@@ -59,47 +59,13 @@ def get_transform_size(
 
 
 def apply_transform_weight(
-    weight: torch.Tensor,
+    transform_weight: torch.Tensor,
     value: torch.Tensor,
     location: TransformLocation,
     module_type: type[torch.nn.Module],
 ) -> torch.Tensor:
     """
-    :param weight: transform weight to apply
-    :param value: value to apply weight to
-    :param location: determines how weight should be applied
-    :param model_type: result of type(module), passed in to determine application of
-        weight transform. This is needed because torch uses convention:
-        - torch.nn.Linear(in_features, out_features) has weight shape
-            (out_features, in_features)
-        - torch.nn.Embedding(num_embeddings, embedding_dim) has weight shape
-            (num_embeddings, embedding_dim)
-        The transform has to account for Linear's transposed weights
-    :return: value after weight has been applied
-    """
-    # get function used to apply transform
-    fn, axis = _get_transform_method(module_type, location)
-
-    # reshape for head_dim
-    head_dim = weight.shape[0]
-    num_heads = value.shape[axis] // head_dim
-    value = value.unflatten(axis, (num_heads, head_dim))
-
-    # apply transform
-    value = fn(weight, value)
-
-    # [undo] reshape for head_dim
-    value = value.flatten(axis - 1, axis)
-
-    return value
-
-
-def _get_transform_method(
-    module_type: type[torch.nn.Module],
-    location: TransformLocation,
-) -> Tuple[Callable[[torch.Tensor, torch.Tensor], torch.Tensor], int]:
-    """
-    Using the transform location, determine how to apply the transform weight to the
+    Using the transform location, apply the transform_weight to the
     given value wrt linear weights. For more info on input and output transforms,
     see `TransformLocation`
 
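For reference, the torch weight-layout convention that the removed docstring described can be checked directly; a minimal sketch (plain torch, nothing from this PR assumed):

```python
import torch

# torch.nn.Linear stores its weight transposed relative to the forward pass:
linear = torch.nn.Linear(in_features=4, out_features=8)
print(linear.weight.shape)  # torch.Size([8, 4]) -> (out_features, in_features)

# torch.nn.Embedding stores its weight "untransposed":
embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)
print(embedding.weight.shape)  # torch.Size([10, 4]) -> (num_embeddings, embedding_dim)
```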
@@ -129,51 +95,85 @@ def _get_transform_method(
                 = y U
                 = yh
 
-    :param weight: transform weight to apply
-    :param value: value to apply weight to
+    :param transform_weight: transform weight to apply
+    :param value: value to apply transform_weight to
     :param location: determines how weight should be applied
-    :return: value after transform weight has been applied
+    :param module_type: result of type(module), passed in to determine application of
+        weight transform
+    :return: value after transform_weight has been applied
     """
-    fn = axis = None
+
+    assert transform_weight.shape[0] == transform_weight.shape[1]
 
     if module_type == torch.nn.Linear:
         if location == TransformLocation.INPUT:
-            fn = lambda weight, value: value @ weight
-            axis = -1
+            return _multihead_matmul(value, transform_weight)
 
         elif location == TransformLocation.WEIGHT_INPUT:
-            fn = lambda weight, value: value @ weight.T
-            axis = -1
+            # equivalent to (transform_weight @ value.T).T
+            return _multihead_matmul(value, transform_weight.T)
 
         elif location == TransformLocation.WEIGHT_OUTPUT:
-            fn = lambda weight, value: weight.T @ value
-            axis = -2
+            # equivalent to (value.T @ transform_weight).T
+            return _multihead_matmul(transform_weight.T, value)
 
         elif location == TransformLocation.OUTPUT:
-            fn = lambda weight, value: value @ weight
-            axis = -1
+            return _multihead_matmul(value, transform_weight)
 
     # similar derivation to torch.nn.Linear, but `y = (x W)`
-    if module_type == torch.nn.Embedding:
+    elif module_type == torch.nn.Embedding:
         if location == TransformLocation.INPUT:
-            fn = lambda weight, value: value @ weight
-            axis = -1
+            return _multihead_matmul(value, transform_weight)
 
         elif location == TransformLocation.WEIGHT_INPUT:
-            fn = lambda weight, value: weight @ value
-            axis = -1
+            return _multihead_matmul(
+                transform_weight,
+                value,
+            )
 
         elif location == TransformLocation.WEIGHT_OUTPUT:
-            fn = lambda weight, value: value @ weight
-            axis = -1
+            return _multihead_matmul(value, transform_weight)
 
         elif location == TransformLocation.OUTPUT:
-            fn = lambda weight, value: value @ weight
-            axis = -1
+            return _multihead_matmul(value, transform_weight)
 
-    if fn is None:
-        raise NotImplementedError(
-            f"Applying transforms to {module_type} {location} is not supported"
-        )
+    raise NotImplementedError(
+        f"Applying transforms to {module_type} {location} is not supported"
+    )
 
-    return fn, axis
+
+def _multihead_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+    """
+    Performs A @ B over the last two dims of two matrices A and B that may
+    have different shapes, as is the case with multi-headed transform
+    weights. If the shapes differ, this is equivalent to converting the
+    last two dims of the smaller matrix into a block-diagonal matrix with
+    the same shape as the last two dims of the larger matrix.
+
+    E.g. if A is half the size of B, this function will perform
+    [[A ] @ B
+     [ A]]
+
+    If B is a third of the size of A, this function will perform
+    A @ [[B  ]
+         [ B ]
+         [  B]]
+
+    This function will error out if the shapes are not evenly divisible
+
+    :param A: left-hand tensor
+    :param B: right-hand tensor
+    :return: result
+    """
+    if A.shape[-1] > B.shape[-2]:
+        head_dim = B.shape[-2]
+        num_heads = A.shape[-1] // head_dim
+        A = A.unflatten(-1, (num_heads, head_dim))
+        return (A @ B).flatten(-2, -1)
+    elif A.shape[-1] < B.shape[-2]:
+        head_dim = A.shape[-1]
+        num_heads = B.shape[-2] // head_dim
+        B = B.unflatten(-2, (num_heads, head_dim))
+        return (A @ B).flatten(-3, -2)
+    else:
+        return A @ B
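A quick way to sanity-check the block-diagonal claim in the docstring above is to compare against an explicit block-diagonal matrix built with `torch.block_diag`; a sketch, assuming `_multihead_matmul` as defined in this diff is in scope:

```python
import torch

# Case 1: A is wider than B -> B acts as block_diag(B, B) on A's last dim
A = torch.randn(6, 6)  # e.g. two heads with head_dim=3
B = torch.randn(3, 3)  # per-head transform weight
assert torch.allclose(_multihead_matmul(A, B), A @ torch.block_diag(B, B))

# Case 2: B is taller than A -> A acts as block_diag(A, A) on B's rows
A2 = torch.randn(3, 3)
B2 = torch.randn(6, 6)
assert torch.allclose(_multihead_matmul(A2, B2), torch.block_diag(A2, A2) @ B2)
```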
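And an end-to-end sketch of the new dispatch itself, assuming `apply_transform_weight` and `TransformLocation` are importable from this module:

```python
import torch

# Linear weights are stored as (out_features, in_features), so a transform
# on the input side of a Linear multiplies along the last dim:
W = torch.randn(8, 4)  # as stored by torch.nn.Linear(4, 8)
U = torch.randn(4, 4)  # square transform weight over the input features

out = apply_transform_weight(U, W, TransformLocation.WEIGHT_INPUT, torch.nn.Linear)
# matches the inline comment in the diff: (U @ W.T).T == W @ U.T
assert torch.allclose(out, W @ U.T)
```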