@@ -101,13 +101,17 @@ def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str)
     optimized_rotation = torch.load(optimized_rotation_path, weights_only=True)
     R1 = optimized_rotation["R1"].to(torch.float32)
     config = model.params
+    # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
+    #  `n_heads`.
     num_heads = config.n_heads
     head_dim = config.dim // num_heads
 
     rotate_embeddings(model, R1)
     rotate_head(model, R1)
     cleanup_memory()
 
+    # pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got
+    #  `Union[Tensor, Module]`.
     for idx, layer in enumerate(model.layers):
         key = f"model.layers.{idx}.self_attn.R2"
         R2 = optimized_rotation[key].to(torch.float32)
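All of the added lines are Pyre suppressions: a # pyre-fixme[CODE]: comment silences that one error code on the line immediately below it and changes nothing at runtime. The errors arise because attribute access on nn.Module goes through __getattr__, which is typed as returning Union[Tensor, Module], so everything reached through the model (config, model.layers, ...) carries that union. A standalone illustration of the pattern (the function and the params/n_heads attributes are made up for the example, not this file's API):

import torch

def head_count(model: torch.nn.Module) -> int:
    # Attribute access on nn.Module is typed Union[Tensor, Module], so Pyre
    # flags .n_heads below; the fixme suppresses exactly that diagnostic.
    config = model.params
    # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
    #  `n_heads`.
    return config.n_heads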
@@ -130,6 +134,8 @@ def fuse_ln_linear(
 
         # Calculating new weight and bias
         W_ = linear.weight.data.to(dtype=torch.float32)
+        # pyre-fixme[58]: `*` is not supported for operand types `Tensor` and
+        #  `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`.
         linear.weight.data = (W_ * layernorm.weight.to(dtype=torch.float32)).to(
             linear_dtype
         )
@@ -140,6 +146,8 @@ def fuse_ln_linear(
                     torch.zeros(linear.out_features, dtype=torch.float32)
                 )
             linear.bias.data = linear.bias.data.to(dtype=torch.float32) + torch.matmul(
+                # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+                #  `Union[Tensor, Module]`.
                 W_, layernorm.bias.to(dtype=torch.float32)
             )
             linear.bias.data = linear.bias.data.to(linear_dtype)
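The operations these suppressions annotate implement the standard norm-fusion identity: for a norm with per-channel scale gamma and optional bias beta feeding a linear layer W, W @ (gamma * x + beta) == (W * gamma) @ x + W @ beta, so the scale folds into the weight columns and the norm bias into the linear bias. A minimal sketch with a numerical check, assuming plain PyTorch modules (fuse_norm_scale is a hypothetical name, not this file's API):

import torch

def fuse_norm_scale(norm: torch.nn.Module, linears) -> None:
    # Fold the norm's scale/bias into downstream linears, then make the
    # norm an identity scale (illustrative; mirrors fuse_ln_linear above).
    gamma = norm.weight.data.to(torch.float32)
    beta = getattr(norm, "bias", None)
    for linear in linears:
        dtype = linear.weight.dtype
        W = linear.weight.data.to(torch.float32)
        # W @ (gamma * x) == (W * gamma) @ x: gamma broadcasts over columns.
        linear.weight.data = (W * gamma).to(dtype)
        if beta is not None:
            if linear.bias is None:
                linear.bias = torch.nn.Parameter(torch.zeros(linear.out_features))
            b = linear.bias.data.to(torch.float32) + W @ beta.data.to(torch.float32)
            linear.bias.data = b.to(dtype)
    norm.weight.data = torch.ones_like(norm.weight.data)
    if beta is not None:
        norm.bias.data = torch.zeros_like(norm.bias.data)

norm, lin, x = torch.nn.LayerNorm(8), torch.nn.Linear(8, 4), torch.randn(2, 8)
ref = lin(norm(x))
fuse_norm_scale(norm, [lin])
assert torch.allclose(lin(norm(x)), ref, atol=1e-5)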
@@ -148,10 +156,18 @@ def fuse_ln_linear(
 def fuse_layer_norms(model: torch.nn.Module):
     # Embedding fusion
     for W in [model.tok_embeddings]:
+        # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
+        #  `weight`.
         W_ = W.weight.data.to(dtype=torch.float32)
+        # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
+        #  `weight`.
         W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)
 
     # Fuse the linear operations in Layernorm into the adjacent linear blocks.
+    # pyre-fixme[29]:
+    #  `Union[BoundMethod[typing.Callable(torch._tensor.Tensor.__iter__)[[Named(self,
+    #  torch._tensor.Tensor)], typing.Any], torch._tensor.Tensor],
+    #  torch._tensor.Tensor, torch.nn.modules.module.Module]` is not a function.
     for layer in model.layers:
         # fuse the input layernorms into the linear layers
         fuse_ln_linear(layer.ffn_norm, [layer.feed_forward.w3, layer.feed_forward.w1])
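The embedding fusion at the start of fuse_layer_norms works because mean-subtraction commutes with table lookup: centering every row of tok_embeddings once is the same as subtracting the feature mean from each looked-up vector, which is the mean-removal step that a pure RMS normalization lacks. A quick equivalence check (shapes chosen arbitrarily):

import torch

emb = torch.nn.Embedding(10, 8)
tokens = torch.tensor([1, 5, 5])
# Center each table row over the feature dimension once...
centered = emb.weight.data - emb.weight.data.mean(dim=-1, keepdim=True)
# ...and it matches per-token centering of the looked-up vectors.
x = emb(tokens)
assert torch.allclose(centered[tokens], x - x.mean(dim=-1, keepdim=True), atol=1e-6)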
@@ -170,9 +186,15 @@ def fuse_layer_norms(model: torch.nn.Module):
         layer.attention_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)
 
     fuse_ln_linear(
+        # pyre-fixme[6]: For 1st argument expected `Module` but got `Union[Tensor,
+        #  Module]`.
         model.norm,
+        # pyre-fixme[6]: For 2nd argument expected `Iterable[Linear]` but got
+        #  `Iterable[Union[Tensor, Module]]`.
         [model.output],
     )
+    # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
+    #  `weight`.
     W_norm = model.norm.weight.data
     model.norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)
 