Commit 59ae6d1

Andrew Grebenisan authored and facebook-github-bot committed
Refactor quantizer: Only replace with per-tensor variants (pytorch#14974)
Summary: In the previous flow, we would replace ops with their default variants, run a special fusion pass that constructed singleton tensors for a variety of fused quantized ops, and then call a replace-ops pass to turn those into per-tensor variants. I confirmed this was for legacy reasons, so a cleanup was long overdue. This diff replaces ops with the per-tensor variants directly and removes the pass that swapped singleton tensors for scalars.

Reviewed By: hsharma35, zonglinpeng

Differential Revision: D83873738
1 parent 9560800 commit 59ae6d1
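
The shape of the change, as a minimal sketch: previously each scale/zero-point scalar was materialized as a [1]-shaped tensor via aten.full so the default op variant could consume it, and a later pass rewrote the op to its per-tensor form; now the scalar is passed straight to the .per_tensor overload. The helper names below are illustrative, but the call_function pattern and argument positions are taken from the diff that follows.

import torch
import torch.fx as fx

# Before: materialize the dequantize scale as a singleton tensor node in the
# graph; a follow-up replace-ops pass then had to convert the fused op to its
# per-tensor variant.
def old_style_scale_arg(graph_module: fx.GraphModule, dequant: fx.Node) -> fx.Node:
    return graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], dequant.args[1]),  # args[1] of a dequantize node is its scale
        {"dtype": torch.float32},
    )

# After: pass the Python scalar through unchanged and target the *.per_tensor
# overload directly, so no follow-up replacement pass is needed.
def new_style_scale_arg(dequant: fx.Node):
    return dequant.args[1]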

File tree

5 files changed, +138 -467 lines changed


backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -425,6 +425,7 @@ python_unittest(
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/passes:lib",
+        ":ref_implementations",
     ],
 )

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 30 additions & 147 deletions
@@ -66,33 +66,18 @@ def get_args_and_kwargs_add(
     dequants_inputs: List[fx.Node],
     quant_node: fx.Node,
 ) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
-    X_scale_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[1]),
-        {"dtype": torch.float},
-    )
-    X_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    Y_scale_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[1].args[1]),
-        {"dtype": torch.float},
-    )
-    Y_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[1].args[2]),
-        {"dtype": torch.int32},
-    )
+    X_scale = dequants_inputs[0].args[1]
+
+    X_zero_point = dequants_inputs[0].args[2]
+    Y_scale = dequants_inputs[1].args[1]
+    Y_zero_point = dequants_inputs[1].args[2]
     args = (
         inputs_inputs[0],
-        X_scale_,
-        X_zero_point_,
+        X_scale,
+        X_zero_point,
         inputs_inputs[1],
-        Y_scale_,
-        Y_zero_point_,
+        Y_scale,
+        Y_zero_point,
         quant_node.args[1],
         quant_node.args[2],
     )
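
With the per-tensor overload, the scales and zero points above are plain Python scalars rather than graph nodes. For reference, a scalar sketch of what a per-tensor quantized add is conventionally expected to compute (standard dequantize-add-requantize semantics; the actual Cadence kernel is the authority on rounding and saturation):

def quantized_add_ref(
    x: int, X_scale: float, X_zero_point: int,
    y: int, Y_scale: float, Y_zero_point: int,
    out_scale: float, out_zero_point: int,
) -> int:
    # Dequantize both operands, add in float, then requantize to the output
    # parameters. In the args above, quant_node.args[1] and quant_node.args[2]
    # carry out_scale and out_zero_point.
    acc = X_scale * (x - X_zero_point) + Y_scale * (y - Y_zero_point)
    return round(acc / out_scale) + out_zero_point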
@@ -130,31 +115,12 @@ def get_args_and_kwargs_linear(
     else:
         bias = bias_inputs[0]

-    # Create single element tensors for weight_zero_point, out_multiplier, out_shift.
-    # Note that the function expects int32_t, when it would default to int64_t, so
-    # we explicitly require that type.
-    weight_zero_point_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_weights[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
     args = tuple(inputs_inputs + weights_inputs + [bias])
     kwargs = {
         "src_zero_point": dequants_inputs[0].args[2],
-        "weight_zero_point": weight_zero_point_,
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "weight_zero_point": dequants_weights[0].args[2],
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
         "out_zero_point": quant_node.args[2],
         "offset": None,
     }
@@ -179,22 +145,8 @@ def get_args_and_kwargs_layer_norm(
     ), "per-channel quantization is not supported for layer norm, both scale and zero_point should be scalars"

     # Make the scale and zero_point tensors
-    scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[2],
-        ),
-        {"dtype": torch.int32},
-    )
+    scale = dequants_inputs[0].args[1]
+    zero_point = dequants_inputs[0].args[2]

     weight = other_inputs[1] if len(other_inputs) > 1 else None

@@ -221,7 +173,7 @@ def get_args_and_kwargs_layer_norm(
     )

     # Make the args and kwargs for the replacement op
-    args = tuple(inputs_inputs + [scale_tensor] + [zero_point_tensor])
+    args = tuple(inputs_inputs + [scale, zero_point])
     kwargs = {
         "normalized_shape": other_inputs[0],
         "weight": weight,
@@ -309,31 +261,6 @@ def get_args_and_kwargs_conv(

     (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)

-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the weight zero point
-    weight_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], weight_zero_point),
-        {"dtype": torch.int32},
-    )
-
-    # Create a single element tensor for the bias scale
-    bias_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], bias_scale),
-        {"dtype": torch.float32},
-    )
-
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs + weights_inputs + [bias])
     kwargs = {
@@ -342,12 +269,12 @@ def get_args_and_kwargs_conv(
         "dilation": dilation,
         "groups": groups,
         "input_zero_point": dequants_inputs[0].args[2],
-        "weight_zero_point": weight_zero_point_tensor,
-        "bias_scale": bias_scale_tensor,
+        "weight_zero_point": weight_zero_point,
+        "bias_scale": bias_scale,
         "out_scale": quant_node.args[1],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
     return args, kwargs
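
out_multiplier and out_shift come from quantize_tensor_multiplier(requantize_scale_t), which factors each float requantization scale into a fixed-point multiplier plus a power-of-two shift so the kernel can avoid float math; the .item() calls above extract them as scalars for the per-tensor overload. A scalar sketch of the usual gemmlowp-style decomposition (the function name is illustrative; the real helper operates on tensors, and the sign convention for the stored shift is the kernel's):

import math

def quantize_multiplier(scale: float) -> tuple[int, int]:
    # Decompose scale ~= multiplier * 2**(shift - 31), where multiplier is a
    # Q31 fixed-point value in [2**30, 2**31).
    if scale == 0.0:
        return 0, 0
    mantissa, shift = math.frexp(scale)       # scale = mantissa * 2**shift
    multiplier = round(mantissa * (1 << 31))  # mantissa as Q31
    if multiplier == (1 << 31):               # rounding carried mantissa up to 1.0
        multiplier //= 2
        shift += 1
    return multiplier, shift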

@@ -368,27 +295,11 @@ def get_args_and_kwargs_relu(
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs)

-    X_zero_point = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], dequants_inputs[0].args[2]),
-        {"dtype": torch.int32},
-    )
-    out_multiplier_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_multiplier[0].item()),
-        {"dtype": torch.int32},
-    )
-    out_shift_ = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        ([1], out_shift[0].item()),
-        {"dtype": torch.int32},
-    )
-
     kwargs = {
-        "X_zero_point": X_zero_point,
+        "X_zero_point": dequants_inputs[0].args[2],
         "out_zero_point": quant_node.args[2],
-        "out_multiplier": out_multiplier_,
-        "out_shift": out_shift_,
+        "out_multiplier": out_multiplier[0].item(),
+        "out_shift": out_shift[0].item(),
     }
     return args, kwargs

@@ -436,48 +347,20 @@ def get_args_and_kwargs_softmax(
         {"dtype": torch.int32},
     )
     # Make the scale and zero_point tensors
-    in_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    in_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            dequants_inputs[0].args[2],
-        ),
-        {"dtype": torch.int32},
-    )
-    out_scale_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            quant_node.args[1],
-        ),
-        {"dtype": torch.float32},
-    )
-    out_zero_point_tensor = graph_module.graph.call_function(
-        torch.ops.aten.full.default,
-        (
-            [1],
-            quant_node.args[2],
-        ),
-        {"dtype": torch.int32},
-    )
+    in_scale = dequants_inputs[0].args[1]
+    in_zero_point = dequants_inputs[0].args[2]
+    out_scale = quant_node.args[1]
+    out_zero_point = quant_node.args[2]

     # Make the args and kwargs for the replacement op
     args = (
         inputs_inputs[0],
         mask_tensor,
         op_node.args[1],
-        in_scale_tensor,
-        in_zero_point_tensor,
-        out_scale_tensor,
-        out_zero_point_tensor,
+        in_scale,
+        in_zero_point,
+        out_scale,
+        out_zero_point,
     )
     kwargs = {}

backends/cadence/aot/quantizer/patterns.py

Lines changed: 12 additions & 9 deletions
@@ -112,7 +112,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_linear.default
+        return torch.ops.cadence.quantized_linear.per_tensor


 class AddPattern(QuantizationPattern):
@@ -150,7 +150,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_add.default
+        return torch.ops.cadence.quantized_add.per_tensor


 class BmmPattern(QuantizationPattern):
@@ -174,6 +174,8 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
+        # TODO: T240804887 This is actually a per-tensor variant,
+        # we just need to change the name of the op
         return torch.ops.cadence.quantized_matmul.default


@@ -265,7 +267,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor


 class Conv2dPattern(QuantizationPattern):
@@ -307,7 +309,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor


 class LayerNormPattern(QuantizationPattern):
@@ -345,7 +347,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_layer_norm.default
+        return torch.ops.cadence.quantized_layer_norm.per_tensor


 class LinearPattern(QuantizationPattern):
@@ -387,7 +389,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_linear.default
+        return torch.ops.cadence.quantized_linear.per_tensor


 class MatmulPattern(QuantizationPattern):
@@ -411,6 +413,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
+        # TODO: T240804887 This is actually a per-tensor variant, we just need to change the name of the op
         return torch.ops.cadence.quantized_matmul.default


@@ -437,7 +440,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_relu.default
+        return torch.ops.cadence.quantized_relu.per_tensor


 # Regular relu op
@@ -496,7 +499,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_conv2d_nchw.default
+        return torch.ops.cadence.quantized_conv2d_nchw.per_tensor


 # Conv1d + regular relu op fusion
@@ -544,7 +547,7 @@ def get_anchors(
         )

     def replacement_op(self) -> OpOverload:
-        return torch.ops.cadence.quantized_softmax.default
+        return torch.ops.cadence.quantized_softmax.per_tensor


 class MixedW8A32LinearPattern(QuantizationPattern):
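
For context, a minimal sketch of how a fusion pass consumes replacement_op(): the matched subgraph's terminal node is swapped for a call to the per-tensor overload. The helper below is illustrative, not the actual code in fusion_pass.py, but it uses only standard torch.fx APIs:

import torch.fx as fx
from torch._ops import OpOverload

def rewrite_fused_node(
    graph_module: fx.GraphModule,
    fused_node: fx.Node,
    replacement: OpOverload,  # e.g. torch.ops.cadence.quantized_linear.per_tensor
    args: tuple,
    kwargs: dict,
) -> None:
    # Insert the per-tensor op where the matched pattern ends, reroute all
    # users of the old node to it, then drop the old node and recompile.
    with graph_module.graph.inserting_before(fused_node):
        new_node = graph_module.graph.call_function(replacement, args, kwargs)
    fused_node.replace_all_uses_with(new_node)
    graph_module.graph.erase_node(fused_node)
    graph_module.recompile()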
