7 changes: 6 additions & 1 deletion .github/workflows/pull.yml
@@ -970,11 +970,16 @@ jobs:
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build

# Test models serially
models="mv2 mv3 edsr resnet18 resnet50 dl3"
models="mv2 mv3 edsr resnet18 resnet50 dl3 w2l ic3 ic4 vit"
for model in $models; do
python -m examples.vulkan.export --model_name=$model --test
done

# For selected vision models, test with dynamic shapes
models="mv2 mv3 resnet18 resnet50 ic3 ic4 densenet161"
for model in $models; do
python -m examples.vulkan.export --model_name=$model --test -d
done

test-vulkan-operators-linux:
name: test-vulkan-operators-linux
5 changes: 2 additions & 3 deletions backends/vulkan/_passes/fold_qdq.py
@@ -17,9 +17,8 @@ class FoldQDQPass(ExportPass):
valid quant op patterns have already been fused before this pass.
"""

def __init__(self, edge_program: torch.export.ExportedProgram):
super(FoldQDQPass, self).__init__()
self.edge_program = edge_program
def __init__(self):
super().__init__()

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
10 changes: 7 additions & 3 deletions backends/vulkan/_passes/fuse_patterns.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import executorch.backends.vulkan.patterns as vk_patterns

import torch
@@ -13,13 +15,15 @@


class FusePatternsPass(ExportPass):
def __init__(self, exported_program: ExportedProgram) -> None:
def __init__(self) -> None:
super().__init__()
self.program = exported_program
self._exported_program: Optional[ExportedProgram] = None

def call(self, graph_module: torch.fx.GraphModule):
assert self._exported_program is not None

total_replaced = vk_patterns.replace_all_fusable_subgraphs(
self.program, graph_module
self._exported_program, graph_module
)

if total_replaced > 0:
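For context on this constructor change (and the matching one in fuse_quantized_ops.py below): the pass now reads the exported program from a `_exported_program` attribute that must be populated after construction instead of receiving it in `__init__`. A minimal sketch of the calling convention this implies; the direct attribute assignment is an assumption, since the real pass-manager wiring is not part of this diff:

```python
# Sketch only: build the pass without the program, inject it, then run it.
# PassBase.__call__(graph_module) dispatches to the overridden call() above.
fuse_pass = FusePatternsPass()
fuse_pass._exported_program = exported_program  # assumed injection point
pass_result = fuse_pass(graph_module)
```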
10 changes: 6 additions & 4 deletions backends/vulkan/_passes/fuse_quantized_ops.py
@@ -211,18 +211,20 @@ def fuse_into_linear_qcnw_node(


class FuseQuantizedOpsTransform(ExportPass):
def __init__(self, exported_program: ExportedProgram) -> None:
def __init__(self) -> None:
super().__init__()
self.program = exported_program
self._exported_program: Optional[ExportedProgram] = None

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
assert self._exported_program is not None

for node in graph_module.graph.nodes:
# Check for linear_qcnw pattern (weight-only quantization)
qcnw_details = matches_linear_qcnw_pattern(self.program, node)
qcnw_details = matches_linear_qcnw_pattern(self._exported_program, node)
if qcnw_details is not None:
qcnw_method, qcnw_nbits = qcnw_details
fuse_into_linear_qcnw_node(
self.program, graph_module, node, qcnw_method, qcnw_nbits
self._exported_program, graph_module, node, qcnw_method, qcnw_nbits
)
continue

4 changes: 4 additions & 0 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -230,6 +230,10 @@ def get_arg_tensor_source_repset(
"""
arg_node = op_node.args[arg_i]

# For non-tensor arguments, return ANY_STORAGE
if not utils.is_tensor_arg_node(arg_node):
return utils.ANY_STORAGE

# Special case for cat - use the first tensor in the list as representative
if isinstance(arg_node, list):
arg_node = arg_node[0]
93 changes: 70 additions & 23 deletions backends/vulkan/op_registry.py
@@ -16,8 +16,6 @@

import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -48,6 +46,9 @@ class OpFeatures:
# Optional check function used during partitioning to determine if a node's
# inputs are supported by the operator implementation.
"are_node_inputs_supported_fn",
# Optional function to determine valid representation sets for inputs and outputs
# once a node's actual inputs are known.
"pick_io_storage_fn",
]

def __init__(
@@ -61,6 +62,7 @@ def __init__(
supports_resize: bool = False,
supports_prepacking: bool = False,
are_node_inputs_supported_fn: Optional[Callable] = allow_node,
pick_io_storage_fn: Optional[Callable] = None,
):
self.inputs_storage: utils.TensorRepSetList = utils.TensorRepSetList(
inputs_storage if inputs_storage is not None else []
@@ -77,15 +79,21 @@
self.supports_prepacking = supports_prepacking

self.are_node_inputs_supported_fn = are_node_inputs_supported_fn
self.pick_io_storage_fn = pick_io_storage_fn

def make_op_repsets(
self,
op_node: torch.fx.Node,
texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS,
) -> utils.OpRepSets:
return utils.OpRepSets(
self.inputs_storage, self.outputs_storage, op_node, texture_limits
)
inputs_storage = self.inputs_storage
outputs_storage = self.outputs_storage
if self.pick_io_storage_fn is not None:
i_storage, o_storage = self.pick_io_storage_fn(op_node)
inputs_storage = utils.TensorRepSetList(i_storage)
outputs_storage = utils.TensorRepSetList(o_storage)

return utils.OpRepSets(inputs_storage, outputs_storage, op_node, texture_limits)
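To illustrate the new hook, here is a hypothetical registration whose storage choice depends on the node it is applied to. The rank-based rule is invented for the example; the `utils` constants are the ones used elsewhere in this file:

```python
# Hypothetical pick_io_storage_fn: choose a packed dimension based on input rank.
def pick_io_storage_for_example_op(node: torch.fx.Node):
    ndim = node.args[0].meta["val"].ndim
    if ndim <= 2:
        return utils.WIDTH_PACKED_TEXTURE, utils.WIDTH_PACKED_TEXTURE
    return utils.CHANNELS_PACKED_TEXTURE, utils.CHANNELS_PACKED_TEXTURE

# Passing pick_io_storage_fn=pick_io_storage_for_example_op to OpFeatures makes
# make_op_repsets() override the static inputs_storage/outputs_storage per node.
```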


#######################
@@ -410,28 +418,16 @@ def register_softmax_op():
)
def register_reduce_op():
def check_reduce_node(node: torch.fx.Node) -> bool:
# Only one argument implies that the reduction is over the entire tensor, which
# is not supported yet.
if len(node.args) == 1:
return False

dim_list = node.args[1]
# Only 1D and 2D reductions are supported at the moment.
if isinstance(dim_list, list) and len(dim_list) > 2:
return False

if isinstance(dim_list, list) and len(dim_list) == 2:
# Try to get the memory layout for this node
try:
memory_layout = utils.get_node_memory_layout(node)

# If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
if (
memory_layout is not None
and memory_layout != VkMemoryLayout.DEFAULT_LAYOUT
):
# For now only default layout is supported for 2D reduction.
# Because we can't determine if the input is NCHW or NHWC here,
# assume the reduction dimension is packed so we cannot support it.
return False
except (AssertionError, KeyError, AttributeError):
# If we can't get memory layout information, we'll assume the dims aren't packed
pass

def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
for arg in node.args:
if isinstance(arg, bool):
@@ -446,10 +442,41 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool:

return True

def pick_io_storage_for_reduce(node: torch.fx.Node):
inputs_storage = utils.ANY_TEXTURE
outputs_storage = utils.ANY_TEXTURE

input_tensor = node.args[0]
ndim = input_tensor.meta["val"].ndim
dim_list = node.args[1]
if isinstance(dim_list, list) and len(dim_list) == 2:
reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)

possible_packed_dims = {0, 1, 2}
possible_packed_dims.discard(reduce_dim1_whcn)
possible_packed_dims.discard(reduce_dim2_whcn)

packed_dim = possible_packed_dims.pop()
assert packed_dim in [0, 1, 2]

if packed_dim == 0:
inputs_storage = utils.WIDTH_PACKED_TEXTURE
outputs_storage = utils.WIDTH_PACKED_TEXTURE
elif packed_dim == 1:
inputs_storage = utils.HEIGHT_PACKED_TEXTURE
outputs_storage = utils.HEIGHT_PACKED_TEXTURE
else:
inputs_storage = utils.CHANNELS_PACKED_TEXTURE
outputs_storage = utils.CHANNELS_PACKED_TEXTURE

return inputs_storage, outputs_storage

return OpFeatures(
inputs_storage=utils.ANY_TEXTURE,
supports_resize=True,
are_node_inputs_supported_fn=check_reduce_node,
pick_io_storage_fn=pick_io_storage_for_reduce,
)
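A worked example of the packed-dim selection above, assuming `nchw_dim_to_whcn_dim(dim, ndim)` maps an NCHW-style index to its WHCN counterpart as `ndim - 1 - dim` (this mapping is inferred, not shown in the diff):

```python
# Reducing over H and W of an NCHW activation leaves channels as the only
# dimension that can safely stay packed.
def nchw_dim_to_whcn_dim(dim: int, ndim: int) -> int:  # assumed behaviour
    return ndim - 1 - dim

ndim = 4                  # (N, C, H, W)
dim_list = [2, 3]         # reduce over H and W
whcn = {nchw_dim_to_whcn_dim(d, ndim) for d in dim_list}  # {1, 0}

packed_dim = ({0, 1, 2} - whcn).pop()
assert packed_dim == 2    # -> CHANNELS_PACKED_TEXTURE for inputs and outputs
```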


@@ -474,6 +501,23 @@ def register_2d_pool_op():
]
)
def register_convolution_op():
def check_conv_node(node: torch.fx.Node) -> bool:
x = node.args[0]
x_shape = x.meta["val"].size()
# 4-D input implies 2D convolution
if len(x_shape) == 4:
batches = x.meta["val"].size()[0]
if batches != 1:
return False
# 3-D input implies 1D convolution
if len(x_shape) == 3:
transpose = node.args[6]
# Transposed 1D convolution is not supported yet
if transpose:
return False

return True

return OpFeatures(
inputs_storage=[
utils.CHANNELS_PACKED_TEXTURE, # input
@@ -490,6 +534,7 @@ def register_convolution_op():
],
supports_resize=True,
supports_prepacking=True,
are_node_inputs_supported_fn=check_conv_node,
)
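A standalone restatement of the acceptance rule in `check_conv_node`, applied to plain shapes (the real check reads the shape from the first argument's `meta["val"]` and the transposed flag from `node.args[6]`):

```python
# Sketch of the same rule: batched 2D convs and transposed 1D convs are rejected.
def conv_input_supported(shape, transposed=False):
    if len(shape) == 4:           # 2D convolution
        return shape[0] == 1      # only batch size 1 is supported
    if len(shape) == 3:           # 1D convolution
        return not transposed     # transposed 1D conv not supported yet
    return True

assert conv_input_supported((1, 3, 224, 224))
assert not conv_input_supported((8, 3, 224, 224))
assert not conv_input_supported((1, 16, 100), transposed=True)
```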


@@ -716,6 +761,7 @@ def register_ported_ops_with_prepacking():
return OpFeatures(
inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
supports_prepacking=True,
supports_resize=True,
)


@@ -746,6 +792,7 @@ def register_ported_ops_with_prepacking_all_dims():
return OpFeatures(
inputs_storage=utils.ANY_TEXTURE,
supports_prepacking=True,
supports_resize=True,
)


10 changes: 6 additions & 4 deletions backends/vulkan/partitioner/vulkan_partitioner.py
@@ -36,7 +36,7 @@
Partitioner,
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from executorch.exir.dialects._ops import ops as exir_ops

from torch.export.exported_program import ExportedProgram
@@ -254,9 +254,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:  # noqa: C901
self.log_skip(node, "permute node of non compatible linear node")
return False

is_in_local_scalar_dense_chain, dst_node_is_compatible = (
self.is_in_local_scalar_dense_chain(node)
)
(
is_in_local_scalar_dense_chain,
dst_node_is_compatible,
) = self.is_in_local_scalar_dense_chain(node)
if is_in_local_scalar_dense_chain and dst_node_is_compatible:
return True
elif is_in_local_scalar_dense_chain:
@@ -419,6 +420,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
logger.info(f"Found {pl} Vulkan subgraphs to be partitioned.")

tag_constant_data(exported_program)
tag_mutated_buffer(exported_program)

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
12 changes: 7 additions & 5 deletions backends/vulkan/patterns/quantized_linear.py
@@ -92,9 +92,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None:
return

# Identify input node
self.fp_input_node, self.quantize_input_node, dq_node = (
utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])
)
(
self.fp_input_node,
self.quantize_input_node,
dq_node,
) = utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])
assert self.fp_input_node is not None
self.all_nodes.append(self.fp_input_node)

@@ -386,7 +388,7 @@ def make_linear_dq8ca_q4gsw_op(
weight_sums_node = create_constant_placeholder(
exp_program=ep,
graph=graph_module.graph,
kind=InputKind.CONSTANT_TENSOR,
kind=InputKind.PARAMETER,
name=sums_name,
data=sum_per_quant_group,
)
@@ -429,7 +431,7 @@ def make_linear_q8ta_q8csw_custom_op(
weight_sums_node = create_constant_placeholder(
exp_program=ep,
graph=graph_module.graph,
kind=InputKind.CONSTANT_TENSOR,
kind=InputKind.PARAMETER,
name=sums_name,
data=sum_per_output_channel,
)
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/glsl/full.yaml
@@ -14,5 +14,6 @@ full:
DTYPE:
- VALUE: half
- VALUE: float
- VALUE: int32
shader_variants:
- NAME: full
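A hypothetical model that would exercise the new int32 variant of the `full` shader, assuming `aten.full` with an integer dtype reaches the Vulkan delegate (the lowering step itself is not shown here):

```python
import torch

class IntMask(torch.nn.Module):
    def forward(self, x):
        # aten.full producing an int32 tensor, e.g. for an integer mask
        return torch.full((x.shape[0], 4), 7, dtype=torch.int32) + x

ep = torch.export.export(IntMask(), (torch.zeros(2, 4, dtype=torch.int32),))
```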