@@ -141,17 +141,27 @@ def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
             tensor, self.scale, self.zp, self.qmin, self.qmax, self.dtype
         )
 
+    # Temporary helper until non-batch dimensions can be inferred
+    # Detects if a node feeds into a conv op by checking all downstream users
+    @staticmethod
+    def _feeds_into_conv(node: torch.fx.Node) -> bool:
+        users_list = [node]
+
+        while users_list:
+            current_user = users_list.pop()
+            if "convolution" in str(current_user.target):
+                return True
+            users_list.extend(current_user.users)
+
+        return False
+
     @classmethod
     def _from_dynamic_input_node(cls, quant_node: torch.fx.Node) -> QuantParams:
         q_input = quant_node.args[0]  # fp32 input
         assert isinstance(q_input, torch.fx.Node)
-        num_nonbatch_dims = 1
-
-        # Compute non-batch dimensions (shape length - 1), defaulting to 1
-        q_input_val = q_input.meta.get("val", None)
-        q_input_shape = getattr(q_input_val, "shape", None)
-        if q_input_shape is not None:
-            num_nonbatch_dims = max(len(q_input_shape) - 1, 1)
+        # TODO - materialize this from the quant_node scale count and val shape
+        # Set non-batch dims to 3 if the node feeds into a conv (only 2D convolution is supported), otherwise set to 1 for linear
+        num_nonbatch_dims = 3 if cls._feeds_into_conv(quant_node) else 1
 
         return cls(
             per_channel=False,  # True is not valid
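For context, the sketch below is not part of the PR; it only illustrates the worklist traversal that _feeds_into_conv performs over node.users. It assumes the helper runs on an ATen-dialect graph, where a convolution node's target stringifies to something like aten.convolution.default; the hand-built graph and the standalone feeds_into_conv function here are hypothetical stand-ins for that setting.

# Hypothetical illustration, not code from the PR: build a tiny FX graph by hand
# (nothing is executed; the graph exists only so we can inspect node.users)
import torch
import torch.fx as fx

graph = fx.Graph()
x = graph.placeholder("x")
w = graph.placeholder("w")
y = graph.placeholder("y")

# x and w feed an aten convolution node; y only feeds an add
conv = graph.call_function(
    torch.ops.aten.convolution.default,
    (x, w, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1),
)
add = graph.call_function(torch.ops.aten.add.Tensor, (conv, y))
graph.output(add)


def feeds_into_conv(node: fx.Node) -> bool:
    # Same worklist walk over node.users as _feeds_into_conv in the diff
    worklist = [node]
    while worklist:
        current = worklist.pop()
        if "convolution" in str(current.target):
            return True
        worklist.extend(current.users)
    return False


print(feeds_into_conv(x))  # True: x is consumed by aten.convolution.default
print(feeds_into_conv(y))  # False: y only reaches the add and the graph output

Because the worklist pops from the end, the walk is depth-first over all transitive users, and it inspects the starting node itself before its users, so calling it directly on a convolution node returns True immediately.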