@@ -137,18 +137,26 @@ def __init__(
137137 self .scale = dim_head ** - 0.5
138138 self .heads = heads
139139 hidden_dim = dim_head * heads
140- self .to_qkv = nn .Conv2d (dim , hidden_dim * 3 , 1 , bias = False )
140+ self .to_q = nn .Conv2d (dim , hidden_dim , 1 , bias = False )
141+ self .to_kv = nn .Conv2d (dim , hidden_dim * 2 , 1 , bias = False )
141142 self .to_out = nn .Conv2d (hidden_dim , dim , 1 )
142143
143- def forward (self , x ):
144+ def forward (
145+ self ,
146+ x ,
147+ context = None
148+ ):
144149 b , c , h , w = x .shape
145- qkv = self .to_qkv (x ).chunk (3 , dim = 1 )
150+ context = default (context , x )
151+
152+ qkv = (self .to_q (x ), * self .to_kv (context ).chunk (2 , dim = 1 ))
146153 q , k , v = map (lambda t : rearrange (t , 'b (h c) x y -> b h c (x y)' , h = self .heads ), qkv )
147154
148155 q = q * self .scale
149156
150157 sim = einsum ('b h d i, b h d j -> b h i j' , q , k )
151158 attn = sim .softmax (dim = - 1 )
159+
152160 out = einsum ('b h i j, b h d j -> b h i d' , attn , v )
153161 out = rearrange (out , 'b h (x y) d -> b (h d) x y' , x = h , y = w )
154162 return self .to_out (out )
0 commit comments