diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index 72d7fc164c4..5cf36bdc605 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -1458,17 +1458,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # kernel_width for conv1d, and X = kernel_height * kernel_width for
         # conv2d. We extract X as the kernel_size for im2row.
         kernel_size = list(weight_shape[1:-1] if channel_last else weight_shape[2:])
-        # If the transposed_convolution op was quantized, we need the input tensor's
-        # zero_point for im2row. Otherwise in_zero_point defaults to a zero
-        # tensor. For quantized ops, the input_zero_point is at args[8] as an int,
-        # and we need to convert it to a tensor since transposed_im2row expects
-        # a Tensor for in_zero_point (unlike im2row which expects an int).
         assert isinstance(in_tensor, torch.fx.Node)
-        in_zero_point = (
-            torch.tensor(node.args[8], dtype=torch.int32)
-            if quantized_op
-            else torch.tensor(0, dtype=torch.int32)
-        )

         # Cast to appropriate types
         stride = cast(Sequence[int], stride)
@@ -1493,6 +1483,20 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:

         graph = node.graph

+        # If the transposed_convolution op was quantized, we need the input tensor's
+        # zero_point for im2row. Otherwise in_zero_point defaults to a zero tensor.
+        # We create the tensor as a graph node using aten.full to avoid
+        # DataDependentOutputException during FakeTensor shape propagation.
+        in_zero_point_val = node.args[8] if quantized_op else 0
+        in_zero_point = graph.call_function(
+            exir_ops.edge.aten.full.default,
+            args=([1], in_zero_point_val),
+            kwargs={"dtype": torch.int32},
+        )
+        # Insert the node before the current node
+        node.prepend(in_zero_point)
+        in_zero_point.meta = node.meta
+
         # Create a transposed_im2row node with the input. This will create a 2d
         # matrix of shape [out_height*out_weight, X*in_channels]. X is as
         # defined in the comment above.
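
The following standalone sketch (not part of the patch) illustrates the pattern the new code relies on: instead of materializing a constant with a Python-side torch.tensor call, the constant is inserted into the FX graph as an aten.full node, so shape propagation treats it like any other traced op. The helper name insert_zero_point_node and the toy module M are hypothetical; the sketch uses the core torch.ops.aten.full.default as a stand-in for exir_ops.edge.aten.full.default and omits the meta bookkeeping the real pass performs.

import torch
from torch import fx


def insert_zero_point_node(node: fx.Node, zero_point: int) -> fx.Node:
    """Create a [1]-shaped int32 constant node directly before `node`."""
    graph = node.graph
    full_node = graph.call_function(
        torch.ops.aten.full.default,   # core-aten analogue of the edge op
        args=([1], zero_point),
        kwargs={"dtype": torch.int32},
    )
    node.prepend(full_node)            # move it before its consumer
    return full_node


# Usage: splice the constant into a traced module and rewire a consumer.
class M(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


gm = fx.symbolic_trace(M())
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")
zp = insert_zero_point_node(add_node, 7)
add_node.args = (add_node.args[0], zp)  # replace the literal 1 with the node
gm.recompile()
print(gm(torch.zeros(2)))               # tensor([7., 7.])

Because the constant now lives in the graph, a FakeTensor-based shape propagation pass can infer its shape and dtype symbolically, which is the behavior the patch's comment says the torch.tensor version broke.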