Commit 2fd0303

allow for easy transformation of text embeds and superclass embeds into text encodings and superclass text encodings, if the clip transformer (taking the post-embed hidden states to the post-transformer output after the last layernorm) is passed in
1 parent e15c326 commit 2fd0303

File tree

perfusion_pytorch/embedding.py
setup.py

2 files changed: +20 -4 lines changed


perfusion_pytorch/embedding.py

Lines changed: 19 additions & 3 deletions
@@ -118,7 +118,8 @@ def forward(
         self,
         x: Union[Tensor, List[str]],
         concept_id: Optional[Union[int, Tuple[int, ...]]] = None,
-        return_embed_with_superclass = True
+        return_embed_with_superclass = True,
+        clip_transformer_fn: Optional[Callable[[Tensor], Tensor]] = None
     ) -> EmbeddingReturn:
 
         assert not (self.training and self.num_concepts > 1), 'cannot train with multiple concepts'

@@ -195,15 +196,30 @@ def forward(
         # if training, and superclass embed id given
         # also return embeddings with superclass, for deriving superclass_text_enc
 
+        superclass_embeds = None
+
         if self.training and exists(self.superclass_embed_id) and return_embed_with_superclass:
             x = x.masked_fill(concept_masks[0], self.superclass_embed_id)
 
             with torch.no_grad():
                 superclass_embeds = self.embed(x)
 
-            return EmbeddingReturn(embeds, superclass_embeds, embed_mask, concept_indices)
+        # if the clip transformer function is passed in, transform the embeds and superclass_embeds into the text_enc and superclass_text_enc, to be forwarded by cross attentions into the Rank1EditModules
+
+        if exists(clip_transformer_fn):
+            with torch.no_grad():
+                embeds = clip_transformer_fn(embeds)
+
+                if exists(superclass_embeds):
+                    superclass_embeds = clip_transformer_fn(superclass_embeds)
+
+        # return tuple, with
+        # 1. text embeds | encodings
+        # 2. superclass text embeds | encodings
+        # 3. text mask
+        # 4. concept indices
 
-        return EmbeddingReturn(embeds, None, embed_mask, concept_indices)
+        return EmbeddingReturn(embeds, superclass_embeds, embed_mask, concept_indices)
 
     @beartype
     def merge_embedding_wrappers(
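For illustration, below is a minimal sketch of what the new `clip_transformer_fn` hook does. The `nn.Sequential` stand-in (a single encoder layer plus final layernorm), the dimension 512, and the sequence length 77 are illustrative assumptions, not part of this commit; in practice the function would wrap the actual CLIP text transformer, covering everything after the token embedding up to and including the last layernorm.

import torch
from torch import nn

dim = 512

# hypothetical stand-in for the clip transformer described in the commit
# message: post-embed hidden states -> output after the final layernorm
clip_transformer_fn = nn.Sequential(
    nn.TransformerEncoderLayer(d_model = dim, nhead = 8, batch_first = True),
    nn.LayerNorm(dim)
)

embeds = torch.randn(1, 77, dim)             # token embeddings for the prompt with the concept token
superclass_embeds = torch.randn(1, 77, dim)  # same prompt with the concept swapped for its superclass

# mirrors the new branch in forward: both tensors are encoded under no_grad,
# yielding text_enc and superclass_text_enc for the Rank1EditModules
with torch.no_grad():
    text_enc = clip_transformer_fn(embeds)
    superclass_text_enc = clip_transformer_fn(superclass_embeds)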

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'perfusion-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.14',
+  version = '0.1.15',
   license='MIT',
   description = 'Perfusion - Pytorch',
   author = 'Phil Wang',
