Skip to content

Commit 62e30e5

Browse files
committed
Add computation for non-batch dims; remove non-batch dims check
1 parent 84b3634 commit 62e30e5

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

backends/xnnpack/operators/quant_params.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,14 @@ def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
145145
def _from_dynamic_input_node(cls, quant_node: torch.fx.Node) -> QuantParams:
146146
q_input = quant_node.args[0] # fp32 input
147147
assert isinstance(q_input, torch.fx.Node)
148-
# TODO - materialize this from the quant_node scale count and val shape
149148
num_nonbatch_dims = 1
150149

150+
# Compute non-batch dimensions (shape length - 1), defaulting to 1
151+
q_input_val = q_input.meta.get("val", None)
152+
q_input_shape = getattr(q_input_val, "shape", None)
153+
if q_input_shape is not None:
154+
num_nonbatch_dims = max(len(q_input_shape) - 1, 1)
155+
151156
return cls(
152157
per_channel=False, # True is not valid
153158
q_input=q_input,

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -512,11 +512,6 @@ Error defineTensor(
512512
buffer_ptr == nullptr,
513513
Internal,
514514
"Dynamically quantized tensor should not have constant data but found non-nullptr");
515-
// TODO(T179441835): Dynamic Quantization with num_nonbatch_dims > 1
516-
ET_CHECK_OR_RETURN_ERROR(
517-
qparams->num_nonbatch_dims() == 1,
518-
Internal,
519-
"Dynamically Quantized Tensors currently only support per token quantization");
520515
status = xnn_define_dynamically_quantized_tensor_value(
521516
/*subgraph=*/subgraph_ptr,
522517
/*datatype=*/getDataType(tensor_value->datatype()),

0 commit comments

Comments
 (0)