
Commit d88601f

Commit message: "up"
1 parent a7cee32 commit d88601f

2 files changed: 8 additions, 8 deletions

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 3 additions & 3 deletions
@@ -39,17 +39,17 @@ def _pack_embedding_weight(weight: Tensor, bitwidth: int) -> Tensor:
         weight_1 = weight_view[:, :, 1] << 2
         weight_2 = weight_view[:, :, 2] << 4
         weight_3 = weight_view[:, :, 3] << 6
-        packed_weight = weight_0 + weight_1 + weight_2 + weight_3
+        packed_weight = weight_0 | weight_1 | weight_2 | weight_3
         return packed_weight
     elif bitwidth == 4:
         assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2"
         weight_range_shifted = weight.add(8).view(torch.uint8)
         weight_view = weight_range_shifted.view(
             weight.shape[0], weight.shape[1] // 2, 2
         )
-        weight_even = weight_view[:, :, 0] * 16  # left shift 4
+        weight_even = weight_view[:, :, 0] << 4
         weight_odd = weight_view[:, :, 1]
-        packed_weight = weight_even + weight_odd
+        packed_weight = weight_even | weight_odd
         return packed_weight
     elif bitwidth == 8:
         return weight
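This change swaps arithmetic packing for bitwise packing. Because each shifted operand occupies a disjoint bit field of the output byte, no carries can occur, so OR produces exactly the same packed bytes as addition, and the shift form states the intent more directly than `* 16`. A minimal standalone sketch of the equivalence for the 2-bit case (a toy tensor, not the ExecuTorch pass itself):

import torch

# Toy 2-bit packing example: each column holds a value in [0, 3],
# so every shifted operand occupies a disjoint bit pair of the output byte.
weight_view = torch.randint(0, 4, (4, 2, 4), dtype=torch.uint8)

weight_0 = weight_view[:, :, 0]
weight_1 = weight_view[:, :, 1] << 2
weight_2 = weight_view[:, :, 2] << 4
weight_3 = weight_view[:, :, 3] << 6

# With disjoint bit fields there are no carries, so bitwise OR (new form)
# yields exactly the same packed bytes as addition (old form).
packed_or = weight_0 | weight_1 | weight_2 | weight_3
packed_add = weight_0 + weight_1 + weight_2 + weight_3
assert torch.equal(packed_or, packed_add)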

exir/tests/test_quant_fusion_pass.py

Lines changed: 5 additions & 5 deletions
@@ -378,17 +378,17 @@ def forward(self, indices):
         # )

     def test_embedding_torchao(self) -> None:
-        for bit_width, test_dtype_variant, test_per_group in zip(
+        for bit_width, use_dtype_variant, test_per_group in zip(
             [2, 4, 8], [True, False], [True, False]
         ):
-            self._test_embedding_torchao(bit_width, test_dtype_variant, test_per_group)
+            self._test_embedding_torchao(bit_width, use_dtype_variant, test_per_group)

     def _test_embedding_torchao(
-        self, bit_width: int, test_dtype_variant: bool, test_per_group: bool
+        self, bit_width: int, use_dtype_variant: bool, test_per_group: bool
     ) -> None:
         assert bit_width in [2, 4, 8]
         embedding_suffix = f"{bit_width}bit" if bit_width < 8 else "byte"
-        if test_dtype_variant:
+        if use_dtype_variant:
             embedding_suffix = f"{embedding_suffix}_dtype"

         indices = torch.tensor([1, 2, 3], dtype=torch.int64)
@@ -399,7 +399,7 @@ def _test_embedding_torchao(

         # torchao adds a dtype cast to match embeddings original weight type
         # this does not happen for float32 because it is the default dtype
-        model = model.to(torch.float16) if test_dtype_variant else model
+        model = model.to(torch.float16) if use_dtype_variant else model

         # quantize the model
         granularity = PerGroup(32) if test_per_group else PerAxis(0)
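The granularity choice at the end of this hunk controls how many quantization scales the embedding weight gets. As a rough plain-PyTorch illustration of the difference (not torchao's implementation; the tensor shapes and the int8 range are made up for the example), PerAxis(0) amounts to one scale per embedding row, while PerGroup(32) yields one scale per 32-column group within each row:

import torch

weight = torch.randn(16, 64)  # hypothetical embedding table: 16 rows, 64 columns

# PerAxis(0)-style: one symmetric scale per row (axis 0) -> shape (16, 1)
per_axis_scale = weight.abs().amax(dim=1, keepdim=True) / 127.0

# PerGroup(32)-style: one scale per (row, 32-column group) -> shape (16, 2)
group_size = 32
grouped = weight.view(16, 64 // group_size, group_size)
per_group_scale = grouped.abs().amax(dim=-1) / 127.0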
