Commit 5db0e13

cleanup, construct on dtype, change default
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 90ea08f commit 5db0e13

File tree

3 files changed: +18 -17 lines

src/compressed_tensors/transform/factory/hadamard.py
src/compressed_tensors/transform/factory/matrix_multiply.py
src/compressed_tensors/transform/transform_scheme.py

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 7 additions & 8 deletions
@@ -53,15 +53,14 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = module.weight.dtype
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
-        precision = self.scheme.precision
 
         factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
         perm = self.perms[weight] if self.scheme.randomize else None
-        return HadamardTransform(weight, perm, args, precision, type(module))
+        return HadamardTransform(weight, perm, self.scheme, args, type(module))
 
     def _create_weight(
         self,
@@ -85,17 +84,17 @@ def __init__(
         self,
         weight: Parameter,
         perm: Optional[Parameter],
+        scheme: TransformScheme,
         args: TransformArgs,
-        precision: torch.dtype,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
+        self.scheme = scheme
         self.args = args
-        self.precision = precision
         self.module_type = module_type
-        self._scale = torch.tensor(weight.size(0), dtype=self.precision).sqrt()
+        self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
 
     def forward(self, value: Tensor) -> Tensor:
         weight = self.weight
@@ -108,8 +107,8 @@ def forward(self, value: Tensor) -> Tensor:
 
         return (
             apply_transform_weight(
-                weight.to(self.precision),
-                value.to(self.precision),
+                weight.to(self.scheme.precision),
+                value.to(self.scheme.precision),
                 self.args.location,
                 self.module_type,
             )
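
Note: taken together, these hadamard.py changes construct the transform weight at the scheme's precision rather than the module weight's dtype, and pass the whole TransformScheme into the transform in place of a bare precision argument. A minimal, hypothetical sketch of the resulting cast-rotate-restore pattern (SchemeStub and hadamard_rotate are illustrative stand-ins, not the library's API):

    import torch
    from dataclasses import dataclass

    @dataclass
    class SchemeStub:
        # stand-in for TransformScheme; only the precision field is modeled
        precision: torch.dtype = torch.float32  # new default from this commit

    def hadamard_rotate(value: torch.Tensor, weight: torch.Tensor, scheme: SchemeStub) -> torch.Tensor:
        # normalize by sqrt(n) so a +-1 Hadamard matrix acts as an orthonormal
        # rotation, mirroring the _scale buffer in the diff above
        scale = torch.tensor(weight.size(0), dtype=scheme.precision).sqrt()
        rotated = value.to(scheme.precision) @ weight.to(scheme.precision) / scale
        return rotated.to(value.dtype)  # hand the result back in the caller's dtype

    # usage: a 4x4 Hadamard matrix applied to bfloat16 activations in float32
    H = torch.tensor([[1, 1, 1, 1], [1, -1, 1, -1], [1, 1, -1, -1], [1, -1, -1, 1]])
    x = torch.randn(2, 4, dtype=torch.bfloat16)
    y = hadamard_rotate(x, H, SchemeStub())
    assert y.dtype == torch.bfloat16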

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 8 additions & 8 deletions
@@ -24,7 +24,7 @@
 from compressed_tensors.utils import get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Module, Parameter
 
 
 @TransformFactory.register("random-matrix")
@@ -52,7 +52,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = module.weight.dtype
+        dtype = self.scheme.precision
         device = get_offloaded_device(module)
         precision = self.scheme.precision
 
@@ -78,29 +78,29 @@ class RandomMatrixTransform(TransformBase):
     def __init__(
         self,
         weight: Tensor,
+        scheme: TransformScheme,
         args: TransformArgs,
-        precision: torch.dtype,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
+        self.scheme = scheme
         self.args = args
-        self.precision = precision
         self.module_type = module_type
 
     def forward(self, value: Tensor) -> Parameter:
         return apply_transform_weight(
-            self.weight.to(self.precision),
-            value.to(self.precision),
+            self.weight.to(self.scheme.precision),
+            value.to(self.scheme.precision),
             self.args.location,
             self.module_type,
         ).to(value.dtype)
 
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
         return apply_transform_weight(
-            inverse.to(self.precision),
-            value.to(self.precision),
+            inverse.to(self.scheme.precision),
+            value.to(self.scheme.precision),
             self.args.location,
             self.module_type,
         ).to(value.dtype)
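
Note: matrix_multiply.py gets the same scheme-precision plumbing: forward casts weight and activation to scheme.precision and restores the activation dtype afterward, while right_inverse first inverts the weight at high precision. A hedged sketch of those two paths, with torch.linalg.inv at float64 standing in for the library's high_precision_invert helper (an assumption about its internals):

    import torch

    def apply_at_precision(weight: torch.Tensor, value: torch.Tensor, precision: torch.dtype) -> torch.Tensor:
        # cast both operands to the scheme precision, multiply, then restore
        # the activation's original dtype, matching the .to(value.dtype) above
        return (value.to(precision) @ weight.to(precision)).to(value.dtype)

    def right_inverse_at_precision(weight: torch.Tensor, value: torch.Tensor, precision: torch.dtype) -> torch.Tensor:
        # invert at float64 for numerical stability, then apply the inverse
        # at the scheme precision
        inverse = torch.linalg.inv(weight.to(torch.float64))
        return (value.to(precision) @ inverse.to(precision)).to(value.dtype)

    # round trip: applying the transform then its right inverse approximately
    # recovers the input, up to bfloat16 rounding
    w = torch.eye(8) + 0.1 * torch.randn(8, 8)  # well-conditioned test matrix
    x = torch.randn(3, 8, dtype=torch.bfloat16)
    y = apply_at_precision(w, x, torch.float32)
    x_back = right_inverse_at_precision(w, y, torch.float32)
    assert torch.allclose(x.float(), x_back.float(), atol=1e-1)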

src/compressed_tensors/transform/transform_scheme.py

Lines changed: 3 additions & 1 deletion
@@ -36,11 +36,13 @@ class TransformScheme(BaseModel):
     :param randomize: True if uniquely randomized transform weights should be used,
         otherwise use identical transform weights where applicable
     :param requires_grad: True if weights include gradients for training
+    :param precision: Precision at which this transform should be applied. This applies
+        to both weight fusing and online rotations
     """
 
     type: str
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
     head_dim: Optional[int] = Field(default=None)
-    precision: TorchDtype = Field(default=torch.bfloat16)
+    precision: TorchDtype = Field(default=torch.float32)
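
Note: with the default flipped from torch.bfloat16 to torch.float32, schemes that do not set precision explicitly are now constructed and applied in float32. A usage sketch, assuming TransformScheme is importable from compressed_tensors.transform as the source layout suggests:

    import torch
    from compressed_tensors.transform import TransformScheme

    # the new default: transforms are fused and applied in float32
    scheme = TransformScheme(type="random-matrix")
    assert scheme.precision == torch.float32

    # opting back into the old behavior now requires an explicit precision
    scheme_bf16 = TransformScheme(type="random-matrix", precision=torch.bfloat16)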
