@@ -25,7 +25,8 @@ class EmbeddingWrapper(Module):
     def __init__(
         self,
         embed: nn.Embedding,
-        num_concepts = 1
+        num_concepts = 1,
+        superclass_embed_id: Optional[Union[int, Tuple[int, ...]]] = None
     ):
         super().__init__()
         self.embed = embed
@@ -34,7 +35,20 @@ def __init__(
         self.num_embeds = num_embeds
         self.num_concepts = num_concepts
         self.concepts = nn.Parameter(torch.zeros(num_concepts, dim))
-        nn.init.normal_(self.concepts, std = 0.02)
+
+        if exists(superclass_embed_id):
+            # the author had better results initializing the concept embed to the superclass embed, so allow for that option
+
+            if not isinstance(superclass_embed_id, tuple):
+                superclass_embed_id = (superclass_embed_id,)
+
+            superclass_embed_indices = torch.tensor(list(superclass_embed_id))
+            superclass_embeds = embed(superclass_embed_indices)
+            self.concepts.data.copy_(superclass_embeds)
+        else:
+            # otherwise fall back to the usual small init for embeddings
+
+            nn.init.normal_(self.concepts, std = 0.02)
 
         self.concept_embed_ids = tuple(range(num_embeds, num_embeds + num_concepts))
 
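For context before the final hunk: below is a minimal, self-contained sketch of what the new superclass-initialization branch does, with assumed toy sizes (a vocab of 10 tokens, dim 16, and token id 3 as the superclass word; none of these come from this commit):

```python
import torch
from torch import nn

# toy embedding table standing in for the wrapped nn.Embedding
embed = nn.Embedding(10, 16)
num_concepts = 1
superclass_embed_id = 3

# concept parameter starts as zeros, as in __init__
concepts = nn.Parameter(torch.zeros(num_concepts, embed.weight.shape[1]))

# mirror the new branch: normalize the id to a tuple, look up the
# superclass embedding, and copy it into the concept parameter
if not isinstance(superclass_embed_id, tuple):
    superclass_embed_id = (superclass_embed_id,)

superclass_embeds = embed(torch.tensor(list(superclass_embed_id)))
concepts.data.copy_(superclass_embeds)

# the concept vector now equals the superclass token's embedding
assert torch.allclose(concepts.data, embed.weight[3])
```

Copying the superclass embedding gives the new concept a semantically meaningful starting point instead of near-zero noise, which is what the in-code comment reports worked better for the author.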
@@ -44,7 +58,7 @@ def parameters(self):
     def forward(
         self,
         x,
-        concept_id: Optional[Union[int, Tuple[int, ...]]] = None,
     ):
         concept_masks = tuple(concept_id == x for concept_id in self.concept_embed_ids)
 
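And a small sketch of what that mask line in forward computes, with assumed toy token ids (not from this commit): each concept id beyond the base vocabulary is compared elementwise against the input ids, yielding one boolean mask per concept marking where that concept token occurs.

```python
import torch

x = torch.tensor([[1, 10, 4], [10, 2, 3]])  # batch of token ids; 10 is a concept id
concept_embed_ids = (10,)                    # ids appended past num_embeds

# elementwise comparison broadcasts the int id against the id tensor
concept_masks = tuple(concept_id == x for concept_id in concept_embed_ids)
print(concept_masks[0])
# tensor([[False,  True, False],
#         [ True, False, False]])
```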