
Commit 0bdd1ec

complete multiple concepts inference code
1 parent 895e1b5 commit 0bdd1ec

File tree

3 files changed: +83, -59 lines

README.md

Lines changed: 12 additions & 19 deletions
````diff
@@ -48,46 +48,39 @@ wrapped_to_values = Rank1EditModule(
 
 text_enc = torch.randn(4, 77, 768)                  # regular input
 text_enc_with_superclass = torch.randn(4, 77, 768)  # init_input in algorithm 1, for key-locking
-concept_ids = torch.randint(0, 77, (4,))
+concept_indices = torch.randint(0, 77, (4,))
 
 keys = wrapped_to_keys(
     text_enc,
-    text_enc_with_superclass,
-    concept_ids
+    concept_indices = concept_indices,
+    text_enc_with_superclass = text_enc_with_superclass,
 )
 
 values = wrapped_to_values(
     text_enc,
-    text_enc_with_superclass,
-    concept_ids
+    concept_indices = concept_indices,
+    text_enc_with_superclass = text_enc_with_superclass,
 )
 
 # after much training ...
 
 wrapped_to_keys.eval()
 wrapped_to_values.eval()
 
-keys = wrapped_to_keys(
-    text_enc,
-    text_enc_with_superclass,
-    concept_ids
-)
+keys = wrapped_to_keys(text_enc)
+
+values = wrapped_to_values(text_enc)
 
-values = wrapped_to_values(
-    text_enc,
-    text_enc_with_superclass,
-    concept_ids
-)
 ```
 
 ## Todo
 
-- [ ] handle rank-1 update for multiple concepts
-- [x] handle training with multiple concepts
-- [ ] handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs
-- [ ] accept multiple concept indices
 - [ ] offer a magic function that automatically tries to wire up the cross attention by looking for appropriately named `nn.Linear` and auto-inferring which ones are keys or values
+- [ ] show example in readme for inference with multiple concepts
+- [ ] review multiple concepts
 
+- [x] handle multiple concepts in one prompt at inference - summation of the sigmoid term + outputs
+- [x] accept multiple concept indices
 - [x] offer a way to combine separately learned concepts from multiple `Rank1EditModule` into one for inference
 - [x] offer function for merging `Rank1EditModule`s
 - [x] add the zero-shot masking of concept proposed in paper
````
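The new unchecked todo item asks for a README example of inference with multiple concepts. A hedged sketch of what such a call could look like with the `forward` signature introduced in this commit, where `concept_id` accepts a tuple of unique ids once the module is in eval mode (illustrative only, not an official README example):

```python
# hypothetical usage, assuming wrapped_to_keys / wrapped_to_values already hold two
# trained (or merged) concepts - see merge_rank1_edit_modules in perfusion.py below
wrapped_to_keys.eval()
wrapped_to_values.eval()

keys = wrapped_to_keys(text_enc, concept_id = (0, 1))
values = wrapped_to_values(text_enc, concept_id = (0, 1))
```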

perfusion_pytorch/perfusion.py

Lines changed: 70 additions & 39 deletions
```diff
@@ -140,7 +140,6 @@ def __init__(
         assert not exists(key_or_values_proj.bias), 'key value projection in attention should not have bias'
 
         self.num_concepts = num_concepts
-        self.has_multiple_concepts = num_concepts > 1
 
         self.weight = key_or_values_proj.weight
         dim_output, dim_input = self.weight.shape
@@ -176,15 +175,23 @@ def __init__(
         C_inv = torch.inverse(C)
         self.register_buffer('C_inv', C_inv)
 
+    @property
+    def num_concepts(self):
+        return self._num_concepts
+
+    @num_concepts.setter
+    def num_concepts(self, value):
+        self._num_concepts = value
+
+        if value == 1:
+            return
+
         # for multiple concepts
         # need cholesky decomposed L_t_inv
         # Appendix B
 
-        if not self.has_multiple_concepts:
-            return
-
         try:
-            L = torch.linalg.cholesky(C_inv)
+            L = torch.linalg.cholesky(self.C_inv)
         except:
             print('unable to perform cholesky. please make sure input covariance matrix is properly calculated')
             exit()
```
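For context on the `num_concepts` setter above: `torch.linalg.cholesky` on a symmetric positive definite matrix returns a lower triangular factor `L` with `C_inv = L @ L.T`, and presumably that factor's transpose and its inverse are what end up stored as the `L_T` / `L_T_inv` buffers registered just below, defining the metric space of Appendix B. A minimal, self-contained sketch of the factorization (not from the repo; the matrix here is a stand-in):

```python
import torch

d = 8
A = torch.randn(d, d)
C_inv = A @ A.T + torch.eye(d)      # stand-in symmetric positive definite matrix

L = torch.linalg.cholesky(C_inv)    # lower triangular, so that C_inv = L @ L.T
assert torch.allclose(L @ L.T, C_inv, atol = 1e-4)

L_T = L.T                           # maps vectors into the whitened / metric space
L_T_inv = torch.inverse(L_T)        # maps them back out
```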
```diff
@@ -195,6 +202,10 @@ def __init__(
         self.register_buffer('L_T', L_T, persistent = False)
         self.register_buffer('L_T_inv', L_T_inv, persistent = False)
 
+    @property
+    def device(self):
+        return next(self.buffers()).device
+
     def parameters(self):
         if not self.is_key_proj:
             return []
@@ -205,7 +216,8 @@ def parameters(self):
     def forward(
         self,
         text_enc: FloatTensor,
-        concept_indices: Union[IndicesTensor, Tuple[IndicesTensor, ...]],
+        *,
+        concept_indices: Optional[IndicesTensor] = None,
         text_enc_with_superclass: Optional[FloatTensor] = None,
         concept_id: Union[int, Tuple[int, ...]] = 0
     ):
@@ -220,6 +232,7 @@ def forward(
         d - feature dimension
         i - input dimension
         o - output dimension
+        c - concepts dimension (for multiple concepts)
         """
 
         batch, device = text_enc.shape[0], self.C_inv.device
@@ -237,47 +250,41 @@ def forward(
         # determine whether it is single (for training) or multi-concept (only at inference)
         # may separate into different modules at a future date if too complex in one module
 
-        is_multi_concepts = isinstance(concept_indices, tuple)
+        is_multi_concepts = isinstance(concept_id, tuple)
 
         if is_multi_concepts:
-            num_concepts_at_forward = len(concept_indices)
-
             assert not self.training, 'multi concepts can only be done at inference'
-            assert isinstance(concept_id, tuple)
             assert all_unique(concept_id)
-            assert len(concept_id) == num_concepts_at_forward
             assert all([cid < self.num_concepts for cid in concept_id])
 
-            raise NotImplementedError
+            concept_id_tensor = torch.tensor(concept_id, dtype = torch.long, device = self.device)
         else:
-            num_concepts_at_forward = 1
-
-            assert isinstance(concept_id, int)
             assert concept_id < self.num_concepts
+            concept_id_tensor = torch.tensor([concept_id], dtype = torch.long, device = self.device)
 
-        # extract the concept text encoding input
+        # get the initialization state
 
-        batch_indices = torch.arange(batch, device = device)
-        batch_indices = rearrange(batch_indices, 'b -> b 1')
-        concept_indices = rearrange(concept_indices, 'b -> b 1')
+        if self.training:
+            initted = self.initted[concept_id].item()
 
-        concept_text_enc = text_enc[batch_indices, concept_indices]
-        concept_text_enc = reduce(concept_text_enc, 'b 1 d -> d', 'mean')
+            # extract the concept text encoding input
 
-        # only if training
-        # do exponential smoothing of the inputs, both concept and superclass
+            batch_indices = torch.arange(batch, device = device)
+            batch_indices = rearrange(batch_indices, 'b -> b 1')
+            concept_indices = rearrange(concept_indices, 'b -> b 1')
 
-        if exists(text_enc_with_superclass):
-            superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
-            superclass_text_enc = reduce(superclass_text_enc, 'b 1 d -> d', 'mean')
+            concept_text_enc = text_enc[batch_indices, concept_indices]
+            concept_text_enc = reduce(concept_text_enc, 'b 1 d -> d', 'mean')
 
-            superclass_output = einsum('i, o i -> o', superclass_text_enc, weights)
+            # only if training
+            # do exponential smoothing of the inputs, both concept and superclass
 
-        # get the initialization state
+            if exists(text_enc_with_superclass):
+                superclass_text_enc = text_enc_with_superclass[batch_indices, concept_indices]
+                superclass_text_enc = reduce(superclass_text_enc, 'b 1 d -> d', 'mean')
 
-        initted = self.initted[concept_id].item()
+                superclass_output = einsum('i, o i -> o', superclass_text_enc, weights)
 
-        if self.training:
             # store the superclass i* if not all initialized
             # else fetch it from the buffer
 
@@ -311,29 +318,49 @@ def forward(
                 self.initted[concept_id].data.copy_(Tensor([True]))
                 self.ema_concept_text_encs[concept_id].data.copy_(concept_text_enc)
         else:
-            assert initted, 'you have not initialized or trained this module yet'
+            assert self.initted[concept_id_tensor].all(), 'you have not initialized or trained this module for the concepts id given'
 
         # make it easier to match with paper
 
-        i, o, W = self.ema_concept_text_encs[concept_id], self.concept_outputs[concept_id], weights
+        i, o, W = self.ema_concept_text_encs[concept_id_tensor], self.concept_outputs[concept_id_tensor], weights
 
         # main contribution eq (3)
 
-        i_energy = opt_einsum('o, o i, i ->', i, Ci, i)
+        i_energy = opt_einsum('c o, o i, c i ->', i, Ci, i)
 
-        sim = opt_einsum('b n o, o i, i -> b n', text_enc, Ci, i)
+        sim = opt_einsum('b n o, o i, c i -> c b n', text_enc, Ci, i)
         sim = rearrange(sim, '... -> ... 1')
 
         sigmoid_term = (((sim / i_energy) - beta) / temperature).sigmoid()
 
-        text_enc_output = einsum('b n i, o i -> b n o', text_enc, W)
+        if is_multi_concepts:
+            L_T, L_T_inv = self.L_T, self.L_T_inv
+
+            # metric - metric space - variable with tilde in Appendix B
+
+            # equation (6)
+
+            i_metric = einsum('o i, c i -> c o', L_T, i)
+            u_metric, _ = torch.linalg.qr(i_metric.T)
+            u = einsum('o i, i c -> c o', L_T_inv, u_metric)
+
+            # equation (10)
 
-        concept_output = einsum('i, o i -> o', i, W)
-        concept_output = rearrange(concept_output, 'd -> 1 1 d')
+            em_orthogonal = text_enc - opt_einsum('c o, b n i, c i -> b n o', u, text_enc, u)
 
-        W_em_orthogonal_term = text_enc_output - (sim * concept_output / i_energy)
+            W_em_orthogonal_term = einsum('b n i, o i -> b n o', em_orthogonal, W)
+        else:
+            text_enc_output = einsum('b n i, o i -> b n o', text_enc, W)
+
+            concept_output = einsum('c i, o i -> c o', i, W)
+            concept_output = rearrange(concept_output, 'c d -> c 1 1 d')
+
+            W_em_orthogonal_term = text_enc_output - reduce(sim * concept_output / i_energy, 'c ... -> ...', 'sum')
 
-        return W_em_orthogonal_term + sigmoid_term * rearrange(o, 'd -> 1 1 d')
+        gated_term = sigmoid_term * rearrange(o, 'c d -> c 1 1 d')
+        gated_term = reduce(gated_term, 'c ... -> ...', 'sum')
+
+        return W_em_orthogonal_term + gated_term
 
 # for merging trained Rank1EditModule(s) above
 
```
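In the multi-concept branch above, the per-concept encodings `i` are mapped into the metric space with `L_T`, orthonormalized with a QR decomposition, and mapped back with `L_T_inv` (equation (6)); the span of those directions is then removed from `text_enc` before applying `W` (equation (10)), and the per-concept gated outputs are summed. Below is a plain Euclidean sketch of just the QR-then-project-out pattern, ignoring the `L_T` metric; all names are illustrative and this is not code from the repo:

```python
import torch

dim, num_concepts = 16, 2
directions = torch.randn(num_concepts, dim)   # stand-in for the per-concept encodings `i`
tokens = torch.randn(4, 77, dim)              # stand-in for `text_enc`

# orthonormalize the concept directions, then remove their span from every token embedding
u, _ = torch.linalg.qr(directions.T)          # (dim, num_concepts), orthonormal columns
proj = torch.einsum('o c, b n i, i c -> b n o', u, tokens, u)
em_orthogonal = tokens - proj

# sanity check: nothing of the concept directions is left in the result
residual = torch.einsum('b n d, d c -> b n c', em_orthogonal, u)
assert residual.abs().max() < 1e-5
```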

```diff
@@ -354,6 +381,10 @@ def merge_rank1_edit_modules(
 
     concept_outputs = torch.cat(tuple(m.concept_outputs.data for m in modules), dim = 0)
     merged_module.concept_outputs = nn.Parameter(concept_outputs, requires_grad = not first_module.is_key_proj)
+
+    ema_concept_text_encs = torch.cat(tuple(m.ema_concept_text_encs.data for m in modules), dim = 0)
+    merged_module.register_buffer('ema_concept_text_encs', ema_concept_text_encs)
+
     merged_module.register_buffer('initted', torch.ones(total_concepts, 1).bool())
 
     return merged_module
```
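`merge_rank1_edit_modules` now also concatenates each source module's `ema_concept_text_encs` buffer, which the updated `forward` indexes by concept id at inference. A hedged sketch of combining two separately trained single-concept modules; the exact calling convention is not visible in this diff, so the variadic call and the module names below are assumptions:

```python
from perfusion_pytorch.perfusion import merge_rank1_edit_modules

# keys_module_a / keys_module_b are hypothetical names for two trained Rank1EditModule instances
merged_to_keys = merge_rank1_edit_modules(keys_module_a, keys_module_b)
merged_to_keys.eval()

# concept ids follow concatenation order: 0 from module a, 1 from module b
keys = merged_to_keys(text_enc, concept_id = (0, 1))
```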

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.26',
+  version = '0.0.27',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
```
