Skip to content

Commit 7c53454

Browse files
committed
Add check to determine if node feeds into conv and set non-batch dims accordingly
1 parent 6da8b7d commit 7c53454

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

backends/xnnpack/operators/quant_params.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,17 +141,27 @@ def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
141141
tensor, self.scale, self.zp, self.qmin, self.qmax, self.dtype
142142
)
143143

144+
# Temporary helper until non-batch dimensions can be inferred
145+
# Detects if a node feeds into a conv op by checking all downstream users
146+
@staticmethod
147+
def _feeds_into_conv(node: torch.fx.Node) -> bool:
148+
users_list = [node]
149+
150+
while users_list:
151+
current_user = users_list.pop()
152+
if "convolution" in str(current_user.target):
153+
return True
154+
users_list.extend(current_user.users)
155+
156+
return False
157+
144158
@classmethod
145159
def _from_dynamic_input_node(cls, quant_node: torch.fx.Node) -> QuantParams:
146160
q_input = quant_node.args[0] # fp32 input
147161
assert isinstance(q_input, torch.fx.Node)
148-
num_nonbatch_dims = 1
149-
150-
# Compute non-batch dimensions (shape length - 1), defaulting to 1
151-
q_input_val = q_input.meta.get("val", None)
152-
q_input_shape = getattr(q_input_val, "shape", None)
153-
if q_input_shape is not None:
154-
num_nonbatch_dims = max(len(q_input_shape) - 1, 1)
162+
# TODO - materialize this from the quant_node scale count and val shape
163+
# Set non-batch dims to 3 if node feeds into conv (only 2D is supported), otherwise set to 1 for linear
164+
num_nonbatch_dims = 3 if cls._feeds_into_conv(quant_node) else 1
155165

156166
return cls(
157167
per_channel=False, # True is not valid

0 commit comments

Comments
 (0)