Commit c849223

make exercise exercisable
1 parent 50f0953 commit c849223

5 files changed: 40 additions and 232 deletions

README.md

Lines changed: 15 additions & 18 deletions
@@ -7,18 +7,21 @@ from T2-Weighted MRI"](https://www.var.ovgu.de/pub/2019_Meyer_ISBI_Zone_Segmenta
 by Meyer et al.
 
 
-1. To get started run
+### Task 1: To get started run
 
 ```bash
 python ./data/download.py
 ```
 
 in your terminal. The script will download and prepare the medical scans and domain-expert
-annotations for you.
+annotations for you, or you can copy the data from bender at the following location.
+```bash
+TODO: update bender location here.
+```
 
-Data loading and resampling work already.
+Data loading and resampling work already. The next task is optional. If you want to skip it, download `compute_roi.py` from eCampus and replace the existing `compute_roi()` function in the repository with its contents.
 
-1. #### Find the bounding box roi as described below by finishing the `compute_roi` function.
+### Task 2 (Optional): Find the bounding box roi as described below by finishing the `compute_roi` function.
 Once you have obtained the train and test data, you must create a preprocessing pipeline.
 Proceed to `src/util.py` and compute the so-called region of interest.
 Meyer et al. define this region as:
@@ -70,32 +73,27 @@ local coordinates now allows array indexing. Following Meyer et al. we discard a
 
 Test your implementation by setting the if-condition wrapping the plotting utility in `compute_roi` to `True` and running vscode pytest `test_roi`. Remember to set it back to `False` afterwards.
 
-2. #### Implement the UNet.
+### Task 3: Implement the UNet.
+Navigate to the `train.py` file in the `src` folder.
+Finish the `UNet3D` class, as discussed in the lecture.
+Use [torch.nn.Conv3d](https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html), [torch.nn.ReLU](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html), [torch.nn.MaxPool3d](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html) and [torch.nn.Upsample](https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html) to build the model. For upsampling, we suggest using `mode='nearest'` for reproducibility.
 
-Navigate to the `train.py` module file in the `src` folder.
-Finish the `UNet3D` class, as discussed in the lecture.
-Use the [flax.linen.Conv](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.Conv.html), [flax.linen.relu](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.activation.relu.html), and [flax.linen.ConvTranspose](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.ConvTranspose.html), to build your model.
-
-3. #### Implement the focal-loss
+### Task 4: Implement the focal-loss.
 
 Open the `util.py` module in `src` and implement the `softmax_focal_loss` function as discussed in the lecture:
 
 $$\mathcal{L}(\mathbf{o},\mathbf{I})=-\mathbf{I}\cdot(1-\sigma_s(\mathbf{o}))^\gamma\cdot\alpha\cdot\ln(\sigma_s(\mathbf{o})) $$
 
 with output logits $\mathbf{o}$, the corresponding labels $\mathbf{I}$ and the softmax function $\sigma_s$.
 
-4. #### Run and test the training script.
+### Task 5: Run and test the training script.
 
 Execute the training script by running `scripts/train.slurm` (locally or using `sbatch`).
 
 After training you can test your model by changing the `checkpoint_name` variable in `src/sample.py` to the desired model checkpoint and running `scripts/test.slurm`.
 
-#### Solution:
-![slice](./fig/prostatext2.png)
-![slice](./fig/prostatext2_net.png)
-![slice](./fig/prostatext2_true.png)
 
-5. #### (Optional) Implement mean Intersection-over-Union (mIoU)
+### Task 6: Implement mean Intersection-over-Union (mIoU)
 
 Open the `meanIoU.py` in `src` and implement the `compute_iou` function as discussed below.
 mIoU is the most common metric used for evaluating semantic segmentation tasks. It can be computed using the values from a confusion matrix as given below
@@ -113,5 +111,4 @@ python -m src.meanIoU
 
 ### Acknowledgments:
 We thank our course alumna Barbara Wichtmann for bringing this problem to our attention.
-Without her feedback, this code would not exist.
-
+Without her feedback, this code would not exist.
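
The focal-loss formula in the Task 4 hunk above can be checked against a small NumPy reference. This is only a sketch under assumptions: the name `softmax_focal_loss_reference` and the defaults `alpha=1.0` and `gamma=2.0` are hypothetical (the real signature is elided in this diff), labels are assumed to be one-hot along the last axis, and the exercise itself expects a `torch` implementation.

```python
import numpy as np


def softmax_focal_loss_reference(
    logits: np.ndarray,
    labels: np.ndarray,
    alpha: float = 1.0,  # assumed default, not taken from the repository
    gamma: float = 2.0,  # assumed default, not taken from the repository
) -> np.ndarray:
    """Evaluate -I * (1 - softmax(o))^gamma * alpha * ln(softmax(o)), summed over classes."""
    logits = logits.astype(np.float64)
    # Subtract the per-row maximum so the exponentials cannot overflow.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    probs = np.exp(log_probs)
    loss = -labels * (1.0 - probs) ** gamma * alpha * log_probs
    # Sum over the class axis, so one loss value remains per element.
    return loss.sum(axis=-1)


example_logits = np.zeros((1, 2))
example_labels = np.array([[1.0, 0.0]])
example_loss = softmax_focal_loss_reference(example_logits, example_labels)
```

With `gamma=0` the modulating factor becomes one and the expression reduces to the ordinary cross-entropy, which makes a convenient sanity check.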

README.pdf

-179 KB
Binary file not shown.

src/meanIoU.py

Lines changed: 2 additions & 20 deletions
@@ -19,26 +19,8 @@ def compute_iou(preds: th.Tensor, target: th.Tensor) -> th.Tensor:
         jnp.ndarray: Mean Intersection over Union values
     """
     assert preds.shape == target.shape
-
-    b, h, w, s = target.shape
-    preds = preds.permute((0, 3, 1, 2))
-    target = target.permute((0, 3, 1, 2))
-    batch_preds = th.reshape(preds, (b * s, h, w, 1))
-    batch_target = th.reshape(target, (b * s, h, w, 1))
-    batch_iou = []
-    for idx in range(b * s):
-        preds = batch_preds[idx]
-        target = batch_target[idx]
-        per_class_iou = []
-        for cls in range(0, 5):
-            if th.any(preds == cls) or th.any(target == cls):
-                tp = th.sum((preds == cls) & (target == cls))
-                fp = th.sum((preds != cls) & (target == cls))
-                fn = th.sum((preds == cls) & (target != cls))
-                iou = tp / (tp + fp + fn + 1e-8)
-                per_class_iou.append(iou)
-        batch_iou.append(th.mean(th.tensor(per_class_iou)))
-    return th.mean(th.tensor(batch_iou))
+    # TODO: Implement meanIoU
+    return th.tensor(0.0)
 
 
 if __name__ == "__main__":
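
The README describes mIoU as computable from a confusion matrix. A minimal NumPy sketch of that idea follows. It is an assumption-laden illustration rather than the course solution: `mean_iou_from_confusion` is a hypothetical name, the five classes are taken from `range(0, 5)` in the removed code, and the score is averaged over the whole array instead of per slice as the removed implementation did.

```python
import numpy as np


def mean_iou_from_confusion(
    preds: np.ndarray, target: np.ndarray, num_classes: int = 5
) -> float:
    """Mean IoU over all classes that occur in the prediction or the target."""
    assert preds.shape == target.shape
    preds = preds.astype(np.int64).ravel()
    target = target.astype(np.int64).ravel()
    # Confusion matrix: rows index the true class, columns the predicted class.
    flat_index = target * num_classes + preds
    confusion = np.bincount(flat_index, minlength=num_classes**2)
    confusion = confusion.reshape(num_classes, num_classes)
    tp = np.diag(confusion)
    fp = confusion.sum(axis=0) - tp  # predicted as the class, labelled otherwise
    fn = confusion.sum(axis=1) - tp  # labelled as the class, predicted otherwise
    union = tp + fp + fn
    present = union > 0  # ignore classes missing from both arrays
    return float(np.mean(tp[present] / union[present]))


example_target = np.array([0, 0, 1, 1])
example_preds = np.array([0, 1, 1, 1])
example_miou = mean_iou_from_confusion(example_preds, example_target)
```

In the example, class 0 has IoU 1/2 and class 1 has IoU 2/3, so the mean is 7/12.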

src/train.py

Lines changed: 7 additions & 113 deletions
@@ -74,79 +74,8 @@ def __init__(self):
         input_feat = 1
         init_feat = 16
         out_neurons = 5
-        # Five Downscale blocks
-        self.downscale_1 = th.nn.Sequential(
-            th.nn.Conv3d(input_feat, init_feat, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-            th.nn.Conv3d(init_feat, init_feat, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-        )
-        self.downscale_2 = th.nn.Sequential(
-            th.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
-            # th.nn.BatchNorm3d(init_feat),
-            th.nn.Conv3d(init_feat, init_feat * 2, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-            th.nn.Conv3d(init_feat * 2, init_feat * 2, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-        )
-        self.downscale_3 = th.nn.Sequential(
-            th.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
-            # th.nn.BatchNorm3d(init_feat * 2),
-            th.nn.Conv3d(init_feat * 2, init_feat * 4, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-            th.nn.Conv3d(init_feat * 4, init_feat * 4, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-        )
-        self.downscale_4 = th.nn.Sequential(
-            th.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
-            # th.nn.BatchNorm3d(init_feat * 4),
-            th.nn.Conv3d(init_feat * 4, init_feat * 8, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-            th.nn.Conv3d(init_feat * 8, init_feat * 8, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-        )
-        self.downscale_5 = th.nn.Sequential(
-            th.nn.MaxPool3d((1, 2, 2), stride=(1, 2, 2)),
-            # th.nn.BatchNorm3d(init_feat * 8),
-            th.nn.Conv3d(init_feat * 8, init_feat * 16, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-            th.nn.Conv3d(init_feat * 16, init_feat * 16, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-        )
-        # Four Upscale conv blocks
-        self.upscale_4 = th.nn.Sequential(
-            # th.nn.BatchNorm3d(init_feat * 16 + init_feat * 8),
-            th.nn.Conv3d(
-                init_feat * 16 + init_feat * 8, init_feat * 8, (3, 3, 3), padding=1
-            ),
-            th.nn.ReLU(),
-            th.nn.Conv3d(init_feat * 8, init_feat * 8, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-        )
-        self.upscale_3 = th.nn.Sequential(
-            # th.nn.BatchNorm3d(init_feat * 8 + init_feat * 4),
-            th.nn.Conv3d(
-                init_feat * 8 + init_feat * 4, init_feat * 4, (3, 3, 3), padding=1
-            ),
-            th.nn.ReLU(),
-            th.nn.Conv3d(init_feat * 4, init_feat * 4, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-        )
-        self.upscale_2 = th.nn.Sequential(
-            # th.nn.BatchNorm3d(init_feat * 4 + init_feat * 2),
-            th.nn.Conv3d(
-                init_feat * 4 + init_feat * 2, init_feat * 2, (3, 3, 3), padding=1
-            ),
-            th.nn.ReLU(),
-            th.nn.Conv3d(init_feat * 2, init_feat * 2, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-        )
-        self.upscale_1 = th.nn.Sequential(
-            # th.nn.BatchNorm3d(init_feat * 2 + init_feat),
-            th.nn.Conv3d(init_feat * 2 + init_feat, init_feat, (3, 3, 3), padding=1),
-            th.nn.ReLU(),
-            th.nn.Conv3d(init_feat, out_neurons, (3, 3, 3), padding=1),
-        )
+        # TODO: Initialize downscaling blocks
+        # TODO: Initialize upscaling blocks
 
     def forward(self, x: th.Tensor) -> th.Tensor:
         """Forward pass.
@@ -157,42 +86,8 @@ def forward(self, x: th.Tensor) -> th.Tensor:
         Returns:
             th.Tensor: Segmented output.
         """
-        x1 = self.downscale_1(x)
-        x1 = pad_odd(x1)
-
-        x2 = self.downscale_2(x1)
-        x2 = pad_odd(x2)
-
-        x3 = self.downscale_3(x2)
-        x3 = pad_odd(x3)
-
-        x4 = self.downscale_4(x3)
-        x4 = pad_odd(x4)
-
-        x5 = self.downscale_5(x4)
-        x5 = pad_odd(x5)
-
-        x6 = self.__upsize(x5)
-        x6 = x6[:, :, : x4.shape[2], : x4.shape[3], : x4.shape[4]]
-        x6 = th.cat([x4, x6], dim=1)
-        x6 = self.upscale_4(x6)
-
-        x7 = self.__upsize(x6)
-        x7 = x7[:, :, : x3.shape[2], : x3.shape[3], : x3.shape[4]]
-        x7 = th.cat([x3, x7], dim=1)
-        x7 = self.upscale_3(x7)
-
-        x8 = self.__upsize(x7)
-        x8 = x8[:, :, : x2.shape[2], : x2.shape[3], : x2.shape[4]]
-        x8 = th.cat([x2, x8], dim=1)
-        x8 = self.upscale_2(x8)
-
-        x9 = self.__upsize(x8)
-        x9 = x9[:, :, : x1.shape[2], : x1.shape[3], : x1.shape[4]]
-        x9 = th.cat([x1, x9], dim=1)
-        x9 = self.upscale_1(x9)
-        out = x9[:, :, : x.shape[2], : x.shape[3], : x.shape[4]]
-        return out
+        # TODO: Implement 3D UNet as discussed in the lecture
+        return th.tensor(0.0)
 
     def __upsize(self, input_: th.Tensor) -> th.Tensor:
         """Upsample image.
@@ -203,8 +98,8 @@ def __upsize(self, input_: th.Tensor) -> th.Tensor:
         Returns:
             th.Tensor: Upsampled image.
         """
-        _, _, d, h, w = input_.shape
-        return th.nn.Upsample(size=(d, h * 2, w * 2), mode="nearest")(input_)
+        # TODO: Upsample the height and width using th.nn.Upsample with nearest mode.
+        return th.tensor(0.0)
 
 
 def train():
@@ -225,7 +120,7 @@ def train():
 
     model = UNet3D().to(device)
     opt = th.optim.Adam(model.parameters(), lr=1e-4)
-    load_new = False
+    load_new = True
 
     writer = metric_writers.create_default_writer(
         "./runs/" + str(datetime.now()), asynchronous=False
@@ -251,7 +146,6 @@ def train():
     val_loss_list = []
     train_loss_lost = []
     iter_count = 0
-    # loss_fn = th.nn.CrossEntropyLoss()
 
     for e in range(epochs):
         random.shuffle(epoch_batches)
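
The `__upsize` TODO asks for nearest-mode upsampling of height and width only, and the removed `forward` crops each upsampled tensor to its skip connection before concatenating. The NumPy sketch below illustrates just that shape bookkeeping. It is not the `torch` code the exercise expects: `upsize_nearest` and `crop_and_concat` are hypothetical helper names, and the arrays are assumed to be laid out as (batch, channel, depth, height, width).

```python
import numpy as np


def upsize_nearest(x: np.ndarray) -> np.ndarray:
    """Double height and width by repeating every entry, leaving depth untouched."""
    x = np.repeat(x, 2, axis=3)  # duplicate along height
    return np.repeat(x, 2, axis=4)  # duplicate along width


def crop_and_concat(skip: np.ndarray, up: np.ndarray) -> np.ndarray:
    """Crop `up` to the spatial size of `skip`, then stack along the channel axis."""
    _, _, d, h, w = skip.shape
    return np.concatenate([skip, up[:, :, :d, :h, :w]], axis=1)


coarse = np.arange(4, dtype=np.float32).reshape(1, 1, 1, 2, 2)
fine = upsize_nearest(coarse)  # spatial size grows from 2x2 to 4x4
skip = np.zeros((1, 3, 1, 3, 3), dtype=np.float32)
merged = crop_and_concat(skip, fine)  # 3 skip channels plus 1 upsampled channel
```

Repeating entries is deterministic, which fits the README's reproducibility argument for `mode='nearest'`. The crop is needed whenever an odd spatial size was padded on the way down, so the upsampled tensor ends up one element larger than the skip tensor.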

src/util.py

Lines changed: 16 additions & 81 deletions
@@ -2,14 +2,11 @@
 
 from typing import List, Tuple
 
-import chex
-import jax
-import jax.numpy as jnp
 import matplotlib.colors as mcolors
 import matplotlib.pyplot as plt
 import numpy as np
-import optax
 import SimpleITK as sitk  # noqa: N813
+import torch as th
 from SimpleITK.SimpleITK import Image
 
 # from . import zone_segmentation_utils as utils
@@ -140,20 +137,26 @@ def compute_roi(images: Tuple[Image, Image, Image]):
     rects = []
     for pos, size in enumerate(sizes):
         lines = box_lines(size)
-        rotated = [(rotation[pos] @ line.T).T for line in lines]
-        shifted = [origins[pos] + line for line in rotated]
+        # TODO: Rotate and shift the lines.
+        rotated = []
+        shifted = []
         rects.append(shifted)
 
     # find the intersection.
     rects_stacked = np.stack(rects)  # Had to rename because of mypy
+    # TODO: Find the axis maxima and minima
     bbs = [
-        (np.amin(rect, axis=(0, 1)), np.amax(rect, axis=(0, 1)))
+        (
+            np.zeros_like(rect[0, 0]),
+            np.zeros_like(rect[0, 0]),
+        )  # TODO: fixme
         for rect in rects_stacked
     ]
 
     # compute intersection
-    lower_end = np.amax(np.stack([bb[0] for bb in bbs], axis=0), axis=0)
-    upper_end = np.amin(np.stack([bb[1] for bb in bbs], axis=0), axis=0)
+    # TODO: Implement me.
+    lower_end = np.zeros_like(bbs[0][0])
+    upper_end = np.zeros_like(bbs[0][1])
     roi_bb = np.stack((lower_end, upper_end))
     roi_bb_size = roi_bb[1] - roi_bb[0]
 
@@ -164,8 +167,8 @@ def compute_roi(images: Tuple[Image, Image, Image]):
     # compute roi coordinates in image space.
     img_coord_rois = [
         (
-            (np.linalg.inv(rot) @ (roi_bb[0] - offset).T).T / spacing,
-            (np.linalg.inv(rot) @ (roi_bb[1] - offset).T).T / spacing,
+            np.zeros_like(roi_bb[0]),  # TODO: Implement me
+            np.zeros_like(roi_bb[1]),  # TODO: Implement me
         )
         for rot, offset, spacing in zip(rotation, origins, spacings)
     ]
@@ -244,54 +247,6 @@ def in_array(in_int, dim):
     return intersections, box_indices
 
 
-def sigmoid_focal_loss(
-    logits: jnp.ndarray,
-    labels: jnp.ndarray,
-    alpha: float = -1,
-    gamma: float = 2,
-) -> jnp.ndarray:
-    """Compute a sigmoid focal loss.
-
-    Implementation of the focal loss as used https://arxiv.org/abs/1708.02002.
-    This loss often appears in the segmentation context.
-    Use this loss function if classes are not mutually exclusive.
-    See `sigmoid_binary_cross_entropy` for more information.
-
-    Args:
-        logits: A float array of arbitrary shape.
-            The predictions for each example.
-        labels: A float array, its shape must be identical to
-            that of logits. It containes the binary
-            classification label for each element in logits
-            (0 for the out of class and 1 for in class).
-            This array is often one-hot encoded.
-        alpha: (optional) Weighting factor in range (0,1) to balance
-            positive vs negative examples. Default = -1 (no weighting).
-        gamma: Exponent of the modulating factor (1 - p_t) to
-            balance easy vs hard examples.
-
-    Returns:
-        A loss value array with a shape identical to the logits and target
-        arrays.
-    """
-    chex.assert_type([logits], float)
-    labels = labels.astype(logits.dtype)
-
-    # see also the original implementation at:
-    # https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py
-    p = jax.nn.sigmoid(logits)
-    ce_loss = optax.sigmoid_binary_cross_entropy(logits, labels)
-    p_t = p * labels + (1 - p) * (1 - labels)
-    loss = ce_loss * ((1 - p_t) ** gamma)
-    if alpha >= 0:
-        alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
-        loss = alpha_t * loss
-    return loss
-
-
-import torch as th
-
-
 def softmax_focal_loss(
     logits: th.Tensor,
     labels: th.Tensor,
@@ -308,25 +263,5 @@ def softmax_focal_loss(
     # return jnp.sum(loss, axis=-1)
     logits = logits.float()
     labels = labels.float()
-    focus = th.pow(1.0 - th.nn.functional.softmax(logits, dim=-1), gamma)
-    loss = -labels * focus * alpha * th.nn.functional.log_softmax(logits, dim=-1)
-    return th.sum(loss, dim=-1)
-
-
-# def tversky(y_true, y_pred, alpha=.3, beta=.7):
-#     """See: https://arxiv.org/pdf/1706.05721.pdf"""
-#     y_true_f = jnp.reshape(y_true, -1)
-#     y_pred_f = jnp.reshape(y_pred, -1)
-#     intersection = jnp.sum(y_true_f * y_pred_f)
-#     G_P = alpha * jnp.sum((1 - y_true_f) * y_pred_f)  # G not P
-#     P_G = beta * jnp.sum(y_true_f * (1 - y_pred_f))  # P not G
-#     return (intersection + 1.) / (intersection + 1. + G_P + P_G)
-#
-# def Tversky_loss(y_true, y_pred):
-#     return -tversky(y_true, y_pred)
-#
-#
-# def dice_coeff(logits, labels):
-#     pred_probs = jax.nn.softmax(logits)
-#     intersection = jnp.sum(labels * logits)
-#     return ((2. * intersection + 1.) / (jnp.sum(labels) + jnp.sum(pred_probs) + 1.))*(-1.)
+    # TODO: Implement softmax focal loss.
+    return th.tensor(0.0)
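
For the `compute_roi` hunks above, the geometry behind the TODOs can be sketched in plain NumPy: take per-box axis minima and maxima, intersect the boxes, and map a world coordinate back to array indices. This is a simplified illustration under assumptions. `box_intersection` and `world_to_index` are hypothetical names, and each box is assumed to be given as an (n, 3) array of corner points rather than the stacked line segments used in `src/util.py`.

```python
from typing import List

import numpy as np


def box_intersection(corner_sets: List[np.ndarray]) -> np.ndarray:
    """Return the stacked (lower, upper) corners of the common axis-aligned box."""
    lows = np.stack([pts.min(axis=0) for pts in corner_sets])
    highs = np.stack([pts.max(axis=0) for pts in corner_sets])
    # The intersection starts at the largest lower bound
    # and ends at the smallest upper bound.
    return np.stack((lows.max(axis=0), highs.min(axis=0)))


def world_to_index(
    point: np.ndarray, rotation: np.ndarray, origin: np.ndarray, spacing: np.ndarray
) -> np.ndarray:
    """Undo shift, rotation and voxel spacing to obtain (fractional) array indices."""
    return np.linalg.inv(rotation) @ (point - origin) / spacing


boxes = [
    np.array([[0.0, 0.0, 0.0], [4.0, 4.0, 4.0]]),
    np.array([[2.0, 2.0, 2.0], [6.0, 6.0, 6.0]]),
]
roi = box_intersection(boxes)  # lower corner (2, 2, 2), upper corner (4, 4, 4)
index = world_to_index(
    np.array([2.0, 3.0, 4.0]), np.eye(3), np.ones(3), np.full(3, 0.5)
)
```

If the boxes do not overlap, the lower corner exceeds the upper corner on at least one axis, so a negative entry in `roi[1] - roi[0]` signals an empty region.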
