File tree Expand file tree Collapse file tree 1 file changed +16
-3
lines changed
src/compressed_tensors/transform/utils Expand file tree Collapse file tree 1 file changed +16
-3
lines changed Original file line number Diff line number Diff line change 21
21
__all__ = ["get_transform_size" , "apply_transform_weight" ]
22
22
23
23
24
+ TRANSFORM_PRECISION = torch .float64
25
+
26
+
24
27
def get_transform_size (
25
28
module : torch .nn .Module ,
26
29
location : TransformLocation ,
@@ -77,14 +80,24 @@ def apply_transform_weight(
77
80
The transform has to account for Linear's transposed weights
78
81
:return: value after weight has been applied
79
82
"""
83
+ # get function used to apply transform
80
84
fn , axis = _get_transform_method (module_type , location )
81
85
82
- assert weight . shape [ 0 ] == weight . shape [ 1 ]
86
+ # reshape for head_dim
83
87
head_dim = weight .shape [0 ]
84
88
num_heads = value .shape [axis ] // head_dim
85
-
86
89
value = value .unflatten (axis , (num_heads , head_dim ))
87
- value = fn (weight , value )
90
+
91
+ # remember the original dtype so it can be restored after the transform
92
+ value_dtype = value .dtype
93
+
94
+ # apply transform, with both operands cast to TRANSFORM_PRECISION
95
+ value = fn (weight .to (TRANSFORM_PRECISION ), value .to (TRANSFORM_PRECISION ))
96
+
97
+ # [undo] cast to transform precision
98
+ value = value .to (value_dtype )
99
+
100
+ # [undo] reshape for head_dim
88
101
value = value .flatten (axis - 1 , axis )
89
102
90
103
return value
You can’t perform that action at this time.
0 commit comments