Commit f8f7156

key by weight only
Signed-off-by: Kyle Sayers <[email protected]>
1 parent: 131673e

File tree

3 files changed: 42 additions, 12 deletions

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 19 additions & 5 deletions
@@ -42,6 +42,7 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non
         super().__init__(name, scheme, seed)
         self.weights = ParameterizedDefaultDict(self._create_weight)
         self.perms = ParameterizedDefaultDict(self._create_permutation)
+        self._shared_tensors_device = None
 
     def create_transform(self, module: Module, args: TransformArgs):
         """
@@ -57,20 +58,33 @@ def create_transform(self, module: Module, args: TransformArgs):
         device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
 
-        factory_kwargs = {"construct_device": exec_device}
-        weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
+        factory_kwargs = {"device": device, "construct_device": exec_device}
+        weight = self.weights.get(size, factory_kwargs=factory_kwargs)
+        # TODO: permutations should be keyed by fused modules, not weight
         perm = self.perms[weight] if self.scheme.randomize else None
         return HadamardTransform(weight, perm, self.scheme, args, type(module))
 
     def _create_weight(
         self,
         size: int,
-        dtype: dtype,
         device: device,
         construct_device: device,
     ) -> Parameter:
-        # construct on execution device, cache on offload device
-        data = deterministic_hadamard_matrix(size, dtype, construct_device)
+        # check that shared tensors device is consistent
+        if self._shared_tensors_device is None:
+            self._shared_tensors_device = device
+
+        if device != self._shared_tensors_device:
+            raise NotImplementedError(
+                "Creating multi-gpu transform weights are not supported as of now due "
+                "to the limitations of shared tensors across GPUs."
+                # in the future, tensors can be shared within GPUs,
+                # and can be all-reduced during updates and compression
+            )
+
+        # construct on execution device, cache shared tensor on offload device
+        precision = self.scheme.precision
+        data = deterministic_hadamard_matrix(size, precision, construct_device)
         data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
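The core change here is that cached weights are now keyed by size alone, with device and construct_device passed through factory_kwargs instead of being part of the cache key. The sketch below illustrates that keying behavior under an assumption: that ParameterizedDefaultDict.get keys entries only by its positional arguments and forwards factory_kwargs to the factory without including them in the lookup. KeyedFactoryCache and make_weight are illustrative stand-ins, not the compressed-tensors implementation.

from typing import Any, Callable, Dict, Optional, Tuple


class KeyedFactoryCache:
    """Cache entries keyed by positional args only; factory_kwargs affect construction, not lookup."""

    def __init__(self, factory: Callable[..., Any]):
        self._factory = factory
        self._values: Dict[Tuple, Any] = {}

    def get(self, *keys, factory_kwargs: Optional[Dict[str, Any]] = None) -> Any:
        # keys identify the cached entry; factory_kwargs are only forwarded to the factory
        if keys not in self._values:
            self._values[keys] = self._factory(*keys, **(factory_kwargs or {}))
        return self._values[keys]


def make_weight(size: int, device: str = "cpu") -> str:
    # placeholder factory standing in for _create_weight
    return f"weight[{size}x{size}]@{device}"


cache = KeyedFactoryCache(make_weight)
w1 = cache.get(64, factory_kwargs={"device": "cuda:0"})
w2 = cache.get(64, factory_kwargs={"device": "cuda:1"})
assert w1 is w2  # same size -> same cached object, regardless of requested device

Because device no longer distinguishes cache entries, two modules offloaded to different devices would receive the same shared tensor, which is why _create_weight now raises NotImplementedError when it sees a second device.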

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 21 additions & 5 deletions
@@ -41,6 +41,7 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non
         super().__init__(name, scheme, seed)
         self.weights = ParameterizedDefaultDict(self._create_weight)
         self.inverses = ParameterizedDefaultDict(self._create_inverse)
+        self._shared_tensors_device = None
 
     def create_transform(self, module: Module, args: TransformArgs):
         """
@@ -52,19 +53,34 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = self.scheme.precision
         device = get_offloaded_device(module)
 
-        weight = self.weights[size, dtype, device]
+        factory_kwargs = {"device": device}
+        weight = self.weights.get(size, factory_kwargs=factory_kwargs)
         if args.inverse:
             weight = self.inverses[weight]
 
         return RandomMatrixTransform(weight, self.scheme, args, type(module))
 
-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        # TODO: verify that weight is invertible (has non-zero determinant)
+    def _create_weight(self, size: int, device: device) -> Parameter:
+        # check that shared tensors device is consistent
+        if self._shared_tensors_device is None:
+            self._shared_tensors_device = device
+
+        if device != self._shared_tensors_device:
+            raise NotImplementedError(
+                "Creating multi-gpu transform weights are not supported as of now due "
+                "to the limitations of shared tensors across GPUs"
+                # in the future, tensors can be shared within GPUs,
+                # and can be all-reduced during updates and compression
+            )
+
+        # TODO: construct such that weight is invertible (has non-zero determinant)
        data = torch.rand(
-            (size, size), generator=self.generator, dtype=dtype, device=device
+            (size, size),
+            generator=self.generator,
+            dtype=self.scheme.precision,
+            device=device,
         )
         return Parameter(data, requires_grad=self.scheme.requires_grad)
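The remaining TODO notes that torch.rand does not guarantee an invertible matrix. One way such a check could look is sketched below; sample_invertible, the retry loop, and the tolerance are assumptions for illustration only, not part of this commit or the compressed-tensors API.

import torch


def sample_invertible(size: int, generator: torch.Generator, dtype=torch.float32,
                      device="cpu", max_tries: int = 10) -> torch.Tensor:
    for _ in range(max_tries):
        data = torch.rand((size, size), generator=generator, dtype=dtype, device=device)
        # require a finite condition number and a non-negligible determinant
        # before accepting the sample; det != 0 alone can be numerically flaky
        if torch.linalg.cond(data).isfinite() and torch.linalg.det(data).abs() > 1e-6:
            return data
    raise RuntimeError("failed to sample an invertible matrix")


gen = torch.Generator().manual_seed(0)
w = sample_invertible(64, gen)
w_inv = torch.linalg.inv(w)  # the kind of inverse _create_inverse would rely on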

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 2 additions & 2 deletions
@@ -31,11 +31,11 @@ class RandomHadamardFactory(HadamardFactory):
     def _create_weight(
         self,
         size: int,
-        dtype: dtype,
         device: device,
         construct_device: device,
     ) -> Parameter:
         # construct on execution device, cache on offload device
-        data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
+        precision = self.scheme.precision
+        data = random_hadamard_matrix(size, precision, construct_device, self.generator)
         data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
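Both Hadamard factories follow the same pattern after this commit: build the matrix at scheme.precision on the execution device, then move it to the offload device for caching. A minimal sketch of that pattern is below; build_matrix is an illustrative placeholder for deterministic_hadamard_matrix / random_hadamard_matrix, and the function names are not part of the library.

import torch
from torch.nn import Parameter


def build_matrix(size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # placeholder construction; the real factories build Hadamard matrices here
    return torch.eye(size, dtype=dtype, device=device)


def create_weight(size: int, precision: torch.dtype,
                  construct_device: torch.device, offload_device: torch.device) -> Parameter:
    # heavy construction happens on the (typically GPU) execution device ...
    data = build_matrix(size, precision, construct_device)
    # ... then the shared tensor is cached on the (typically CPU) offload device
    data = data.to(device=offload_device)
    return Parameter(data, requires_grad=False)


weight = create_weight(64, torch.float32, torch.device("cpu"), torch.device("cpu"))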
