@@ -167,39 +167,55 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
         img_modulated = self.img_norm1(img)
         img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
+        del img_modulated
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del img_qkv
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
+        del txt_modulated
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         if self.flipped_img_txt:
+            q = torch.cat((img_q, txt_q), dim=2)
+            del img_q, txt_q
+            k = torch.cat((img_k, txt_k), dim=2)
+            del img_k, txt_k
+            v = torch.cat((img_v, txt_v), dim=2)
+            del img_v, txt_v
             # run actual attention
-            attn = attention(torch.cat((img_q, txt_q), dim=2),
-                             torch.cat((img_k, txt_k), dim=2),
-                             torch.cat((img_v, txt_v), dim=2),
+            attn = attention(q, k, v,
                              pe=pe, mask=attn_mask, transformer_options=transformer_options)
+            del q, k, v

             img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
         else:
+            q = torch.cat((txt_q, img_q), dim=2)
+            del txt_q, img_q
+            k = torch.cat((txt_k, img_k), dim=2)
+            del txt_k, img_k
+            v = torch.cat((txt_v, img_v), dim=2)
+            del txt_v, img_v
             # run actual attention
-            attn = attention(torch.cat((txt_q, img_q), dim=2),
-                             torch.cat((txt_k, img_k), dim=2),
-                             torch.cat((txt_v, img_v), dim=2),
+            attn = attention(q, k, v,
                              pe=pe, mask=attn_mask, transformer_options=transformer_options)
+            del q, k, v

             txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
         img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        del img_attn
         img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

         # calculate the txt bloks
         txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        del txt_attn
         txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

         if txt.dtype == torch.float16:
@@ -249,12 +265,15 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
         qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del qkv
         q, k = self.norm(q, k, v)

         # compute attention
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+        del q, k, v
         # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        mlp = self.mlp_act(mlp)
+        output = self.linear2(torch.cat((attn, mlp), 2))
         x += apply_mod(output, mod.gate, None, modulation_dims)
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
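
Every `del` added above follows the same pattern: drop the last live Python reference to a large intermediate tensor before the next large allocation, so PyTorch's caching allocator can reuse that block and peak VRAM in these attention blocks stays lower. Below is a minimal, self-contained sketch of that effect, not taken from the patched code: the `peak_mb` helper and the 4096×4096 sizes are purely illustrative. On a CUDA device it should report a noticeably lower peak for the variant that deletes the intermediate before the next allocation.

```python
import torch

def peak_mb(fn):
    """Run fn once on the GPU and return the peak allocated memory in MiB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    return torch.cuda.max_memory_allocated() / 2**20

def keep_intermediate():
    a = torch.randn(4096, 4096, device="cuda")  # intermediate that is not needed below
    b = a * 2
    # 'a' is still referenced here, so its block cannot be reused for the next allocation
    return torch.randn(4096, 4096, device="cuda") + b

def drop_intermediate():
    a = torch.randn(4096, 4096, device="cuda")
    b = a * 2
    del a  # release 'a' before the next big allocation so its block can be reused
    return torch.randn(4096, 4096, device="cuda") + b

if torch.cuda.is_available():
    print(f"keep: {peak_mb(keep_intermediate):.0f} MiB  drop: {peak_mb(drop_intermediate):.0f} MiB")
```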