README.md: 3 additions, 1 deletion
@@ -78,7 +78,9 @@ values = wrapped_to_values(
 ## Todo
 
-- [ ] handle rank-1 update for multiple concepts
+- [ ] handle rank-1 update for multiple concepts
+- [x] handle training with multiple concepts
+- [ ] handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs
 - [x] take care of the function that takes in the dataset and text encoder and precomputes the covariance matrix needed for the rank-1 update
 - [x] instead of having the researcher worry about different learning rates, offer the fractional gradient trick from other paper (to learn the concept embedding)
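The checked covariance item above refers to the rank-1 update, which needs the uncentered covariance of the text encoder's outputs over the training data. A minimal sketch of such a precompute step, assuming hypothetical names (`texts`, `text_encoder`, `precompute_covariance`) rather than the repo's actual API:

```python
import torch

# a sketch only - assumes text_encoder(text) returns a (seq_len, dim) tensor
@torch.no_grad()
def precompute_covariance(texts, text_encoder, dim):
    # accumulate outer products of every token encoding in the dataset
    cov = torch.zeros(dim, dim)
    num_tokens = 0
    for text in texts:
        enc = text_encoder(text)        # (seq_len, dim), assumed shape
        cov += enc.t() @ enc            # sum of k kᵀ over the sequence
        num_tokens += enc.shape[0]
    return cov / num_tokens             # uncentered covariance E[k kᵀ]
```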
         assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'
@@ -194,7 +198,9 @@ def forward(
 
         # get the initialization state
 
-        initted = self.initted.item()
+        assert concept_id < self.num_concepts
+
+        initted = self.initted[concept_id].item()
 
         if self.training:
             # store the superclass i* if not all initialized
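This hunk generalizes initialization from a single flag to one flag per concept: `self.initted` becomes indexable by `concept_id`, guarded by the new bounds assert. A minimal sketch of that pattern, with assumed names (`ConceptEmbeds`, `maybe_init`) that are illustrative, not the repo's API:

```python
import torch
from torch import nn

class ConceptEmbeds(nn.Module):
    def __init__(self, num_concepts, dim):
        super().__init__()
        self.num_concepts = num_concepts
        self.embeds = nn.Parameter(torch.randn(num_concepts, dim))
        # one boolean flag per concept, persisted with the state dict
        self.register_buffer('initted', torch.zeros(num_concepts, dtype=torch.bool))

    def maybe_init(self, concept_id, init_value):
        assert concept_id < self.num_concepts
        if self.initted[concept_id].item():
            return
        # hard copy the initial value (e.g. a superclass embedding) on first use
        with torch.no_grad():
            self.embeds[concept_id].copy_(init_value)
            self.initted[concept_id] = True
```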
@@ -204,21 +210,21 @@ def forward(
         assert exists(superclass_output), 'text_enc_with_superclass must be passed in for the first batch'
 
         # for the prompt ids not initialized yet, hard copy over the initial superclass outputs
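The comment above describes the first-batch bootstrapping: for prompt ids whose concept is not yet initialized, the superclass output is copied over verbatim. A minimal sketch of that masked copy, with assumed tensor names and shapes:

```python
import torch

# sketch only - assumes concept_output and superclass_output are both
# (batch, seq_len, dim) and initted_mask is a (batch,) boolean tensor
def merge_outputs(concept_output, superclass_output, initted_mask):
    # broadcast the per-prompt mask over sequence and feature dimensions
    mask = initted_mask.view(-1, 1, 1)
    # keep the learned output where initialized, superclass output elsewhere
    return torch.where(mask, concept_output, superclass_output)
```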