24
24
from compressed_tensors .utils import get_offloaded_device
25
25
from compressed_tensors .utils .helpers import ParameterizedDefaultDict
26
26
from torch import Tensor , device , dtype
27
- from torch .nn import Linear , Module , Parameter
27
+ from torch .nn import Module , Parameter
28
28
29
29
30
30
@TransformFactory .register ("random-matrix" )
@@ -52,14 +52,14 @@ def create_transform(self, module: Module, args: TransformArgs):
52
52
"""
53
53
assert hasattr (module , "weight" )
54
54
size = get_transform_size (module , args .location , self .scheme .head_dim )
55
- dtype = module . weight . dtype
55
+ dtype = self . scheme . precision
56
56
device = get_offloaded_device (module )
57
57
58
58
weight = self .weights [size , dtype , device ]
59
59
if args .inverse :
60
60
weight = self .inverses [weight ]
61
61
62
- return RandomMatrixTransform (weight , args , type (module ))
62
+ return RandomMatrixTransform (weight , self . scheme , args , type (module ))
63
63
64
64
def _create_weight (self , size : int , dtype : dtype , device : device ) -> Parameter :
65
65
# TODO: verify that weight is invertible (has non-zero determinant)
@@ -78,25 +78,34 @@ class RandomMatrixTransform(TransformBase):
78
78
def __init__(
    self,
    weight: Tensor,
    scheme: TransformScheme,
    args: TransformArgs,
    module_type: type[torch.nn.Module],
):
    """Store the transform weight together with the settings used to apply it.

    :param weight: matrix applied by this transform; it is already the
        inverse matrix when ``args.inverse`` is set
    :param scheme: scheme this transform belongs to; its ``precision`` is
        used as the compute dtype for online transforms
    :param args: transform arguments; ``args.location`` is read when the
        transform is applied and ``args.is_online()`` selects the precision
    :param module_type: type of the module the transform is applied to
    """
    super().__init__()
    self.weight = weight  # is an inverse if args.inverse
    self.scheme = scheme
    self.args = args
    self.module_type = module_type

    # Online transforms compute in the scheme's precision; every other
    # transform computes in float64.
    if args.is_online():
        self._precision = scheme.precision
    else:
        self._precision = torch.float64
88
91
89
92
def forward(self, value: Tensor) -> Parameter:
    """Apply the transform weight to ``value``.

    Both the weight and ``value`` are cast to ``self._precision`` for the
    computation; the result is cast back to the dtype ``value`` had on entry.

    :param value: tensor to transform
    :return: transformed tensor in the original dtype of ``value``
    """
    original_dtype = value.dtype
    transformed = apply_transform_weight(
        self.weight.to(self._precision),
        value.to(self._precision),
        self.args.location,
        self.module_type,
    )
    return transformed.to(original_dtype)
93
99
94
100
def right_inverse(self, value: Tensor) -> Tensor:
    """Apply the inverse of the transform weight to ``value``.

    The inverse is obtained from ``high_precision_invert``. Both operands
    are cast to ``self._precision`` for the computation, and the result is
    cast back to the dtype ``value`` had on entry.

    :param value: tensor to transform with the inverted weight
    :return: transformed tensor in the original dtype of ``value``
    """
    original_dtype = value.dtype
    inverted = high_precision_invert(self.weight)
    # NOTE(review): high_precision_invert returns the inverse in the dtype of
    # self.weight, so the cast below starts from that dtype -- confirm this
    # round trip is intended when self._precision is wider than the weight.
    restored = apply_transform_weight(
        inverted.to(self._precision),
        value.to(self._precision),
        self.args.location,
        self.module_type,
    )
    return restored.to(original_dtype)
99
108
100
109
101
110
def high_precision_invert(weight: Tensor) -> Tensor:
    """Invert ``weight`` in float64 and return the inverse in its original dtype.

    :param weight: square matrix to invert
    :return: inverse of ``weight``, cast back to ``weight.dtype``
    """
    original_dtype = weight.dtype
    upcast = weight.to(torch.float64)  # invert in double precision
    return torch.linalg.inv(upcast).to(original_dtype)
0 commit comments