
Commit 13cb9e3

fix get_offloaded_device

Signed-off-by: Kyle Sayers <[email protected]>
Parent: 6929f16

File tree: 5 files changed (+30 -20 lines)

src/compressed_tensors/transform/factory/base.py

Lines changed: 11 additions & 4 deletions
@@ -18,6 +18,7 @@

 import torch
 import torch.nn.utils.parametrize as P
+import tqdm
 from compressed_tensors import InternalModule
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
@@ -84,15 +85,21 @@ def create_transform(self, module: Module, args: TransformArgs) -> "TransformBase":
         """
         raise NotImplementedError()

-    def apply_to_model(self, model: Module):
+    def apply_to_model(self, model: Module, use_tqdm=True):
         """
         Create transforms and apply them to the model

         :param model: module to apply transforms to
         """
-        for arg in self.scheme.apply:
-            for _, module in match_named_modules(model, arg.targets, arg.ignore):
-                self._apply_to_module(module, arg)
+        modules_args = [
+            (module, arg)
+            for arg in self.scheme.apply
+            for _, module in match_named_modules(model, arg.targets, arg.ignore)
+        ]
+
+        desc = f"Applying {self.name} transforms"
+        for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
+            self._apply_to_module(module, arg)

         self._update_tied_weights()

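The rewrite above flattens the nested matching loop into a list so tqdm can report progress against a known total. A minimal, self-contained sketch of the same pattern, where fake_modules and apply_transform are hypothetical stand-ins for the factory's real matching and application logic:

import tqdm

fake_modules = [f"layer_{i}" for i in range(100)]  # stand-in for matched modules

def apply_transform(module):
    pass  # stand-in for TransformFactory._apply_to_module

use_tqdm = True
# materializing the list first gives tqdm a fixed length, so the bar shows
# progress out of a known total; disable=(not use_tqdm) silences it entirely
for module in tqdm.tqdm(fake_modules, desc="Applying transforms", disable=(not use_tqdm)):
    apply_transform(module)
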
src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 6 additions & 6 deletions
@@ -22,7 +22,7 @@
     apply_transform_weight,
     get_transform_size,
 )
-from compressed_tensors.utils import get_execution_device
+from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
 from torch.nn import Module, Parameter
@@ -54,8 +54,9 @@ def create_transform(self, module: Module, args: TransformArgs):
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         exec_device = get_execution_device(module)
+        device = get_offloaded_device(module)

-        factory_kwargs = {"construct_device": exec_device}
+        factory_kwargs = {"device": device, "construct_device": exec_device}
         weight = self.weights.get(size, factory_kwargs=factory_kwargs)
         # TODO: permutations should be keyed by fused modules, not weight
         perm = self.perms[weight] if self.scheme.randomize else None
@@ -64,12 +65,12 @@ def create_transform(self, module: Module, args: TransformArgs):
     def _create_weight(
         self,
         size: int,
+        device: device,
         construct_device: device,
     ) -> Parameter:
-        # construct on execution device, cache shared tensor on cpu
         precision = self.scheme.precision
         data = deterministic_hadamard_matrix(size, precision, construct_device)
-        data = data.to(device="cpu")
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)

     def _create_permutation(self, weight: Parameter) -> Parameter:
@@ -104,10 +105,9 @@ def forward(self, value: Tensor) -> Tensor:
         if self.args.inverse:
             weight = weight.T

-        # onloading is done by accelerate if parent module is offloaded
         return (
             apply_transform_weight(
-                weight.to(dtype=self._precision, device=value.device),
+                weight.to(dtype=self._precision),
                 value.to(self._precision),
                 self.args.location,
                 self.module_type,

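The forward pass uses weight.T for the inverse direction, which relies on Hadamard matrices being orthogonal once normalized. A self-contained check of that property using Sylvester's construction (a stand-in for the library's deterministic_hadamard_matrix, whose normalization may differ):

import math
import torch

def sylvester_hadamard(n: int) -> torch.Tensor:
    # Sylvester's construction; n must be a power of two
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h

h = sylvester_hadamard(8) / math.sqrt(8)  # H @ H.T = n * I, so H / sqrt(n) is orthogonal
assert torch.allclose(h @ h.T, torch.eye(8), atol=1e-6)  # the inverse is just the transpose
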
src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 9 additions & 7 deletions
@@ -22,7 +22,8 @@
     get_transform_size,
 )
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
-from torch import Tensor
+from compressed_tensors.utils.offload import get_offloaded_device
+from torch import Tensor, device
 from torch.nn import Module, Parameter


@@ -52,20 +53,22 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
+        device = get_offloaded_device(module)

-        weight = self.weights.get(size)
+        factory_kwargs = {"device": device}
+        weight = self.weights.get(size, factory_kwargs=factory_kwargs)
         if args.inverse:
             weight = self.inverses[weight]

         return RandomMatrixTransform(weight, self.scheme, args, type(module))

-    def _create_weight(self, size: int) -> Parameter:
+    def _create_weight(self, size: int, device: device) -> Parameter:
         # TODO: construct such that weight is invertible (has non-zero determinant)
         data = torch.rand(
             (size, size),
             generator=self.generator,
             dtype=self.scheme.precision,
-            device=torch.device("cpu"),
+            device=device,
         )
         return Parameter(data, requires_grad=self.scheme.requires_grad)

@@ -91,9 +94,8 @@ def __init__(
         self._precision = scheme.precision if args.is_online() else torch.float64

     def forward(self, value: Tensor) -> Parameter:
-        # onloading is done by accelerate if parent module is offloaded
         return apply_transform_weight(
-            self.weight.to(dtype=self._precision, device=value.device),
+            self.weight.to(dtype=self._precision),
             value.to(self._precision),
             self.args.location,
             self.module_type,
@@ -102,7 +104,7 @@ def forward(self, value: Tensor) -> Parameter:
     def right_inverse(self, value: Tensor) -> Tensor:
         inverse = high_precision_invert(self.weight)
         return apply_transform_weight(
-            inverse.to(dtype=self._precision, device=value.device),
+            inverse.to(dtype=self._precision),
             value.to(self._precision),
             self.args.location,
             self.module_type,

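create_transform now threads the module's offload device through factory_kwargs, so the shared weight is created directly where the module's parameters are cached instead of always on cpu. A simplified analog of the size-keyed cache; WeightCache is a hypothetical stand-in for ParameterizedDefaultDict, whose real keying and API may differ:

import torch

class WeightCache:
    # hypothetical stand-in: one shared tensor per transform size
    def __init__(self, factory):
        self._factory = factory
        self._cache = {}

    def get(self, size: int, factory_kwargs: dict) -> torch.nn.Parameter:
        if size not in self._cache:
            self._cache[size] = self._factory(size, **factory_kwargs)
        return self._cache[size]

def create_weight(size: int, device: torch.device) -> torch.nn.Parameter:
    data = torch.rand((size, size), dtype=torch.float32, device=device)
    return torch.nn.Parameter(data, requires_grad=False)

weights = WeightCache(create_weight)
w1 = weights.get(64, factory_kwargs={"device": torch.device("cpu")})
w2 = weights.get(64, factory_kwargs={"device": torch.device("cpu")})
assert w1 is w2  # transforms of the same size share one tensor
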
src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 2 additions & 2 deletions
@@ -31,10 +31,10 @@ class RandomHadamardFactory(HadamardFactory):
     def _create_weight(
         self,
         size: int,
+        device: device,
         construct_device: device,
     ) -> Parameter:
-        # construct on execution device, cache on offload device
         precision = self.scheme.precision
         data = random_hadamard_matrix(size, precision, construct_device, self.generator)
-        data = data.to(device="cpu")
+        data = data.to(device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)

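Both Hadamard factories now construct the matrix on the execution device and cache it on the module's offload device rather than hard-coding cpu. A rough sketch of that construct-then-cache flow, with make_matrix as a hypothetical stand-in for random_hadamard_matrix:

import torch

def make_matrix(size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    return torch.eye(size, dtype=dtype, device=device)  # stand-in construction

# build where compute is fast, then cache wherever the module's weights live
construct_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
offload_device = torch.device("cpu")  # e.g. accelerate's cpu offload cache

data = make_matrix(64, torch.float32, construct_device)
data = data.to(device=offload_device)
weight = torch.nn.Parameter(data, requires_grad=False)
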
src/compressed_tensors/utils/offload.py

Lines changed: 2 additions & 1 deletion
@@ -130,7 +130,8 @@ def get_offloaded_device(module: torch.nn.Module) -> torch.device:
         first_key = list(module._hf_hook.weights_map.keys())[0]
         prefix_dataset = module._hf_hook.weights_map.dataset
         return prefix_dataset[first_key].device
-    return next(module.parameters()).device
+    else:
+        return get_execution_device(module)


 @check_accelerate(fallback=None)

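With this fix, a module without accelerate offloading reports its execution device instead of the device of whichever parameter happens to come first. Illustrative usage, assuming accelerate is installed and the module has no offloading hooks:

import torch
from compressed_tensors.utils import get_execution_device, get_offloaded_device

linear = torch.nn.Linear(16, 16)  # plain module, no accelerate hooks attached
# for a non-offloaded module the two now agree by construction
assert get_offloaded_device(linear) == get_execution_device(linear)
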