Commit 03059d9

allow for simply fine-tuning off of prompts with a superclass string, without manipulating the BPE dictionary
1 parent 27e31c5 · commit 03059d9

File tree: 3 files changed (+114, −9 lines)

README.md

Lines changed: 29 additions & 0 deletions

````diff
@@ -73,6 +73,35 @@ values = wrapped_to_values(text_enc)
 
 ```
 
+The repository also contains an `EmbeddingWrapper` that makes it easy to train on a new concept (and for eventual inference with multiple concepts)
+
+```python
+import torch
+from torch import nn
+
+from perfusion_pytorch import EmbeddingWrapper
+
+embed = nn.Embedding(49407, 512) # open clip embedding, somewhere in the module tree of stable diffusion
+
+# wrap it, and will automatically create a new concept for learning, based on the superclass embed string
+
+wrapped_embed = EmbeddingWrapper(
+    embed,
+    superclass_string = 'dog'
+)
+
+# now just pass in your prompts with the superclass id
+
+embeds_with_new_concept, embeds_with_superclass, embed_mask = wrapped_embed([
+    'a portrait of dog',
+    'dog running through a green field',
+    'a man walking his dog'
+]) # (3, 77, 512), (3, 77, 512), (3, 77)
+
+# now pass both embeds through clip text transformer
+# the embed_mask needs to be passed to the cross attention as key padding mask
+```
+
 ## Todo
 
 - [ ] wire up with SD 1.5, starting with xiao's dreambooth-sd
````
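As the last two comments in the snippet above suggest, the returned `embed_mask` is meant to be used as a key padding mask downstream. Below is a minimal sketch of that wiring, using a plain `nn.MultiheadAttention` as a stand-in for the CLIP text transformer's attention; the layer and the random tensors are illustrative and not part of this repository.

```python
import torch
from torch import nn

# stand-in attention layer; in practice this would be the attention inside
# the CLIP text transformer that consumes the wrapped embeddings
attn = nn.MultiheadAttention(embed_dim = 512, num_heads = 8, batch_first = True)

# shapes reported by the README example above
embeds_with_new_concept = torch.randn(3, 77, 512)
embed_mask = torch.ones(3, 77).bool()  # True at real-token positions, False at padding

# nn.MultiheadAttention expects True at positions to be *ignored*, so the mask is inverted
attended, _ = attn(
    embeds_with_new_concept,
    embeds_with_new_concept,
    embeds_with_new_concept,
    key_padding_mask = ~embed_mask
)
```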

perfusion_pytorch/embedding.py

Lines changed: 84 additions & 8 deletions

```diff
@@ -1,41 +1,66 @@
 import torch
-from torch import nn
+from torch import nn, Tensor
 from torch.nn import Module
 
 from collections import namedtuple
 
 from beartype import beartype
-from beartype.typing import Optional, Tuple, Union
+from beartype.door import is_bearable
+from beartype.typing import Optional, Tuple, Union, Callable, List
 
 from einops import rearrange
 
+from open_clip import tokenizer
+
 # constants
 
 EmbeddingReturn = namedtuple('EmbeddingReturn', [
     'embed_with_concept',
-    'embed_with_superclass'
+    'embed_with_superclass',
+    'embed_mask'
 ])
 
 # helper functions
 
 def exists(val):
     return val is not None
 
+def default(val, d):
+    return val if exists(val) else d
+
 def is_all_unique(arr):
     return len(set(arr)) == len(arr)
 
 def filter_tuple_indices(tup, indices):
     return tuple(tup[i] for i in indices)
 
+@beartype
+def get_mask(
+    x: Tensor,
+    ids: Tuple[int, ...]
+):
+    masks = tuple(x == i for i in ids)
+    mask, *rest_masks = masks
+
+    for rest_mask in rest_masks:
+        mask = mask | rest_mask
+
+    return mask
+
 # embedding wrapper class
 
 class EmbeddingWrapper(Module):
+
     @beartype
     def __init__(
         self,
         embed: nn.Embedding,
         num_concepts = 1,
-        superclass_embed_id: Optional[Union[int, Tuple[int, ...]]] = None
+        superclass_embed_id: Optional[Union[int, Tuple[int, ...]]] = None,
+        superclass_string: Optional[str] = None,
+        tokenize: Callable[str, Tensor] = tokenizer.tokenize,
+        tokenizer_pad_id: int = 0,
+        tokenizer_sos_eos_id: Tuple[int, int] = (49406, 49407)
     ):
         super().__init__()
         self.embed = embed
@@ -45,7 +70,27 @@ def __init__(
         self.num_concepts = num_concepts
         self.concepts = nn.Parameter(torch.zeros(num_concepts, dim))
 
+        assert exists(superclass_embed_id) ^ exists(superclass_string), 'either superclass embed id is given, or the superclass string'
+
+        self.pad_id = tokenizer_pad_id
+        self.tokenize = None
+
+        if exists(superclass_string):
+            self.tokenize = tokenize
+
+            ids = tokenize([superclass_string])[0]
+
+            mask_for_ids = get_mask(ids, (tokenizer_pad_id, *tokenizer_sos_eos_id))
+            ids = ids[~mask_for_ids]
+
+            assert ids.shape[-1] == 1, f'your superclass concept string must map exactly one token id'
+            superclass_embed_id = ids[0].item()
+
+            print(f'super class embed for "{superclass_string}"" set as {superclass_embed_id}')
+            print(f'you can now pass in a list of strings containing superclass concept, and this wrapper will return the embedding w/ concept and superclass required for finetuning')
+
         self.superclass_embed_id = superclass_embed_id
+
         assert not (exists(superclass_embed_id) and num_concepts > 1), 'cannot do multi concept with superclass embed id given'
 
         if exists(superclass_embed_id):
@@ -67,18 +112,42 @@ def __init__(
     def parameters(self):
         return [self.concepts]
 
+    @beartype
     def forward(
         self,
-        x,
+        x: Union[Tensor, List[str]],
         concept_id: Optional[Union[int, Tuple[int, ...]]] = None,
         return_embed_with_superclass = True
     ) -> EmbeddingReturn:
-        concept_masks = tuple(concept_id == x for concept_id in self.concept_embed_ids)
 
         if exists(concept_id):
            if not isinstance(concept_id, tuple):
                concept_id = (concept_id,)
 
+            assert len(concept_id) == 1, 'can only train or inference on single concepts if passing in list of superclass strings'
+            assert self.num_concepts == 1
+
+        if is_bearable(x, List[str]):
+            inferred_concept_id = self.concept_embed_ids[0]
+
+            x = self.tokenize(x)
+
+            superclass_mask = x == self.superclass_embed_id
+            assert superclass_mask.any(dim = -1).all(), 'superclass embed id must be present for all prompts'
+
+            # automatically replace the superclass id with the concept id
+            x = torch.where(superclass_mask, inferred_concept_id, x)
+
+        # get the embedding mask, defined as not padding id
+        # default to open clip tokenizer padding id of 0
+
+        embed_mask = x != self.pad_id
+
+        # get masks for all concepts (support for multi-concepts)
+
+        concept_masks = tuple(concept_id == x for concept_id in self.concept_embed_ids)
+
+        if exists(concept_id):
             assert is_all_unique(concept_id), 'concept ids must be all unique'
             assert all([cid < self.num_concepts for cid in concept_id])
 
@@ -87,13 +156,20 @@ def forward(
             assert all(filter_tuple_indices(has_concept, concept_id)), f'concept ids {filter_tuple_indices(self.concept_embed_ids, concept_id)} not found in ids passed in'
             concept_masks = filter_tuple_indices(concept_masks, concept_id)
 
+        # just fetch the first embedding, as the concept embeddings are kept external to nn.Embedding
+
         for concept_mask in concept_masks:
             x = x.masked_fill(concept_mask, 0)
 
+        # get all the embeddings that are not the concept or superclass concept
+        # rest of embeddings are also not learnable, only concept embedding
+
         with torch.no_grad():
             embeds = self.embed(x)
             embeds.detach_()
 
+        # substitute the concept back into the embeddings
+
         for concept, concept_mask in zip(self.concepts, concept_masks):
             embeds = torch.where(
                 rearrange(concept_mask, '... -> ... 1'),
@@ -110,9 +186,9 @@ def forward(
             with torch.no_grad():
                 superclass_embeds = self.embed(x)
 
-            return EmbeddingReturn(embeds, superclass_embeds)
+            return EmbeddingReturn(embeds, superclass_embeds, embed_mask)
 
-        return EmbeddingReturn(embeds, None)
+        return EmbeddingReturn(embeds, None, embed_mask)
 
 @beartype
 def merge_embedding_wrappers(
```
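The heart of this change is how `__init__` turns `superclass_string` into a single token id. The following standalone sketch mirrors the hunks above; it assumes `open_clip` is installed and uses its default special ids (pad 0, SOS 49406, EOS 49407), and the variable names are only for illustration.

```python
import torch
from torch import Tensor
from typing import Tuple

from open_clip import tokenizer

def get_mask(x: Tensor, ids: Tuple[int, ...]):
    # True wherever x equals any of the given ids
    masks = tuple(x == i for i in ids)
    mask, *rest_masks = masks
    for rest_mask in rest_masks:
        mask = mask | rest_mask
    return mask

# tokenize the superclass string -> (77,) tensor with SOS, EOS and zero padding
ids = tokenizer.tokenize(['dog'])[0]

# drop pad (0), SOS (49406) and EOS (49407), leaving only the content tokens
ids = ids[~get_mask(ids, (0, 49406, 49407))]

assert ids.shape[-1] == 1, 'the superclass string must map to exactly one token id'
superclass_embed_id = ids[0].item()
print(superclass_embed_id)  # the token id the wrapper will swap for the learned concept
```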

setup.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.9',
+  version = '0.1.10',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
```
