 from torch.nn import Module
 import torch.nn.functional as F

-from einops import rearrange
+from einops import rearrange, reduce

 from opt_einsum import contract as opt_einsum

@@ -45,7 +45,8 @@ def calculate_input_covariance(

     all_embeds = torch.cat((all_embeds), dim = 0)
     all_embeds = rearrange(all_embeds, 'n d -> d n')
-    return torch.cov(all_embeds, **cov_kwargs)
+
+    return torch.cov(all_embeds, correction = 0, **cov_kwargs)

 # a module that wraps the keys and values projection of the cross attentions to text encodings

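For reference, a minimal sketch of what the covariance call above computes (sizes are illustrative, not taken from the diff): torch.cov treats each row as one variable, hence the 'n d -> d n' transpose, and correction = 0 selects the population (divide-by-n) covariance rather than the default Bessel-corrected estimate.

    import torch
    from einops import rearrange

    embeds = torch.randn(100_000, 768)          # (num texts, dim) stacked text encodings, shapes illustrative
    embeds = rearrange(embeds, 'n d -> d n')    # torch.cov expects each row to be one variable (one dimension)
    C = torch.cov(embeds, correction = 0)       # (dim, dim) population covariance of the text encodings
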
@@ -56,7 +57,6 @@ def __init__(
         self,
         key_or_values_proj: nn.Linear,
         *,
-        num_finetune_prompts: int,
         C: Tensor,  # covariance of input, precomputed from 100K laion text
         text_seq_len: int = 77,
         is_key_proj: bool = False,
@@ -90,14 +90,14 @@ def __init__(
         # for exponentially smoothing the inputs
         # will smooth both concept and superclass token inputs

-        self.register_buffer('initted', torch.zeros(num_finetune_prompts).bool())
-        self.register_buffer('ema_concept_text_encs', torch.zeros(num_finetune_prompts, dim_input))
+        self.register_buffer('initted', Tensor([False]))
+        self.register_buffer('ema_concept_text_encs', torch.zeros(dim_input))

         # superclass outputs - only optimized for values, but not keys

         self.is_key_proj = is_key_proj  # will lock the output to the super-class, and turn off gradients

-        self.superclass_outputs = nn.Parameter(torch.zeros(num_finetune_prompts, dim_output), requires_grad = not is_key_proj)
+        self.superclass_outputs = nn.Parameter(torch.zeros(dim_output), requires_grad = not is_key_proj)

         # C in the paper, inverse precomputed

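These buffers hold the state for a single concept: a scalar initted flag, one exponentially smoothed concept text encoding of shape (dim_input,), and one superclass output of shape (dim_output,). The smoothing they back is the plain EMA update used later in forward, where decay is assumed to be a coefficient close to 1 (for example 0.99):

    concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
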
@@ -150,57 +150,52 @@ def forward(
         concept_indices = rearrange(concept_indices, 'b -> b 1')

         concept_text_enc = text_enc[batch_indices, concept_indices]
-        concept_text_enc = rearrange(concept_text_enc, 'b 1 d -> b d')
+        concept_text_enc = reduce(concept_text_enc, 'b 1 d -> d', 'mean')

-        # only if training, and if prompt ids are given
+        # only if training
         # do exponential smoothing of the inputs, both concept and superclass

         if exists(text_enc_with_superclass):
             superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
-            superclass_text_enc = rearrange(superclass_text_enc, 'b 1 d -> b d')
+            superclass_text_enc = reduce(superclass_text_enc, 'b 1 d -> d', 'mean')

-            superclass_output = einsum('b i, o i -> b o', superclass_text_enc, weights)
+            superclass_output = einsum('i, o i -> o', superclass_text_enc, weights)

         if self.training and exists(prompt_ids):
             # get the initialization state
             # as well as the exponentially smoothed text encodings

-            initted = self.initted[prompt_ids]
-            all_initted = initted.all()
+            initted = self.initted.item()

             ema_concept_text_enc = self.ema_concept_text_encs[prompt_ids]

             # store the superclass i* if not all initialized
             # else fetch it from the buffer

-            if not all_initted:
+            if not initted:
                 assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first epoch for all prompts to initialize the module correctly'

                 non_initted_prompt_ids = prompt_ids[~initted]

                 # for the prompt ids not initialized yet, hard copy over the initial superclass outputs
-                self.superclass_outputs[non_initted_prompt_ids].data.copy_(superclass_output)
+                self.superclass_outputs.data.copy_(superclass_output)

-            superclass_output = self.superclass_outputs[prompt_ids]
+            superclass_output = self.superclass_outputs

             # if any in the batch is not initialized, initialize

-            if not all_initted:
-                ema_concept_text_enc = torch.where(
-                    rearrange(initted, 'b -> b 1'),
-                    ema_concept_text_enc,
-                    concept_text_enc
-                )
+            if not initted:
+                ema_concept_text_enc = concept_text_enc

             # exponential moving average for concept input encoding

             concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)

             # store

-            if not all_initted:
-                self.initted[prompt_ids] = True
-                self.ema_concept_text_encs[prompt_ids] = ema_concept_text_enc
+            if not initted:
+                self.initted.data.copy_(Tensor([True]))
+                self.ema_concept_text_encs.data.copy_(ema_concept_text_enc)

         # take care of the output
         # for the keys, make sure to turn off gradients as it is 'locked'
@@ -214,19 +209,18 @@ def forward(

         # main contribution eq (3)

-        i_energy = opt_einsum('b o, o i, b i -> b', i, Ci, i)
-        i_energy = rearrange(i_energy, '... -> ... 1 1')
+        i_energy = opt_einsum('o, o i, i ->', i, Ci, i)

-        sim = opt_einsum('b n o, o i, b i -> b n', text_enc, Ci, i)
+        sim = opt_einsum('b n o, o i, i -> b n', text_enc, Ci, i)
         sim = rearrange(sim, '... -> ... 1')

         sigmoid_term = (((sim / i_energy) - beta) / temperature).sigmoid()

         text_enc_output = einsum('b n i, o i -> b n o', text_enc, W)

-        concept_output = einsum('b i, o i -> b o', i, W)
-        concept_output = rearrange(concept_output, 'b d -> b 1 d')
+        concept_output = einsum('i, o i -> o', i, W)
+        concept_output = rearrange(concept_output, 'd -> 1 1 d')

         W_em_orthogonal_term = text_enc_output - (sim * concept_output / i_energy)

-        return W_em_orthogonal_term + sigmoid_term * rearrange(o, 'b d -> b 1 d')
+        return W_em_orthogonal_term + sigmoid_term * rearrange(o, 'd -> 1 1 d')
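As a reading aid, my own rendering of what the returned expression computes, using symbols matched to the code (e is one token of text_enc, i* is the smoothed concept input i, o* is the concept output o, W the locked projection weights, C the precomputed input covariance, beta and tau the bias and temperature): with i_energy = i*^T C^{-1} i* and sim = e^T C^{-1} i*, the per-token output is the gated rank-one edit the "eq (3)" comment refers to,

    h = W e - \frac{e^\top C^{-1} i^*}{i^{*\top} C^{-1} i^*} \, W i^* + \sigma\!\left( \frac{e^\top C^{-1} i^* / (i^{*\top} C^{-1} i^*) - \beta}{\tau} \right) o^*

i.e. the component of W e along the concept direction (measured in the C^{-1} metric) is removed and replaced by a sigmoid-gated copy of the target concept output o*.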