Skip to content

Commit 0914f6f

Browse files
committed
eagerly generate with precision
Signed-off-by: Kyle Sayers <[email protected]>
1 parent f485af6 commit 0914f6f

File tree

3 files changed: +22 additions, −18 deletions

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
)
2525
from compressed_tensors.utils import get_execution_device, get_offloaded_device
2626
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
27-
from torch import Tensor, device
27+
from torch import Tensor, device, dtype
2828
from torch.nn import Module, Parameter
2929

3030

@@ -55,8 +55,13 @@ def create_transform(self, module: Module, args: TransformArgs):
5555
size = get_transform_size(module, args.location, self.scheme.head_dim)
5656
exec_device = get_execution_device(module)
5757
device = get_offloaded_device(module)
58+
precision = self.scheme.precision if args.is_online() else torch.float64
5859

59-
factory_kwargs = {"device": device, "construct_device": exec_device}
60+
factory_kwargs = {
61+
"device": device,
62+
"construct_device": exec_device,
63+
"precision": precision,
64+
}
6065
weight = self.weights.get(size, factory_kwargs=factory_kwargs)
6166
# TODO: permutations should be keyed by fused modules, not weight
6267
perm = self.perms[weight] if self.scheme.randomize else None
@@ -67,8 +72,8 @@ def _create_weight(
6772
size: int,
6873
device: device,
6974
construct_device: device,
75+
precision: dtype,
7076
) -> Parameter:
71-
precision = self.scheme.precision
7277
data = deterministic_hadamard_matrix(size, precision, construct_device)
7378
data = data.to(device=device)
7479
return Parameter(data, requires_grad=self.scheme.requires_grad)
@@ -93,8 +98,7 @@ def __init__(
9398
self.scheme = scheme
9499
self.args = args
95100
self.module_type = module_type
96-
self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
97-
self._precision = scheme.precision if args.is_online() else torch.float64
101+
self._scale = torch.tensor(weight.size(0), dtype=torch.float64).sqrt()
98102

99103
def forward(self, value: Tensor) -> Tensor:
100104
weight = self.weight
@@ -107,8 +111,8 @@ def forward(self, value: Tensor) -> Tensor:
107111

108112
return (
109113
apply_transform_weight(
110-
weight.to(dtype=self._precision, device=value.device),
111-
value.to(self._precision),
114+
weight.to(device=value.device),
115+
value.to(dtype=weight.dtype),
112116
self.args.location,
113117
self.module_type,
114118
)

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
2525
from compressed_tensors.utils.offload import get_offloaded_device
26-
from torch import Tensor, device
26+
from torch import Tensor, device, dtype
2727
from torch.nn import Module, Parameter
2828

2929

@@ -53,20 +53,21 @@ def create_transform(self, module: Module, args: TransformArgs):
5353
assert hasattr(module, "weight")
5454
size = get_transform_size(module, args.location, self.scheme.head_dim)
5555
device = get_offloaded_device(module)
56+
precision = self.scheme.precision if args.is_online() else torch.float64
5657

57-
factory_kwargs = {"device": device}
58+
factory_kwargs = {"device": device, "precision": precision}
5859
weight = self.weights.get(size, factory_kwargs=factory_kwargs)
5960
if args.inverse:
6061
weight = self.inverses[weight]
6162

6263
return RandomMatrixTransform(weight, self.scheme, args, type(module))
6364

64-
def _create_weight(self, size: int, device: device) -> Parameter:
65+
def _create_weight(self, size: int, device: device, precision: dtype) -> Parameter:
6566
# TODO: construct such that weight is invertible (has non-zero determinant)
6667
data = torch.rand(
6768
(size, size),
6869
generator=self.generator,
69-
dtype=self.scheme.precision,
70+
dtype=precision,
7071
device=device,
7172
)
7273
return Parameter(data, requires_grad=self.scheme.requires_grad)
@@ -90,21 +91,20 @@ def __init__(
9091
self.scheme = scheme
9192
self.args = args
9293
self.module_type = module_type
93-
self._precision = scheme.precision if args.is_online() else torch.float64
9494

9595
def forward(self, value: Tensor) -> Parameter:
9696
return apply_transform_weight(
97-
self.weight.to(dtype=self._precision, device=value.device),
98-
value.to(self._precision),
97+
self.weight.to(device=value.device),
98+
value.to(dtype=self.weight.dtype),
9999
self.args.location,
100100
self.module_type,
101101
).to(value.dtype)
102102

103103
def right_inverse(self, value: Tensor) -> Tensor:
104104
inverse = high_precision_invert(self.weight)
105105
return apply_transform_weight(
106-
inverse.to(dtype=self._precision, device=value.device),
107-
value.to(self._precision),
106+
inverse.to(device=value.device),
107+
value.to(dtype=inverse.dtype),
108108
self.args.location,
109109
self.module_type,
110110
).to(value.dtype)

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from compressed_tensors.transform import HadamardFactory, TransformFactory
1616
from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
17-
from torch import device
17+
from torch import device, dtype
1818
from torch.nn import Parameter
1919

2020

@@ -33,8 +33,8 @@ def _create_weight(
3333
size: int,
3434
device: device,
3535
construct_device: device,
36+
precision: dtype,
3637
) -> Parameter:
37-
precision = self.scheme.precision
3838
data = random_hadamard_matrix(size, precision, construct_device, self.generator)
3939
data = data.to(device=device)
4040
return Parameter(data, requires_grad=self.scheme.requires_grad)

0 commit comments

Comments (0)