
Commit 2badf6e

Merge pull request #806 from NVIDIA/gh/release
[nnUNet/PyT] Release
2 parents: 3f9cdc6 + d646ec9

File tree: 23 files changed, +2749 −0 lines changed

PyTorch/Segmentation/nnUNet/Dockerfile

Lines changed: 16 additions & 0 deletions
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:20.12-py3
FROM ${FROM_IMAGE_NAME}

ADD . /workspace/nnunet_pyt
WORKDIR /workspace/nnunet_pyt

RUN pip install --upgrade pip
RUN pip install --disable-pip-version-check -r requirements.txt
RUN pip install pytorch-lightning==1.0.0 --no-dependencies
RUN pip install monai==0.4.0 --no-dependencies
RUN pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/ nvidia-dali-cuda110==0.29.0
RUN pip install torch_optimizer==0.0.1a15 --no-dependencies
RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
RUN unzip awscliv2.zip
RUN ./aws/install
RUN rm -rf awscliv2.zip aws
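
Nothing in this diff builds or starts the container; a typical invocation would look like the following sketch, where the nnunet image tag and the /path/to/data mount are illustrative placeholders rather than values taken from this commit:

docker build -t nnunet .
docker run -it --rm --gpus all --ipc=host -v /path/to/data:/data nnunet /bin/bash

The AWS CLI installed in the last steps is presumably there so that datasets can be fetched with aws s3 commands from inside the container.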

PyTorch/Segmentation/nnUNet/README.md

Lines changed: 706 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 260 additions & 0 deletions
import itertools
import os

import numpy as np
import nvidia.dali.fn as fn
import nvidia.dali.math as math
import nvidia.dali.ops as ops
import nvidia.dali.tfrecord as tfrec
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator


class TFRecordTrain(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, **kwargs):
        super(TFRecordTrain, self).__init__(batch_size, num_threads, device_id)
        self.dim = kwargs["dim"]
        self.seed = kwargs["seed"]
        self.oversampling = kwargs["oversampling"]
        self.input = ops.TFRecordReader(
            path=kwargs["tfrecords"],
            index_path=kwargs["tfrecords_idx"],
            features={
                "X_shape": tfrec.FixedLenFeature([self.dim + 1], tfrec.int64, 0),
                "Y_shape": tfrec.FixedLenFeature([self.dim + 1], tfrec.int64, 0),
                "X": tfrec.VarLenFeature([], tfrec.float32, 0.0),
                "Y": tfrec.FixedLenFeature([], tfrec.string, ""),
                "fname": tfrec.FixedLenFeature([], tfrec.string, ""),
            },
            num_shards=kwargs["gpus"],
            shard_id=device_id,
            random_shuffle=True,
            pad_last_batch=True,
            read_ahead=True,
            seed=self.seed,
        )
        self.patch_size = kwargs["patch_size"]
        self.crop_shape = types.Constant(np.array(self.patch_size), dtype=types.INT64)
        self.crop_shape_float = types.Constant(np.array(self.patch_size), dtype=types.FLOAT)
        self.layout = "CDHW" if self.dim == 3 else "CHW"
        self.axis_name = "DHW" if self.dim == 3 else "HW"

    def load_data(self, features):
        img = fn.reshape(features["X"], shape=features["X_shape"], layout=self.layout)
        lbl = fn.reshape(features["Y"], shape=features["Y_shape"], layout=self.layout)
        lbl = fn.reinterpret(lbl, dtype=types.DALIDataType.UINT8)
        return img, lbl

    def random_augmentation(self, probability, augmented, original):
        condition = fn.cast(fn.coin_flip(probability=probability), dtype=types.DALIDataType.BOOL)
        neg_condition = condition ^ True
        return condition * augmented + neg_condition * original

    @staticmethod
    def slice_fn(img, start_idx, length):
        return fn.slice(img, start_idx, length, axes=[0])

    def crop_fn(self, img, lbl):
        center = fn.segmentation.random_mask_pixel(lbl, foreground=fn.coin_flip(probability=self.oversampling))
        crop_anchor = self.slice_fn(center, 1, self.dim) - self.crop_shape // 2
        adjusted_anchor = math.max(0, crop_anchor)
        max_anchor = self.slice_fn(fn.shapes(lbl), 1, self.dim) - self.crop_shape
        crop_anchor = math.min(adjusted_anchor, max_anchor)
        img = fn.slice(img.gpu(), crop_anchor, self.crop_shape, axis_names=self.axis_name, out_of_bounds_policy="pad")
        lbl = fn.slice(lbl.gpu(), crop_anchor, self.crop_shape, axis_names=self.axis_name, out_of_bounds_policy="pad")
        return img, lbl

    def zoom_fn(self, img, lbl):
        resized_shape = self.crop_shape * self.random_augmentation(0.15, fn.uniform(range=(0.7, 1.0)), 1.0)
        img, lbl = fn.crop(img, crop=resized_shape), fn.crop(lbl, crop=resized_shape)
        img = fn.resize(img, interp_type=types.DALIInterpType.INTERP_CUBIC, size=self.crop_shape_float)
        lbl = fn.resize(lbl, interp_type=types.DALIInterpType.INTERP_NN, size=self.crop_shape_float)
        return img, lbl

    def noise_fn(self, img):
        img_noised = img + fn.normal_distribution(img, stddev=fn.uniform(range=(0.0, 0.33)))
        return self.random_augmentation(0.15, img_noised, img)

    def blur_fn(self, img):
        img_blured = fn.gaussian_blur(img, sigma=fn.uniform(range=(0.5, 1.5)))
        return self.random_augmentation(0.15, img_blured, img)

    def brightness_fn(self, img):
        brightness_scale = self.random_augmentation(0.15, fn.uniform(range=(0.7, 1.3)), 1.0)
        return img * brightness_scale

    def contrast_fn(self, img):
        min_, max_ = fn.reductions.min(img), fn.reductions.max(img)
        scale = self.random_augmentation(0.15, fn.uniform(range=(0.65, 1.5)), 1.0)
        img = math.clamp(img * scale, min_, max_)
        return img

    def flips_fn(self, img, lbl):
        kwargs = {"horizontal": fn.coin_flip(probability=0.33), "vertical": fn.coin_flip(probability=0.33)}
        if self.dim == 3:
            kwargs.update({"depthwise": fn.coin_flip(probability=0.33)})
        return fn.flip(img, **kwargs), fn.flip(lbl, **kwargs)

    def define_graph(self):
        features = self.input(name="Reader")
        img, lbl = self.load_data(features)
        img, lbl = self.crop_fn(img, lbl)
        img, lbl = self.zoom_fn(img, lbl)
        img = self.noise_fn(img)
        img = self.blur_fn(img)
        img = self.brightness_fn(img)
        img = self.contrast_fn(img)
        img, lbl = self.flips_fn(img, lbl)
        return img, lbl


class TFRecordEval(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, **kwargs):
        super(TFRecordEval, self).__init__(batch_size, num_threads, device_id)
        self.input = ops.TFRecordReader(
            path=kwargs["tfrecords"],
            index_path=kwargs["tfrecords_idx"],
            features={
                "X_shape": tfrec.FixedLenFeature([4], tfrec.int64, 0),
                "Y_shape": tfrec.FixedLenFeature([4], tfrec.int64, 0),
                "X": tfrec.VarLenFeature([], tfrec.float32, 0.0),
                "Y": tfrec.FixedLenFeature([], tfrec.string, ""),
                "fname": tfrec.FixedLenFeature([], tfrec.string, ""),
            },
            shard_id=device_id,
            num_shards=kwargs["gpus"],
            read_ahead=True,
            random_shuffle=False,
            pad_last_batch=True,
        )

    def load_data(self, features):
        img = fn.reshape(features["X"].gpu(), shape=features["X_shape"], layout="CDHW")
        lbl = fn.reshape(features["Y"].gpu(), shape=features["Y_shape"], layout="CDHW")
        lbl = fn.reinterpret(lbl, dtype=types.DALIDataType.UINT8)
        return img, lbl

    def define_graph(self):
        features = self.input(name="Reader")
        img, lbl = self.load_data(features)
        return img, lbl, features["fname"]


class TFRecordTest(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, **kwargs):
        super(TFRecordTest, self).__init__(batch_size, num_threads, device_id)
        self.input = ops.TFRecordReader(
            path=kwargs["tfrecords"],
            index_path=kwargs["tfrecords_idx"],
            features={
                "X_shape": tfrec.FixedLenFeature([4], tfrec.int64, 0),
                "X": tfrec.VarLenFeature([], tfrec.float32, 0.0),
                "fname": tfrec.FixedLenFeature([], tfrec.string, ""),
            },
            shard_id=device_id,
            num_shards=kwargs["gpus"],
            read_ahead=True,
            random_shuffle=False,
            pad_last_batch=True,
        )

    def define_graph(self):
        features = self.input(name="Reader")
        img = fn.reshape(features["X"].gpu(), shape=features["X_shape"], layout="CDHW")
        return img, features["fname"]


class TFRecordBenchmark(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, **kwargs):
        super(TFRecordBenchmark, self).__init__(batch_size, num_threads, device_id)
        self.dim = kwargs["dim"]
        self.input = ops.TFRecordReader(
            path=kwargs["tfrecords"],
            index_path=kwargs["tfrecords_idx"],
            features={
                "X_shape": tfrec.FixedLenFeature([self.dim + 1], tfrec.int64, 0),
                "Y_shape": tfrec.FixedLenFeature([self.dim + 1], tfrec.int64, 0),
                "X": tfrec.VarLenFeature([], tfrec.float32, 0.0),
                "Y": tfrec.FixedLenFeature([], tfrec.string, ""),
                "fname": tfrec.FixedLenFeature([], tfrec.string, ""),
            },
            shard_id=device_id,
            num_shards=kwargs["gpus"],
            read_ahead=True,
        )
        self.patch_size = kwargs["patch_size"]
        self.layout = "CDHW" if self.dim == 3 else "CHW"

    def load_data(self, features):
        img = fn.reshape(features["X"].gpu(), shape=features["X_shape"], layout=self.layout)
        lbl = fn.reshape(features["Y"].gpu(), shape=features["Y_shape"], layout=self.layout)
        lbl = fn.reinterpret(lbl, dtype=types.DALIDataType.UINT8)
        return img, lbl

    def crop_fn(self, img, lbl):
        img = fn.crop(img, crop=self.patch_size)
        lbl = fn.crop(lbl, crop=self.patch_size)
        return img, lbl

    def define_graph(self):
        features = self.input(name="Reader")
        img, lbl = self.load_data(features)
        img, lbl = self.crop_fn(img, lbl)
        return img, lbl


class LightningWrapper(DALIGenericIterator):
    def __init__(self, pipe, **kwargs):
        super().__init__(pipe, **kwargs)

    def __next__(self):
        out = super().__next__()
        out = out[0]
        return out


def fetch_dali_loader(tfrecords, idx_files, batch_size, mode, **kwargs):
    assert len(tfrecords) > 0, "Got empty tfrecord list"
    assert len(idx_files) == len(tfrecords), f"Got {len(idx_files)} index files but {len(tfrecords)} tfrecords"

    if kwargs["benchmark"]:
        tfrecords = list(itertools.chain(*(20 * [tfrecords])))
        idx_files = list(itertools.chain(*(20 * [idx_files])))

    pipe_kwargs = {
        "tfrecords": tfrecords,
        "tfrecords_idx": idx_files,
        "gpus": kwargs["gpus"],
        "seed": kwargs["seed"],
        "patch_size": kwargs["patch_size"],
        "dim": kwargs["dim"],
        "oversampling": kwargs["oversampling"],
    }

    if kwargs["benchmark"] and mode == "eval":
        pipeline = TFRecordBenchmark
        output_map = ["image", "label"]
        dynamic_shape = False
    elif mode == "training":
        pipeline = TFRecordTrain
        output_map = ["image", "label"]
        dynamic_shape = False
    elif mode == "eval":
        pipeline = TFRecordEval
        output_map = ["image", "label", "fname"]
        dynamic_shape = True
    else:
        pipeline = TFRecordTest
        output_map = ["image", "fname"]
        dynamic_shape = True

    device_id = int(os.getenv("LOCAL_RANK", "0"))
    pipe = pipeline(batch_size, kwargs["num_workers"], device_id, **pipe_kwargs)
    return LightningWrapper(
        pipe,
        auto_reset=True,
        reader_name="Reader",
        output_map=output_map,
        dynamic_shape=dynamic_shape,
    )
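
The module's only public entry point is fetch_dali_loader. A minimal sketch of calling it for training is shown below; the TFRecord paths, batch size, and hyperparameter values are illustrative assumptions, not values taken from this commit:

# Usage sketch (illustrative values): build a training loader from
# pre-generated TFRecord shards and their DALI index files.
from glob import glob

tfrecords = sorted(glob("/data/tfrecords/train-*.tfrecord"))  # hypothetical paths
idx_files = sorted(glob("/data/tfrecords_idx/train-*.idx"))

train_loader = fetch_dali_loader(
    tfrecords,
    idx_files,
    batch_size=2,
    mode="training",
    gpus=1,
    seed=1,
    patch_size=[128, 128, 128],
    dim=3,
    oversampling=0.33,
    num_workers=8,
    benchmark=False,
)

for batch in train_loader:
    # LightningWrapper unwraps the per-GPU list, so each batch is a dict
    # keyed by the output_map names ("image", "label" in training mode).
    img, lbl = batch["image"], batch["label"]
    break

In "eval" and test modes the iterator additionally yields the "fname" field and enables dynamic_shape, since whole volumes rather than fixed-size patches are returned.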
