 import torch
 from torch import nn
 from torch.nn import Module
+
 from beartype import beartype
+from beartype.typing import Optional, Tuple, Union
 
 from einops import rearrange
 
+# helper functions
+
 def exists(val):
     return val is not None
 
+def is_all_unique(arr):
+    return len(set(arr)) == len(arr)
+
+def filter_tuple_indices(tup, indices):
+    return tuple(tup[i] for i in indices)
+
+# embedding wrapper class
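+# wraps a frozen nn.Embedding so that ids appended past the base vocab select trainable concept embeddings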
+
 class EmbeddingWrapper(Module):
     @beartype
     def __init__(
@@ -19,17 +31,37 @@ def __init__(
         self.embed = embed
         num_embeds, dim = embed.weight.shape
 
+        self.num_embeds = num_embeds
         self.num_concepts = num_concepts
         self.concepts = nn.Parameter(torch.zeros(num_concepts, dim))
         nn.init.normal_(self.concepts, std = 0.02)
 
-        self.concept_ids = tuple(range(num_embeds, num_embeds + num_concepts))
+        self.concept_embed_ids = tuple(range(num_embeds, num_embeds + num_concepts))
 
     def parameters(self):
         return [self.concepts]
 
-    def forward(self, x):
-        concept_masks = tuple(concept_id == x for concept_id in self.concept_ids)
+    def forward(
+        self,
+        x,
+        concept_id: Optional[Union[int, Tuple[int, ...]]] = None
+    ):
+        concept_masks = tuple(concept_embed_id == x for concept_embed_id in self.concept_embed_ids)
+        concepts = tuple(self.concepts)
+
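+        # if specific concept ids are passed in, validate them and keep only those masks and concept parameters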
+        if exists(concept_id):
+            if not isinstance(concept_id, tuple):
+                concept_id = (concept_id,)
+
+            assert is_all_unique(concept_id), 'concept ids must be all unique'
+            assert all([cid < self.num_concepts for cid in concept_id]), 'concept ids must be in range'
+
+            has_concept = tuple(concept_mask.any(dim = -1).all() for concept_mask in concept_masks)
+
+            assert all(filter_tuple_indices(has_concept, concept_id)), f'concept ids {filter_tuple_indices(self.concept_embed_ids, concept_id)} not found in ids passed in'
+            concept_masks = filter_tuple_indices(concept_masks, concept_id)
+            concepts = filter_tuple_indices(concepts, concept_id)
 
         for concept_mask in concept_masks:
             x = x.masked_fill(concept_mask, 0)
@@ -38,7 +70,8 @@ def forward(self, x):
         embeds = self.embed(x)
         embeds.detach_()
 
-        for concept_id, concept, concept_mask in zip(self.concept_ids, self.concepts, concept_masks):
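+        # overwrite the positions of each selected concept id with its trainable embedding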
+        for concept, concept_mask in zip(concepts, concept_masks):
             embeds = torch.where(
                 rearrange(concept_mask, '... -> ... 1'),
                 concept,
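
A minimal usage sketch of the wrapper after this change. The embedding sizes, tensor shapes, and the constructor call `EmbeddingWrapper(embed, num_concepts = 2)` are illustrative assumptions, not taken from the commit:

import torch
from torch import nn

embed = nn.Embedding(256, 32)                        # frozen base table, ids 0..255
wrapper = EmbeddingWrapper(embed, num_concepts = 2)  # concept embed ids become 256 and 257 (assumed signature)

ids = torch.randint(0, 256, (1, 1024))
ids[0, 10] = 257                                     # inject the second concept token

out = wrapper(ids, concept_id = 1)                   # asserts concept 1 is actually present in the ids
assert out.shape == (1, 1024, 32)

# only the concept vectors receive gradients; base embeddings are detached in forward
opt = torch.optim.Adam(wrapper.parameters(), lr = 1e-3)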