diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 9a63d178e2d..1f77b30cda3 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -221,13 +221,6 @@ def update_features_impl(op: OpKey): @update_features( [ operator.getitem, - # Quantization related ops will be fused via graph passes - exir_ops.edge.quantized_decomposed.quantize_per_channel.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, # Symbolic integer ops torch.ops.aten.sym_size.int, operator.add, @@ -250,6 +243,35 @@ def register_ephemeral_op(features: OpFeatures): return features +@update_features( + [ + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + exir_ops.edge.quantized_decomposed.quantize_per_token.default, + exir_ops.edge.quantized_decomposed.dequantize_per_token.default, + exir_ops.edge.quantized_decomposed.choose_qparams.tensor, + exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default, + ] +) +def register_quantization_op(features: OpFeatures): + # Quantization requires buffer storage and width packing for scales/zero_points + # but we need to provide texture impl features for the partitioner to work properly + features.texture_impl = TextureImplFeatures( + uses_axis_map=True, + valid_packed_dims={ + PackedDim.WIDTH, + }, + ) + features.buffer_impl = True + features.resize_fn = True + features.optimal_storage = VkStorageType.BUFFER + return features + + @update_features( [ exir_ops.edge.aten.add.Tensor,