Skip to content

Commit 695c7d5

Browse files
authored
Replaced int4 string with torch.int4
Differential Revision: D77058163
Pull Request resolved: #11845
1 parent 608a745 commit 695c7d5

File tree

3 files changed

+27
-21
lines changed

3 files changed

+27
-21
lines changed

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def annotate_matmul_input1(node: Node):
233233
)
234234
quantization_config_8a4w_per_channel = get_ptq_per_channel_quant_config(
235235
act_dtype=torch.uint8,
236-
weight_dtype="int4",
236+
weight_dtype=torch.int4,
237237
act_observer=MinMaxObserver,
238238
act_symmetric=True,
239239
)

backends/qualcomm/quantizer/qconfig.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,7 @@ def get_ptq_per_channel_quant_config(
241241
torch.int8,
242242
torch.int16,
243243
}
244-
# TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
245-
supported_weight_dtypes = {"int4", torch.int8, torch.int16}
244+
supported_weight_dtypes = {torch.int4, torch.int8, torch.int16}
246245
assert (
247246
act_dtype in supported_act_types
248247
), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
@@ -276,9 +275,11 @@ def get_ptq_per_channel_quant_config(
276275
)
277276

278277
weight_quantization_spec = QuantizationSpec(
279-
dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
280-
quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
281-
quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
278+
dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
279+
quant_min=(
280+
-7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
281+
),
282+
quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
282283
qscheme=torch.per_channel_symmetric,
283284
ch_axis=0,
284285
observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
@@ -310,9 +311,11 @@ def get_ptq_per_block_quant_config(
310311
act_symmetric=act_symmetric,
311312
)
312313
weight_quantization_spec = QuantizationSpec(
313-
dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
314-
quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
315-
quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
314+
dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
315+
quant_min=(
316+
-7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
317+
),
318+
quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
316319
qscheme=torch.per_channel_symmetric,
317320
ch_axis=0,
318321
observer_or_fake_quant_ctr=PerBlockParamObserver.with_args(**extra_args),
@@ -463,8 +466,7 @@ def get_qat_per_channel_quant_config(
463466
torch.int8,
464467
torch.int16,
465468
}
466-
# TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype
467-
supported_weight_dtypes = {"int4", torch.int8, torch.int16}
469+
supported_weight_dtypes = {torch.int4, torch.int8, torch.int16}
468470
assert (
469471
act_dtype in supported_act_types
470472
), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}"
@@ -491,17 +493,21 @@ def get_qat_per_channel_quant_config(
491493
)
492494

493495
weight_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
494-
dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
495-
quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
496-
quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
496+
dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
497+
quant_min=(
498+
-7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
499+
),
500+
quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
497501
qscheme=torch.per_channel_symmetric,
498502
ch_axis=0,
499503
observer=MovingAveragePerChannelMinMaxObserver,
500504
)
501505
weight_quantization_spec = QuantizationSpec(
502-
dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
503-
quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1,
504-
quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max,
506+
dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
507+
quant_min=(
508+
-7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).min + 1
509+
),
510+
quant_max=7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max,
505511
qscheme=torch.per_channel_symmetric,
506512
ch_axis=0,
507513
observer_or_fake_quant_ctr=weight_fake_quant_ctr,

backends/qualcomm/quantizer/quantizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class QuantDtype(IntEnum):
8585
partial(
8686
get_ptq_per_channel_quant_config,
8787
act_dtype=torch.uint16,
88-
weight_dtype="int4",
88+
weight_dtype=torch.int4,
8989
),
9090
None,
9191
),
@@ -94,12 +94,12 @@ class QuantDtype(IntEnum):
9494
partial(
9595
get_ptq_per_channel_quant_config,
9696
act_dtype=torch.uint16,
97-
weight_dtype="int4",
97+
weight_dtype=torch.int4,
9898
),
9999
partial(
100100
get_ptq_per_block_quant_config,
101101
act_dtype=torch.uint16,
102-
weight_dtype="int4",
102+
weight_dtype=torch.int4,
103103
),
104104
),
105105
(QuantDtype.use_8a8w, False): (
@@ -113,7 +113,7 @@ class QuantDtype(IntEnum):
113113
partial(
114114
get_qat_per_channel_quant_config,
115115
act_dtype=torch.uint16,
116-
weight_dtype="int4",
116+
weight_dtype=torch.int4,
117117
),
118118
None,
119119
),

0 commit comments

Comments (0)