@@ -41,6 +41,7 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non
         super().__init__(name, scheme, seed)
         self.weights = ParameterizedDefaultDict(self._create_weight)
         self.inverses = ParameterizedDefaultDict(self._create_inverse)
+        self._shared_tensors_device = None

     def create_transform(self, module: Module, args: TransformArgs):
         """
@@ -52,19 +53,34 @@ def create_transform(self, module: Module, args: TransformArgs):
         """
         assert hasattr(module, "weight")
         size = get_transform_size(module, args.location, self.scheme.head_dim)
-        dtype = self.scheme.precision
         device = get_offloaded_device(module)

-        weight = self.weights[size, dtype, device]
+        factory_kwargs = {"device": device}
+        weight = self.weights.get(size, factory_kwargs=factory_kwargs)
         if args.inverse:
             weight = self.inverses[weight]

         return RandomMatrixTransform(weight, self.scheme, args, type(module))

-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        # TODO: verify that weight is invertible (has non-zero determinant)
+    def _create_weight(self, size: int, device: device) -> Parameter:
+        # check that the shared tensors device is consistent
+        if self._shared_tensors_device is None:
+            self._shared_tensors_device = device
+
+        if device != self._shared_tensors_device:
+            raise NotImplementedError(
+                "Creating multi-gpu transform weights is not supported as of now due "
+                "to the limitations of shared tensors across GPUs"
+                # in the future, tensors can be shared within GPUs,
+                # and can be all-reduced during updates and compression
+            )
+
+        # TODO: construct such that weight is invertible (has non-zero determinant)
         data = torch.rand(
-            (size, size), generator=self.generator, dtype=dtype, device=device
+            (size, size),
+            generator=self.generator,
+            dtype=self.scheme.precision,
+            device=device,
         )
         return Parameter(data, requires_grad=self.scheme.requires_grad)

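A note on the caching change: the old code keyed the shared weight dict on `(size, dtype, device)`, while the new code keys only on `size` and forwards the device through `factory_kwargs`; this is why `_create_weight` now has to guard against mixed-device creation. The snippet below is a simplified, hypothetical stand-in for `ParameterizedDefaultDict.get` (the real class is not shown in this diff and may differ); it is only meant to illustrate why the device no longer participates in the cache key.

```python
from typing import Any, Callable, Dict, Optional, Tuple


class SimpleParameterizedDict:
    """Hypothetical, simplified stand-in for ParameterizedDefaultDict (illustration only)."""

    def __init__(self, factory: Callable[..., Any]):
        self._factory = factory
        self._cache: Dict[Tuple[Any, ...], Any] = {}

    def get(self, *key: Any, factory_kwargs: Optional[Dict[str, Any]] = None) -> Any:
        # Only the positional key (here: size) identifies the cached tensor.
        # factory_kwargs (here: {"device": ...}) are forwarded to the factory on
        # first creation but never become part of the key, so the first device to
        # request a given size "wins" for every later caller -- which is why
        # _create_weight now tracks _shared_tensors_device and rejects mismatches.
        if key not in self._cache:
            self._cache[key] = self._factory(*key, **(factory_kwargs or {}))
        return self._cache[key]


# Example use mirroring the diff (names taken from the diff; the dict class is not):
# weights = SimpleParameterizedDict(self._create_weight)
# weight = weights.get(size, factory_kwargs={"device": device})
```

Under this reading, the first module to request a given size fixes the device for every later module of that size, so raising `NotImplementedError` on a mismatch is the conservative choice until cross-GPU sharing (or the all-reduce mentioned in the comment) is implemented.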
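On the remaining TODO ("construct such that weight is invertible"): a matrix sampled from `torch.rand` is invertible with probability 1 in exact arithmetic, but can be rank-deficient or badly conditioned once cast to a low-precision `scheme.precision`. The sketch below shows one possible way to address it, resampling until a rank check passes; it is an illustration only, not part of this patch, and the helper name is made up.

```python
import torch
from typing import Optional


def create_invertible_weight(
    size: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[torch.Generator] = None,
    max_tries: int = 5,
) -> torch.Tensor:
    """Sample a random square matrix, resampling until it is numerically invertible.

    Illustrative sketch only; the factory in the diff returns the first sample as-is.
    """
    for _ in range(max_tries):
        data = torch.rand(
            (size, size), generator=generator, dtype=dtype, device=device
        )
        # Do the rank check in float32: it avoids spurious failures caused by
        # rounding in low-precision dtypes and unsupported half-precision ops.
        if torch.linalg.matrix_rank(data.float()) == size:
            return data
    raise RuntimeError(f"could not sample an invertible {size}x{size} matrix")
```

An alternative would be to nudge the sample away from singularity (for example by adding a scaled identity), at the cost of changing the sampling distribution.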