
Commit 6f9cd8c

mcr229 authored and facebook-github-bot committed
Add Int8DynActInt8WeightLinear module (#5605)
Summary:
Pull Request resolved: pytorch/executorch#5605

Adding Int8DynActInt8WeightLinear for Per Channel DQ Linear

Reviewed By: mergennachin

Differential Revision: D63339550

fbshipit-source-id: 032a699a7f7fcc03177215fc40381038b4354c7a
1 parent 85e7458 commit 6f9cd8c

File tree

  • examples/models/llama2/source_transformation

1 file changed: +93, -0 lines

examples/models/llama2/source_transformation/quantize.py

Lines changed: 93 additions & 0 deletions
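
For background, "8da8w" in the diff below is shorthand for 8-bit dynamically quantized activations with 8-bit quantized weights: the activation is quantized per token at run time, the int8 weight is dequantized per channel, and the matmul runs in the floating-point `precision`. The following sketch is illustrative and not part of the commit; the function name is made up, and it uses a symmetric per-token scheme for simplicity (torchao's per_token_dynamic_quant helper may use an asymmetric one).

import torch

def reference_8da8w_linear(x, w_int8, scales, zeros):
    # Fake per-token dynamic quantization of the activation: pick a
    # symmetric int8 scale per row, quantize, and immediately dequantize,
    # mimicking what a per-token dynamic-quant helper models.
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_dq = (x / x_scale).round().clamp(-128, 127) * x_scale
    # Per-channel dequantization of the int8 weight along axis 0
    # (one scale and zero point per output channel).
    w_dq = (w_int8.to(torch.float32) - zeros.unsqueeze(1)) * scales.unsqueeze(1)
    return torch.nn.functional.linear(x_dq, w_dq)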
@@ -379,6 +379,99 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         # return F.linear(input, self.weight.to(dtype=input.dtype)) * se...


+def linear_forward_8da8w(
+    x,
+    weight_int8,
+    scales,
+    zeros,
+    out_features,
+    precision,
+):
+    from torchao.quantization.utils import per_token_dynamic_quant
+
+    x = per_token_dynamic_quant(x)
+    n_bit = 8
+    quant_min = -(2 ** (n_bit - 1))
+    quant_max = 2 ** (n_bit - 1) - 1
+    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel(
+        weight_int8,
+        scales,
+        zeros,
+        0,
+        quant_min,
+        quant_max,
+        torch.int8,
+        out_dtype=precision,
+    )
+    c = torch.nn.functional.linear(x, w_dq)
+
+    return c
+
+
+class Int8DynActInt8WeightLinear(torch.nn.Module):
+    __constants__ = ["in_features", "out_features"]
+
+    in_features: int
+    out_features: int
+    weight: torch.Tensor
+
+    """
+    This module implements a dynamically quantized linear layer with int8 weights.
+    Weights are per-channel quantized. Parameter of importance:
+    precision: precision of input and output, e.g. torch.float32 means the input
+    activation is float32 and the output is float32.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias=True,
+        device=None,
+        dtype=None,
+        precision: torch.dtype = torch.float32,
+    ) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        assert not bias, "require bias=False"
+        self.precision = precision
+
+        if dtype is not None:
+            raise ValueError("Please specify 'precision' instead of 'dtype'")
+
+        # currently storing unpacked int8 weights
+        self.register_buffer(
+            "weight",
+            torch.empty((out_features, in_features), dtype=torch.int8),
+        )
+        self.register_buffer(
+            "scales",
+            torch.empty(
+                (out_features),
+                dtype=torch.float32,
+            ),
+        )
+        self.register_buffer(
+            "zeros",
+            torch.empty(
+                (out_features),
+                dtype=torch.float32,
+            ),
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        input = input.to(self.precision)
+        return linear_forward_8da8w(
+            input,
+            self.weight,
+            self.scales,
+            self.zeros,
+            self.out_features,
+            self.precision,
+        )
+
+
 #########################################################################
 ##### embedding table quantization ######
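
A minimal usage sketch of the new module (not part of the commit): it assumes torchao is installed and that quantize.py is importable under the path shown; the per-channel quantization parameters are fabricated here for illustration.

import torch

# Import path assumed from the file path in the diff above.
from executorch.examples.models.llama2.source_transformation.quantize import (
    Int8DynActInt8WeightLinear,
)

out_features, in_features = 4, 8
linear = Int8DynActInt8WeightLinear(
    in_features, out_features, bias=False, precision=torch.float32
)

# Symmetric per-channel int8 quantization of a random float weight along
# axis 0: one scale per output channel, zero point 0 (illustrative scheme).
w_fp = torch.randn(out_features, in_features)
scales = w_fp.abs().amax(dim=1).clamp(min=1e-8) / 127.0
w_int8 = (w_fp / scales.unsqueeze(1)).round().clamp(-128, 127).to(torch.int8)

linear.weight.copy_(w_int8)
linear.scales.copy_(scales)
linear.zeros.copy_(torch.zeros(out_features))

x = torch.randn(2, in_features)
# forward() dynamically quantizes x per token, dequantizes the weight per
# channel, and runs a float F.linear.
y = linear(x)
assert y.shape == (2, out_features)

Note that the "dynamic" part lives entirely in forward(): activation scales are computed per token at run time via torchao's per_token_dynamic_quant, so only the weight-side parameters (weight, scales, zeros) are stored on the module.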
