 def exists(val):
     return val is not None

+def all_unique(arr):
+    return len(set(arr)) == len(arr)
+
 IndicesTensor = Union[LongTensor, IntTensor]

 # function for calculating C - input covariance
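Quick note on the new helper (plain Python behavior, no torch involved): all_unique compares the deduplicated length against the original length, so it works on any sequence of hashables, such as the concept_id tuple used further down.

    all_unique((0, 1, 2))   # True  - distinct concept ids
    all_unique((0, 1, 1))   # False - duplicates are rejected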
@@ -186,8 +189,11 @@ def __init__(
             print('unable to perform cholesky. please make sure input covariance matrix is properly calculated')
             exit()

-        L_T_inv = torch.inverse(L.T)
-        self.register_buffer('L_T_inv', L_T_inv)
+        L_T = L.T
+        L_T_inv = torch.inverse(L_T)
+
+        self.register_buffer('L_T', L_T, persistent = False)
+        self.register_buffer('L_T_inv', L_T_inv, persistent = False)

     def parameters(self):
         if not self.is_key_proj:
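For readers unfamiliar with the persistent flag (standard PyTorch behavior, not specific to this change): non-persistent buffers still move with the module across devices, but they are left out of the state_dict, so the Cholesky factors derived from the input covariance are not written into checkpoints. A minimal sketch of that effect:

    import torch
    from torch import nn

    class _Sketch(nn.Module):
        def __init__(self):
            super().__init__()
            L = torch.linalg.cholesky(torch.eye(4))   # stand-in covariance factor
            self.register_buffer('L_T', L.T, persistent = False)
            self.register_buffer('L_T_inv', torch.inverse(L.T), persistent = False)

    print(_Sketch().state_dict().keys())   # odict_keys([]) - non-persistent buffers are not saved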
@@ -199,9 +205,9 @@ def parameters(self):
     def forward(
         self,
         text_enc: FloatTensor,
-        concept_indices: IndicesTensor,
+        concept_indices: Union[IndicesTensor, Tuple[IndicesTensor, ...]],
         text_enc_with_superclass: Optional[FloatTensor] = None,
-        concept_id: int = 0
+        concept_id: Union[int, Tuple[int, ...]] = 0
     ):
         assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'

@@ -228,6 +234,27 @@ def forward(

         beta, temperature = (self.train_beta, self.train_temperature) if self.training else (self.eval_beta, self.eval_temperature)

+        # determine whether it is single (for training) or multi-concept (only at inference)
+        # may separate into different modules at a future date if too complex in one module
+
+        is_multi_concepts = isinstance(concept_indices, tuple)
+
+        if is_multi_concepts:
+            num_concepts_at_forward = len(concept_indices)
+
+            assert not self.training, 'multi concepts can only be done at inference'
+            assert isinstance(concept_id, tuple)
+            assert all_unique(concept_id)
+            assert len(concept_id) == num_concepts_at_forward
+            assert all([cid < self.num_concepts for cid in concept_id])
+
+            raise NotImplementedError
+        else:
+            num_concepts_at_forward = 1
+
+            assert isinstance(concept_id, int)
+            assert concept_id < self.num_concepts
+
         # extract the concept text encoding input

         batch_indices = torch.arange(batch, device = device)
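A rough usage sketch of the widened signature (the names edit_module, text_enc, and concept_indices_a/b are hypothetical placeholders, and the multi-concept branch still ends in raise NotImplementedError in this commit):

    # single concept - unchanged call shape, works for training and inference
    out = edit_module(text_enc, concept_indices_a, concept_id = 0)

    # multiple concepts - inference only; tuples of matching length, unique in-range ids
    edit_module.eval()
    out = edit_module(
        text_enc,
        (concept_indices_a, concept_indices_b),   # one index tensor per concept
        concept_id = (0, 1)
    )                                             # currently raises NotImplementedError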
@@ -248,8 +275,6 @@ def forward(

         # get the initialization state

-        assert concept_id < self.num_concepts
-
         initted = self.initted[concept_id].item()

         if self.training:
@@ -317,16 +342,18 @@ def merge_rank1_edit_modules(
     *modules: Rank1EditModule
 ) -> Rank1EditModule:

-    assert all([m.initted.item() for m in modules]), 'all modules must be initialized and ideally trained'
+    assert all([m.initted.all() for m in modules]), 'all modules must be initialized and ideally trained'
     assert len(set([m.concept_outputs.shape[-1] for m in modules])) == 1, 'concept output dimension must be the same'
     assert len(set([m.is_key_proj for m in modules])) == 1, 'all modules must be either for keys, or values. you cannot merge rank 1 edit modules of keys and values together'

     first_module = modules[0]
     merged_module = deepcopy(first_module)

-    merged_module.num_concepts = sum([m.num_concepts for m in modules])
+    total_concepts = sum([m.num_concepts for m in modules])
+    merged_module.num_concepts = total_concepts

     concept_outputs = torch.cat(tuple(m.concept_outputs.data for m in modules), dim = 0)
     merged_module.concept_outputs = nn.Parameter(concept_outputs, requires_grad = not first_module.is_key_proj)
+    merged_module.register_buffer('initted', torch.ones(total_concepts, 1).bool())

     return merged_module
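A hedged sketch of how the merge is expected to line up after this fix (the two module variables below are hypothetical, already-trained Rank1EditModule instances for the same projection type): re-registering initted with total_concepts rows keeps the per-concept init flags in step with the concatenated concept_outputs.

    merged = merge_rank1_edit_modules(key_module_a, key_module_b)

    total = key_module_a.num_concepts + key_module_b.num_concepts
    assert merged.num_concepts == total
    assert merged.concept_outputs.shape[0] == total   # concatenated along the concept dimension
    assert merged.initted.shape == (total, 1)         # one init flag per merged concept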