@@ -53,19 +53,19 @@ def create_transform(self, module: Module, args: TransformArgs):
53
53
dtype = module .weight .dtype
54
54
device = get_offloaded_device (module )
55
55
56
- if not args . inverse :
57
- weight = self . weights [ size , dtype , device ]
58
- else :
59
- weight = self . inverses [ size , dtype , device ]
56
+ weight = self . weights [ size , dtype , device ]
57
+ if args . inverse :
58
+ weight = self . inverses [ weight ]
59
+
60
60
return RandomMatrixTransform (weight , args )
61
61
62
62
def _create_weight (self , size : int , dtype : dtype , device : device ) -> Parameter :
63
63
data = torch .rand ((size , size ), dtype = dtype , device = device )
64
64
return Parameter (data , requires_grad = self .scheme .requires_grad )
65
65
66
- def _create_inverse (self , size : int , dtype : dtype , device : device ) -> Parameter :
67
- weight = self . weights [ size , dtype , device ]
68
- return Parameter (high_precision_invert ( weight . data ) , requires_grad = False )
66
+ def _create_inverse (self , weight : Parameter ) -> Parameter :
67
+ data = high_precision_invert ( weight . data )
68
+ return Parameter (data , requires_grad = False )
69
69
70
70
71
71
class RandomMatrixTransform (TransformBase ):
0 commit comments