
Commit 0129dc1

SS-JIA authored and facebook-github-bot committed
Update requirements for partitioning to_dim_order_copy (pytorch#7859)
Summary:

## Context

The previous registration of the `to_dim_order_copy` op was incorrect. Currently there is no implementation of the op in the Vulkan backend, but since Vulkan manages memory layout internally, the op node can be removed as long as the only thing being changed is the dim order. In some instances the op can also be used to modify the dtype, in which case it will not be removed, and the Vulkan delegate cannot execute it correctly. Therefore, update the registration of the op to reflect this restriction.

This diff should unblock enabling dim order ops for Vulkan.

ghstack-source-id: 262603955
exported-using-ghexport

Differential Revision: D68528213
1 parent 7c67968 commit 0129dc1
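
To make the restriction concrete, here is a minimal, runnable sketch of the dtype check described above, applied to a hand-built FX graph. Everything in it is illustrative: `aten._to_copy` stands in for `_to_dim_order_copy`, and the `dtype_unchanged` helper merely mirrors the `check_dim_order_copy_node` logic from the diff below; none of this is code from the commit itself.

```python
import torch
import torch.fx

# Hand-build a tiny FX graph and attach "val" metadata the way the exporter
# does. aten._to_copy stands in for _to_dim_order_copy here (assumption).
graph = torch.fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.empty(1, 3, 8, 8, dtype=torch.float32)

copy = graph.call_function(torch.ops.aten._to_copy.default, (x,))
copy.meta["val"] = torch.empty(1, 3, 8, 8, dtype=torch.float32)


def dtype_unchanged(node: torch.fx.Node) -> bool:
    """Mirror of check_dim_order_copy_node: removable only if dtype matches."""
    in_arg = node.args[0]
    if not isinstance(in_arg, torch.fx.Node):
        return False
    in_val = in_arg.meta.get("val", None)
    out_val = node.meta.get("val", None)
    if in_val is None or out_val is None:
        return False
    return in_val.dtype == out_val.dtype


print(dtype_unchanged(copy))  # True: dim-order-only copy, safe to remove

# If the same op also casts the dtype, the node must be kept, and since the
# Vulkan delegate cannot execute it, it should not be partitioned.
copy.meta["val"] = torch.empty(1, 3, 8, 8, dtype=torch.float16)
print(dtype_unchanged(copy))  # False
```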

1 file changed

backends/vulkan/op_registry.py

Lines changed: 30 additions & 2 deletions
```diff
@@ -228,8 +228,6 @@ def update_features_impl(op: OpKey):
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
-        # dim order copy operator will be removed; memory layout is handled internally
-        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
```
```diff
@@ -321,6 +319,36 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
 
     return features
 
+@update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
+def register_to_copy_dim_order_op(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
+    )
+    features.buffer_impl = True
+    features.resize_fn = True
+
+    # Currently there is no "real" implementation for to_dim_order_copy, but it can be
+    # removed as long as the operator is not changing the dtype, i.e. the operator call
+    # is modifying the dim order only. Therefore, check that the input and output dtypes
+    # are the same, if so the operator is safe to remove.
+    def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
+        in_arg = node.args[0]
+        if not isinstance(in_arg, torch.fx.Node):
+            return False
+
+        in_tensor = in_arg.meta.get("val", None)
+        out_tensor = node.meta.get("val", None)
+
+        if in_tensor.dtype != out_tensor.dtype:
+            return False
+
+        return True
+
+    features.check_node_fn = check_dim_order_copy_node
+
+    return features
+
 
 @update_features(
     [
```
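
For context on how such a registration is consumed, below is a stripped-down sketch of the registry pattern this file appears to follow. The `OpFeatures` fields and the `check_node_fn` hook mirror the diff above, but the registry dict, the decorator body, and the `node_is_supported` helper are simplified assumptions for illustration, not the actual Vulkan partitioner code.

```python
from typing import Callable, Dict, Optional

import torch.fx

# Simplified registry: op key -> OpFeatures. Assumption: the real module also
# accepts lists of ops and tracks many more feature fields.
_op_registry: Dict[object, "OpFeatures"] = {}


class OpFeatures:
    def __init__(self) -> None:
        self.buffer_impl: bool = False
        self.resize_fn: bool = False
        # Optional per-node check; when present, the partitioner consults it
        # to decide whether a specific node instance can be delegated.
        self.check_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None


def update_features(op: object):
    # Decorator: the wrapped function populates an OpFeatures entry for `op`.
    def wrap(fn: Callable[[OpFeatures], OpFeatures]):
        _op_registry[op] = fn(OpFeatures())
        return fn

    return wrap


def node_is_supported(node: torch.fx.Node) -> bool:
    # Partitioner-side lookup: unregistered ops are rejected outright;
    # registered ops may still be rejected per node by check_node_fn.
    features = _op_registry.get(node.target)
    if features is None:
        return False
    if features.check_node_fn is not None:
        return features.check_node_fn(node)
    return True
```

Under this shape, the new `register_to_copy_dim_order_op` entry keeps `_to_dim_order_copy` eligible for partitioning only when `check_dim_order_copy_node` confirms the dtype is unchanged.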
