@@ -140,7 +140,6 @@ def __init__(
         assert not exists(key_or_values_proj.bias), 'key value projection in attention should not have bias'
 
         self.num_concepts = num_concepts
-        self.has_multiple_concepts = num_concepts > 1
 
         self.weight = key_or_values_proj.weight
         dim_output, dim_input = self.weight.shape
@@ -176,15 +175,23 @@ def __init__(
         C_inv = torch.inverse(C)
         self.register_buffer('C_inv', C_inv)
 
+    @property
+    def num_concepts(self):
+        return self._num_concepts
+
+    @num_concepts.setter
+    def num_concepts(self, value):
+        self._num_concepts = value
+
+        if value == 1:
+            return
+
         # for multiple concepts
         # need cholesky decomposed L_t_inv
         # Appendix B
 
-        if not self.has_multiple_concepts:
-            return
-
         try:
-            L = torch.linalg.cholesky(C_inv)
+            L = torch.linalg.cholesky(self.C_inv)
         except:
             print('unable to perform cholesky. please make sure input covariance matrix is properly calculated')
             exit()
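An aside on the Cholesky step: Appendix B needs a factor L with L @ L.T equal to C_inv, from which L_T and L_T_inv are then derived. A minimal sketch of what this setter computes, with hypothetical dimensions:

import torch

# covariance of 1024 hypothetical inputs of dimension 8
X = torch.randn(1024, 8)
C = (X.T @ X) / 1024
C_inv = torch.inverse(C)

L = torch.linalg.cholesky(C_inv)    # lower triangular, L @ L.T == C_inv
L_T, L_T_inv = L.T, torch.inverse(L.T)

assert torch.allclose(L @ L.T, C_inv, atol = 1e-3)

torch.linalg.cholesky fails when the matrix is not positive definite, which is presumably what the bare except above is guarding against.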
@@ -195,6 +202,10 @@ def __init__(
         self.register_buffer('L_T', L_T, persistent = False)
         self.register_buffer('L_T_inv', L_T_inv, persistent = False)
 
+    @property
+    def device(self):
+        return next(self.buffers()).device
+
     def parameters(self):
         if not self.is_key_proj:
             return []
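The new device property is the buffer analog of the usual next(self.parameters()).device idiom; it is needed because parameters() can return an empty list here, so next(self.parameters()) would raise StopIteration, whereas registered buffers such as C_inv are always present. A small sketch of the pattern:

import torch
from torch import nn

# hypothetical module with buffers but no parameters
class OnlyBuffers(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('C_inv', torch.eye(4))

    @property
    def device(self):
        return next(self.buffers()).device

print(OnlyBuffers().device)    # cpu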
@@ -205,7 +216,8 @@ def parameters(self):
     def forward(
         self,
         text_enc: FloatTensor,
-        concept_indices: Union[IndicesTensor, Tuple[IndicesTensor, ...]],
+        *,
+        concept_indices: Optional[IndicesTensor] = None,
         text_enc_with_superclass: Optional[FloatTensor] = None,
         concept_id: Union[int, Tuple[int, ...]] = 0
     ):
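Since concept_indices is now keyword-only and optional, the training and inference call patterns separate cleanly. A hedged sketch of the two conventions implied by the new signature (shapes illustrative only; `module` stands for a constructed Rank1EditModule):

import torch

text_enc = torch.randn(2, 77, 768)        # (batch, seq, dim), illustrative
concept_indices = torch.tensor([5, 5])    # concept token position per batch element

# training, single concept - concept_indices must be passed by keyword
# out = module(text_enc, concept_indices = concept_indices, concept_id = 0)

# inference, several merged concepts - no concept_indices, a tuple of unique ids
# out = module(text_enc, concept_id = (0, 1))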
@@ -220,6 +232,7 @@ def forward(
         d - feature dimension
         i - input dimension
         o - output dimension
+        c - concepts dimension (for multiple concepts)
         """
 
         batch, device = text_enc.shape[0], self.C_inv.device
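The new c axis is what lets a single sim computation cover several concepts at once; the updated sim einsum further down yields one (batch, seq) similarity map per concept. A quick shape check with stand-in tensors (square dims so o == i, purely for illustration):

import torch

text_enc = torch.randn(2, 77, 768)    # b n o
Ci = torch.eye(768)                   # o i, stands in for C_inv
i = torch.randn(3, 768)               # c i, one i* per concept

sim = torch.einsum('bno,oi,ci->cbn', text_enc, Ci, i)
print(sim.shape)                      # torch.Size([3, 2, 77])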
@@ -237,47 +250,41 @@ def forward(
         # determine whether it is single (for training) or multi-concept (only at inference)
         # may separate into different modules at a future date if too complex in one module
 
-        is_multi_concepts = isinstance(concept_indices, tuple)
+        is_multi_concepts = isinstance(concept_id, tuple)
 
         if is_multi_concepts:
-            num_concepts_at_forward = len(concept_indices)
-
             assert not self.training, 'multi concepts can only be done at inference'
-            assert isinstance(concept_id, tuple)
             assert all_unique(concept_id)
-            assert len(concept_id) == num_concepts_at_forward
             assert all([cid < self.num_concepts for cid in concept_id])
 
-            raise NotImplementedError
+            concept_id_tensor = torch.tensor(concept_id, dtype = torch.long, device = self.device)
         else:
-            num_concepts_at_forward = 1
-
-            assert isinstance(concept_id, int)
             assert concept_id < self.num_concepts
+            concept_id_tensor = torch.tensor([concept_id], dtype = torch.long, device = self.device)
 
-        # extract the concept text encoding input
+        # get the initialization state
 
-        batch_indices = torch.arange(batch, device = device)
-        batch_indices = rearrange(batch_indices, 'b -> b 1')
-        concept_indices = rearrange(concept_indices, 'b -> b 1')
+        if self.training:
+            initted = self.initted[concept_id].item()
 
-        concept_text_enc = text_enc[batch_indices, concept_indices]
-        concept_text_enc = reduce(concept_text_enc, 'b 1 d -> d', 'mean')
+            # extract the concept text encoding input
 
-        # only if training
-        # do exponential smoothing of the inputs, both concept and superclass
+            batch_indices = torch.arange(batch, device = device)
+            batch_indices = rearrange(batch_indices, 'b -> b 1')
+            concept_indices = rearrange(concept_indices, 'b -> b 1')
 
-        if exists(text_enc_with_superclass):
-            superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
-            superclass_text_enc = reduce(superclass_text_enc, 'b 1 d -> d', 'mean')
+            concept_text_enc = text_enc[batch_indices, concept_indices]
+            concept_text_enc = reduce(concept_text_enc, 'b 1 d -> d', 'mean')
 
-            superclass_output = einsum('i, o i -> o', superclass_text_enc, weights)
+            # only if training
+            # do exponential smoothing of the inputs, both concept and superclass
 
-        # get the initialization state
+            if exists(text_enc_with_superclass):
+                superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
+                superclass_text_enc = reduce(superclass_text_enc, 'b 1 d -> d', 'mean')
 
-        initted = self.initted[concept_id].item()
+                superclass_output = einsum('i, o i -> o', superclass_text_enc, weights)
 
-        if self.training:
             # store the superclass i* if not all initialized
             # else fetch it from the buffer
 
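One reason concept_id is materialized as a 1-element long tensor even in the single-concept branch: tensor indexing preserves the leading concepts axis that the 'c ...' einsums below rely on, whereas indexing with a plain int would drop it. For example:

import torch

ema = torch.randn(3, 768)              # (num_concepts, dim), like ema_concept_text_encs
print(ema[torch.tensor([0])].shape)    # torch.Size([1, 768]), keeps the c axis
print(ema[0].shape)                    # torch.Size([768]), int indexing drops it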
@@ -311,29 +318,49 @@ def forward(
                 self.initted[concept_id].data.copy_(Tensor([True]))
                 self.ema_concept_text_encs[concept_id].data.copy_(concept_text_enc)
         else:
-            assert initted, 'you have not initialized or trained this module yet'
+            assert self.initted[concept_id_tensor].all(), 'you have not initialized or trained this module for the concept ids given'
 
         # make it easier to match with paper
 
-        i, o, W = self.ema_concept_text_encs[concept_id], self.concept_outputs[concept_id], weights
+        i, o, W = self.ema_concept_text_encs[concept_id_tensor], self.concept_outputs[concept_id_tensor], weights
 
         # main contribution eq (3)
 
-        i_energy = opt_einsum('o, o i, i ->', i, Ci, i)
+        i_energy = opt_einsum('c o, o i, c i ->', i, Ci, i)
 
-        sim = opt_einsum('b n o, o i, i -> b n', text_enc, Ci, i)
+        sim = opt_einsum('b n o, o i, c i -> c b n', text_enc, Ci, i)
         sim = rearrange(sim, '... -> ... 1')
 
         sigmoid_term = (((sim / i_energy) - beta) / temperature).sigmoid()
 
-        text_enc_output = einsum('b n i, o i -> b n o', text_enc, W)
+        if is_multi_concepts:
+            L_T, L_T_inv = self.L_T, self.L_T_inv
+
+            # metric - metric space - variable with tilde in Appendix B
+
+            # equation (6)
+
+            i_metric = einsum('o i, c i -> c o', L_T, i)
+            u_metric, _ = torch.linalg.qr(i_metric.T)
+            u = einsum('o i, i c -> c o', L_T_inv, u_metric)
+
+            # equation (10)
 
-        concept_output = einsum('i, o i -> o', i, W)
-        concept_output = rearrange(concept_output, 'd -> 1 1 d')
+            em_orthogonal = text_enc - opt_einsum('c o, b n i, c i -> b n o', u, text_enc, u)
 
-        W_em_orthogonal_term = text_enc_output - (sim * concept_output / i_energy)
+            W_em_orthogonal_term = einsum('b n i, o i -> b n o', em_orthogonal, W)
+        else:
+            text_enc_output = einsum('b n i, o i -> b n o', text_enc, W)
+
+            concept_output = einsum('c i, o i -> c o', i, W)
+            concept_output = rearrange(concept_output, 'c d -> c 1 1 d')
+
+            W_em_orthogonal_term = text_enc_output - reduce(sim * concept_output / i_energy, 'c ... -> ...', 'sum')
 
-        return W_em_orthogonal_term + sigmoid_term * rearrange(o, 'd -> 1 1 d')
+        gated_term = sigmoid_term * rearrange(o, 'c d -> c 1 1 d')
+        gated_term = reduce(gated_term, 'c ... -> ...', 'sum')
+
+        return W_em_orthogonal_term + gated_term
 
 # for merging trained Rank1EditModule(s) above
 
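If I read Appendix B correctly, the point of doing the QR in the tilde (metric) space is that the resulting directions u come out orthonormal under the C_inv inner product, which is what makes the equation (10) projection well behaved. A self-contained check of that property, with a random SPD matrix standing in for C_inv:

import torch

d, c = 16, 3
A = torch.randn(d, d)
C_inv = A @ A.T + d * torch.eye(d)           # random SPD stand-in for C_inv

L = torch.linalg.cholesky(C_inv)
L_T, L_T_inv = L.T, torch.inverse(L.T)

i = torch.randn(c, d)                        # per-concept text encodings i*

i_metric = torch.einsum('oi,ci->co', L_T, i)
u_metric, _ = torch.linalg.qr(i_metric.T)    # orthonormal columns in tilde space
u = torch.einsum('oi,ic->co', L_T_inv, u_metric)

# orthonormal under the C_inv metric: u_c @ C_inv @ u_k == delta(c, k)
gram = torch.einsum('co,oi,ki->ck', u, C_inv, u)
assert torch.allclose(gram, torch.eye(c), atol = 1e-4)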
@@ -354,6 +381,10 @@ def merge_rank1_edit_modules(
 
     concept_outputs = torch.cat(tuple(m.concept_outputs.data for m in modules), dim = 0)
     merged_module.concept_outputs = nn.Parameter(concept_outputs, requires_grad = not first_module.is_key_proj)
+
+    ema_concept_text_encs = torch.cat(tuple(m.ema_concept_text_encs.data for m in modules), dim = 0)
+    merged_module.register_buffer('ema_concept_text_encs', ema_concept_text_encs)
+
     merged_module.register_buffer('initted', torch.ones(total_concepts, 1).bool())
 
     return merged_module
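Carrying ema_concept_text_encs over matters because forward now reads self.ema_concept_text_encs[concept_id_tensor] to recover each concept's i*; without this concat, a merged module would not carry those learned inputs over. The stacking itself is a plain dim 0 concat, one row per concept:

import torch

ema_a = torch.randn(1, 768)    # ema_concept_text_encs of a single-concept module
ema_b = torch.randn(1, 768)
merged = torch.cat((ema_a, ema_b), dim = 0)
print(merged.shape)            # torch.Size([2, 768])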