1919QuantArgs = tuple [float , int , int , int , torch .dtype ]
2020
2121
22+ def extract_input_shapes_from_graph (
23+ module : GraphModule ,
24+ ) -> dict [int , tuple [int , ...]]:
25+ """
26+ Extract input shapes from the FX graph placeholder nodes.
27+
28+ Returns a dict mapping input index to expected shape tuple.
29+ """
30+ input_shapes : dict [int , tuple [int , ...]] = {}
31+ idx = 0
32+ for node in module .graph .nodes :
33+ if node .op == "placeholder" :
34+ # Get the tensor_meta from the node if available
35+ if "val" in node .meta :
36+ val = node .meta ["val" ]
37+ if isinstance (val , torch .Tensor ):
38+ input_shapes [idx ] = tuple (val .shape )
39+ elif hasattr (val , "shape" ):
40+ input_shapes [idx ] = tuple (val .shape )
41+ idx += 1
42+ return input_shapes
43+
44+
2245@torch .no_grad ()
2346def trace (
2447 model : torch .nn .Module ,
@@ -138,6 +161,9 @@ def __init__(
138161 super ().__init__ ()
139162 self .module : GraphModule = module
140163 self .quant_args : dict [int , QuantArgs ] = {}
164+ self .expected_shapes : dict [int , tuple [int , ...]] = (
165+ extract_input_shapes_from_graph (module )
166+ )
141167
142168 if input_args is not None :
143169 logger .warning (
@@ -151,6 +177,20 @@ def __init__(
151177
152178 def forward (self , * args : torch .Tensor ) -> Any :
153179 """Run inference, dequantizing configured inputs."""
180+ # Validate input shapes for quantized inputs
181+ for index in self .quant_args :
182+ if index >= len (args ):
183+ continue
184+ actual_shape = tuple (args [index ].shape )
185+ if index not in self .expected_shapes :
186+ continue
187+ expected_shape = self .expected_shapes [index ]
188+ if actual_shape != expected_shape :
189+ raise ValueError (
190+ f"Shape mismatch for quantized input at index { index } : "
191+ f"expected { expected_shape } , got { actual_shape } "
192+ )
193+
154194 dequantized_args = []
155195 for index , node in enumerate (args ):
156196 if index in self .quant_args :
0 commit comments