Skip to content

Commit f278cc4

Browse files
ethansfngfacebook-github-bot
authored andcommitted
Add shape check for quantizedinputwrapper (#16515)
Summary: Check the provided input args tensor shape to make sure they align with extracted input tensor shape Reviewed By: DrJessop, eigen-k Differential Revision: D90362667
1 parent 35eb01a commit f278cc4

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

backends/cadence/aot/compiler_funcs.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,29 @@
1919
QuantArgs = tuple[float, int, int, int, torch.dtype]
2020

2121

22+
def extract_input_shapes_from_graph(
23+
module: GraphModule,
24+
) -> dict[int, tuple[int, ...]]:
25+
"""
26+
Extract input shapes from the FX graph placeholder nodes.
27+
28+
Returns a dict mapping input index to expected shape tuple.
29+
"""
30+
input_shapes: dict[int, tuple[int, ...]] = {}
31+
idx = 0
32+
for node in module.graph.nodes:
33+
if node.op == "placeholder":
34+
# Get the tensor_meta from the node if available
35+
if "val" in node.meta:
36+
val = node.meta["val"]
37+
if isinstance(val, torch.Tensor):
38+
input_shapes[idx] = tuple(val.shape)
39+
elif hasattr(val, "shape"):
40+
input_shapes[idx] = tuple(val.shape)
41+
idx += 1
42+
return input_shapes
43+
44+
2245
@torch.no_grad()
2346
def trace(
2447
model: torch.nn.Module,
@@ -138,6 +161,9 @@ def __init__(
138161
super().__init__()
139162
self.module: GraphModule = module
140163
self.quant_args: dict[int, QuantArgs] = {}
164+
self.expected_shapes: dict[int, tuple[int, ...]] = (
165+
extract_input_shapes_from_graph(module)
166+
)
141167

142168
if input_args is not None:
143169
logger.warning(
@@ -151,6 +177,20 @@ def __init__(
151177

152178
def forward(self, *args: torch.Tensor) -> Any:
153179
"""Run inference, dequantizing configured inputs."""
180+
# Validate input shapes for quantized inputs
181+
for index in self.quant_args:
182+
if index >= len(args):
183+
continue
184+
actual_shape = tuple(args[index].shape)
185+
if index not in self.expected_shapes:
186+
continue
187+
expected_shape = self.expected_shapes[index]
188+
if actual_shape != expected_shape:
189+
raise ValueError(
190+
f"Shape mismatch for quantized input at index {index}: "
191+
f"expected {expected_shape}, got {actual_shape}"
192+
)
193+
154194
dequantized_args = []
155195
for index, node in enumerate(args):
156196
if index in self.quant_args:

0 commit comments

Comments
 (0)