@@ -56,9 +56,9 @@
 
 from .source_transformation.attention import replace_attention_to_attention_sha
 from .source_transformation.quantize import (
-    set_quantized_computation_dtype,
     get_quant_embedding_transform,
     get_quant_weight_transform,
+    set_quantized_computation_dtype,
 )
 from .source_transformation.quantized_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
@@ -596,31 +596,24 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
 
     # At this point, the model is loaded in the default fp32.
 
-    # Convert the non-weights of the model (the buffers) to the dtype_override.
-    # Need to do this before source transform quantization since the quantized
-    # parameters become buffers.
-    for buf in edge_manager.model.buffers():
-        buf.data = buf.data.to(dtype=dtype_override.to_torch_dtype())
+    # TODO: validate the combination of checkpoint dtype and dtype_override.
+
+    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
 
     # We want to quantize (in the source transforms) the weights of the model
     # in the checkpoint dtype.
     logging.info(f"Checkpoint dtype: {edge_manager.model.checkpoint_dtype}")
     edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
         _get_source_transforms(
-            args.model,
-            dtype_override,
-            DType.from_torch_dtype(edge_manager.model.checkpoint_dtype),
-            args,
+            modelname=args.model,
+            dtype_override=dtype_override,
+            checkpoint_dtype=DType.from_torch_dtype(
+                edge_manager.model.checkpoint_dtype
+            ),
+            args=args,
         )
     )
 
-    # Convert the parameters to the dtype_override.
-    # If source transform quantization has already happened at this point (-qmode),
-    # the quantized weights will become buffers and not be returned by .parameters(),
-    # so we don't convert them to the dtype_override.
-    for param in edge_manager.model.parameters():
-        param.data = param.data.to(dtype=dtype_override.to_torch_dtype())
-
     return edge_manager
 
|