Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,17 +1458,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
# kernel_width for conv1d, and X = kernel_height * kernel_width for
# conv2d. We extract X as the kernel_size for im2row.
kernel_size = list(weight_shape[1:-1] if channel_last else weight_shape[2:])
# If the transposed_convolution op was quantized, we need the input tensor's
# zero_point for im2row. Otherwise in_zero_point defaults to a zero
# tensor. For quantized ops, the input_zero_point is at args[8] as an int,
# and we need to convert it to a tensor since transposed_im2row expects
# a Tensor for in_zero_point (unlike im2row which expects an int).
assert isinstance(in_tensor, torch.fx.Node)
in_zero_point = (
torch.tensor(node.args[8], dtype=torch.int32)
if quantized_op
else torch.tensor(0, dtype=torch.int32)
)

# Cast to appropriate types
stride = cast(Sequence[int], stride)
Expand All @@ -1493,6 +1483,20 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:

graph = node.graph

# If the transposed_convolution op was quantized, we need the input tensor's
# zero_point for im2row. Otherwise in_zero_point defaults to a zero tensor.
# We create the tensor as a graph node using aten.full to avoid
# DataDependentOutputException during FakeTensor shape propagation.
in_zero_point_val = node.args[8] if quantized_op else 0
in_zero_point = graph.call_function(
exir_ops.edge.aten.full.default,
args=([1], in_zero_point_val),
kwargs={"dtype": torch.int32},
)
# Insert the node before the current node
node.prepend(in_zero_point)
in_zero_point.meta = node.meta

# Create a transposed_im2row node with the input. This will create a 2d
# matrix of shape [out_height*out_width, X*in_channels]. X is as
# defined in the comment above.
Expand Down
Loading