12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from typing import Optional
15
+ from typing import Optional , Tuple , Callable
16
16
17
17
import torch
18
18
from compressed_tensors .transform import TransformLocation
@@ -42,7 +42,8 @@ def get_matrix_size(
42
42
size = module .out_features
43
43
44
44
if head_dim is not None :
45
- assert size % head_dim == 0
45
+ if size % head_dim != 0 :
46
+ raise ValueError ("Cannot " )
46
47
return head_dim
47
48
48
49
else :
def apply_transform_weight(
    weight: torch.Tensor,
    value: torch.Tensor,
    location: TransformLocation,
    module_type: type[torch.nn.Module],
) -> torch.Tensor:
    """
    Apply a square transform weight to a value, supporting values whose
    relevant dimension is a multiple of the transform size (i.e. multi-head
    or fused layouts): the value is unfolded into heads of size
    ``weight.shape[0]`` along the location-dependent axis, transformed
    head-by-head, then folded back to its original shape.

    :param weight: square transform matrix of shape (head_dim, head_dim)
    :param value: tensor to transform (an activation or a module weight)
    :param location: determines how ``weight`` is applied to ``value``,
        see ``TransformLocation``
    :param module_type: type of the module whose value is being transformed;
        only ``torch.nn.Linear`` is currently supported
    :return: ``value`` after the transform weight has been applied
    :raises NotImplementedError: if ``module_type`` is not supported
    :raises ValueError: if ``weight`` is not square or the value's axis size
        is not divisible by the transform size
    """
    if module_type == torch.nn.Linear:
        fn, axis = get_linear_transform_fn(module_type, location)
    else:
        raise NotImplementedError(
            f"Applying transforms to {module_type} is not supported"
        )

    # Validate explicitly rather than with `assert`, which is stripped
    # when Python runs with optimizations enabled (-O).
    if weight.shape[0] != weight.shape[1]:
        raise ValueError(
            f"Transform weight must be square, got shape {tuple(weight.shape)}"
        )
    head_dim = weight.shape[0]
    if value.shape[axis] % head_dim != 0:
        raise ValueError(
            f"Size of axis {axis} of value ({value.shape[axis]}) must be "
            f"divisible by the transform size ({head_dim})"
        )
    num_heads = value.shape[axis] // head_dim

    # Unfold the axis into (num_heads, head_dim), apply the transform to
    # each head, then flatten back to the original layout.
    value = value.unflatten(axis, (num_heads, head_dim))
    value = fn(weight, value)
    value = value.flatten(axis - 1, axis)

    return value
76
+
77
+
78
def get_linear_transform_fn(
    module_type: type[torch.nn.Module],
    location: TransformLocation,
) -> Tuple[Callable[[torch.Tensor, torch.Tensor], torch.Tensor], int]:
    """
    Using the transform location, determine how to apply the transform weight
    to a given value wrt linear weights. For more info on input and output
    transforms, see ``TransformLocation``.

    :param module_type: type of the module being transformed; only
        ``torch.nn.Linear`` is currently supported
    :param location: determines how the transform weight should be applied
    :return: tuple of (a function ``fn(weight, value)`` that applies the
        transform weight to a value, and the axis of the value along which
        heads are unfolded before ``fn`` is applied)
    :raises NotImplementedError: if the (module_type, location) combination
        is not supported
    """
    fn = axis = None

    if module_type == torch.nn.Linear:
        # Activations entering the module: transform acts on the feature
        # (last) axis via right-multiplication.
        if location == TransformLocation.INPUT:
            fn = lambda weight, value: value @ weight
            axis = -1

        # Fusing into the linear weight on its input-feature axis.
        elif location == TransformLocation.WEIGHT_INPUT:
            fn = lambda weight, value: value @ weight.T
            axis = -1

        # Fusing into the linear weight on its output-feature axis
        # (rows of the weight), hence left-multiplication on axis -2.
        elif location == TransformLocation.WEIGHT_OUTPUT:
            fn = lambda weight, value: weight.T @ value
            axis = -2

        # Activations leaving the module: same form as INPUT.
        elif location == TransformLocation.OUTPUT:
            fn = lambda weight, value: value @ weight
            axis = -1

    if fn is None:
        raise NotImplementedError(
            f"Applying transforms to {module_type} {location} is not supported"
        )

    return fn, axis
0 commit comments