20 changes: 2 additions & 18 deletions examples/models/llama/export_llama_lib.py
@@ -618,25 +618,9 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
)

# At this point, the model is loaded in the default fp32.

# Checkpoint dtype should be lower or equal precision to the dtype override.
# override dtype
checkpoint_dtype = edge_manager.model.checkpoint_dtype
Contributor Author: Not sure why we need to keep this original checkpoint_dtype, since it would be overridden in the line below; I kept it only to preserve the behavior of the existing code. @jackzhxng and @kimishpatel, what's the reason behind this? Is it for weight quantization?

Contributor: Yep, it's for quantization.

if not (
checkpoint_dtype == dtype_override.to_torch_dtype()
or (
checkpoint_dtype == torch.float16
and dtype_override.to_torch_dtype() == torch.float32
)
or (
checkpoint_dtype == torch.bfloat16
and dtype_override.to_torch_dtype() == torch.float32
)
):
logging.warning(
f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
)

edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
edge_manager.to_dtype(dtype_override)

# We want to quantize (in the source transforms) the weights of the model
# in the checkpoint dtype.
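For context on the review thread above, here is an illustrative sketch (not ExecuTorch code; `quantize_weight_from_checkpoint` is a hypothetical helper) of why the model keeps `checkpoint_dtype` around even after the override is applied: a later source-transform quantization pass can derive its quantization parameters from the precision the checkpoint was saved in, even though the live weights have already been cast to the override dtype.

```python
import torch

# Hypothetical helper, for illustration only: quantize a weight using the
# checkpoint's original precision, not the (possibly higher) override dtype
# the model was cast to.
def quantize_weight_from_checkpoint(
    weight: torch.Tensor, checkpoint_dtype: torch.dtype
) -> torch.Tensor:
    # Round-trip through the checkpoint dtype so the scale reflects the values
    # the checkpoint actually stored.
    w = weight.to(checkpoint_dtype).to(torch.float32)
    scale = w.abs().max() / 127.0
    return torch.round(w / scale).to(torch.int8)

# The exported model may already be fp32 (the override), while the checkpoint
# was bf16; the attribute kept on the model records that original precision.
weight_after_override = torch.randn(4, 4, dtype=torch.float32)
qweight = quantize_weight_from_checkpoint(
    weight_after_override, checkpoint_dtype=torch.bfloat16
)
```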
21 changes: 20 additions & 1 deletion extension/llm/export/builder.py
@@ -147,7 +147,26 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LLMEdgeManager":
assert not dtype_override or isinstance(
dtype_override, DType
), "Override dtype needs to be of type <DType>"
if dtype_override is not None and dtype_override != self.dtype:

# Checkpoint dtype should be lower or equal precision to the dtype override.
if hasattr(self.model, "checkpoint_dtype"):
checkpoint_dtype = self.model.checkpoint_dtype
if not (
checkpoint_dtype == dtype_override.to_torch_dtype()
or (
checkpoint_dtype == torch.float16
and dtype_override.to_torch_dtype() == torch.float32
)
or (
checkpoint_dtype == torch.bfloat16
and dtype_override.to_torch_dtype() == torch.float32
)
):
logging.warning(
f"Checkpoint dtype {checkpoint_dtype} precision is higher than dtype override {dtype_override.to_torch_dtype()}."
)

if dtype_override != self.dtype:
torch_dtype = dtype_override.to_torch_dtype()
logging.info(f"model.to {torch_dtype}")
self.model = self.model.to(dtype=torch_dtype)
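To summarize the rule that the warning in `to_dtype` now encodes, here is a standalone restatement (my sketch, covering only the three dtypes the diff handles): the checkpoint dtype must equal the override, or be a half-precision dtype promoted to an fp32 override; anything else triggers the warning.

```python
import torch

# Sketch of the compatibility rule checked in to_dtype (assumes only the
# dtypes the diff mentions): True means "no warning is logged".
def checkpoint_dtype_is_compatible(
    checkpoint_dtype: torch.dtype, override: torch.dtype
) -> bool:
    return (
        checkpoint_dtype == override
        or (checkpoint_dtype == torch.float16 and override == torch.float32)
        or (checkpoint_dtype == torch.bfloat16 and override == torch.float32)
    )

assert checkpoint_dtype_is_compatible(torch.bfloat16, torch.float32)     # upcast: fine
assert not checkpoint_dtype_is_compatible(torch.float32, torch.float16)  # downcast: warning fires
```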