Skip to content

Commit 70ac273

Browse files
committed
fix for multi gpu, thanks to @TMullerSG
1 parent dbdce0f commit 70ac273

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

med_seg_diff_pytorch/med_seg_diff_pytorch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -494,12 +494,12 @@ def __init__(
494494
):
495495
super().__init__()
496496

497-
self.model = model
497+
self.model = model if isinstance(model, Unet) else model.module
498+
498499
self.input_img_channels = self.model.input_img_channels
499500
self.mask_channels = self.model.mask_channels
500501
self.self_condition = self.model.self_condition
501-
502-
self.image_size = model.image_size
502+
self.image_size = self.model.image_size
503503

504504
self.objective = objective
505505

0 commit comments

Comments
 (0)