@@ -57,8 +57,9 @@ def __init__(
5757 # will smooth both concept and superclass token inputs
5858
5959 self .register_buffer ('initted' , torch .zeros (num_finetune_prompts ).bool ())
60- self .register_buffer ('ema_concept_text_enc' , torch .zeros (num_finetune_prompts , dim_input ))
61- self .register_buffer ('ema_superclass_text_enc' , torch .zeros (num_finetune_prompts , dim_input ))
60+ self .register_buffer ('ema_concept_text_encs' , torch .zeros (num_finetune_prompts , dim_input ))
61+ self .register_buffer ('ema_superclass_text_encs' , torch .zeros (num_finetune_prompts , dim_input ))
62+ self .register_buffer ('superclass_outputs' , torch .zeros (num_finetune_prompts , dim_output ))
6263
6364 # C in the paper, inverse precomputed
6465
@@ -109,6 +110,8 @@ def forward(
109110 superclass_text_enc = text_enc_with_superclass [batch_indices , concept_indices ]
110111 superclass_text_enc = rearrange (superclass_text_enc , 'b 1 d -> b d' )
111112
113+ superclass_output = einsum ('b i, o i -> b o' , superclass_text_enc , weights )
114+
112115 # only if training, and if prompt ids are given
113116 # do exponential smoothing of the inputs, both concept and superclass
114117
@@ -120,8 +123,13 @@ def forward(
120123 initted = rearrange (initted , 'b -> b 1' )
121124 all_initted = initted .all ()
122125
123- ema_concept_text_enc = self .ema_concept_text_enc [prompt_ids ]
124- ema_superclass_text_enc = self .ema_superclass_text_enc [prompt_ids ]
126+ ema_concept_text_enc = self .ema_concept_text_encs [prompt_ids ]
127+ ema_superclass_text_enc = self .ema_superclass_text_encs [prompt_ids ]
128+
129+ # for keys, the superclass output (o*) is stored the first time a prompt id is seen in forward
130+ # (not in __init__), and is reused as-is afterwards - it is never optimized
131+
132+ stored_superclass_output = self .superclass_outputs [prompt_ids ]
125133
126134 # if any in the batch is not initialized, initialize
127135
@@ -138,6 +146,12 @@ def forward(
138146 superclass_text_enc
139147 )
140148
149+ superclass_output = torch .where (
150+ initted ,
151+ stored_superclass_output ,
152+ superclass_output
153+ )
154+
141155 # exponential moving average of both concept and superclass
142156
143157 concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay )
@@ -147,20 +161,19 @@ def forward(
147161
148162 if not all_initted :
149163 self .initted [prompt_ids ] = True
150- self .ema_concept_text_enc [prompt_ids ] = ema_concept_text_enc
151- self .ema_superclass_text_enc [prompt_ids ] = ema_superclass_text_enc
164+ self .ema_concept_text_encs [prompt_ids ] = ema_concept_text_enc
165+ self .ema_superclass_text_encs [prompt_ids ] = ema_superclass_text_enc
166+ self .superclass_outputs [prompt_ids ] = superclass_output
152167
153168 # take care of the output
154169 # for the keys, make sure to turn off gradients as it is 'locked'
155170
156- superclass_text_enc_output = einsum ('b i, o i -> b o' , superclass_text_enc , weights )
157-
158171 if self .is_key_proj :
159- superclass_text_enc_output = superclass_text_enc_output .detach ()
172+ superclass_output = superclass_output .detach ()
160173
161174 # make it easier to match with paper
162175
163- i , o , W = concept_text_enc , superclass_text_enc_output , weights
176+ i , o , W = concept_text_enc , superclass_output , weights
164177
165178 # main contribution eq (3)
166179
0 commit comments