
Commit c6c04fb

committed
for embedding wrapper, allow passing in concept ids, which asserts that the necessary concept embedding ids are present in each prompt across the batch
1 parent 0bdd1ec commit c6c04fb

File tree

3 files changed: +36 -7 lines changed


perfusion_pytorch/embedding.py

Lines changed: 33 additions & 4 deletions
@@ -1,13 +1,25 @@
 import torch
 from torch import nn
 from torch.nn import Module
+
 from beartype import beartype
+from beartype.typing import Optional, Tuple, Union
 
 from einops import rearrange
 
+# helper functions
+
 def exists(val):
     return val is not None
 
+def is_all_unique(arr):
+    return len(set(arr)) == len(arr)
+
+def filter_tuple_indices(tup, indices):
+    return tuple(tup[i] for i in indices)
+
+# embedding wrapper class
+
 class EmbeddingWrapper(Module):
     @beartype
     def __init__(
@@ -19,17 +31,34 @@ def __init__(
         self.embed = embed
         num_embeds, dim = embed.weight.shape
 
+        self.num_embeds = num_embeds
         self.num_concepts = num_concepts
         self.concepts = nn.Parameter(torch.zeros(num_concepts, dim))
         nn.init.normal_(self.concepts, std = 0.02)
 
-        self.concept_ids = tuple(range(num_embeds, num_embeds + num_concepts))
+        self.concept_embed_ids = tuple(range(num_embeds, num_embeds + num_concepts))
 
     def parameters(self):
         return [self.concepts]
 
-    def forward(self, x):
-        concept_masks = tuple(concept_id == x for concept_id in self.concept_ids)
+    def forward(
+        self,
+        x,
+        concept_id: Optional[Union[int, Tuple[int, ...]]] = None
+    ):
+        concept_masks = tuple(concept_id == x for concept_id in self.concept_embed_ids)
+
+        if exists(concept_id):
+            if not isinstance(concept_id, tuple):
+                concept_id = (concept_id,)
+
+            assert is_all_unique(concept_id), 'concept ids must be all unique'
+            assert all([cid < self.num_concepts for cid in concept_id])
+
+            has_concept = tuple(concept_mask.any(dim = -1).all() for concept_mask in concept_masks)
+
+            assert all(filter_tuple_indices(has_concept, concept_id)), f'concept ids {filter_tuple_indices(self.concept_embed_ids, concept_id)} not found in ids passed in'
+            concept_masks = filter_tuple_indices(concept_masks, concept_id)
 
         for concept_mask in concept_masks:
             x = x.masked_fill(concept_mask, 0)
@@ -38,7 +67,7 @@ def forward(self, x):
         embeds = self.embed(x)
         embeds.detach_()
 
-        for concept_id, concept, concept_mask in zip(self.concept_ids, self.concepts, concept_masks):
+        for concept, concept_mask in zip(self.concepts, concept_masks):
            embeds = torch.where(
                rearrange(concept_mask, '... -> ... 1'),
                concept,
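
A minimal usage sketch of the new concept_id argument (not from the repo's docs; the vocabulary size, embedding dim, and token shapes are made up, and both the EmbeddingWrapper(embed, num_concepts = ...) constructor signature and forward returning the patched embeddings are assumptions based on the surrounding diff):

import torch
from torch import nn
from perfusion_pytorch.embedding import EmbeddingWrapper

embed = nn.Embedding(400, 512)                        # base vocabulary of 400 tokens
wrapper = EmbeddingWrapper(embed, num_concepts = 2)   # assumed constructor signature

# concept embed ids are appended after the base vocabulary: here 400 and 401

token_ids = torch.randint(0, 400, (2, 77))
token_ids[:, 5] = 400                                 # every prompt in the batch contains concept 0

embeds = wrapper(token_ids, concept_id = 0)           # ok: id 400 is present in each prompt
# wrapper(token_ids, concept_id = 1)                  # would trip the assert: id 401 is absent

Note that the wrapper overrides parameters() to return only the concept rows, and forward detaches the base table's output, so an optimizer built from it trains just the new concept embeddings while the frozen vocabulary stays untouched.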

perfusion_pytorch/perfusion.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
 def exists(val):
     return val is not None
 
-def all_unique(arr):
+def is_all_unique(arr):
     return len(set(arr)) == len(arr)
 
 IndicesTensor = Union[LongTensor, IntTensor]
@@ -254,7 +254,7 @@ def forward(
 
         if is_multi_concepts:
             assert not self.training, 'multi concepts can only be done at inference'
-            assert all_unique(concept_id)
+            assert is_all_unique(concept_id)
             assert all([cid < self.num_concepts for cid in concept_id])
 
         concept_id_tensor = torch.tensor(concept_id, dtype = torch.long, device = self.device)
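
The all_unique -> is_all_unique rename keeps this helper in sync with the copy added to embedding.py; behavior is unchanged. A quick standalone illustration of the guard (values made up):

def is_all_unique(arr):
    return len(set(arr)) == len(arr)

assert is_all_unique((0, 1, 2))        # distinct concept ids pass
assert not is_all_unique((0, 0, 1))    # duplicates would trip the multi-concept assert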

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.27',
+  version = '0.0.28',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
