Merged
Changes from 1 commit
5 changes: 0 additions & 5 deletions modelopt/torch/export/quant_utils.py
@@ -869,11 +869,6 @@ def postprocess_state_dict(state_dict: dict, maxbound: float, quantization: str
post_state_dict[prefix + new_suffix] = value
break

# Squeeze tensors with a leading dimension of 1
Contributor:

Can you specify what errors you get for phi4-mm?

We needed this to stick with our checkpoint format, and I recall that if we don't squeeze, it breaks some deployment flows with TRT-LLM. It works for all our supported models so far.

for key, value in post_state_dict.items():
if isinstance(value, torch.Tensor) and value.dim() == 3 and value.shape[0] == 1:
post_state_dict[key] = value.squeeze(0)

# remove real quant parameters from the state dict
keys_to_delete = []
for key, value in post_state_dict.items():
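For context, the squeeze pass being removed in this diff can be sketched in isolation. The `TinyTensor` class below is a hypothetical stand-in for `torch.Tensor` (exposing only `dim`, `shape`, and `squeeze`) so the sketch runs without PyTorch; the real `postprocess_state_dict` operates on actual tensors and does more than this one pass.

```python
class TinyTensor:
    """Hypothetical stand-in for torch.Tensor, for illustration only."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def dim(self):
        return len(self.shape)

    def squeeze(self, axis):
        # Drop the given axis if its extent is 1, mirroring torch.squeeze(dim).
        if self.shape[axis] == 1:
            return TinyTensor(self.shape[:axis] + self.shape[axis + 1:])
        return self


def squeeze_leading_ones(post_state_dict):
    """Sketch of the removed pass: squeeze 3-D tensors with a leading dim of 1."""
    for key, value in post_state_dict.items():
        if isinstance(value, TinyTensor) and value.dim() == 3 and value.shape[0] == 1:
            post_state_dict[key] = value.squeeze(0)
    return post_state_dict


sd = {"w_scale": TinyTensor((1, 4, 4)), "bias": TinyTensor((4,))}
squeeze_leading_ones(sd)
# "w_scale" becomes (4, 4); "bias" is left untouched
```

This is why the reviewer flags the removal: consumers that expect the squeezed 2-D shape in the exported checkpoint (such as the TRT-LLM deployment flow mentioned above) would instead see the 3-D tensor.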