Skip to content

Commit 20cad68

Browse files
committed
fix a bug with the EMA update (thanks to Yoad); also allow initializing the concept embed from a superclass embed id, as suggested in feedback
1 parent 2c3fb7a commit 20cad68

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

perfusion_pytorch/embedding.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ class EmbeddingWrapper(Module):
2525
def __init__(
2626
self,
2727
embed: nn.Embedding,
28-
num_concepts = 1
28+
num_concepts = 1,
29+
superclass_embed_id: Optional[Union[int, Tuple[int, ...]]] = None
2930
):
3031
super().__init__()
3132
self.embed = embed
@@ -34,7 +35,20 @@ def __init__(
3435
self.num_embeds = num_embeds
3536
self.num_concepts = num_concepts
3637
self.concepts = nn.Parameter(torch.zeros(num_concepts, dim))
37-
nn.init.normal_(self.concepts, std = 0.02)
38+
39+
if exists(superclass_embed_id):
40+
# author had better results initializing the concept embed to the super class embed, allow for that option
41+
42+
if not isinstance(superclass_embed_id, tuple):
43+
superclass_embed_id = (superclass_embed_id,)
44+
45+
superclass_embed_indices = torch.tensor(list(superclass_embed_id))
46+
superclass_embeds = embed(superclass_embed_indices)
47+
self.concepts.data.copy_(superclass_embeds)
48+
else:
49+
# otherwise initialize to usually small init for embeds
50+
51+
nn.init.normal_(self.concepts, std = 0.02)
3852

3953
self.concept_embed_ids = tuple(range(num_embeds, num_embeds + num_concepts))
4054

@@ -44,7 +58,7 @@ def parameters(self):
4458
def forward(
4559
self,
4660
x,
47-
concept_id: Optional[Union[int, Tuple[int, ...]]] = None
61+
concept_id: Optional[Union[int, Tuple[int, ...]]] = None,
4862
):
4963
concept_masks = tuple(concept_id == x for concept_id in self.concept_embed_ids)
5064

perfusion_pytorch/perfusion.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,11 @@ def forward(
316316

317317
if not initted:
318318
self.initted[concept_id].data.copy_(Tensor([True]))
319-
self.ema_concept_text_encs[concept_id].data.copy_(concept_text_enc)
319+
320+
# update ema i_*
321+
322+
self.ema_concept_text_encs[concept_id].data.copy_(concept_text_enc)
323+
320324
else:
321325
assert self.initted[concept_id_tensor].all(), 'you have not initialized or trained this module for the concepts id given'
322326

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'perfusion-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.30',
6+
version = '0.0.31',
77
license='MIT',
88
description = 'Perfusion - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)