Skip to content

Commit 7734cce

Browse files
authored
[Transform] Better dispatch support for offloaded and multi-gpu (#423)
* key by weight only Signed-off-by: Kyle Sayers <[email protected]> * always return on CPU, onload at runtime Signed-off-by: Kyle Sayers <[email protected]> * fix get_offloaded_device Signed-off-by: Kyle Sayers <[email protected]> * reduce diff Signed-off-by: Kyle Sayers <[email protected]> * reduce diff Signed-off-by: Kyle Sayers <[email protected]> * reduce diff Signed-off-by: Kyle Sayers <[email protected]> * move to device to support pipeline parallel Signed-off-by: Kyle Sayers <[email protected]> * eagerly generate with precision Signed-off-by: Kyle Sayers <[email protected]> * add comment Signed-off-by: Kyle Sayers <[email protected]> --------- Signed-off-by: Kyle Sayers <[email protected]>
1 parent a97308e commit 7734cce

File tree

5 files changed

+46
-31
lines changed

5 files changed

+46
-31
lines changed

src/compressed_tensors/transform/factory/base.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import torch
2020
import torch.nn.utils.parametrize as P
21+
import tqdm
2122
from compressed_tensors.registry.registry import RegistryMixin, T
2223
from compressed_tensors.transform import (
2324
TransformArgs,
@@ -84,15 +85,21 @@ def create_transform(self, module: Module, args: TransformArgs) -> "TransformBas
8485
"""
8586
raise NotImplementedError()
8687

87-
def apply_to_model(self, model: Module):
88+
def apply_to_model(self, model: Module, use_tqdm=True):
8889
"""
8990
Create transforms and apply them to the model
9091
9192
:param model: module to apply transforms to
9293
"""
93-
for arg in self.scheme.apply:
94-
for _, module in match_named_modules(model, arg.targets, arg.ignore):
95-
self._apply_to_module(module, arg)
94+
modules_args = [
95+
(module, arg)
96+
for arg in self.scheme.apply
97+
for _, module in match_named_modules(model, arg.targets, arg.ignore)
98+
]
99+
100+
desc = f"Applying {self.name} transforms"
101+
for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
102+
self._apply_to_module(module, arg)
96103

97104
self._update_tied_weights()
98105

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,24 +53,28 @@ def create_transform(self, module: Module, args: TransformArgs):
5353
"""
5454
assert hasattr(module, "weight")
5555
size = get_transform_size(module, args.location, self.scheme.head_dim)
56-
dtype = self.scheme.precision
57-
device = get_offloaded_device(module)
5856
exec_device = get_execution_device(module)
59-
60-
factory_kwargs = {"construct_device": exec_device}
61-
weight = self.weights.get(size, dtype, device, factory_kwargs=factory_kwargs)
57+
device = get_offloaded_device(module)
58+
precision = self.scheme.precision if args.is_online() else torch.float64
59+
60+
factory_kwargs = {
61+
"device": device,
62+
"construct_device": exec_device,
63+
"precision": precision,
64+
}
65+
weight = self.weights.get(size, factory_kwargs=factory_kwargs)
66+
# TODO: permutations should be keyed by fused modules, not weight
6267
perm = self.perms[weight] if self.scheme.randomize else None
6368
return HadamardTransform(weight, perm, self.scheme, args, type(module))
6469

6570
def _create_weight(
6671
self,
6772
size: int,
68-
dtype: dtype,
6973
device: device,
7074
construct_device: device,
75+
precision: dtype,
7176
) -> Parameter:
72-
# construct on execution device, cache on offload device
73-
data = deterministic_hadamard_matrix(size, dtype, construct_device)
77+
data = deterministic_hadamard_matrix(size, precision, construct_device)
7478
data = data.to(device=device)
7579
return Parameter(data, requires_grad=self.scheme.requires_grad)
7680

@@ -94,8 +98,7 @@ def __init__(
9498
self.scheme = scheme
9599
self.args = args
96100
self.module_type = module_type
97-
self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
98-
self._precision = scheme.precision if args.is_online() else torch.float64
101+
self._scale = torch.tensor(weight.size(0), dtype=torch.float64).sqrt()
99102

100103
def forward(self, value: Tensor) -> Tensor:
101104
weight = self.weight
@@ -108,8 +111,8 @@ def forward(self, value: Tensor) -> Tensor:
108111

109112
return (
110113
apply_transform_weight(
111-
weight.to(self._precision),
112-
value.to(self._precision),
114+
weight.to(device=value.device),
115+
value.to(dtype=weight.dtype),
113116
self.args.location,
114117
self.module_type,
115118
)

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
apply_transform_weight,
2222
get_transform_size,
2323
)
24-
from compressed_tensors.utils import get_offloaded_device
2524
from compressed_tensors.utils.helpers import ParameterizedDefaultDict
25+
from compressed_tensors.utils.offload import get_offloaded_device
2626
from torch import Tensor, device, dtype
2727
from torch.nn import Module, Parameter
2828

@@ -52,19 +52,23 @@ def create_transform(self, module: Module, args: TransformArgs):
5252
"""
5353
assert hasattr(module, "weight")
5454
size = get_transform_size(module, args.location, self.scheme.head_dim)
55-
dtype = self.scheme.precision
5655
device = get_offloaded_device(module)
56+
precision = self.scheme.precision if args.is_online() else torch.float64
5757

58-
weight = self.weights[size, dtype, device]
58+
factory_kwargs = {"device": device, "precision": precision}
59+
weight = self.weights.get(size, factory_kwargs=factory_kwargs)
5960
if args.inverse:
6061
weight = self.inverses[weight]
6162

6263
return RandomMatrixTransform(weight, self.scheme, args, type(module))
6364

64-
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
65-
# TODO: verify that weight is invertible (has non-zero determinant)
65+
def _create_weight(self, size: int, device: device, precision: dtype) -> Parameter:
66+
# TODO: construct such that weight is invertible (has non-zero determinant)
6667
data = torch.rand(
67-
(size, size), generator=self.generator, dtype=dtype, device=device
68+
(size, size),
69+
generator=self.generator,
70+
dtype=precision,
71+
device=device,
6872
)
6973
return Parameter(data, requires_grad=self.scheme.requires_grad)
7074

@@ -87,21 +91,20 @@ def __init__(
8791
self.scheme = scheme
8892
self.args = args
8993
self.module_type = module_type
90-
self._precision = scheme.precision if args.is_online() else torch.float64
9194

9295
def forward(self, value: Tensor) -> Parameter:
9396
return apply_transform_weight(
94-
self.weight.to(self._precision),
95-
value.to(self._precision),
97+
self.weight.to(device=value.device),
98+
value.to(dtype=self.weight.dtype),
9699
self.args.location,
97100
self.module_type,
98101
).to(value.dtype)
99102

100103
def right_inverse(self, value: Tensor) -> Tensor:
101104
inverse = high_precision_invert(self.weight)
102105
return apply_transform_weight(
103-
inverse.to(self._precision),
104-
value.to(self._precision),
106+
inverse.to(device=value.device),
107+
value.to(dtype=inverse.dtype),
105108
self.args.location,
106109
self.module_type,
107110
).to(value.dtype)

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,10 @@ class RandomHadamardFactory(HadamardFactory):
3131
def _create_weight(
3232
self,
3333
size: int,
34-
dtype: dtype,
3534
device: device,
3635
construct_device: device,
36+
precision: dtype,
3737
) -> Parameter:
38-
# construct on execution device, cache on offload device
39-
data = random_hadamard_matrix(size, dtype, construct_device, self.generator)
38+
data = random_hadamard_matrix(size, precision, construct_device, self.generator)
4039
data = data.to(device=device)
4140
return Parameter(data, requires_grad=self.scheme.requires_grad)

src/compressed_tensors/utils/offload.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,10 @@ def get_offloaded_device(module: torch.nn.Module) -> torch.device:
131131
first_key = list(module._hf_hook.weights_map.keys())[0]
132132
prefix_dataset = module._hf_hook.weights_map.dataset
133133
return prefix_dataset[first_key].device
134-
return next(module.parameters()).device
134+
else:
135+
# if the module is not offloaded, then any added weights
136+
should be placed on the module's execution device
137+
return get_execution_device(module)
135138

136139

137140
@check_accelerate(fallback=None)

0 commit comments

Comments
 (0)