Skip to content

Commit df89a3e

Browse files
committed
up
1 parent 4565a0b commit df89a3e

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

examples/models/llama/source_transformation/quantize.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -515,13 +515,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
515515

516516

517517
def get_quant_embedding_transform(
518-
args, use_shared_embedding: bool = False, dtype_override: Optional[DType] = None
518+
embedding_quantize: str,
519+
use_shared_embedding: bool = False,
520+
dtype_override: Optional[DType] = None,
519521
):
520-
use_torchao = args.embedding_quantize.startswith("torchao:")
522+
use_torchao = embedding_quantize.startswith("torchao:")
521523
if use_torchao:
522-
quant_args = args.embedding_quantize.split(":")[1].split(",")
524+
quant_args = embedding_quantize.split(":")[1].split(",")
523525
else:
524-
quant_args = args.embedding_quantize.split(",")
526+
quant_args = embedding_quantize.split(",")
525527
assert len(quant_args) in [
526528
2,
527529
3,

0 commit comments

Comments
 (0)