@@ -129,11 +129,11 @@ def __init__(
         self.register_buffer('initted', torch.zeros(num_concepts, 1).bool())
         self.register_buffer('ema_concept_text_encs', torch.zeros(num_concepts, dim_input))

-        # superclass outputs - only optimized for values, but not keys
+        # concept outputs - only optimized for values, but not keys

         self.is_key_proj = is_key_proj # will lock the output to the super-class, and turn off gradients

-        self.superclass_output = nn.Parameter(torch.zeros(num_concepts, dim_output), requires_grad = not is_key_proj)
+        self.concept_output = nn.Parameter(torch.zeros(num_concepts, dim_output), requires_grad = not is_key_proj)

         # C in the paper, inverse precomputed

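For context on the buffer that trailing comment refers to: in the Perfusion / ROME line of work, C is the uncentered covariance of input text encodings, E[i iᵀ], estimated once from the frozen encoder and inverted up front. A minimal sketch of that precomputation, with an illustrative function name and a ridge term that are assumptions, not part of this commit:

```python
import torch

# hypothetical helper, not from this commit: estimate C = E[i i^T] over text
# encodings from the frozen encoder, then invert it once
def precompute_C_inv(text_encs: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # text_encs: (num_samples, dim_input)
    C = text_encs.t() @ text_encs / text_encs.shape[0]
    # small ridge term (an assumption) keeps the inverse well conditioned
    return torch.linalg.inv(C + eps * torch.eye(C.shape[-1]))
```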
@@ -143,7 +143,7 @@ def parameters(self):
         if self.is_key_proj:
             return []

-        return [self.superclass_outputs]
+        return [self.concept_output]

     @beartype
     def forward(
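The override above keeps key projections out of the optimizer entirely, matching `requires_grad = not is_key_proj` from the constructor. A self-contained toy reproducing that pattern (class name and sizes invented for illustration):

```python
import torch
from torch import nn

# toy reproduction of the pattern above: both modules own a concept_output,
# but only the value projection reports it as trainable
class ToyRank1(nn.Module):
    def __init__(self, num_concepts, dim_output, is_key_proj):
        super().__init__()
        self.is_key_proj = is_key_proj
        self.concept_output = nn.Parameter(
            torch.zeros(num_concepts, dim_output),
            requires_grad = not is_key_proj
        )

    def parameters(self):
        if self.is_key_proj:
            return []  # keys stay locked to the superclass output
        return [self.concept_output]

keys, values = ToyRank1(2, 8, True), ToyRank1(2, 8, False)
assert len(list(keys.parameters())) == 0
assert len(list(values.parameters())) == 1
```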
@@ -209,15 +209,15 @@ def forward(
         if not initted:
             assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first batch'

-            # for the prompt ids not initialized yet, hard copy over the initial superclass outputs
-            self.superclass_output[concept_id].data.copy_(superclass_output)
+            # init concept output with superclass output - fixed for keys, learned for values
+            self.concept_output[concept_id].data.copy_(superclass_output)

-        elif exists(superclass_output):
+        elif exists(superclass_output) and self.is_key_proj:
             # if text enc with superclass is passed in for more than 1 batch
-            # just take the opportunity to exponentially average it a bit more
+            # just take the opportunity to exponentially average it a bit more for the keys, whose concept output stays fixed to the superclass

-            ema_superclass_output = self.superclass_output * decay + superclass_output * (1. - decay)
-            self.superclass_output[concept_id].data.copy_(ema_superclass_output)
+            ema_concept_output = self.concept_output[concept_id] * decay + superclass_output * (1. - decay)
+            self.concept_output[concept_id].data.copy_(ema_concept_output)

         # if any in the batch is not initialized, initialize

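Standalone, the branch above amounts to: hard-copy the superclass output the first time a concept is seen, then (for keys only) keep nudging it with an exponential moving average. A small runnable sketch, with `decay = 0.995` assumed as a typical value rather than taken from this commit:

```python
import torch

decay = 0.995  # assumed value; the actual decay is configured elsewhere in the module
concept_output = torch.zeros(4)     # one concept's slot, uninitialized
superclass_output = torch.randn(4)

# first batch: hard copy, so the concept starts exactly at the superclass output
concept_output.copy_(superclass_output)

# later batches (keys only): exponentially average toward fresh superclass outputs
fresh_superclass_output = torch.randn(4)
concept_output.copy_(concept_output * decay + fresh_superclass_output * (1. - decay))
```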
@@ -234,13 +234,13 @@ def forward(

         if not initted:
             self.initted[concept_id].data.copy_(Tensor([True]))
-            self.ema_concept_text_encs[concept_id].data.copy_(ema_concept_text_enc)
+            self.ema_concept_text_encs[concept_id].data.copy_(concept_text_enc)
         else:
             assert initted, 'you have not initialized or trained this module yet'

         # make it easier to match with paper

-        i, o, W = concept_text_enc, self.superclass_output[concept_id], weights
+        i, o, W = self.ema_concept_text_encs[concept_id], self.concept_output[concept_id], weights

         # main contribution eq (3)

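For readers matching this against the paper: eq (3) is a gated, ROME-style rank-1 edit of the projection. Below is a minimal ungated sketch under the variable roles set up above (`i`, `o`, `W`, plus a precomputed `C_inv`); the sigmoid gating and batching details of the actual module are omitted, and the function name is illustrative:

```python
import torch

# minimal ungated sketch of the rank-1 edit (all names illustrative)
def rank1_edit_forward(x, W, C_inv, i, o):
    # x:     (batch, seq, dim_in)  incoming text encodings
    # W:     (dim_out, dim_in)     frozen projection weight
    # C_inv: (dim_in, dim_in)      precomputed inverse input covariance
    # i:     (dim_in,)             (EMA of the) concept input encoding
    # o:     (dim_out,)            concept output: locked for keys, learned for values
    Ci = C_inv @ i                      # C^-1 i
    i_energy = i @ Ci                   # scalar i^T C^-1 i
    coeff = (x @ Ci) / i_energy         # (batch, seq) similarity to the concept input
    residual = o - W @ i                # rank-1 direction o* - W i*
    return x @ W.t() + coeff.unsqueeze(-1) * residual
```

At `x = i` this returns exactly `o`, which is the key-locking behavior: the concept's key maps to the superclass key while unrelated tokens are nearly untouched.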