- from beartype.typing import Union
+ from beartype.typing import Union, Optional

import torch
from torch import nn, einsum, Tensor, IntTensor, LongTensor, FloatTensor
from torch.nn import Module
import torch.nn.functional as F

def exists(val):
    return val is not None

+ IndicesTensor = Union[LongTensor, IntTensor]
+
# a module that wraps the keys and values projection of the cross attentions to text encodings

class Rank1EditModule(Module):
@@ -51,6 +53,13 @@ def __init__(

        self.text_seq_len = text_seq_len

+         # for exponentially smoothing the inputs
+         # will smooth both concept and superclass token inputs
+
+         self.register_buffer('initted', torch.zeros(num_finetune_prompts).bool())
+         self.register_buffer('ema_concept_text_enc', torch.zeros(num_finetune_prompts, dim_input))
+         self.register_buffer('ema_superclass_text_enc', torch.zeros(num_finetune_prompts, dim_input))
+
        # C in the paper, inverse precomputed

        self.register_buffer('C_inv', torch.inverse(C))
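
The three buffers just added implement per-prompt exponential smoothing with lazy initialization. As a standalone illustration of the pattern (the shapes, names, and decay value here are made up for the example, not taken from the repo):

```python
import torch

num_prompts, dim, decay = 8, 768, 0.99

ema = torch.zeros(num_prompts, dim)        # one smoothed encoding per prompt
initted = torch.zeros(num_prompts).bool()  # lazy-init flag per prompt

def smooth(prompt_ids, enc):
    # first time a prompt is seen, seed its EMA with the raw encoding;
    # afterwards blend: new = old * decay + current * (1 - decay)
    old = torch.where(initted[prompt_ids, None], ema[prompt_ids], enc)
    new = old * decay + enc * (1. - decay)
    ema[prompt_ids] = new.detach()
    initted[prompt_ids] = True
    return new
```

The `forward` hunk below applies this same pattern to both the concept and superclass encodings.
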
@@ -60,7 +69,9 @@ def forward(
        self,
        text_enc: FloatTensor,
        text_enc_with_superclass: FloatTensor,
-         concept_indices: Union[IntTensor, LongTensor]
+         concept_indices: IndicesTensor,
+         *,
+         prompt_ids: Optional[IndicesTensor] = None
    ):
        assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'

@@ -98,21 +109,58 @@ def forward(
        superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
        superclass_text_enc = rearrange(superclass_text_enc, 'b 1 d -> b d')

+         # only if training, and if prompt ids are given,
+         # exponentially smooth the inputs, both concept and superclass
+
+         if self.training and exists(prompt_ids):
+             # get the initialization state
+             # as well as the exponentially smoothed text encodings
+
+             initted = self.initted[prompt_ids]
+             initted = rearrange(initted, 'b -> b 1')
+             all_initted = initted.all()
+
+             ema_concept_text_enc = self.ema_concept_text_enc[prompt_ids]
+             ema_superclass_text_enc = self.ema_superclass_text_enc[prompt_ids]
+
+             # if any in the batch is not initialized, initialize
+
+             if not all_initted:
+                 ema_concept_text_enc = torch.where(
+                     initted,
+                     ema_concept_text_enc,
+                     concept_text_enc
+                 )
+
+                 ema_superclass_text_enc = torch.where(
+                     initted,
+                     ema_superclass_text_enc,
+                     superclass_text_enc
+                 )
+
+             # exponential moving average of both concept and superclass
+
+             concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)
+             superclass_text_enc = ema_superclass_text_enc * decay + superclass_text_enc * (1. - decay)
+
+             # store the smoothed encodings back into the buffers on every step
+             # (detached, so the buffers stay out of the autograd graph)
+
+             self.initted[prompt_ids] = True
+             self.ema_concept_text_enc[prompt_ids] = concept_text_enc.detach()
+             self.ema_superclass_text_enc[prompt_ids] = superclass_text_enc.detach()
+
+         # take care of the output
+         # for the keys, make sure to turn off gradients as it is 'locked'
+
        superclass_text_enc_output = einsum('b i, o i -> b o', superclass_text_enc, weights)

        if self.is_key_proj:
            superclass_text_enc_output = superclass_text_enc_output.detach()

-         # only during training do they exponentially smooth
-
-         if self.training:
-             online_estimated_concept_enc = decay * superclass_text_enc + (1. - decay) * concept_text_enc
-         else:
-             online_estimated_concept_enc = concept_text_enc
-
        # make it easier to match with paper

-         i, o, W = online_estimated_concept_enc, superclass_text_enc_output, weights
+         i, o, W = concept_text_enc, superclass_text_enc_output, weights

        # main contribution eq (3)

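The hunk ends just before eq (3) itself. For orientation: with `i` the (smoothed) concept input, `o` the key-locked superclass output, `W` the frozen projection weight, and the precomputed `C_inv` from `__init__`, a ROME-style rank-1 edit has the rough shape sketched below. This is a hedged sketch of the general update rule those symbols suggest, not the paper's exact formula — Perfusion additionally gates the edit with a sigmoid over input similarity and applies it functionally at runtime rather than materializing an edited weight:

```python
# per-sample sketch: i (dim_in,), o (dim_out,), W (dim_out, dim_in), C_inv (dim_in, dim_in)
u = C_inv @ i                 # edit direction in input space
residual = o - W @ i          # what the frozen weight fails to produce for the concept
W_edited = W + torch.outer(residual, u) / (u @ i)
```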