Skip to content
36 changes: 29 additions & 7 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,6 @@ def update_features_impl(op: OpKey):
@update_features(
[
operator.getitem,
# Quantization related ops will be fused via graph passes
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
# Symbolic integer ops
torch.ops.aten.sym_size.int,
operator.add,
Expand All @@ -250,6 +243,35 @@ def register_ephemeral_op(features: OpFeatures):
return features


@update_features(
    [
        # Per-channel / per-tensor quantize and dequantize
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
        # Per-token quantize and dequantize
        exir_ops.edge.quantized_decomposed.quantize_per_token.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
        # Quantization parameter selection
        exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
        exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default,
    ]
)
def register_quantization_op(features: OpFeatures):
    """Set the feature flags shared by the quantization ops listed above.

    The given ``features`` object is updated in place and returned: it gets a
    width-packed texture implementation that uses the axis map, buffer
    support, a resize function, and BUFFER as its optimal storage type.
    """
    # Quantization needs buffer storage and width packing for the
    # scales/zero_points; texture impl features are still supplied because the
    # partitioner needs them to work properly.
    width_packed_texture = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims={PackedDim.WIDTH},
    )

    features.texture_impl = width_packed_texture
    features.buffer_impl = True
    features.resize_fn = True
    features.optimal_storage = VkStorageType.BUFFER
    return features


@update_features(
[
exir_ops.edge.aten.add.Tensor,
Expand Down
Loading