|
13 | 13 | import numpy as np |
14 | 14 |
|
15 | 15 | from functools import reduce # Required in Python 3 |
16 | | -from typing import Tuple |
| 16 | +from typing import Tuple, Any |
17 | 17 | from torch import Tensor |
| 18 | +from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict |
18 | 19 |
|
19 | 20 | from .cextension import COMPILED_WITH_CUDA, lib |
20 | 21 |
|
@@ -580,58 +581,77 @@ def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=Non |
580 | 581 | self.nested = state2 is not None |
581 | 582 |
|
582 | 583 | @classmethod |
583 | | - def from_dict(cls, quant_state_dict: dict[str, torch.Tensor], device: torch.device) -> 'QuantState': |
| 584 | + def from_dict(cls, qs_dict: dict[str, Any], device: torch.device) -> 'QuantState': |
584 | 585 | """ |
585 | 586 | unpacks dict of tensors into QuantState |
585 | 586 | where necessary, converts strings back into torch.dtype, ints, etc. |
| 588 | +
|
| 589 | + qs_dict may contain an item that packs the non-tensor components under a key like |
| 590 | + `...weight.quant_state.bitsandbytes__[nf4/fp4]` |
| 591 | + such an item is detected via the key stored in qs_key, and then unpacked |
587 | 592 | """ |
588 | 593 |
|
589 | | - quant_state_dict = {k.split('.')[-1]:v for k, v in quant_state_dict.items()} |
590 | | - if 'quant_state_dict' in quant_state_dict: |
591 | | - quant_state_dict|= quant_state_dict.pop('quant_state_dict') |
| 594 | + # unpacking tensor with non-tensor components |
| 595 | + qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] |
| 596 | + if len(qs_key) == 1: |
| 597 | + qs_key = qs_key[0] |
| 598 | + assert 'bitsandbytes__nf4' in qs_key or 'bitsandbytes__fp4' in qs_key, \ |
| 599 | + f"invalid qs_key value {qs_key}" |
| 600 | + qs_dict |= unpack_tensor_to_dict(qs_dict.pop(qs_key)) |
| 601 | + |
| 602 | + qs_dict = {k.split('.')[-1]:v for k, v in qs_dict.items()} # strip prefixes |
592 | 603 |
|
593 | | - if 'nested_absmax' in quant_state_dict: |
594 | | - offset = torch.tensor(float(quant_state_dict['nested_offset'])).to(device) |
| 604 | + if 'nested_absmax' in qs_dict: |
| 605 | + offset = torch.tensor(float(qs_dict['nested_offset'])).to(device) |
595 | 606 | state2 = cls( |
596 | | - absmax=quant_state_dict['nested_absmax'].to(device), |
597 | | - code=quant_state_dict['nested_code'].to(device), |
598 | | - blocksize=int(quant_state_dict['nested_blocksize']), |
599 | | - dtype=getattr(torch, quant_state_dict['nested_dtype']), |
| 607 | + absmax=qs_dict['nested_absmax'].to(device), |
| 608 | + code=qs_dict['nested_code'].to(device), |
| 609 | + blocksize=qs_dict['nested_blocksize'], |
| 610 | + dtype=getattr(torch, qs_dict['nested_dtype']), |
600 | 611 | ) |
601 | 612 | else: |
602 | 613 | offset, state2 = None, None |
603 | 614 |
|
604 | 615 | quant_state = cls( |
605 | | - absmax=quant_state_dict['absmax'].to(device), |
606 | | - shape=torch.Size(map(int, quant_state_dict['shape'].split('.'))), |
607 | | - dtype=getattr(torch, quant_state_dict['dtype']), |
608 | | - blocksize=int(quant_state_dict['blocksize']), |
| 616 | + absmax=qs_dict['absmax'].to(device), |
| 617 | + shape=torch.Size(qs_dict['shape']), |
| 618 | + dtype=getattr(torch, qs_dict['dtype']), |
| 619 | + blocksize=qs_dict['blocksize'], |
609 | 620 | offset=offset, |
610 | 621 | state2=state2, |
611 | | - quant_type=quant_state_dict['quant_type'], |
612 | | - code=quant_state_dict['code'].to(device), |
| 622 | + quant_type=qs_dict['quant_type'], |
| 623 | + code=qs_dict['code'].to(device), |
613 | 624 | ) |
614 | 625 | return quant_state |
615 | 626 |
|
616 | | - def as_dict(self): |
617 | | - """dict of tensors and strings to use in serialization via _save_to_state_dict()""" |
| 627 | + def as_dict(self, packed=False): |
| 628 | + """ |
| 629 | + returns dict of tensors and strings to use in serialization via _save_to_state_dict() |
| 630 | + param: packed -- if True, pack non-tensor components into a tensor and return dict[str, torch.Tensor] for state_dict |
| 631 | + """ |
618 | 632 | qs_dict = { |
619 | 633 | 'absmax': self.absmax, |
620 | 634 | 'code': self.code, |
621 | | - 'shape': ','.join(map(str, self.shape)), |
622 | | - 'dtype': str(self.dtype).strip('torch'), |
623 | | - 'blocksize': str(self.blocksize), |
| 635 | + 'shape': tuple(self.shape), |
| 636 | + 'dtype': str(self.dtype).strip('torch.'), |
| 637 | + 'blocksize': self.blocksize, |
624 | 638 | 'quant_type': self.quant_type, |
625 | 639 | } |
626 | 640 | if self.nested: |
627 | 641 | qs_dict.update({ |
628 | 642 | 'nested_absmax': self.state2.absmax, |
629 | 643 | 'nested_code': self.state2.code, |
630 | | - 'nested_offset': f"{self.offset.item()}", |
631 | | - 'nested_blocksize': str(self.state2.blocksize), |
632 | | - 'nested_dtype': str(self.state2.dtype).strip('torch'), |
| 644 | + 'nested_offset': self.offset.item(), |
| 645 | + 'nested_blocksize': self.state2.blocksize, |
| 646 | + 'nested_dtype': str(self.state2.dtype).strip('torch.'), |
633 | 647 | }) |
634 | | - return qs_dict |
| 648 | + if not packed: |
| 649 | + return qs_dict |
| 650 | + |
| 651 | + qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} |
| 652 | + non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} |
| 653 | + qs_packed_dict["quant_state." + "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) |
| 654 | + return qs_packed_dict |
635 | 655 |
|
636 | 656 | def to(self, device): |
637 | 657 | # make sure the quantization state is on the right device |
|
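The new serialization path leans on the two helpers imported at the top of the diff, pack_dict_to_tensor and unpack_tensor_to_dict from bitsandbytes.utils, to shuttle the non-tensor metadata (shape, dtype, blocksize, quant_type, nested offset, ...) through a tensors-only state_dict. Below is a minimal sketch of what such helpers could look like, assuming a JSON-bytes-in-a-uint8-tensor encoding; the names match the import, but the bodies are illustrative, not the library's exact implementation:

    import json
    import torch

    def pack_dict_to_tensor(source_dict: dict) -> torch.Tensor:
        # serialize the dict to UTF-8 JSON bytes and store them in a uint8 tensor,
        # so the metadata can ride along in an ordinary state_dict
        json_bytes = bytes(json.dumps(source_dict), "utf-8")
        return torch.tensor(list(json_bytes), dtype=torch.uint8)

    def unpack_tensor_to_dict(tensor_data: torch.Tensor) -> dict:
        # reverse step: read the uint8 bytes back out, decode the JSON, rebuild the dict
        json_bytes = bytes(tensor_data.cpu().numpy())
        return json.loads(json_bytes.decode("utf-8"))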
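Putting as_dict(packed=True) and from_dict together, a hedged round-trip sketch based on the diff above; qs is assumed to be an existing QuantState (e.g. produced by the 4-bit quantization path), and the "layer.weight." prefix is only an example of how the keys would appear inside a module's state_dict:

    # pack: every entry is now a tensor, safe for _save_to_state_dict()
    packed = qs.as_dict(packed=True)
    state_dict = {"layer.weight." + k: v for k, v in packed.items()}

    # unpack: from_dict finds the single "quant_state.bitsandbytes__*" tensor,
    # decodes the non-tensor components from it, strips the key prefixes,
    # and rebuilds the (possibly nested) QuantState on the requested device
    restored = QuantState.from_dict(state_dict, device=torch.device("cpu"))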