|
19 | 19 | INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, |
20 | 20 | LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, |
21 | 21 | OutlierTracer, |
| 22 | + enable_ipex_fusion, |
22 | 23 | ) |
23 | 24 |
|
24 | 25 | T = TypeVar("T", bound="torch.nn.Module") |
@@ -444,17 +445,35 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): |
444 | 445 | save weight and bias, |
445 | 446 | then fill state_dict with components of quant_state |
446 | 447 | """ |
| 448 | + if ( |
| 449 | + getattr(self.weight, "quant_state", None) is not None |
| 450 | + and getattr(self.weight.quant_state, "op_context", None) is not None |
| 451 | + ): |
| 452 | + context = self.weight.quant_state.op_context |
| 453 | + self.weight.data = context.to_public(context.get_weight()).reshape([1, -1]) |
| 454 | + |
447 | 455 | super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias |
448 | 456 |
|
449 | 457 | if getattr(self.weight, "quant_state", None) is not None: |
| 458 | + if ( |
| 459 | + self.weight.quant_state.absmax.shape.numel() == 0 |
| 460 | + and getattr(self.weight.quant_state, "op_context", None) is not None |
| 461 | + ): |
| 462 | + self.weight.quant_state.absmax = context.get_scales().reshape(-1) |
| 463 | + delattr(self.weight.quant_state, "op_context") |
450 | 464 | for k, v in self.weight.quant_state.as_dict(packed=True).items(): |
451 | 465 | destination[prefix + "weight." + k] = v if keep_vars else v.detach() |
452 | | - if getattr(self.weight.quant_state, "op_context", None) is not None: |
453 | | - context = self.weight.quant_state.op_context |
454 | | - destination[prefix + "weight." + "absmax"] = context.get_scales().reshape(-1) |
455 | | - self.weight.data = context.to_public(context.get_weight()).reshape([1, -1]) |
456 | 466 |
|
457 | 467 | def forward(self, x: torch.Tensor): |
| 468 | + # Check if ipex fusion can be used |
| 469 | + if ( |
| 470 | + x.device.type == "cpu" |
| 471 | + and not hasattr(self.weight.quant_state, "op_context") |
| 472 | + and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 |
| 473 | + and self.weight.quant_state.quant_type == "nf4" |
| 474 | + ): |
| 475 | + enable_ipex_fusion(self.weight, self.weight.quant_state) |
| 476 | + |
458 | 477 | # weights are cast automatically as Int8Params, but the bias has to be cast manually |
459 | 478 | if self.bias is not None and self.bias.dtype != x.dtype: |
460 | 479 | self.bias.data = self.bias.data.to(x.dtype) |
|
0 commit comments