Skip to content

Commit e9f22c1

Browse files
Andrew Grebenisan and facebook-github-bot
authored and committed
Replace quantized conv and relu non-tensor variants with per tensor variants
Summary: Fix to call only the per-tensor variants for quantized conv and quantized relu, since those are the only ones we are supporting. Differential Revision: D83873738
1 parent 5326a1b commit e9f22c1

File tree

2 files changed

+11
-52
lines changed

2 files changed

+11
-52
lines changed

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 7 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -306,31 +306,6 @@ def get_args_and_kwargs_conv(
306306

307307
(out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)
308308

309-
out_multiplier_ = graph_module.graph.call_function(
310-
torch.ops.aten.full.default,
311-
([1], out_multiplier[0].item()),
312-
{"dtype": torch.int32},
313-
)
314-
out_shift_ = graph_module.graph.call_function(
315-
torch.ops.aten.full.default,
316-
([1], out_shift[0].item()),
317-
{"dtype": torch.int32},
318-
)
319-
320-
# Create a single element tensor for the weight zero point
321-
weight_zero_point_tensor = graph_module.graph.call_function(
322-
torch.ops.aten.full.default,
323-
([1], weight_zero_point),
324-
{"dtype": torch.int32},
325-
)
326-
327-
# Create a single element tensor for the bias scale
328-
bias_scale_tensor = graph_module.graph.call_function(
329-
torch.ops.aten.full.default,
330-
([1], bias_scale),
331-
{"dtype": torch.float32},
332-
)
333-
334309
# Make the args and kwargs for the replacement op
335310
args = tuple(inputs_inputs + weights_inputs + [bias])
336311
kwargs = {
@@ -339,12 +314,12 @@ def get_args_and_kwargs_conv(
339314
"dilation": dilation,
340315
"groups": groups,
341316
"input_zero_point": dequants_inputs[0].args[2],
342-
"weight_zero_point": weight_zero_point_tensor,
343-
"bias_scale": bias_scale_tensor,
317+
"weight_zero_point": weight_zero_point,
318+
"bias_scale": bias_scale,
344319
"out_scale": quant_node.args[1],
345320
"out_zero_point": quant_node.args[2],
346-
"out_multiplier": out_multiplier_,
347-
"out_shift": out_shift_,
321+
"out_multiplier": out_multiplier[0].item(),
322+
"out_shift": out_shift[0].item(),
348323
}
349324
return args, kwargs
350325

@@ -365,27 +340,11 @@ def get_args_and_kwargs_relu(
365340
# Make the args and kwargs for the replacement op
366341
args = tuple(inputs_inputs)
367342

368-
X_zero_point = graph_module.graph.call_function(
369-
torch.ops.aten.full.default,
370-
([1], dequants_inputs[0].args[2]),
371-
{"dtype": torch.int32},
372-
)
373-
out_multiplier_ = graph_module.graph.call_function(
374-
torch.ops.aten.full.default,
375-
([1], out_multiplier[0].item()),
376-
{"dtype": torch.int32},
377-
)
378-
out_shift_ = graph_module.graph.call_function(
379-
torch.ops.aten.full.default,
380-
([1], out_shift[0].item()),
381-
{"dtype": torch.int32},
382-
)
383-
384343
kwargs = {
385-
"X_zero_point": X_zero_point,
344+
"X_zero_point": dequants_inputs[0].args[2],
386345
"out_zero_point": quant_node.args[2],
387-
"out_multiplier": out_multiplier_,
388-
"out_shift": out_shift_,
346+
"out_multiplier": out_multiplier[0].item(),
347+
"out_shift": out_shift[0].item(),
389348
}
390349
return args, kwargs
391350

backends/cadence/aot/quantizer/patterns.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def get_anchors(
265265
)
266266

267267
def replacement_op(self) -> OpOverload:
268-
return torch.ops.cadence.quantized_conv2d_nchw.default
268+
return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
269269

270270

271271
class Conv2dPattern(QuantizationPattern):
@@ -307,7 +307,7 @@ def get_anchors(
307307
)
308308

309309
def replacement_op(self) -> OpOverload:
310-
return torch.ops.cadence.quantized_conv2d_nchw.default
310+
return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
311311

312312

313313
class LayerNormPattern(QuantizationPattern):
@@ -437,7 +437,7 @@ def get_anchors(
437437
)
438438

439439
def replacement_op(self) -> OpOverload:
440-
return torch.ops.cadence.quantized_relu.default
440+
return torch.ops.cadence.quantized_relu.per_tensor
441441

442442

443443
# Regular relu op
@@ -496,7 +496,7 @@ def get_anchors(
496496
)
497497

498498
def replacement_op(self) -> OpOverload:
499-
return torch.ops.cadence.quantized_conv2d_nchw.default
499+
return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
500500

501501

502502
# Conv1d + regular relu op fusion

0 commit comments

Comments (0)