Skip to content

Commit 3ace4d7

Browse files
committed
Fix dtype
1 parent 7418ccf commit 3ace4d7

File tree

6 files changed

+15
-13
lines changed

6 files changed

+15
-13
lines changed

cinema/examples/inference/landmark_heatmap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
3939
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
4040
logits = model(batch)[view] # (1, 3, x, y)
4141
probs = torch.sigmoid(logits) # (1, 3, width, height)
42-
probs_list.append(probs[0].detach().cpu().numpy())
42+
probs_list.append(probs[0].detach().to(torch.float32).cpu().numpy())
4343
coords = heatmap_soft_argmax(probs)[0].numpy()
4444
coords = [int(x) for x in coords]
4545

cinema/examples/inference/mae.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def reconstruct_images(
7777
patch_size=patch_size_dict[view],
7878
grid_size=grid_size_dict[view],
7979
)
80-
reconstructed_dict[view] = reconstructed.detach().cpu().numpy()[0, 0]
81-
masks_dict[view] = masks.detach().cpu().numpy()[0, 0]
80+
reconstructed_dict[view] = reconstructed.detach().to(torch.float32).cpu().numpy()[0, 0]
81+
masks_dict[view] = masks.detach().to(torch.float32).cpu().numpy()[0, 0]
8282
reconstructed_dict["sax"] = reconstructed_dict["sax"][..., :sax_slices]
8383
masks_dict["sax"] = masks_dict["sax"][..., :sax_slices]
8484
return reconstructed_dict, masks_dict
@@ -115,12 +115,12 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
115115
lax_4c_image = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(exp_dir / "data/ukb/1/1_lax_4c.nii.gz")))
116116

117117
image_dict = {
118-
"sax": sax_image[..., t],
119-
"lax_2c": lax_2c_image[..., 0, t],
120-
"lax_3c": lax_3c_image[..., 0, t],
121-
"lax_4c": lax_4c_image[..., 0, t],
118+
"sax": sax_image[None, ..., t],
119+
"lax_2c": lax_2c_image[None, ..., 0, t],
120+
"lax_3c": lax_3c_image[None, ..., 0, t],
121+
"lax_4c": lax_4c_image[None, ..., 0, t],
122122
}
123-
batch = {k: torch.from_numpy(v[None, ...]) for k, v in image_dict.items()}
123+
batch = {k: torch.from_numpy(v) for k, v in image_dict.items()}
124124

125125
# forward
126126
sax_slices = batch["sax"].shape[-1]
@@ -137,10 +137,12 @@ def run(device: torch.device, dtype: torch.dtype) -> None:
137137
grid_size_dict,
138138
sax_slices,
139139
)
140+
batch = {k: v.detach().to(torch.float32).cpu().numpy()[0, 0] for k, v in batch.items()}
141+
batch["sax"] = batch["sax"][..., :sax_slices]
140142

141143
# visualize
142144
fig = plot_mae_reconstruction(
143-
image_dict,
145+
batch,
144146
reconstructed_dict,
145147
masks_dict,
146148
)

cinema/examples/inference/segmentation_lax_4c.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def run(trained_dataset: str, seed: int, device: torch.device, dtype: torch.dtyp
126126
batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
127127
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
128128
logits = model(batch)[view] # (1, 4, x, y)
129-
labels = torch.argmax(logits, dim=1)[0].detach().cpu().numpy() # (x, y)
129+
labels = torch.argmax(logits, dim=1)[0].detach().to(torch.float32).cpu().numpy() # (x, y)
130130

131131
# the model seems to hallucinate an additional right ventricle and myocardium sometimes
132132
# find the connected component that is closest to left ventricle

cinema/examples/inference/segmentation_sax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def run(trained_dataset: str, seed: int, device: torch.device, dtype: torch.dtyp
107107
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
108108
logits = model(batch)[view] # (1, 4, x, y, z)
109109
labels_list.append(torch.argmax(logits, dim=1)[0, ..., :n_slices])
110-
labels = torch.stack(labels_list, dim=-1).detach().cpu().numpy() # (x, y, z, t)
110+
labels = torch.stack(labels_list, dim=-1).detach().to(torch.float32).cpu().numpy() # (x, y, z, t)
111111

112112
# visualise segmentations
113113
fig = plot_segmentations(images, labels, t_step)

cinema/mae/mae_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def test_conv_mae_size(
125125
# value can be nan if target is empty
126126
# this is unlikely to happen with large mask_ratio
127127
if min(ns_masked) > 0:
128-
assert not np.isnan(loss.detach().cpu().numpy())
128+
assert not np.isnan(loss.detach().to(torch.float32).cpu().numpy())
129129
for v in metrics.values():
130130
assert not np.isnan(v.detach())
131131
assert v.shape == ()

cinema/segmentation/train_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_segmentation_eval_metrics(
105105

106106
metrics = segmentation_metrics(logits, labels, spacing)
107107
for v in metrics.values():
108-
assert not np.any(np.isnan(v.detach().cpu().numpy()))
108+
assert not np.any(np.isnan(v.detach().to(torch.float32).cpu().numpy()))
109109
assert v.shape == (batch,)
110110

111111
# ensure inputs are not modified

0 commit comments

Comments
 (0)