diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index dc7f763fade..3df3b2abea1 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -618,25 +618,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     )
 
     # At this point, the model is loaded in the default fp32.
-
-    # Checkpoint dtype should be lower or equal precision to the dtype override.
+    # override dtype
     checkpoint_dtype = edge_manager.model.checkpoint_dtype
-    if not (
-        checkpoint_dtype == dtype_override.to_torch_dtype()
-        or (
-            checkpoint_dtype == torch.float16
-            and dtype_override.to_torch_dtype() == torch.float32
-        )
-        or (
-            checkpoint_dtype == torch.bfloat16
-            and dtype_override.to_torch_dtype() == torch.float32
-        )
-    ):
-        logging.warning(
-            f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
-        )
-
-    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    edge_manager.to_dtype(dtype_override)
 
     # We want to quantize (in the source transforms) the weights of the model
     # in the checkpoint dtype.
diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py
index 751e2d16175..b4882c6be07 100644
--- a/extension/llm/export/builder.py
+++ b/extension/llm/export/builder.py
@@ -147,7 +147,26 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager":
         assert not dtype_override or isinstance(
            dtype_override, DType
        ), "Override dtype needs to be of type "
-        if dtype_override is not None and dtype_override != self.dtype:
+
+        # Checkpoint dtype should be lower or equal precision to the dtype override.
+        if hasattr(self.model, "checkpoint_dtype"):
+            checkpoint_dtype = self.model.checkpoint_dtype
+            if not (
+                checkpoint_dtype == dtype_override.to_torch_dtype()
+                or (
+                    checkpoint_dtype == torch.float16
+                    and dtype_override.to_torch_dtype() == torch.float32
+                )
+                or (
+                    checkpoint_dtype == torch.bfloat16
+                    and dtype_override.to_torch_dtype() == torch.float32
+                )
+            ):
+                logging.warning(
+                    f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
+                )
+
+        if dtype_override != self.dtype:
             torch_dtype = dtype_override.to_torch_dtype()
             logging.info(f"model.to {torch_dtype}")
             self.model = self.model.to(dtype=torch_dtype)
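
For context, the dtype-compatibility rule this change centralizes inside LLMEdgeManager.to_dtype() reduces to a small check: the checkpoint dtype must equal the override dtype, or be a fp16/bf16 checkpoint widened to fp32; anything else (e.g. fp32 checkpoint cast down to fp16) loses precision and triggers a warning before the cast proceeds. The helper below is a standalone, illustrative sketch of that rule only; it is not part of the executorch API, and the function name is hypothetical.

    import logging
    import torch

    def _check_precision(checkpoint_dtype: torch.dtype, override_dtype: torch.dtype) -> bool:
        # OK when dtypes match, or when a reduced-precision checkpoint is widened to fp32.
        ok = checkpoint_dtype == override_dtype or (
            checkpoint_dtype in (torch.float16, torch.bfloat16)
            and override_dtype == torch.float32
        )
        if not ok:
            # Mirrors the warning emitted in to_dtype(): precision loss is allowed but flagged.
            logging.warning(
                f"Checkpoint dtype {checkpoint_dtype} precision is higher than "
                f"dtype override {override_dtype}."
            )
        return ok

    _check_precision(torch.bfloat16, torch.float32)  # True: widening is fine
    _check_precision(torch.float32, torch.float16)   # False: warns about precision loss

With the check living in to_dtype(), callers such as _prepare_for_llama_export only invoke edge_manager.to_dtype(dtype_override) and no longer duplicate the comparison before casting the model.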