import torch
-from torch import nn
+from torch import nn, Tensor
from torch.nn import Module

from collections import namedtuple

from beartype import beartype
-from beartype.typing import Optional, Tuple, Union
+from beartype.door import is_bearable
+from beartype.typing import Optional, Tuple, Union, Callable, List

from einops import rearrange

+from open_clip import tokenizer
+
# constants

EmbeddingReturn = namedtuple('EmbeddingReturn', [
    'embed_with_concept',
-    'embed_with_superclass'
+    'embed_with_superclass',
+    'embed_mask'
])

# helper functions

def exists(val):
    return val is not None

+def default(val, d):
+    return val if exists(val) else d
+
def is_all_unique(arr):
    return len(set(arr)) == len(arr)

def filter_tuple_indices(tup, indices):
    return tuple(tup[i] for i in indices)

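+# helper to build a boolean mask over token ids: True wherever x equals any id in `ids`
+# e.g. get_mask(torch.tensor([49406, 320, 49407, 0]), (0, 49406, 49407)) -> tensor([True, False, True, True])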
+@beartype
+def get_mask(
+    x: Tensor,
+    ids: Tuple[int, ...]
+):
+    masks = tuple(x == i for i in ids)
+    mask, *rest_masks = masks
+
+    for rest_mask in rest_masks:
+        mask = mask | rest_mask
+
+    return mask
+
# embedding wrapper class

class EmbeddingWrapper(Module):
+
    @beartype
    def __init__(
        self,
        embed: nn.Embedding,
        num_concepts = 1,
-        superclass_embed_id: Optional[Union[int, Tuple[int, ...]]] = None
+        superclass_embed_id: Optional[Union[int, Tuple[int, ...]]] = None,
+        superclass_string: Optional[str] = None,
+        tokenize: Callable[[List[str]], Tensor] = tokenizer.tokenize,
+        tokenizer_pad_id: int = 0,
+        tokenizer_sos_eos_id: Tuple[int, int] = (49406, 49407)  # open clip <start_of_text> / <end_of_text> token ids
    ):
        super().__init__()
        self.embed = embed
@@ -45,7 +70,27 @@ def __init__(
        self.num_concepts = num_concepts
        self.concepts = nn.Parameter(torch.zeros(num_concepts, dim))

+        assert exists(superclass_embed_id) ^ exists(superclass_string), 'either superclass_embed_id or superclass_string should be given, but not both'
+
+        self.pad_id = tokenizer_pad_id
+        self.tokenize = None
+
+        if exists(superclass_string):
+            self.tokenize = tokenize
+
+            ids = tokenize([superclass_string])[0]
+
+            mask_for_ids = get_mask(ids, (tokenizer_pad_id, *tokenizer_sos_eos_id))
+            ids = ids[~mask_for_ids]
+
+            assert ids.shape[-1] == 1, 'your superclass concept string must map to exactly one token id'
+            superclass_embed_id = ids[0].item()
+
+            print(f'superclass embed for "{superclass_string}" set as {superclass_embed_id}')
+            print('you can now pass in a list of strings containing the superclass concept, and this wrapper will return the embeddings with the concept and superclass needed for finetuning')
+
        self.superclass_embed_id = superclass_embed_id
+
        assert not (exists(superclass_embed_id) and num_concepts > 1), 'cannot do multi concept with superclass embed id given'

        if exists(superclass_embed_id):
@@ -67,18 +112,42 @@ def __init__(
    def parameters(self):
        return [self.concepts]

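+    # forward accepts either a tensor of token ids or a list of prompt strings
+    # (string input assumes the wrapper was initialized with superclass_string, so that self.tokenize is set)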
+    @beartype
    def forward(
        self,
-        x,
+        x: Union[Tensor, List[str]],
        concept_id: Optional[Union[int, Tuple[int, ...]]] = None,
        return_embed_with_superclass = True
    ) -> EmbeddingReturn:
-        concept_masks = tuple(concept_id == x for concept_id in self.concept_embed_ids)

        if exists(concept_id):
            if not isinstance(concept_id, tuple):
                concept_id = (concept_id,)

+            if is_bearable(x, List[str]):
+                assert len(concept_id) == 1, 'can only train or infer on a single concept when passing in a list of superclass strings'
+                assert self.num_concepts == 1
+
+        if is_bearable(x, List[str]):
+            inferred_concept_id = self.concept_embed_ids[0]
+
+            x = self.tokenize(x)
+
+            superclass_mask = x == self.superclass_embed_id
+            assert superclass_mask.any(dim = -1).all(), 'superclass embed id must be present in all prompts'
+
+            # automatically replace the superclass id with the concept id
+            x = torch.where(superclass_mask, inferred_concept_id, x)
+
+        # get the embedding mask, defined as not the padding id
+        # defaults to the open clip tokenizer padding id of 0
+
+        embed_mask = x != self.pad_id
+
+        # get masks for all concepts (support for multi-concepts)
+
+        concept_masks = tuple(concept_id == x for concept_id in self.concept_embed_ids)
+
+        if exists(concept_id):
            assert is_all_unique(concept_id), 'concept ids must be all unique'
            assert all([cid < self.num_concepts for cid in concept_id])

@@ -87,13 +156,20 @@ def forward(
            assert all(filter_tuple_indices(has_concept, concept_id)), f'concept ids {filter_tuple_indices(self.concept_embed_ids, concept_id)} not found in the ids passed in'
            concept_masks = filter_tuple_indices(concept_masks, concept_id)

+        # fill concept positions with token id 0 before the lookup, as the concept embeddings are kept external to nn.Embedding
+
        for concept_mask in concept_masks:
            x = x.masked_fill(concept_mask, 0)

+        # get all the embeddings that are not the concept or superclass concept
+        # the rest of the embeddings are not learnable, only the concept embeddings are
+
        with torch.no_grad():
            embeds = self.embed(x)
            embeds.detach_()

+        # substitute the concepts back into the embeddings
+
        for concept, concept_mask in zip(self.concepts, concept_masks):
            embeds = torch.where(
                rearrange(concept_mask, '... -> ... 1'),
@@ -110,9 +186,9 @@ def forward(
            with torch.no_grad():
                superclass_embeds = self.embed(x)

-            return EmbeddingReturn(embeds, superclass_embeds)
+            return EmbeddingReturn(embeds, superclass_embeds, embed_mask)

-        return EmbeddingReturn(embeds, None)
+        return EmbeddingReturn(embeds, None, embed_mask)

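+# a minimal usage sketch, assuming the open clip vocab size of 49408 and an illustrative embedding dim of 512:
+#
+#   embed = nn.Embedding(49408, 512)
+#   wrapper = EmbeddingWrapper(embed, superclass_string = 'dog')
+#   embeds, superclass_embeds, embed_mask = wrapper(['a photo of a dog swimming'])
+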
@beartype
def merge_embedding_wrappers(