Skip to content

Commit a1fde8b

Browse files
authored
[ET-VK][ez] Update requirements for partitioning to_dim_order_copy
Differential Revision: D68528213 Pull Request resolved: #7859
1 parent 021df27 commit a1fde8b

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

backends/vulkan/op_registry.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,6 @@ def update_features_impl(op: OpKey):
228228
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
229229
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
230230
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
231-
# dim order copy operator will be removed; memory layout is handled internally
232-
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
233231
]
234232
)
235233
def register_ephemeral_op(features: OpFeatures):
@@ -322,6 +320,37 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
322320
return features
323321

324322

323+
@update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
324+
def register_to_copy_dim_order_op(features: OpFeatures):
325+
features.texture_impl = TextureImplFeatures(
326+
uses_axis_map=True,
327+
valid_packed_dims=all_packed_dims,
328+
)
329+
features.buffer_impl = True
330+
features.resize_fn = True
331+
332+
# Currently there is no "real" implementation for to_dim_order_copy, but it can be
333+
# removed as long as the operator is not changing the dtype, i.e. the operator call
334+
# is modifying the dim order only. Therefore, check that the input and output dtypes
335+
# are the same, if so the operator is safe to remove.
336+
def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
337+
in_arg = node.args[0]
338+
if not isinstance(in_arg, torch.fx.Node):
339+
return False
340+
341+
in_tensor = in_arg.meta.get("val", None)
342+
out_tensor = node.meta.get("val", None)
343+
344+
if in_tensor.dtype != out_tensor.dtype:
345+
return False
346+
347+
return True
348+
349+
features.check_node_fn = check_dim_order_copy_node
350+
351+
return features
352+
353+
325354
@update_features(
326355
[
327356
exir_ops.edge.aten.bmm.default,

0 commit comments

Comments
 (0)