Skip to content

Commit 16b74de

Browse files
committed
Another modification to the Karras 3D U-Net for @QuantPrincess's medical imaging work (#295)
1 parent 3a3ab1b commit 16b74de

File tree

2 files changed

+102
-21
lines changed

2 files changed

+102
-21
lines changed

denoising_diffusion_pytorch/karras_unet_3d.py

Lines changed: 101 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import math
66
from math import sqrt, ceil
77
from functools import partial
8+
from typing import Optional, Union, Tuple
89

910
import torch
1011
from torch import nn, einsum
@@ -207,12 +208,15 @@ def __init__(
207208
attn_dim_head = 64,
208209
attn_res_mp_add_t = 0.3,
209210
attn_flash = False,
210-
downsample = False
211+
downsample = False,
212+
downsample_config: Tuple[bool, bool, bool] = (True, True, True)
211213
):
212214
super().__init__()
213215
dim_out = default(dim_out, dim)
214216

215217
self.downsample = downsample
218+
self.downsample_config = downsample_config
219+
216220
self.downsample_conv = None
217221

218222
curr_dim = dim
@@ -259,7 +263,10 @@ def forward(
259263
):
260264
if self.downsample:
261265
t, h, w = x.shape[-3:]
262-
x = F.interpolate(x, (t // 2, h // 2, w // 2), mode = 'trilinear')
266+
resize_factors = tuple((2 if downsample else 1) for downsample in self.downsample_config)
267+
interpolate_shape = tuple(shape // factor for shape, factor in zip((t, h, w), resize_factors))
268+
269+
x = F.interpolate(x, interpolate_shape, mode = 'trilinear')
263270
x = self.downsample_conv(x)
264271

265272
x = self.pixel_norm(x)
@@ -294,12 +301,15 @@ def __init__(
294301
attn_dim_head = 64,
295302
attn_res_mp_add_t = 0.3,
296303
attn_flash = False,
297-
upsample = False
304+
upsample = False,
305+
upsample_config: Tuple[bool, bool, bool] = (True, True, True)
298306
):
299307
super().__init__()
300308
dim_out = default(dim_out, dim)
301309

302310
self.upsample = upsample
311+
self.upsample_config = upsample_config
312+
303313
self.needs_skip = not upsample
304314

305315
self.to_emb = None
@@ -341,7 +351,10 @@ def forward(
341351
):
342352
if self.upsample:
343353
t, h, w = x.shape[-3:]
344-
x = F.interpolate(x, (t * 2, h * 2, w * 2), mode = 'trilinear')
354+
resize_factors = tuple((2 if upsample else 1) for upsample in self.upsample_config)
355+
interpolate_shape = tuple(shape * factor for shape, factor in zip((t, h, w), resize_factors))
356+
357+
x = F.interpolate(x, interpolate_shape, mode = 'trilinear')
345358

346359
res = self.res_conv(x)
347360

@@ -416,12 +429,14 @@ def __init__(
416429
self,
417430
*,
418431
image_size,
432+
frames,
419433
dim = 192,
420434
dim_max = 768, # channels will double every downsample and cap out to this value
421435
num_classes = None, # in paper, they do 1000 classes for a popular benchmark
422436
channels = 4, # 4 channels in paper for some reason, must be alpha channel?
423437
num_downsamples = 3,
424-
num_blocks_per_stage = 4,
438+
num_blocks_per_stage: Union[int, Tuple[int, ...]] = 4,
439+
downsample_types: Optional[Tuple[str, ...]] = None,
425440
attn_res = (16, 8),
426441
fourier_dim = 16,
427442
attn_dim_head = 64,
@@ -440,7 +455,9 @@ def __init__(
440455
# determine dimensions
441456

442457
self.channels = channels
458+
self.frames = frames
443459
self.image_size = image_size
460+
444461
input_channels = channels * (2 if self_condition else 1)
445462

446463
# input and output blocks
@@ -478,6 +495,25 @@ def __init__(
478495

479496
self.num_downsamples = num_downsamples
480497

498+
# specifying downsample types per stage ('image', 'frame', or 'all' for both)
499+
500+
downsample_types = default(downsample_types, 'all')
501+
downsample_types = cast_tuple(downsample_types, num_downsamples)
502+
503+
assert len(downsample_types) == num_downsamples
504+
assert all([t in {'all', 'frame', 'image'} for t in downsample_types])
505+
506+
# number of blocks per downsample
507+
508+
num_blocks_per_stage = cast_tuple(num_blocks_per_stage, num_downsamples)
509+
510+
if len(num_blocks_per_stage) == num_downsamples:
511+
first, *_ = num_blocks_per_stage
512+
num_blocks_per_stage = (first, *num_blocks_per_stage)
513+
514+
assert len(num_blocks_per_stage) == (num_downsamples + 1)
515+
assert all([num_blocks >= 1 for num_blocks in num_blocks_per_stage])
516+
481517
# attention
482518

483519
attn_res = set(cast_tuple(attn_res))
@@ -498,17 +534,18 @@ def __init__(
498534
self.ups = ModuleList([])
499535

500536
curr_dim = dim
501-
curr_res = image_size
537+
curr_image_res = image_size
538+
curr_frame_res = frames
502539

503540
self.skip_mp_cat = MPCat(t = mp_cat_t, dim = 1)
504541

505542
# take care of skip connection for initial input block and first three encoder blocks
506543

507544
prepend(self.ups, Decoder(dim * 2, dim, **block_kwargs))
508545

509-
assert num_blocks_per_stage >= 1
546+
init_num_blocks_per_stage, *rest_num_blocks_per_stage = num_blocks_per_stage
510547

511-
for _ in range(num_blocks_per_stage):
548+
for _ in range(init_num_blocks_per_stage):
512549
enc = Encoder(curr_dim, curr_dim, **block_kwargs)
513550
dec = Decoder(curr_dim * 2, curr_dim, **block_kwargs)
514551

@@ -517,20 +554,53 @@ def __init__(
517554

518555
# stages
519556

520-
for _ in range(self.num_downsamples):
557+
for _, layer_num_blocks_per_stage, layer_downsample_type in zip(range(self.num_downsamples), rest_num_blocks_per_stage, downsample_types):
558+
521559
dim_out = min(dim_max, curr_dim * 2)
522-
upsample = Decoder(dim_out, curr_dim, has_attn = curr_res in attn_res, upsample = True, **block_kwargs)
523560

524-
curr_res //= 2
525-
has_attn = curr_res in attn_res
561+
downsample_image = layer_downsample_type in {'all', 'image'}
562+
downsample_frame = layer_downsample_type in {'all', 'frame'}
526563

527-
downsample = Encoder(curr_dim, dim_out, downsample = True, has_attn = has_attn, **block_kwargs)
564+
assert not (downsample_image and not divisible_by(curr_image_res, 2))
565+
assert not (downsample_frame and not divisible_by(curr_frame_res, 2))
566+
567+
down_and_upsample_config = (
568+
downsample_frame,
569+
downsample_image,
570+
downsample_image
571+
)
572+
573+
upsample = Decoder(
574+
dim_out,
575+
curr_dim,
576+
has_attn = curr_image_res in attn_res,
577+
upsample = True,
578+
upsample_config = down_and_upsample_config,
579+
**block_kwargs
580+
)
581+
582+
if downsample_image:
583+
curr_image_res //= 2
584+
585+
if downsample_frame:
586+
curr_frame_res //= 2
587+
588+
has_attn = curr_image_res in attn_res
589+
590+
downsample = Encoder(
591+
curr_dim,
592+
dim_out,
593+
downsample = True,
594+
downsample_config = down_and_upsample_config,
595+
has_attn = has_attn,
596+
**block_kwargs
597+
)
528598

529599
append(self.downs, downsample)
530600
prepend(self.ups, upsample)
531601
prepend(self.ups, Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs))
532602

533-
for _ in range(num_blocks_per_stage):
603+
for _ in range(layer_num_blocks_per_stage):
534604
enc = Encoder(dim_out, dim_out, has_attn = has_attn, **block_kwargs)
535605
dec = Decoder(dim_out * 2, dim_out, has_attn = has_attn, **block_kwargs)
536606

@@ -541,7 +611,7 @@ def __init__(
541611

542612
# take care of the two middle decoders
543613

544-
mid_has_attn = curr_res in attn_res
614+
mid_has_attn = curr_image_res in attn_res
545615

546616
self.mids = ModuleList([
547617
Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
@@ -563,7 +633,7 @@ def forward(
563633
):
564634
# validate image shape
565635

566-
assert x.shape[1:] == (self.channels, self.image_size, self.image_size, self.image_size)
636+
assert x.shape[1:] == (self.channels, self.frames, self.image_size, self.image_size)
567637

568638
# self conditioning
569639

@@ -689,19 +759,30 @@ def forward(self, x):
689759
# example
690760

691761
if __name__ == '__main__':
762+
692763
unet = KarrasUnet3D(
764+
frames = 32,
693765
image_size = 64,
694-
dim = 192,
766+
dim = 8,
695767
dim_max = 768,
768+
num_downsamples = 6,
769+
num_blocks_per_stage = (4, 3, 2, 2, 2, 2),
770+
downsample_types = (
771+
'image',
772+
'frame',
773+
'image',
774+
'frame',
775+
'image',
776+
'frame',
777+
),
778+
attn_dim_head = 8,
696779
num_classes = 1000,
697780
)
698781

699-
images = torch.randn(2, 4, 64, 64, 64)
782+
images = torch.randn(2, 4, 32, 64, 64)
700783

701784
denoised_images = unet(
702785
images,
703786
time = torch.ones(2,),
704787
class_labels = torch.randint(0, 1000, (2,))
705788
)
706-
707-
assert denoised_images.shape == images.shape
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.10.15'
1+
__version__ = '1.10.16'

0 commit comments

Comments
 (0)