Commit 895e1b5

a bit more progress on multi-concepts
1 parent 794e005 commit 895e1b5

2 files changed: 38 additions, 9 deletions

perfusion_pytorch/embedding.py (3 additions, 1 deletion)

@@ -20,7 +20,9 @@ def __init__(
         num_embeds, dim = embed.weight.shape

         self.num_concepts = num_concepts
-        self.concepts = nn.Parameter(torch.randn(num_concepts, dim))
+        self.concepts = nn.Parameter(torch.zeros(num_concepts, dim))
+        nn.init.normal_(self.concepts, std = 0.02)
+
         self.concept_ids = tuple(range(num_embeds, num_embeds + num_concepts))

     def parameters(self):

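For reference, a minimal sketch of the new initialization in isolation, with made-up sizes; the parameter now starts from a small-std normal draw instead of a unit-variance torch.randn:

import torch
from torch import nn

num_concepts, dim = 2, 768  # hypothetical sizes for illustration

# concepts start as zeros, then receive a small-std normal init
concepts = nn.Parameter(torch.zeros(num_concepts, dim))
nn.init.normal_(concepts, std = 0.02)

print(concepts.data.std())  # roughly 0.02
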
perfusion_pytorch/perfusion.py (35 additions, 8 deletions)

@@ -20,6 +20,9 @@
 def exists(val):
     return val is not None

+def all_unique(arr):
+    return len(set(arr)) == len(arr)
+
 IndicesTensor = Union[LongTensor, IntTensor]

 # function for calculating C - input covariance
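A quick illustration of the new helper, assuming its role is to reject repeated concept ids in the multi-concept path:

def all_unique(arr):
    return len(set(arr)) == len(arr)

assert all_unique((0, 1, 2))         # distinct concept ids pass
assert not all_unique((0, 1, 1))     # a repeated id fails the check
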
@@ -186,8 +189,11 @@ def __init__(
             print('unable to perform cholesky. please make sure input covariance matrix is properly calculated')
             exit()

-        L_T_inv = torch.inverse(L.T)
-        self.register_buffer('L_T_inv', L_T_inv)
+        L_T = L.T
+        L_T_inv = torch.inverse(L_T)
+
+        self.register_buffer('L_T', L_T, persistent = False)
+        self.register_buffer('L_T_inv', L_T_inv, persistent = False)

     def parameters(self):
         if not self.is_key_proj:
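As a side note, a minimal sketch (illustrative module, not from the repo) of what persistent = False changes: the buffers still follow the module across devices but are excluded from state_dict, presumably because they can be recomputed from the input covariance at construction time:

import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        C = torch.eye(3)                          # stand-in for the input covariance
        L = torch.linalg.cholesky(C)              # lower-triangular factor of C
        L_T = L.T
        # non-persistent buffers move with .to() / .cuda() but are left out of state_dict()
        self.register_buffer('L_T', L_T, persistent = False)
        self.register_buffer('L_T_inv', torch.inverse(L_T), persistent = False)

demo = Demo()
assert 'L_T' not in demo.state_dict()
assert 'L_T_inv' not in demo.state_dict()
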
@@ -199,9 +205,9 @@ def parameters(self):
     def forward(
         self,
         text_enc: FloatTensor,
-        concept_indices: IndicesTensor,
+        concept_indices: Union[IndicesTensor, Tuple[IndicesTensor, ...]],
         text_enc_with_superclass: Optional[FloatTensor] = None,
-        concept_id: int = 0
+        concept_id: Union[int, Tuple[int, ...]] = 0
     ):
         assert text_enc.shape[-2] == self.text_seq_len, f'CLIP text sequence length is set to be {self.text_seq_len}, but received text encoding with length {text_enc.shape[-2]}'

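A hedged sketch of the widened calling convention (the module instance and tensor shapes are assumptions, not part of this diff); the tuple form is only validated at this point and still raises NotImplementedError:

import torch

# hypothetical shapes: batch of 2 prompts, CLIP sequence length 77
text_enc = torch.randn(2, 77, 768)
concept_indices = torch.tensor([5, 9])    # position of the concept token in each prompt
other_indices = torch.tensor([7, 3])      # second concept's token positions

# single-concept path (training or inference), unchanged:
# out = rank1_edit_module(text_enc, concept_indices, concept_id = 0)

# multi-concept path (inference only): equal-length tuples of indices and ids
# out = rank1_edit_module(text_enc, (concept_indices, other_indices), concept_id = (0, 1))
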
@@ -228,6 +234,27 @@ def forward(

         beta, temperature = (self.train_beta, self.train_temperature) if self.training else (self.eval_beta, self.eval_temperature)

+        # determine whether it is single (for training) or multi-concept (only at inference)
+        # may separate into different modules at a future date if too complex in one module
+
+        is_multi_concepts = isinstance(concept_indices, tuple)
+
+        if is_multi_concepts:
+            num_concepts_at_forward = len(concept_indices)
+
+            assert not self.training, 'multi concepts can only be done at inference'
+            assert isinstance(concept_id, tuple)
+            assert all_unique(concept_id)
+            assert len(concept_id) == num_concepts_at_forward
+            assert all([cid < self.num_concepts for cid in concept_id])
+
+            raise NotImplementedError
+        else:
+            num_concepts_at_forward = 1
+
+            assert isinstance(concept_id, int)
+            assert concept_id < self.num_concepts
+
         # extract the concept text encoding input

         batch_indices = torch.arange(batch, device = device)
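For context, the "extract the concept text encoding input" step presumably uses batch_indices for per-sample advanced indexing along the sequence dimension; a small standalone sketch of that assumption:

import torch

batch, seq_len, dim = 2, 77, 8
text_enc = torch.randn(batch, seq_len, dim)
concept_indices = torch.tensor([5, 9])       # concept token position per sample

batch_indices = torch.arange(batch)
concept_text_enc = text_enc[batch_indices, concept_indices]
assert concept_text_enc.shape == (batch, dim)
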
@@ -248,8 +275,6 @@ def forward(

         # get the initialization state

-        assert concept_id < self.num_concepts
-
         initted = self.initted[concept_id].item()

         if self.training:
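The lookup above indexes a per-concept initted buffer; a small sketch of that bookkeeping, with the (num_concepts, 1) shape inferred from the merge function below:

import torch

num_concepts = 3
initted = torch.zeros(num_concepts, 1).bool()   # one initialization flag per concept

concept_id = 1
assert not initted[concept_id].item()           # not yet initialized
initted[concept_id] = True
assert not initted.all()                        # other concepts still uninitialized
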
@@ -317,16 +342,18 @@ def merge_rank1_edit_modules(
     *modules: Rank1EditModule
 ) -> Rank1EditModule:

-    assert all([m.initted.item() for m in modules]), 'all modules must be initialized and ideally trained'
+    assert all([m.initted.all() for m in modules]), 'all modules must be initialized and ideally trained'
     assert len(set([m.concept_outputs.shape[-1] for m in modules])) == 1, 'concept output dimension must be the same'
     assert len(set([m.is_key_proj for m in modules])) == 1, 'all modules must be either for keys, or values. you cannot merge rank 1 edit modules of keys and values together'

     first_module = modules[0]
     merged_module = deepcopy(first_module)

-    merged_module.num_concepts = sum([m.num_concepts for m in modules])
+    total_concepts = sum([m.num_concepts for m in modules])
+    merged_module.num_concepts = total_concepts

     concept_outputs = torch.cat(tuple(m.concept_outputs.data for m in modules), dim = 0)
     merged_module.concept_outputs = nn.Parameter(concept_outputs, requires_grad = not first_module.is_key_proj)
+    merged_module.register_buffer('initted', torch.ones(total_concepts, 1).bool())

     return merged_module
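A hedged usage sketch of merging, assuming module_a and module_b are two already-trained Rank1EditModule instances for the same projection type (their construction is outside this diff):

from perfusion_pytorch.perfusion import merge_rank1_edit_modules

# the merged module concatenates concept_outputs and now also carries an
# `initted` buffer sized to the total concept count, so the multi-concept
# forward path can index initialization state per concept
merged = merge_rank1_edit_modules(module_a, module_b)

assert merged.num_concepts == module_a.num_concepts + module_b.num_concepts
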
