Skip to content

Commit d535c56

Browse files
committed
Fix dtype
1 parent 7418ccf commit d535c56

File tree

6 files changed

+7
-7
lines changed

6 files changed

+7
-7
lines changed

cinema/examples/inference/landmark_heatmap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
3939
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
4040
logits = model(batch)[view] # (1, 3, x, y)
4141
probs = torch.sigmoid(logits) # (1, 3, width, height)
42-
probs_list.append(probs[0].detach().cpu().numpy())
42+
probs_list.append(probs[0].detach().astype(torch.float32).cpu().numpy())
4343
coords = heatmap_soft_argmax(probs)[0].numpy()
4444
coords = [int(x) for x in coords]
4545

cinema/examples/inference/mae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def reconstruct_images(
7777
patch_size=patch_size_dict[view],
7878
grid_size=grid_size_dict[view],
7979
)
80-
reconstructed_dict[view] = reconstructed.detach().cpu().numpy()[0, 0]
81-
masks_dict[view] = masks.detach().cpu().numpy()[0, 0]
80+
reconstructed_dict[view] = reconstructed.detach().astype(torch.float32).cpu().numpy()[0, 0]
81+
masks_dict[view] = masks.detach().astype(torch.float32).cpu().numpy()[0, 0]
8282
reconstructed_dict["sax"] = reconstructed_dict["sax"][..., :sax_slices]
8383
masks_dict["sax"] = masks_dict["sax"][..., :sax_slices]
8484
return reconstructed_dict, masks_dict

cinema/examples/inference/segmentation_lax_4c.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def run(trained_dataset: str, seed: int, device: torch.device, dtype: torch.dtyp
126126
batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
127127
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
128128
logits = model(batch)[view] # (1, 4, x, y)
129-
labels = torch.argmax(logits, dim=1)[0].detach().cpu().numpy() # (x, y)
129+
labels = torch.argmax(logits, dim=1)[0].detach().astype(torch.float32).cpu().numpy() # (x, y)
130130

131131
# the model seems to hallucinate an additional right ventricle and myocardium sometimes
132132
# find the connected component that is closest to left ventricle

cinema/examples/inference/segmentation_sax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def run(trained_dataset: str, seed: int, device: torch.device, dtype: torch.dtyp
107107
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
108108
logits = model(batch)[view] # (1, 4, x, y, z)
109109
labels_list.append(torch.argmax(logits, dim=1)[0, ..., :n_slices])
110-
labels = torch.stack(labels_list, dim=-1).detach().cpu().numpy() # (x, y, z, t)
110+
labels = torch.stack(labels_list, dim=-1).detach().astype(torch.float32).cpu().numpy() # (x, y, z, t)
111111

112112
# visualise segmentations
113113
fig = plot_segmentations(images, labels, t_step)

cinema/mae/mae_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def test_conv_mae_size(
125125
# value can be nan if target is empty
126126
# this is unlikely to happen with large mask_ratio
127127
if min(ns_masked) > 0:
128-
assert not np.isnan(loss.detach().cpu().numpy())
128+
assert not np.isnan(loss.detach().astype(torch.float32).cpu().numpy())
129129
for v in metrics.values():
130130
assert not np.isnan(v.detach())
131131
assert v.shape == ()

cinema/segmentation/train_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_segmentation_eval_metrics(
105105

106106
metrics = segmentation_metrics(logits, labels, spacing)
107107
for v in metrics.values():
108-
assert not np.any(np.isnan(v.detach().cpu().numpy()))
108+
assert not np.any(np.isnan(v.detach().astype(torch.float32).cpu().numpy()))
109109
assert v.shape == (batch,)
110110

111111
# ensure inputs are not modified

0 commit comments

Comments
 (0)