We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent dbdce0f commit 70ac273Copy full SHA for 70ac273
med_seg_diff_pytorch/med_seg_diff_pytorch.py
@@ -494,12 +494,12 @@ def __init__(
494
):
495
super().__init__()
496
497
- self.model = model
+ self.model = model if isinstance(model, Unet) else model.module
498
+
499
self.input_img_channels = self.model.input_img_channels
500
self.mask_channels = self.model.mask_channels
501
self.self_condition = self.model.self_condition
-
502
- self.image_size = model.image_size
+ self.image_size = self.model.image_size
503
504
self.objective = objective
505
0 commit comments