diff --git a/backends/cadence/aot/compiler_funcs.py b/backends/cadence/aot/compiler_funcs.py
index 9756602ad2d..ce36e23b62c 100644
--- a/backends/cadence/aot/compiler_funcs.py
+++ b/backends/cadence/aot/compiler_funcs.py
@@ -19,6 +19,29 @@
 QuantArgs = tuple[float, int, int, int, torch.dtype]
 
 
+def extract_input_shapes_from_graph(
+    module: GraphModule,
+) -> dict[int, tuple[int, ...]]:
+    """
+    Extract input shapes from the FX graph placeholder nodes.
+
+    Returns a dict mapping input index to expected shape tuple.
+    """
+    input_shapes: dict[int, tuple[int, ...]] = {}
+    idx = 0
+    for node in module.graph.nodes:
+        if node.op == "placeholder":
+            # Read the FakeTensor recorded in node.meta["val"] if available
+            if "val" in node.meta:
+                val = node.meta["val"]
+                if isinstance(val, torch.Tensor):
+                    input_shapes[idx] = tuple(val.shape)
+                elif hasattr(val, "shape"):
+                    input_shapes[idx] = tuple(val.shape)
+            idx += 1
+    return input_shapes
+
+
 @torch.no_grad()
 def trace(
     model: torch.nn.Module,
@@ -138,6 +161,9 @@ def __init__(
         super().__init__()
         self.module: GraphModule = module
         self.quant_args: dict[int, QuantArgs] = {}
+        self.expected_shapes: dict[int, tuple[int, ...]] = (
+            extract_input_shapes_from_graph(module)
+        )
 
         if input_args is not None:
             logger.warning(
@@ -151,6 +177,18 @@ def __init__(
 
     def forward(self, *args: torch.Tensor) -> Any:
         """Run inference, dequantizing configured inputs."""
+        # Validate input shapes for quantized inputs
+        for index in self.quant_args:
+            if index < len(args):
+                actual_shape = tuple(args[index].shape)
+                if index in self.expected_shapes:
+                    expected_shape = self.expected_shapes[index]
+                    if actual_shape != expected_shape:
+                        raise ValueError(
+                            f"Shape mismatch for quantized input at index {index}: "
+                            f"expected {expected_shape}, got {actual_shape}"
+                        )
+
         dequantized_args = []
         for index, node in enumerate(args):
             if index in self.quant_args:
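
A minimal sketch (not part of the diff) of what the new helper returns, assuming the wrapped GraphModule comes from torch.export, which records a FakeTensor under node.meta["val"] for every placeholder; TinyModel and the import path are hypothetical stand-ins inferred from the repo layout shown above:

import torch
from torch.export import export

# Import path assumed from the file path in the diff; adjust to the
# package root the repo actually installs under.
from executorch.backends.cadence.aot.compiler_funcs import (
    extract_input_shapes_from_graph,
)

class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1.0

# export() populates node.meta["val"] with FakeTensors, which is the
# metadata extract_input_shapes_from_graph reads for each placeholder.
ep = export(TinyModel(), (torch.zeros(2, 3),))
assert extract_input_shapes_from_graph(ep.graph_module) == {0: (2, 3)}

Note the sketch assumes static shapes; with dynamic shapes, meta["val"] can carry SymInt dimensions, so the strict equality check added to forward() would reject otherwise-valid inputs.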