@@ -72,8 +72,6 @@ def __init__(
         self.weight = key_or_values_proj.weight
         dim_output, dim_input = self.weight.shape

-        self.is_key_proj = is_key_proj # will lock the output to the super-class, and turn off gradients
-
         self.train_beta = train_beta
         self.train_temperature = train_temperature
         self.eval_beta = eval_beta
@@ -88,13 +86,23 @@ def __init__(

         self.register_buffer('initted', torch.zeros(num_finetune_prompts).bool())
         self.register_buffer('ema_concept_text_encs', torch.zeros(num_finetune_prompts, dim_input))
-        self.register_buffer('superclass_text_encs', torch.zeros(num_finetune_prompts, dim_input))
-        self.register_buffer('superclass_outputs', torch.zeros(num_finetune_prompts, dim_output))
+
+        # superclass outputs - only optimized for values, but not keys
+
+        self.is_key_proj = is_key_proj # will lock the output to the super-class, and turn off gradients
+
+        self.superclass_outputs = nn.Parameter(torch.zeros(num_finetune_prompts, dim_output), requires_grad = not is_key_proj)

         # C in the paper, inverse precomputed

         self.register_buffer('C_inv', torch.inverse(C))

+    def parameters(self):
+        if self.is_key_proj:
+            return []
+
+        return [self.superclass_outputs]
+
     @beartype
     def forward(
         self,
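Since `parameters()` is overridden above, an optimizer built directly from these modules only ever sees the trainable superclass outputs, and nothing at all for the key projection. A minimal, self-contained sketch of that pattern (the `ToyEdit` name and shapes are illustrative, not from this commit):

```python
import torch
from torch import nn

class ToyEdit(nn.Module):
    def __init__(self, dim, is_key_proj):
        super().__init__()
        self.is_key_proj = is_key_proj
        # keys are locked to the superclass output, so gradients stay off
        self.superclass_outputs = nn.Parameter(torch.zeros(dim), requires_grad = not is_key_proj)

    def parameters(self):
        # keys contribute nothing to the optimizer,
        # values contribute only the superclass outputs
        if self.is_key_proj:
            return []

        return [self.superclass_outputs]

keys, values = ToyEdit(8, is_key_proj = True), ToyEdit(8, is_key_proj = False)
params = [*keys.parameters(), *values.parameters()]
optimizer = torch.optim.Adam(params, lr = 1e-3)
assert len(params) == 1  # only the value-side superclass outputs are optimized
```

Note this override only takes effect when `.parameters()` is called on the module directly; a parent module's recursive `parameters()` traverses registered parameters itself and would still pick up the frozen key-side parameter.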
@@ -134,59 +142,46 @@ def forward(
         concept_text_enc = text_enc[batch_indices, concept_indices]
         concept_text_enc = rearrange(concept_text_enc, 'b 1 d -> b d')

-        # take care of initializing with superclass prompt
-        # for key-locking - this assumes stable diffusion was modified so text encoder takes in a prompt with both the <concept> as well as <superclass> - it seems this also has the limitation that <superclass> must be one token
-
-        superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
-        superclass_text_enc = rearrange(superclass_text_enc, 'b 1 d -> b d')
-
-        superclass_output = einsum('b i, o i -> b o', superclass_text_enc, weights)
-
         # only if training, and if prompt ids are given
         # do exponential smoothing of the inputs, both concept and superclass

+        superclass_output = None
+
+        if exists(text_enc_with_superclass):
+            superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
+            superclass_text_enc = rearrange(superclass_text_enc, 'b 1 d -> b d')
+
+            superclass_output = einsum('b i, o i -> b o', superclass_text_enc, weights)
+
         if self.training and exists(prompt_ids):
             # get the initialization state
             # as well as the exponentially smoothed text encodings

             initted = self.initted[prompt_ids]
-            initted = rearrange(initted, 'b -> b 1')
             all_initted = initted.all()

             ema_concept_text_enc = self.ema_concept_text_encs[prompt_ids]

-            # fetch superclass
+            # store the superclass o* if not all initialized
+            # else fetch it from the stored parameter

-            assert exists(superclass_text_enc) or all_initted
+            if not all_initted:
+                assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first epoch for all prompts to initialize the module correctly'

-            stored_superclass_text_enc = self.superclass_text_encs[prompt_ids]
+                non_initted_prompt_ids = prompt_ids[~initted]

-            # for keys, the superclass output (o*) is stored on init
-            # and never optimized
+                # for the prompt ids not initialized yet, hard copy over the initial superclass outputs
+                self.superclass_outputs.data[non_initted_prompt_ids] = superclass_output[~initted]

-            stored_superclass_output = self.superclass_outputs[prompt_ids]
+            superclass_output = self.superclass_outputs[prompt_ids]

             # if any in the batch is not initialized, initialize

             if not all_initted:
                 ema_concept_text_enc = torch.where(
-                    initted,
+                    rearrange(initted, 'b -> b 1'),
                     ema_concept_text_enc,
                     concept_text_enc
                 )

-                superclass_text_enc = torch.where(
-                    initted,
-                    stored_superclass_text_enc,
-                    superclass_text_enc
-                )
-
-                superclass_output = torch.where(
-                    initted,
-                    stored_superclass_output,
-                    superclass_output
-                )
-
             # exponential moving average for concept input encoding

             concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
@@ -196,8 +191,6 @@ def forward(
             if not all_initted:
                 self.initted[prompt_ids] = True
                 self.ema_concept_text_encs[prompt_ids] = ema_concept_text_enc
-                self.superclass_text_encs[prompt_ids] = superclass_text_enc
-                self.superclass_outputs[prompt_ids] = superclass_output

         # take care of the output
         # for the keys, make sure to turn off gradients as it is 'locked'
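The `initted` bookkeeping plus `torch.where` above implements an initialize-once exponential moving average: rows seen for the first time are seeded with the incoming encoding, so their first EMA step is a no-op rather than a decay toward zero. A self-contained sketch of that pattern (shapes and names are illustrative, not from this commit):

```python
import torch

num_prompts, dim, decay = 4, 8, 0.99
initted = torch.zeros(num_prompts).bool()
ema = torch.zeros(num_prompts, dim)

def ema_update(prompt_ids, enc):
    batch_initted = initted[prompt_ids]
    # first-time rows take the raw encoding as their starting average
    prev = torch.where(batch_initted.unsqueeze(-1), ema[prompt_ids], enc)
    new = prev * decay + enc * (1. - decay)
    ema[prompt_ids] = new
    initted[prompt_ids] = True
    return new

first = ema_update(torch.tensor([0, 2]), torch.ones(2, dim))
assert torch.allclose(first, torch.ones(2, dim))  # seeded, not decayed toward zero
```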