diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py index ca7ce72caed..2aaaa13df6e 100644 --- a/backends/vulkan/partitioner/supported_ops.py +++ b/backends/vulkan/partitioner/supported_ops.py @@ -47,7 +47,8 @@ def __contains__(self, op): operator.getitem, ] -BINARY_OPS = [ +SUPPORTS_DYNAMIC_SHAPE = [ + # Binary broadcasting exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.minimum.default, @@ -55,9 +56,7 @@ def __contains__(self, op): exir_ops.edge.aten.div.Tensor, exir_ops.edge.aten.div.Tensor_mode, exir_ops.edge.aten.pow.Tensor_Tensor, -] - -UNARY_OPS = [ + # Unary elementwise exir_ops.edge.aten.abs.default, exir_ops.edge.aten.clamp.default, exir_ops.edge.aten.cos.default, @@ -71,60 +70,46 @@ def __contains__(self, op): exir_ops.edge.aten.sin.default, exir_ops.edge.aten.sqrt.default, exir_ops.edge.aten.tanh.default, -] - -MATMUL_OPS = [ + # Matrix Multiplication exir_ops.edge.aten.bmm.default, exir_ops.edge.aten.mm.default, exir_ops.edge.aten.addmm.default, exir_ops.edge.aten.linear.default, -] - -POOLING_OPS = [ + # Reduction + exir_ops.edge.aten._log_softmax.default, + exir_ops.edge.aten._softmax.default, + # 2D Pooling exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.max_pool2d_with_indices.default, -] - -CONVOLUTION_OPS = [ + # Convolution exir_ops.edge.aten.convolution.default, exir_ops.edge.et_vk.conv_with_clamp.default, ] -REDUCTION_OPS = [ +NO_DYNAMIC_SHAPE = [ + # Reduction exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.sum.dim_IntList, - exir_ops.edge.aten._log_softmax.default, - exir_ops.edge.aten._softmax.default, -] - -NORMALIZATION_OPS = [ + # Normalization exir_ops.edge.aten._native_batch_norm_legit_no_training.default, exir_ops.edge.aten.native_layer_norm.default, -] - -SHAPE_MANIPULATION_OPS = [ + # Shape Manipulation exir_ops.edge.aten.squeeze_copy.dims, exir_ops.edge.aten.unsqueeze_copy.default, exir_ops.edge.aten.view_copy.default, exir_ops.edge.aten.permute_copy.default, exir_ops.edge.aten.t_copy.default, -] - -INDEXING_OPS = [ + # Indexing and lookup exir_ops.edge.aten.embedding.default, exir_ops.edge.aten.index_select.default, exir_ops.edge.aten.select_copy.int, exir_ops.edge.aten.slice_copy.Tensor, -] - -ORCHESTRATION_OPS = [ + # Tensor combination exir_ops.edge.aten.cat.default, exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.split.Tensor, exir_ops.edge.aten.repeat.default, -] - -CREATION_OPS = [ + # Tensor creation exir_ops.edge.aten.arange.start_step, exir_ops.edge.aten.clone.default, exir_ops.edge.aten.constant_pad_nd.default, @@ -139,39 +124,20 @@ def __contains__(self, op): ] -def register_prim_ops(ops: OpList): - for op in PRIM_OPS: - ops[op].supports_texture = True - ops[op].supports_buffer = True - ops[op].supports_dynamic_shape = True +def enumerate_supported_ops(): + ops = OpList() + # Register in order of least to most capabilities -def register_no_dynamic_shape_ops(ops: OpList): - for op in [ - *REDUCTION_OPS, - *NORMALIZATION_OPS, - *SHAPE_MANIPULATION_OPS, - *INDEXING_OPS, - *ORCHESTRATION_OPS, - *CREATION_OPS, - ]: + for op in NO_DYNAMIC_SHAPE: ops[op].supports_dynamic_shape = False - -def register_dynamic_shape_ops(ops: OpList): - for op in [ - *BINARY_OPS, - *UNARY_OPS, - *MATMUL_OPS, - *POOLING_OPS, - *CONVOLUTION_OPS, - ]: + for op in SUPPORTS_DYNAMIC_SHAPE: ops[op].supports_dynamic_shape = True + for op in PRIM_OPS: + ops[op].supports_texture = True + ops[op].supports_buffer = True + ops[op].supports_dynamic_shape = True -def enumerate_supported_ops(): - ops = OpList() - register_prim_ops(ops) - register_no_dynamic_shape_ops(ops) - register_dynamic_shape_ops(ops) return ops