 import torch
 from compressed_tensors.transform import TransformArgs
 from compressed_tensors.utils import TorchDtype
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, model_validator


 __all__ = ["TransformScheme"]
@@ -36,6 +36,8 @@ class TransformScheme(BaseModel):
     :param randomize: True if uniquely randomized transform weights should be used,
         otherwise use identical transform weights where applicable
     :param requires_grad: True if weights include gradients for training
+    :param block_size: If set, the transform matrix will be block diagonal, with each
+        block being a square matrix of this size.
     :param precision: Precision at which this transform should be applied during online
         rotations. Fused (offline) rotations are always performed in float64
     """
@@ -44,7 +46,21 @@ class TransformScheme(BaseModel):
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
-    head_dim: Optional[int] = Field(default=None)
+    block_size: Optional[int] = Field(default=None)
+    head_dim: Optional[int] = Field(
+        default=None, deprecated="head_dim is deprecated, use block_size instead"
+    )
     precision: TorchDtype = Field(default=torch.float32)

+    @model_validator(mode="after")
+    def validate_model_after(model: "TransformScheme") -> "TransformScheme":
+        """
+        If head_dim is used instead of block_size, set block_size to head_dim
+        and remove head_dim
+        """
+        if model.block_size is None and model.head_dim is not None:
+            model.block_size = model.head_dim
+            model.head_dim = None
+        return model
+
     model_config = ConfigDict(extra="forbid")
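
A minimal usage sketch (not part of this diff) of the migration performed by the validator above; the "hadamard" type value is an assumption for illustration and is not taken from this change:

    # Passing the deprecated head_dim populates block_size and clears head_dim.
    scheme = TransformScheme(type="hadamard", head_dim=128)  # type value assumed
    assert scheme.block_size == 128
    assert scheme.head_dim is None

    # New code should set block_size directly.
    scheme = TransformScheme(type="hadamard", block_size=128)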