Commit 6929f16

always return on CPU, onload at runtime

Signed-off-by: Kyle Sayers <[email protected]>
1 parent f8f7156 commit 6929f16

File tree: 4 files changed, +16 -46 lines


src/compressed_tensors/transform/factory/base.py
Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import List, Optional, Tuple, Set
+from typing import List, Optional, Set, Tuple
 
 import torch
 import torch.nn.utils.parametrize as P

src/compressed_tensors/transform/factory/hadamard.py
Lines changed: 6 additions & 21 deletions

@@ -22,7 +22,7 @@
     apply_transform_weight,
     get_transform_size,
 )
-from compressed_tensors.utils import get_execution_device, get_offloaded_device
+from compressed_tensors.utils import get_execution_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
 from torch.nn import Module, Parameter
@@ -42,7 +42,6 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non
         super().__init__(name, scheme, seed)
         self.weights = ParameterizedDefaultDict(self._create_weight)
         self.perms = ParameterizedDefaultDict(self._create_permutation)
-        self._shared_tensors_device = None
 
     def create_transform(self, module: Module, args: TransformArgs):
         """
@@ -54,11 +53,9 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = self.scheme.precision
-        device = get_offloaded_device(module)
         exec_device = get_execution_device(module)
 
-        factory_kwargs = {"device": device, "construct_device": exec_device}
+        factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, factory_kwargs=factory_kwargs)
         # TODO: permutations should be keyed by fused modules, not weight
         perm = self.perms[weight] if self.scheme.randomize else None
@@ -67,25 +64,12 @@ def create_transform(self, module: Module, args: TransformArgs):
     def _create_weight(
         self,
         size: int,
-        device: device,
         construct_device: device,
     ) -> Parameter:
-        # check that shared tensors device is consistent
-        if self._shared_tensors_device is None:
-            self._shared_tensors_device = device
-
-        if device != self._shared_tensors_device:
-            raise NotImplementedError(
-                "Creating multi-gpu transform weights are not supported as of now due "
-                "to the limitations of shared tensors across GPUs."
-                # in the future, tensors can be shared within GPUs,
-                # and can be all-reduced during updates and compression
-            )
-
-        # construct on execution device, cache shared tensor on offload device
+        # construct on execution device, cache shared tensor on cpu
         precision = self.scheme.precision
         data = deterministic_hadamard_matrix(size, precision, construct_device)
-        data = data.to(device=device)
+        data = data.to(device="cpu")
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
     def _create_permutation(self, weight: Parameter) -> Parameter:
@@ -120,9 +104,10 @@ def forward(self, value: Tensor) -> Tensor:
         if self.args.inverse:
             weight = weight.T
 
+        # onloading is done by accelerate if parent module is offloaded
         return (
             apply_transform_weight(
-                weight.to(self._precision),
+                weight.to(dtype=self._precision, device=value.device),
                 value.to(self._precision),
                 self.args.location,
                 self.module_type,

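The hadamard.py change above is the core of the commit: the factory no longer asks each module for its offload device. It constructs the Hadamard matrix on the execution device, always caches the shared Parameter on CPU, and the transform moves it to the activation's device at call time (accelerate onloads the parent module's own parameters the same way when it is offloaded). Below is a minimal standalone sketch of that pattern in plain torch; the names ToyHadamardTransform and _hadamard are illustrative, not the library's API, and dtype handling is simplified.

# A minimal standalone sketch (plain torch, not the library's actual classes)
# of the pattern this commit adopts: build the transform weight on a fast
# "construct" device, always cache the shared Parameter on cpu, and move it
# to the activation's device at call time. Names here are illustrative.
import torch
from torch import Tensor
from torch.nn import Module, Parameter


def _hadamard(size: int, construct_device: torch.device) -> Tensor:
    # Sylvester construction of a Hadamard matrix (size must be a power of two)
    h = torch.ones(1, 1, device=construct_device)
    while h.shape[0] < size:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h / h.shape[0] ** 0.5


class ToyHadamardTransform(Module):
    def __init__(self, size: int, construct_device: torch.device = torch.device("cpu")):
        super().__init__()
        # construct on the execution device, cache the shared tensor on cpu
        data = _hadamard(size, construct_device).to(device="cpu")
        self.weight = Parameter(data, requires_grad=False)

    def forward(self, value: Tensor) -> Tensor:
        # onload at runtime: follow the activation's device (and dtype, simplified here)
        weight = self.weight.to(dtype=value.dtype, device=value.device)
        return value @ weight


if __name__ == "__main__":
    transform = ToyHadamardTransform(size=8)
    x = torch.randn(2, 8)
    print(transform(x).shape)  # torch.Size([2, 8]); the weight stays cached on cpu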
src/compressed_tensors/transform/factory/matrix_multiply.py
Lines changed: 7 additions & 21 deletions

@@ -21,9 +21,8 @@
     apply_transform_weight,
     get_transform_size,
 )
-from compressed_tensors.utils import get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
-from torch import Tensor, device, dtype
+from torch import Tensor
 from torch.nn import Module, Parameter
 
 
@@ -53,34 +52,20 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        device = get_offloaded_device(module)
 
-        factory_kwargs = {"device": device}
-        weight = self.weights.get(size, factory_kwargs=factory_kwargs)
+        weight = self.weights.get(size)
         if args.inverse:
            weight = self.inverses[weight]
 
         return RandomMatrixTransform(weight, self.scheme, args, type(module))
 
-    def _create_weight(self, size: int, device: device) -> Parameter:
-        # check that shared tensors device is consistent
-        if self._shared_tensors_device is None:
-            self._shared_tensors_device = device
-
-        if device != self._shared_tensors_device:
-            raise NotImplementedError(
-                "Creating multi-gpu transform weights are not supported as of now due "
-                "to the limitations of shared tensors across GPUs"
-                # in the future, tensors can be shared within GPUs,
-                # and can be all-reduced during updates and compression
-            )
-
+    def _create_weight(self, size: int) -> Parameter:
         # TODO: construct such that weight is invertible (has non-zero determinant)
         data = torch.rand(
             (size, size),
             generator=self.generator,
             dtype=self.scheme.precision,
-            device=device,
+            device=torch.device("cpu"),
         )
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
@@ -106,8 +91,9 @@ def __init__(
         self._precision = scheme.precision if args.is_online() else torch.float64
 
     def forward(self, value: Tensor) -> Parameter:
+        # onloading is done by accelerate if parent module is offloaded
        return apply_transform_weight(
-            self.weight.to(self._precision),
+            self.weight.to(dtype=self._precision, device=value.device),
             value.to(self._precision),
             self.args.location,
             self.module_type,
@@ -116,7 +102,7 @@ def forward(self, value: Tensor) -> Parameter:
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
         return apply_transform_weight(
-            inverse.to(self._precision),
+            inverse.to(dtype=self._precision, device=value.device),
             value.to(self._precision),
             self.args.location,
             self.module_type,

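matrix_multiply.py follows the same convention: _create_weight no longer takes a device argument and builds the random matrix directly on CPU, while forward and right_inverse move the cached weight (or its high-precision inverse) to the input's device only when applied. The following is a hedged toy sketch of that behavior, not the library's RandomMatrixTransform; the added diagonal term exists only to keep the toy matrix invertible (the real factory carries a TODO for that).

# A toy sketch (not the library's RandomMatrixTransform) of the same pattern
# for random matrix transforms: the weight is always created on cpu, and both
# the forward and inverse applications onload it to the input's device.
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Module, Parameter


class ToyRandomMatrixTransform(Module):
    def __init__(self, size: int, generator: Optional[torch.Generator] = None):
        super().__init__()
        # always create the shared weight on cpu; the diagonal term is a toy
        # convenience to guarantee invertibility
        data = torch.rand(
            (size, size), generator=generator, device=torch.device("cpu")
        ) + size * torch.eye(size)
        self.weight = Parameter(data, requires_grad=False)

    def forward(self, value: Tensor) -> Tensor:
        # onload at runtime: move the cached weight to the activation's device
        return value @ self.weight.to(dtype=value.dtype, device=value.device)

    def right_inverse(self, value: Tensor) -> Tensor:
        # invert in high precision, then apply on the activation's device
        inverse = torch.linalg.inv(self.weight.to(torch.float64))
        return value @ inverse.to(dtype=value.dtype, device=value.device)


if __name__ == "__main__":
    torch.manual_seed(0)
    transform = ToyRandomMatrixTransform(size=4)
    x = torch.randn(3, 4)
    roundtrip = transform.right_inverse(transform(x))
    print(torch.allclose(roundtrip, x, atol=1e-4))  # True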
src/compressed_tensors/transform/factory/random_hadamard.py
Lines changed: 2 additions & 3 deletions

@@ -14,7 +14,7 @@
 
 from compressed_tensors.transform import HadamardFactory, TransformFactory
 from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
-from torch import device, dtype
+from torch import device
 from torch.nn import Parameter
 
 
@@ -31,11 +31,10 @@ class RandomHadamardFactory(HadamardFactory):
     def _create_weight(
         self,
         size: int,
-        device: device,
         construct_device: device,
     ) -> Parameter:
         # construct on execution device, cache on offload device
         precision = self.scheme.precision
         data = random_hadamard_matrix(size, precision, construct_device, self.generator)
-        data = data.to(device=device)
+        data = data.to(device="cpu")
         return Parameter(data, requires_grad=self.scheme.requires_grad)
