
Commit 49ed26d

Fix bug
1 parent 4a96bbc commit 49ed26d

File tree

2 files changed: +10 -5 lines changed


examples/models/llama/export_llama_lib.py

Lines changed: 7 additions & 3 deletions
@@ -612,6 +612,7 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
             args.model,
+            dtype_override,
             DType.from_torch_dtype(edge_manager.model.checkpoint_dtype),
             args,
         )
@@ -1040,7 +1041,10 @@ def _load_llama_model(


 def _get_source_transforms( # noqa
-    modelname: str, dtype_override: Optional[DType], args
+    modelname: str,
+    dtype_override: DType,
+    checkpoint_dtype: Optional[DType],
+    args,
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
     transforms = []

@@ -1074,7 +1078,7 @@ def _get_source_transforms( # noqa
         """
         modelname = f"{modelname}_q"
         transforms.append(
-            get_quant_weight_transform(args, dtype_override, verbose_export())
+            get_quant_weight_transform(args, checkpoint_dtype, verbose_export())
         )

     if args.embedding_quantize:
@@ -1088,7 +1092,7 @@ def _get_source_transforms( # noqa
         this wil be a no-op.
         """
         modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args, dtype_override))
+        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))

     if args.quantization_mode or args.embedding_quantize:
         transforms.append(
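
For context, a standalone sketch of the distinction this change relies on: the dtype the checkpoint was saved in is derived from the loaded model and passed through separately from the export-time dtype override, and the quantization transforms now key off the checkpoint dtype. The toy module and values below are illustrative only, not the real Llama export path.

import torch

# Hypothetical toy module standing in for the loaded Llama checkpoint.
model = torch.nn.Linear(8, 8).to(torch.float16)

# The dtype the checkpoint weights are stored in (fp16 here) ...
checkpoint_dtype = next(model.parameters()).dtype
# ... is now threaded through _get_source_transforms separately from the
# dtype the exported model should run in.
dtype_override = torch.float32

# Quantization transforms receive checkpoint_dtype, while the export
# itself still honors the override:
model = model.to(dtype_override)
print(f"checkpoint dtype: {checkpoint_dtype}, export dtype: {dtype_override}")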

examples/models/llama/source_transformation/quantize.py

Lines changed: 3 additions & 2 deletions
@@ -859,12 +859,13 @@ def _set_quantized_computation_dtype_rec(
         if isinstance(child, Int8DynActInt4WeightLinear):
             # Change the precision attribute to 'fp32'
             child.precision = dtype
-            print(f"Changed precision of {name} to {dtype}")
+            logging.info(f"Changed precision of {name} to {dtype}")
         elif isinstance(child, QuantizedGroupEmbedding):
             child.dtype = dtype
-            print(f"Changed precision of {name} to {dtype}")
+            logging.info(f"Changed precision of {name} to {dtype}")
         elif isinstance(child, WeightOnlyInt8Linear):
             child.dtype = dtype
+            logging.info(f"Changed precision of {name} to {dtype}")
         else:
             # Recursively apply to child modules
             _set_quantized_computation_dtype_rec(child, dtype)
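
One behavioral note on moving from print to logging.info, shown with a minimal standalone snippet (the module name and dtype below are made-up values): the root logger only emits WARNING and above by default, so whoever drives the export needs logging configured for these messages to appear where the old print output did.

import logging

# At the default WARNING level, logging.info() is silent; configure the
# root logger to actually see the precision-change messages.
logging.basicConfig(level=logging.INFO)

name, dtype = "layers.0.attention.wq", "torch.float32"  # illustrative only
logging.info(f"Changed precision of {name} to {dtype}")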
