
Conversation

@jackzhxng
Contributor

@jackzhxng jackzhxng commented Feb 14, 2025

Summary

Perform quantization on the weights expressed in their original dtype (from the checkpoint) by performing the source transformations before the dtype cast. Previously, the model was converted to the dtype_override arg's dtype and then quantized. This supposedly eliminates quantization noise.

Note: no need to worry about https://github.com/pytorch/ao/blob/main/torchao/quantization/GPTQ.py#L1168; precision is passed in with the checkpoint dtype.
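
To make the ordering concrete, here is a minimal, self-contained sketch of the difference (not the actual export_llama code path; `int4_groupwise_roundtrip` is a hypothetical stand-in for the 8da4w weight quantizer and the 64x64 bf16 tensor is made up):

```python
import torch


def int4_groupwise_roundtrip(w: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Quantize-dequantize a weight with symmetric int4 group-wise quantization.

    Stand-in for the real 8da4w weight quantizer, used only to show where the
    rounding happens. Assumes in_features is divisible by group_size.
    """
    out_features, in_features = w.shape
    w_grouped = w.reshape(out_features, in_features // group_size, group_size)
    # Per-group scale mapping the max magnitude onto the int4 range [-8, 7].
    scale = w_grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7.0
    q = torch.clamp(torch.round(w_grouped / scale), -8, 7)
    return (q * scale).reshape(out_features, in_features)


w_ckpt = torch.randn(64, 64, dtype=torch.bfloat16)  # checkpoint weight in its original dtype

# Old flow: cast to the dtype_override (fp32 here) first, then quantize.
w_old = int4_groupwise_roundtrip(w_ckpt.to(torch.float32))

# New flow: quantize in the checkpoint dtype, then cast the dequantized result.
w_new = int4_groupwise_roundtrip(w_ckpt).to(torch.float32)

# The two orderings can round differently; that difference is the quantization
# noise avoided by running the quantization source transforms before the dtype cast.
print((w_old - w_new).abs().max())
```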

Comparison of arbitrary q_proj tensor from sample Llama checkpoint:

Before:

Mismatched elements: 3260378 / 4194304 (77.7%)
Greatest absolute difference: 0.08802086114883423 at index (1129, 604) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (0, 1350) (up to 1.3e-06 allowed)
Signal-to-noise: 32.8974 dB

After: no difference
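
For reference, the mismatch report above is the format torch.testing.assert_close emits; a minimal sketch with made-up stand-in tensors (not the real q_proj weights; the signal-to-noise line comes from a separate calculation):

```python
import torch

# Stand-ins for the dequantized q_proj weight produced by the old and new flows.
w_old_flow = torch.randn(2048, 2048)
w_new_flow = w_old_flow + 0.01 * torch.randn(2048, 2048)

# Raises with a "Mismatched elements / Greatest absolute difference /
# Greatest relative difference" report like the one quoted above.
torch.testing.assert_close(w_new_flow, w_old_flow)
```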

Test plan

Manual testing

python -m examples.models.llama.export_llama \
-v -c xl_consolidated/consolidated_renamed.pth \
-p xl_consolidated/et_params.json -kv -d fp32 \
-qmode 8da4w --group_size 32 -X \
--use_sdpa_with_kv_cache \
--output_name quantized_baseline.pte \
--max_context_length 4096 -E 4,32

With the following inserted after the quantization:

edge_manager.model(
    torch.tensor([[2, 3, 4]], dtype=torch.long),
    {"input_pos": torch.tensor([0], dtype=torch.long)},
)

And the following modifications to GPTQ.py in torchao for testing: pytorch/ao#1756.

@pytorch-bot

pytorch-bot bot commented Feb 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8488

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures

As of commit be1921c with merge base 0dd7e4e:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 14, 2025
@jackzhxng jackzhxng marked this pull request as draft February 14, 2025 00:29
@jackzhxng jackzhxng force-pushed the jz/dtype-shennanigans branch from e68b028 to cf0cb9d Compare February 14, 2025 00:33
@jackzhxng jackzhxng added the release notes: examples Changes to any of our example LLMs integrations, such as Llama3 and Llava label Feb 21, 2025
@jackzhxng jackzhxng force-pushed the jz/dtype-shennanigans branch from a85ce2a to 77df3eb Compare February 21, 2025 21:01
@jackzhxng jackzhxng marked this pull request as ready for review February 21, 2025 21:01
@jackzhxng jackzhxng changed the title Source transforms on model in original ckpt weights before dtype cast Fix bf16 quantization noise Feb 24, 2025
edge_manager = edge_manager.set_output_dir(output_dir_path).source_transform(
    _get_source_transforms(
        args.model, DType.from_torch_dtype(checkpoint_dtype), args
    )
)
Contributor Author

Added

edge_manager.model(
    torch.tensor([[2, 3, 4]], dtype=torch.long),
    {"input_pos": torch.tensor([0], dtype=torch.long)},
)

here to test.

# We want to compute the actual ops in the precision of the dtype_override,
# since the precision of the quantized linear will initially be the dtype of the
# checkpoint, not the dtype_override.
def _set_precision_to_fp32(module):
Contributor Author

cc @kimishpatel for the issue we were discussing
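
A rough sketch of the idea behind this hunk, assuming (not pinned to the real torchao/ExecuTorch API) that the quantized linear carries `precision`/`scales_precision` attributes controlling its dequant output dtype:

```python
import torch


def set_quantized_computation_dtype_sketch(model: torch.nn.Module, dtype: torch.dtype) -> None:
    """Hypothetical sketch: after quantizing in the checkpoint dtype, move the
    compute/dequant dtype of each quantized linear to the dtype_override so the
    actual ops run in that precision. Attribute names are assumptions."""
    for module in model.modules():
        if hasattr(module, "precision"):  # dequant output dtype (assumed name)
            module.precision = dtype
        if hasattr(module, "scales_precision"):  # dtype of the quantization scales (assumed name)
            module.scales_precision = dtype
        # Move floating-point buffers (e.g. scales, zero points) to the new dtype,
        # leaving the packed integer weights untouched.
        for name, buf in list(module.named_buffers(recurse=False)):
            if buf.is_floating_point():
                setattr(module, name, buf.to(dtype))
```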

@jackzhxng jackzhxng changed the title Fix bf16 quantization noise Fix xnnpack quantization discrepancy for non-fp32 Feb 25, 2025
Comment on lines 180 to 184
# Convert the model's weights only to the checkpoint's dtype, so that
# the checkpoint can be loaded into the model's state dict in its
# own dtype w/o potential precision loss.
for param in self.model_.parameters():
    param.data = param.data.to(dtype=self.checkpoint_dtype)
Contributor

We shouldn't have to do this if the checkpoint is directly loaded, no? Not sure what's happening with self.model_.to(...)

Contributor Author

self.model_.to(...) needs to happen before the params are set to the checkpoint dtype, so that we end up with our weights in the checkpoint dtype (needed for quantization) and the rest of the model in the dtype override. Then when we load the checkpoint, no dtype promotion will happen.

This dtype promotion is technically always lossless, since all of the dtypes we support convert losslessly to fp32, but I'm doing it this way in case we want to support dtypes in the future that don't convert losslessly to fp32. If we can make that assumption though, then we can decouple model.py from dtype casting and move the logic outside, which I think @larryliu0820 was looking to do.
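
A minimal sketch of that sequencing (a hypothetical helper, not the actual model.py code; strict=False is an assumption):

```python
import torch


def load_checkpoint_without_promotion(
    model: torch.nn.Module,
    checkpoint: dict,
    dtype_override: torch.dtype,
    checkpoint_dtype: torch.dtype,
) -> torch.nn.Module:
    """Hypothetical helper illustrating the ordering described above."""
    # 1. Cast the whole model (buffers and any non-checkpoint state) to the dtype override.
    model.to(dtype=dtype_override)

    # 2. Move only the parameters to the checkpoint dtype, so the copy in step 3
    #    is dtype-for-dtype and the weights stay in the checkpoint dtype for the
    #    quantization source transforms.
    for param in model.parameters():
        param.data = param.data.to(dtype=checkpoint_dtype)

    # 3. Loading the checkpoint now involves no dtype promotion of its values.
    model.load_state_dict(checkpoint, strict=False)
    return model
```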

.source_transform(_get_source_transforms(args.model, dtype_override, args))
)

_set_quantized_computation_dtype(
Contributor

Besides the changes you are doing in this function, don't you also need to do edge_manager.model.to(self.dtype)?

@jackzhxng jackzhxng temporarily deployed to upload-benchmark-results February 25, 2025 11:20 — with GitHub Actions Inactive
@facebook-github-bot
Contributor

@jackzhxng has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


@jackzhxng jackzhxng force-pushed the jz/dtype-shennanigans branch from 00c8f4a to 82d748d Compare February 27, 2025 03:10

@jackzhxng jackzhxng force-pushed the jz/dtype-shennanigans branch from 25d5ac7 to 49ed26d Compare February 27, 2025 16:40

facebook-github-bot pushed a commit that referenced this pull request Feb 27, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D70184325

@jackzhxng jackzhxng closed this Mar 19, 2025
@jackzhxng jackzhxng force-pushed the jz/dtype-shennanigans branch from 91c0d0c to a2f9cbe Compare March 19, 2025 00:06
@jackzhxng jackzhxng reopened this Mar 19, 2025
facebook-github-bot pushed a commit that referenced this pull request Mar 19, 2025


facebook-github-bot pushed a commit that referenced this pull request Mar 21, 2025
Summary:
Perform quantization on the weights expressed in their original dtype (from the checkpoint) by passing the checkpoint dtype into the quantization source transformation, then setting the computation dtype (the result dtype of the dequant, i.e. the dtype the ops are actually computed in) to the dtype override. We must do it this way since the checkpoint and computation dtypes are coupled into a single `precision` parameter in the torchao API, which is something we cannot change.

Note: no need to worry about https://github.com/pytorch/ao/blob/main/torchao/quantization/GPTQ.py#L1168; precision is passed in with the checkpoint dtype.

### Comparison of arbitrary q_proj tensor from sample Llama checkpoint:
Before:
```
Mismatched elements: 3260378 / 4194304 (77.7%)
Greatest absolute difference: 0.08802086114883423 at index (1129, 604) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (0, 1350) (up to 1.3e-06 allowed)
Signal-to-noise: 32.8974 dB
```

After: no difference


Test Plan:
### Manual testing
```
python -m examples.models.llama.export_llama \
-v -c xl_consolidated/consolidated_renamed.pth \
-p xl_consolidated/et_params.json -kv -d fp32 \
-qmode 8da4w --group_size 32 -X \
--use_sdpa_with_kv_cache \
--output_name quantized_baseline.pte \
--max_context_length 4096 -E 4,32
```

With the following inserted after the quantization:

```
edge_manager.model(
    torch.tensor([[2, 3, 4]], dtype=torch.long),
    {"input_pos": torch.tensor([0], dtype=torch.long)},
)
```

And the following modifications to GPTQ.py in torchao: pytorch/ao#1756 for testing.

### Automated testing
+ existing CI tests

### Regression testing
TBD

Reviewed By: kimishpatel

Differential Revision: D70184325

Pulled By: jackzhxng

@facebook-github-bot facebook-github-bot merged commit 38851a1 into main Mar 22, 2025
159 of 166 checks passed
@facebook-github-bot facebook-github-bot deleted the jz/dtype-shennanigans branch March 22, 2025 08:32