
Commit efbb6cd

allow for adding full self attention in specific stages of the unet
1 parent 0707009 commit efbb6cd

2 files changed (+16, -7 lines)

med_seg_diff_pytorch/med_seg_diff_pytorch.py

Lines changed: 14 additions & 6 deletions
@@ -4,12 +4,14 @@
 from functools import partial
 from collections import namedtuple
 
+from beartype import beartype
+
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
 from torch.fft import fft2, ifft2
 
-from einops import rearrange, reduce, pack
+from einops import rearrange, reduce
 from einops.layers.torch import Rearrange
 
 from tqdm.auto import tqdm
@@ -224,15 +226,17 @@ def forward(self, x, c):
 
 # model
 
+@beartype
 class Unet(nn.Module):
     def __init__(
         self,
         dim,
         image_size,
         init_dim = None,
         out_dim = None,
-        dim_mults=(1, 2, 4, 8),
+        dim_mults: tuple = (1, 2, 4, 8),
         channels = 3,
+        full_self_attn: tuple = (False, False, False, True),
         self_condition = False,
         resnet_block_groups = 8,
         conditioning_klass = Conditioning,
@@ -272,6 +276,7 @@ def __init__(
         # layers
 
         num_resolutions = len(in_out)
+        assert len(full_self_attn) == num_resolutions
 
         self.conditioners = nn.ModuleList([])
 
@@ -283,15 +288,17 @@ def __init__(
 
         curr_fmap_size = image_size
 
-        for ind, (dim_in, dim_out) in enumerate(in_out):
+        for ind, ((dim_in, dim_out), full_attn) in enumerate(zip(in_out, full_self_attn)):
             is_last = ind >= (num_resolutions - 1)
+            attn_klass = Attention if full_attn else LinearAttention
 
             self.conditioners.append(conditioning_klass(curr_fmap_size, dim_in))
 
+
             self.downs.append(nn.ModuleList([
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
-                Residual(LinearAttention(dim_in)),
+                Residual(attn_klass(dim_in)),
                 Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
             ]))
 
@@ -314,15 +321,16 @@ def __init__(
 
         self.ups = nn.ModuleList([])
 
-        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
+        for ind, ((dim_in, dim_out), full_attn) in enumerate(zip(reversed(in_out), reversed(full_self_attn))):
             is_last = ind == (len(in_out) - 1)
+            attn_klass = Attention if full_attn else LinearAttention
 
             skip_connect_dim = dim_in * (2 if self.skip_connect_condition_fmaps else 1)
 
             self.ups.append(nn.ModuleList([
                 block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
                 block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
-                Residual(LinearAttention(dim_out)),
+                Residual(attn_klass(dim_out)),
                 Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
             ]))
 
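For orientation, a minimal usage sketch (not part of this commit; the dim and image_size values are arbitrary) of the new argument: full_self_attn takes one flag per resolution stage, which the new assert ties to len(dim_mults), and a True entry swaps that stage's LinearAttention block for full Attention.

from med_seg_diff_pytorch.med_seg_diff_pytorch import Unet

# sketch only: uses just the constructor arguments visible in the hunks above,
# leaving every other default untouched
model = Unet(
    dim = 64,
    image_size = 128,
    dim_mults = (1, 2, 4, 8),                      # four resolution stages
    full_self_attn = (False, False, False, True),  # full attention only in the deepest stage
    channels = 3
)

Full self attention is quadratic in the number of feature-map positions, so the default enables it only at the coarsest resolution and keeps linear attention everywhere else.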

setup.py

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'med-seg-diff-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
   author = 'Phil Wang',
@@ -17,6 +17,7 @@
     'medical segmentation'
   ],
   install_requires=[
+    'beartype',
     'einops',
     'torch',
     'tqdm'
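Because the Unet class is now decorated with @beartype and dim_mults / full_self_attn are annotated as tuple, beartype (added above as an install requirement) rejects mistyped arguments at call time. A small illustrative sketch, assuming a recent beartype release; the exact violation class varies by version, but every beartype error derives from BeartypeException:

from beartype.roar import BeartypeException
from med_seg_diff_pytorch.med_seg_diff_pytorch import Unet

try:
    # a list where a tuple is expected should be rejected before __init__ runs
    Unet(dim = 64, image_size = 128, dim_mults = [1, 2, 4, 8])
except BeartypeException as err:
    print('rejected by beartype:', err)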
