 def exists(val):
     return val is not None

-# main contribution of paper
 # a module that wraps the keys and values projection of the cross attentions to text encodings

 class Rank1EditModule(Module):
@@ -22,7 +21,8 @@ def __init__(
         key_or_values_proj: nn.Linear,
         *,
         num_finetune_prompts: int,
-        C: Tensor,
+        C: Tensor,              # covariance of input, precomputed from 100K laion text
+        text_seq_len: int = 256,
         is_key_proj: bool = False,
         input_decay = 0.99,
         train_beta = 0.75,
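
The new C argument is expected to be precomputed offline (per the comment, from ~100K LAION text encodings). A minimal sketch of how such a covariance might be accumulated is below; clip_text_encoder and prompts are hypothetical names, and treating the covariance as uncentered is an assumption, not something this commit specifies.

import torch
from torch import einsum

@torch.no_grad()
def precompute_covariance(clip_text_encoder, prompts, batch_size = 256):
    # sum outer products over all token encodings, then normalize by token count
    cov, count = None, 0
    for i in range(0, len(prompts), batch_size):
        enc = clip_text_encoder(prompts[i:i + batch_size])  # assumed to return (batch, seq, dim)
        enc = enc.reshape(-1, enc.shape[-1])                 # flatten to (tokens, dim)
        outer = einsum('n d, n e -> d e', enc, enc)
        cov = outer if cov is None else cov + outer
        count += enc.shape[0]
    return cov / count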
@@ -34,7 +34,7 @@ def __init__(
         assert not exists(key_or_values_proj.bias), 'key value projection in attention should not have bias'

         self.weight = key_or_values_proj.weight
-        dim_input = self.weight.shape[-1]
+        dim_output, dim_input = self.weight.shape

         self.is_key_proj = is_key_proj  # will lock the output to the super-class, and turn off gradients

@@ -48,10 +48,13 @@ def __init__(
         # they exponentially smooth the text encoding inputs during training
         # in addition to a lowered learning rate on the text encodings

+        self.text_seq_len = text_seq_len
+
         self.register_buffer('initted', torch.zeros(num_finetune_prompts).bool())
         self.register_buffer('ema_concept_text_enc', torch.zeros(num_finetune_prompts, dim_input))
+        self.register_buffer('outputs', torch.zeros(num_finetune_prompts, text_seq_len, dim_output))

-        # buffers
+        # C in the paper, inverse precomputed

         self.register_buffer('C_inv', torch.inverse(C))

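
The exponential smoothing mentioned at the top of this hunk is the update applied later in forward: concept_text_enc = ema * decay + new * (1 - decay). With the default input_decay = 0.99, each step moves the stored concept encoding only 1% toward the newest sample. A tiny worked example with made-up 2-dimensional values:

import torch

decay = 0.99
ema = torch.tensor([0.50, -0.20])  # previously stored concept text encoding (hypothetical)
new = torch.tensor([0.60,  0.00])  # concept token encoding from the current batch (hypothetical)
ema = ema * decay + new * (1. - decay)
# ema is now tensor([0.5010, -0.1980]) -- a 1% step toward the new sample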
@@ -62,13 +65,22 @@ def forward(
         text_enc: Tensor,
         concept_indices: Tensor
     ):
+        assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'
+
         """
         following the pseudocode of Algorithm 1 in appendix
+
+        einstein notation:
+        b - batch
+        n - sequence
+        d - feature dimension
+        i - input dimension
+        o - output dimension
         """

         batch, device = text_enc.shape[0], self.initted.device

-        weights, decay = self.weight, self.input_decay
+        weights, decay, Ci = self.weight, self.input_decay, self.C_inv

         # beta and temperature depends on whether training or inference

@@ -86,20 +98,59 @@ def forward(
         # during training, keep track of exponentially smoothed input

         if self.training:
-            batch_initted = self.initted[prompt_ids]
+
+            # get the initialization state
+            # as well as the exponentially smoothed text encodings
+
+            initted = self.initted[prompt_ids]
+            all_initted = initted.all()
+
             ema_concept_text_enc = self.ema_concept_text_enc[prompt_ids]
+            outputs = self.outputs[prompt_ids]
+
+            # if any in the batch is not initialized, initialize

-            ema_concept_text_enc = torch.where(
-                rearrange(batch_initted, 'b -> b 1'),
-                ema_concept_text_enc,
-                concept_text_enc
-            )
+            if not all_initted:
+                ema_concept_text_enc = torch.where(
+                    rearrange(initted, 'b -> b 1'),
+                    ema_concept_text_enc,
+                    concept_text_enc
+                )
+
+                outputs = torch.where(
+                    rearrange(initted, 'b -> b 1 1'),
+                    outputs,
+                    einsum('o i, b n i -> b n o', weights, text_enc)
+                )

             # update using exponential moving average

             concept_text_enc = ema_concept_text_enc * decay + concept_text_enc * (1. - decay)

-            self.initted[prompt_ids] = True
-            self.ema_concept_text_enc[prompt_ids] = concept_text_enc
+            if not all_initted:
+                self.initted[prompt_ids] = True
+                self.ema_concept_text_enc[prompt_ids] = ema_concept_text_enc
+                self.outputs[prompt_ids] = outputs
+
+        # make it easier to match with paper
+
+        i, o, W = ema_concept_text_enc, outputs, weights
+
+        # main contribution eq (3)
+
+        i_energy = einsum('b d, b d -> b', i @ Ci, i)
+        i_energy = rearrange(i_energy, '... -> ... 1 1')
+
+        sim = einsum('b n d, b d -> b n', text_enc, i @ Ci)
+        sim = rearrange(sim, '... -> ... 1')
+
+        sigmoid_term = (((sim / i_energy) - beta) / temperature).sigmoid()
+
+        orig_output = einsum('b n i, o i -> b n o', text_enc, W)
+
+        concept_output = einsum('b i, o i -> b o', i, W)
+        concept_output = rearrange(concept_output, 'b d -> b 1 d')
+
+        W_em_orthogonal_term = orig_output - (sim * concept_output / i_energy)

-        return einsum('b n i, o i -> b n o', text_enc, weights)
+        return W_em_orthogonal_term + sigmoid_term * o
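
Read off the return expression, the edited projection corresponds to the following (a paraphrase of the paper's eq. (3) reconstructed from this code; here e is one row of text_enc, i_* the smoothed concept input, o_* the stored target output, W the original weight, and T the temperature):

$$
W^{\mathrm{edit}}\, e \;=\; W e \;-\; \frac{i_*^{\top} C^{-1} e}{i_*^{\top} C^{-1} i_*}\, W i_* \;+\; \sigma\!\left( \frac{ \frac{i_*^{\top} C^{-1} e}{i_*^{\top} C^{-1} i_*} - \beta }{T} \right) o_*
$$

In the code, sim is e^T C^{-1} i_* and i_energy is i_*^T C^{-1} i_*, so the first two terms are W_em_orthogonal_term and the last is sigmoid_term * o.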
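
A hypothetical usage sketch (not part of this commit): wrap the key and value projections of a cross-attention layer so the same concept edit is applied to both. The constructor arguments mirror the signature in the first hunk and the call mirrors the forward signature above; the projection dimensions, the placeholder C, and how prompt ids are resolved inside forward are all assumptions here.

import torch
from torch import nn

# hypothetical bias-free cross-attention projections (768-dim CLIP encodings assumed)
to_k = nn.Linear(768, 320, bias = False)
to_v = nn.Linear(768, 320, bias = False)

C = torch.eye(768)  # placeholder; in practice the precomputed covariance of the text encodings

# Rank1EditModule as defined in the diff above
to_k_edit = Rank1EditModule(to_k, num_finetune_prompts = 8, C = C, is_key_proj = True)  # keys locked to the super-class
to_v_edit = Rank1EditModule(to_v, num_finetune_prompts = 8, C = C)

text_enc = torch.randn(4, 256, 768)            # CLIP text encodings; length must equal text_seq_len
concept_indices = torch.randint(0, 256, (4,))  # position of the concept token in each prompt

keys = to_k_edit(text_enc, concept_indices)
values = to_v_edit(text_enc, concept_indices)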