@@ -228,8 +228,6 @@ def update_features_impl(op: OpKey):
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
-        # dim order copy operator will be removed; memory layout is handled internally
-        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
@@ -322,6 +320,37 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
     return features
 
 
+@update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
+def register_to_copy_dim_order_op(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
+    )
+    features.buffer_impl = True
+    features.resize_fn = True
+
+    # Currently there is no "real" implementation for to_dim_order_copy, but it can be
+    # removed as long as the operator is not changing the dtype, i.e. the operator call
+    # is modifying the dim order only. Therefore, check that the input and output dtypes
+    # are the same; if so, the operator is safe to remove.
+    def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
+        in_arg = node.args[0]
+        if not isinstance(in_arg, torch.fx.Node):
+            return False
+
+        in_tensor = in_arg.meta.get("val", None)
+        out_tensor = node.meta.get("val", None)
+
+        if in_tensor.dtype != out_tensor.dtype:
+            return False
+
+        return True
+
+    features.check_node_fn = check_dim_order_copy_node
+
+    return features
+
+
 @update_features(
     [
         exir_ops.edge.aten.bmm.default,
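Below is a minimal sketch, not part of the diff above, illustrating the dtype check that check_dim_order_copy_node performs, applied to a hand-built torch.fx graph. The helper name dtype_preserved and the plain tensors stored in the nodes' meta["val"] entries are illustrative assumptions; in the actual export flow the "val" metadata holds fake tensors produced during tracing, and the real check is the one registered in the Vulkan op registry in the diff.

# Minimal sketch: same dtype comparison as check_dim_order_copy_node, run on a
# hand-built torch.fx graph. dtype_preserved and the plain-tensor meta["val"]
# values are illustrative assumptions, not part of the ExecuTorch API.
import torch
import torch.fx as fx


def dtype_preserved(node: fx.Node) -> bool:
    # The copy is only safe to drop when it does not change its input's dtype.
    in_arg = node.args[0]
    if not isinstance(in_arg, fx.Node):
        return False
    in_val = in_arg.meta.get("val", None)
    out_val = node.meta.get("val", None)
    if in_val is None or out_val is None:
        return False
    return in_val.dtype == out_val.dtype


graph = fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.empty(2, 3, dtype=torch.float32)

# Stand-in for a _to_dim_order_copy call that only changes memory layout.
layout_only = graph.call_function(torch.clone, args=(x,))
layout_only.meta["val"] = torch.empty(2, 3, dtype=torch.float32)
print(dtype_preserved(layout_only))  # True: dim-order-only copy, removable

# Stand-in for a call that also casts the dtype.
casting = graph.call_function(torch.clone, args=(x,))
casting.meta["val"] = torch.empty(2, 3, dtype=torch.float16)
print(dtype_preserved(casting))  # False: dtype changes, so it must be kept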