     apply_transform_weight,
     get_transform_size,
 )
-from compressed_tensors.utils import get_execution_device, get_offloaded_device
+from compressed_tensors.utils import get_execution_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
 from torch.nn import Module, Parameter
@@ -42,7 +42,6 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non
         super().__init__(name, scheme, seed)
         self.weights = ParameterizedDefaultDict(self._create_weight)
         self.perms = ParameterizedDefaultDict(self._create_permutation)
-        self._shared_tensors_device = None

     def create_transform(self, module: Module, args: TransformArgs):
         """
@@ -54,11 +53,9 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = self.scheme.precision
-        device = get_offloaded_device(module)
         exec_device = get_execution_device(module)

-        factory_kwargs = {"device": device, "construct_device": exec_device}
+        factory_kwargs = {"construct_device": exec_device}
         weight = self.weights.get(size, factory_kwargs=factory_kwargs)
         # TODO: permutations should be keyed by fused modules, not weight
         perm = self.perms[weight] if self.scheme.randomize else None
@@ -67,25 +64,12 @@ def create_transform(self, module: Module, args: TransformArgs):
     def _create_weight(
         self,
         size: int,
-        device: device,
         construct_device: device,
     ) -> Parameter:
-        # check that shared tensors device is consistent
-        if self._shared_tensors_device is None:
-            self._shared_tensors_device = device
-
-        if device != self._shared_tensors_device:
-            raise NotImplementedError(
-                "Creating multi-gpu transform weights are not supported as of now due "
-                "to the limitations of shared tensors across GPUs."
-                # in the future, tensors can be shared within GPUs,
-                # and can be all-reduced during updates and compression
-            )
-
-        # construct on execution device, cache shared tensor on offload device
+        # construct on execution device, cache shared tensor on cpu
         precision = self.scheme.precision
         data = deterministic_hadamard_matrix(size, precision, construct_device)
-        data = data.to(device=device)
+        data = data.to(device="cpu")
         return Parameter(data, requires_grad=self.scheme.requires_grad)

     def _create_permutation(self, weight: Parameter) -> Parameter:
@@ -120,9 +104,10 @@ def forward(self, value: Tensor) -> Tensor:
         if self.args.inverse:
             weight = weight.T

+        # onloading is done by accelerate if parent module is offloaded
         return (
             apply_transform_weight(
-                weight.to(self._precision),
+                weight.to(dtype=self._precision, device=value.device),
                 value.to(self._precision),
                 self.args.location,
                 self.module_type,
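For readers skimming the diff, here is a minimal, hypothetical sketch of the pattern this change adopts (it is not the compressed-tensors implementation; `SharedCPUTransform` and the identity matrix stand in for the Hadamard weight built by `deterministic_hadamard_matrix`): the shared tensor is constructed once, cached on cpu, and moved to the activation's dtype and device only at forward time, mirroring the `weight.to(dtype=self._precision, device=value.device)` call above.

```python
import torch
from torch import Tensor
from torch.nn import Module, Parameter


class SharedCPUTransform(Module):
    """Hypothetical stand-in: cache the shared weight on cpu, onload it lazily."""

    def __init__(self, size: int):
        super().__init__()
        # stand-in for the Hadamard construction; built once and cached on cpu
        data = torch.eye(size, dtype=torch.float32)
        self.weight = Parameter(data.to(device="cpu"), requires_grad=False)

    def forward(self, value: Tensor) -> Tensor:
        # move the cached cpu weight to wherever the activation currently lives
        weight = self.weight.to(dtype=value.dtype, device=value.device)
        return value @ weight


x = torch.randn(2, 64)
y = SharedCPUTransform(64)(x)  # on a cuda input, the weight is moved per call
```

Keeping the cached parameter on cpu lets it be shared across modules regardless of which GPU each module executes on, which appears to be why the multi-GPU guard removed from `_create_weight` is no longer needed.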