Cpu fused kernel #1804
base: main
Changes from 70 commits
6be1412
252ac0f
f98c9e5
902bf35
fef8459
55cbaa0
bbef95b
e842513
d4473fa
dea8dd6
e9bb4fe
cdc8d5e
baacfac
eec3521
d7cc1c5
124b754
0f918c7
e1a8b20
eab45c8
d9f5dd8
070f8a0
a84addf
c4bb660
4ba13fd
c0d05ec
62a16a6
d9ad828
09ed6cb
a3f7b61
4708470
1dfe9f7
00289c4
a2578ba
72033dc
1c20ae8
7552fe2
8b32a39
8f1cc36
49d242a
4a9a6dc
6bcd19e
d7e981d
48739b0
f784be8
92192c9
bd02e71
8520069
e921cbb
9b5d97a
fd6cff1
46d6e47
3271c30
176a2b6
196984a
7652115
ea0e649
9277d24
4fb315b
81f1984
0f78bad
fcb8456
c5e1894
f2029c6
df1d669
bb3ac8d
26b5685
445725b
580010c
57b89bf
302a5fe
de5fb9c
6858a90
3b3d609
fbb911b
3179b42
0c88d43
feb8ad2
c6b714d
5497111
@@ -12,7 +12,7 @@
 import bitsandbytes as bnb
 from bitsandbytes.cextension import ROCM_WARP_SIZE_64
-from bitsandbytes.functional import QuantState
+from bitsandbytes.functional import QuantState, convert_weight_packed_for_cpu, has_avx512bf16
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer

@@ -479,6 +479,7 @@ def __init__(
         self.compute_type_is_set = compute_dtype is not None
         self.quant_state = None
         self.quant_storage = quant_storage
+        self.enable_optimized_cpu = False

     def set_compute_type(self, x):
         if x.dtype in [torch.float32, torch.bfloat16]:

@@ -512,8 +513,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
                 destination[prefix + "weight." + k] = v if keep_vars else v.detach()

     def forward(self, x: torch.Tensor):
+        quant_state = self.weight.quant_state
         fix_4bit_weight_quant_state_from_module(self)

+        if (
+            not self.enable_optimized_cpu
+            and x.device.type == "cpu"
+            and has_avx512bf16()
+            and not self.training
+            and x.requires_grad == False
+        ):
+            self.weight.data, quant_state = convert_weight_packed_for_cpu(self.weight.data, quant_state)
+            self.enable_optimized_cpu = True
+            quant_state.enable_optimized_cpu = True
+
Comment on lines 531 to 539

Member:
There are a couple of things I'm wondering about:
- When we serialize from CPU after running through forward(), we probably still want to be compatible with other devices. I'm thinking that when we serialize, we want to undo this transformation if it's present.
- Possibly an edge concern, but if we do a forward pass on CPU and then move the model to an accelerator, what would happen? I assume the weights are then in the wrong order?

@SunMarc I would appreciate any feedback you might have on this part!

Contributor:
For me, I would prefer that we stick with only one packing format for serialization, and that all other hardware / kernels convert from this packing format at initialization or during the forward pass, as we do here. So we need a way to disable serialization, or to send a warning when someone tries to do that. This is probably something we can do in transformers, since I think most models are serialized from there.

Also, instead of that, either we re-convert the weights for CUDA (but this opens the door to many conversion functions between all the packing formats), or we just raise an error asking users to only run the model on one device.
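A minimal sketch of the kind of guard discussed above (warn when serializing CPU-repacked weights), assuming this PR's `enable_optimized_cpu` flag; the subclass name is hypothetical and none of this is part of the PR:

```python
# Illustrative only: warn when a Linear4bit whose weight has been repacked
# for the CPU fused kernel (this PR's `enable_optimized_cpu` flag) is
# serialized, since the checkpoint would not use the common packing format.
import warnings

import bitsandbytes as bnb


class GuardedLinear4bit(bnb.nn.Linear4bit):  # hypothetical subclass, not in the PR
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        if getattr(self, "enable_optimized_cpu", False):
            warnings.warn(
                "This module was repacked for the CPU fused kernel; the saved "
                "4-bit weights will not match the default packing format "
                "expected on other devices.",
                UserWarning,
            )
        super()._save_to_state_dict(destination, prefix, keep_vars)
```

The alternative raised above, undoing the repack before saving, would need an inverse of `convert_weight_packed_for_cpu`, which this diff does not add.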
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)

@@ -527,9 +540,9 @@ def forward(self, x: torch.Tensor):
             x = x.to(self.compute_dtype)

         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
-        weight = self.weight.t()
+        weight = self.weight if getattr(quant_state, "enable_optimized_cpu", False) else self.weight.t()

-        return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, weight, bias=bias, quant_state=quant_state).to(inp_dtype)


 class LinearFP4(Linear4bit):
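For reference, a rough usage sketch of how the new CPU path would be exercised, assuming an AVX512-BF16 capable host and that (as in recent multi-backend bitsandbytes) the 4-bit quantization happens when the layer is moved to "cpu"; the shapes and `nf4` quant type are arbitrary:

```python
# Rough usage sketch of the CPU fused-kernel path added in this PR.
# Assumes a CPU with AVX512-BF16 (otherwise has_avx512bf16() is False and
# the regular 4-bit path is used) and that .to("cpu") quantizes the weight.
import torch

import bitsandbytes as bnb
from bitsandbytes.functional import has_avx512bf16  # added by this PR

layer = bnb.nn.Linear4bit(4096, 4096, compute_dtype=torch.bfloat16, quant_type="nf4")
layer = layer.to("cpu")  # quantizes the floating-point weight into the packed 4-bit format
layer.eval()             # the repack only happens when the module is not training

x = torch.randn(1, 4096, dtype=torch.bfloat16)
with torch.inference_mode():  # and only when the input does not require grad
    out = layer(x)  # first call repacks the weight via convert_weight_packed_for_cpu

# After that first eval-mode forward on a capable CPU, the module flag is set
# and later calls skip both the repack and the weight transpose.
print(out.shape, getattr(layer, "enable_optimized_cpu", False))
```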