Skip to content

Commit 11dc571

Browse files
committed
absolute positional embedding for vision transformer
1 parent f9447b0 commit 11dc571

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

med_seg_diff_pytorch/med_seg_diff_pytorch.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def __init__(
236236
self,
237237
dim,
238238
*,
239+
image_size,
239240
patch_size,
240241
channels = 3,
241242
channels_out = None,
@@ -244,6 +245,13 @@ def __init__(
244245
depth = 4,
245246
):
246247
super().__init__()
248+
assert exists(image_size)
249+
assert (image_size % patch_size) == 0
250+
251+
num_patches_height_width = image_size // patch_size
252+
253+
self.pos_emb = nn.Parameter(torch.zeros(dim, num_patches_height_width, num_patches_height_width))
254+
247255
channels_out = default(channels_out, channels)
248256

249257
patch_dim = channels * (patch_size ** 2)
@@ -272,6 +280,8 @@ def __init__(
272280

273281
def forward(self, x):
274282
x = self.to_tokens(x)
283+
x = x + self.pos_emb
284+
275285
x = self.transformer(x)
276286
return self.to_patches(x)
277287

@@ -283,6 +293,7 @@ def __init__(
283293
fmap_size,
284294
dim,
285295
dynamic = True,
296+
image_size = None,
286297
dim_head = 32,
287298
heads = 4,
288299
depth = 4,
@@ -412,6 +423,7 @@ def __init__(
412423
if conditioning_klass == Conditioning:
413424
conditioning_klass = partial(
414425
Conditioning,
426+
image_size = image_size,
415427
dynamic = dynamic_ff_parser_attn_map,
416428
**conditioning_kwargs
417429
)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'med-seg-diff-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.3.0',
6+
version = '0.3.1',
77
license='MIT',
88
description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)