Skip to content

Commit 4e49ba2

Browse files
authored
Merge pull request #194 from pedropesserl/master
fix issues with training in version 1 and sampling with custom dataset
2 parents 60850bf + a47786b commit 4e49ba2

File tree

3 files changed

+30
-10
lines changed

3 files changed

+30
-10
lines changed

guided_diffusion/custom_dataset_loader.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def __init__(self, args, data_path , transform = None, mode = 'Training',plane =
2323
images = sorted(glob(os.path.join(path, "images/*.png")))
2424
masks = sorted(glob(os.path.join(path, "masks/*.png")))
2525

26-
self.name_list = images[:2]
27-
self.label_list = masks[:2]
26+
self.name_list = images
27+
self.label_list = masks
2828
self.data_path = path
2929
self.mode = mode
3030

@@ -44,18 +44,19 @@ def __getitem__(self, index):
4444
img = Image.open(img_path).convert('RGB')
4545
mask = Image.open(msk_path).convert('L')
4646

47-
if self.mode == 'Training':
48-
label = 0 if self.label_list[index] == 'benign' else 1
49-
else:
50-
label = int(self.label_list[index])
47+
# if self.mode == 'Training':
48+
# label = 0 if self.label_list[index] == 'benign' else 1
49+
# else:
50+
# label = int(self.label_list[index])
5151

5252
if self.transform:
5353
state = torch.get_rng_state()
5454
img = self.transform(img)
5555
torch.set_rng_state(state)
5656
mask = self.transform(mask)
5757

58-
if self.mode == 'Training':
59-
return (img, mask, name)
60-
else:
61-
return (img, mask, name)
58+
return (img, mask, name)
59+
# if self.mode == 'Training':
60+
# return (img, mask, name)
61+
# else:
62+
# return (img, mask, name)

guided_diffusion/unet.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,17 @@ def convert_to_fp32(self):
735735
self.middle_block.apply(convert_module_to_f32)
736736
self.output_blocks.apply(convert_module_to_f32)
737737

738+
def load_part_state_dict(self, state_dict):
739+
740+
own_state = self.state_dict()
741+
for name, param in state_dict.items():
742+
if name not in own_state:
743+
continue
744+
if isinstance(param, th.nn.Parameter):
745+
# backwards compatibility for serialized parameters
746+
param = param.data
747+
own_state[name].copy_(param)
748+
738749
def enhance(self, c, h):
739750
cu = layer_norm(c.size()[1:])(c)
740751
hu = layer_norm(h.size()[1:])(h)

scripts/segmentation_sample.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from guided_diffusion import dist_util, logger
1818
from guided_diffusion.bratsloader import BRATSDataset, BRATSDataset3D
1919
from guided_diffusion.isicloader import ISICDataset
20+
from guided_diffusion.custom_dataset_loader import CustomDataset
2021
import torchvision.utils as vutils
2122
from guided_diffusion.utils import staple
2223
from guided_diffusion.script_util import (
@@ -58,6 +59,13 @@ def main():
5859

5960
ds = BRATSDataset3D(args.data_dir,transform_test)
6061
args.in_ch = 5
62+
else:
63+
tran_list = [transforms.Resize((args.image_size,args.image_size)), transforms.ToTensor()]
64+
transform_test = transforms.Compose(tran_list)
65+
66+
ds = CustomDataset(args, args.data_dir, transform_test, mode = 'Test')
67+
args.in_ch = 4
68+
6169
datal = th.utils.data.DataLoader(
6270
ds,
6371
batch_size=args.batch_size,

0 commit comments

Comments
 (0)