Skip to content

Commit c082d15

Browse files
committed
get all the hyperparameters in place
1 parent 38965af commit c082d15

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

perfusion_pytorch/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from perfusion_pytorch.perfusion import (
2+
Rank1EditModule
3+
)

perfusion_pytorch/perfusion.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import torch
2+
from torch import nn, einsum, Tensor
3+
from torch.nn import Module
4+
import torch.nn.functional as F
5+
6+
from beartype import beartype
7+
from einops import rearrange
8+
9+
# helpers
10+
11+
def exists(val):
12+
return val is not None
13+
14+
# main contribution of paper
15+
# a module that wraps the keys and values projection of the cross attentions to text encodings
16+
17+
class Rank1EditModule(Module):
18+
19+
@beartype
20+
def __init__(
21+
self,
22+
key_or_values_proj: nn.Linear,
23+
*,
24+
C: Tensor,
25+
input_decay = 0.99,
26+
train_beta = 0.75,
27+
train_temperature = 0.1,
28+
eval_beta = 0.70, # in paper, specified a range (0.6 - 0.75) for local-key lock, and (0.4 -0.6) for global-key lock
29+
eval_temperature = 0.15
30+
):
31+
super().__init__()
32+
assert not exists(key_or_values_proj.bias), 'key value projection in attention should not have bias'
33+
34+
self.weight = key_or_values_proj.weight
35+
36+
self.train_beta = train_beta
37+
self.train_temperature = train_temperature
38+
self.eval_beta = eval_beta
39+
self.eval_temperature = eval_temperature
40+
41+
self.input_decay = input_decay
42+
43+
# buffers
44+
45+
self.register_buffer('C_inv', torch.inverse(C))
46+
47+
@beartype
48+
def forward(
49+
self,
50+
text_enc: Tensor,
51+
concept_indices: Tensor
52+
):
53+
"""
54+
following the pseudocode of Algorithm 1 in appendix
55+
"""
56+
57+
return text_enc

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
'text-to-image'
1818
],
1919
install_requires=[
20+
'beartype',
2021
'einops>=0.6.1',
2122
'torch>=2.0'
2223
],

0 commit comments

Comments
 (0)