Skip to content

Commit 07fc8d3

Browse files
committed
Fix dtype and device for GPU support
1 parent 0c25225 commit 07fc8d3

17 files changed

+255
-131
lines changed

cinema/log.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import sys
55
from pathlib import Path
66

7-
import wandb
87
from omegaconf import DictConfig, OmegaConf
98

109

@@ -43,7 +42,7 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "_") -> dict: # type
4342
return dict(items)
4443

4544

46-
def init_wandb(config: DictConfig, tags: list[str]) -> tuple[wandb.sdk.wandb_run.Run | None, Path]:
45+
def init_wandb(config: DictConfig, tags: list[str]) -> tuple: # type:ignore[type-arg]
4746
"""Initialize wandb.
4847
4948
Args:
@@ -54,6 +53,8 @@ def init_wandb(config: DictConfig, tags: list[str]) -> tuple[wandb.sdk.wandb_run
5453
wandb run and checkpoint directory.
5554
"""
5655
if config.logging.wandb.project:
56+
import wandb # lazy import
57+
5758
wandb_run = wandb.init(
5859
project=config.logging.wandb.project,
5960
entity=config.logging.wandb.entity,

cinema/mae/mae_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def test_conv_mae_size(
125125
# value can be nan if target is empty
126126
# this is unlikely to happen with large mask_ratio
127127
if min(ns_masked) > 0:
128-
assert not np.isnan(loss.detach().numpy())
128+
assert not np.isnan(loss.detach().cpu().numpy())
129129
for v in metrics.values():
130130
assert not np.isnan(v.detach())
131131
assert v.shape == ()

cinema/segmentation/train_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_segmentation_eval_metrics(
105105

106106
metrics = segmentation_metrics(logits, labels, spacing)
107107
for v in metrics.values():
108-
assert not np.any(np.isnan(v.detach().numpy()))
108+
assert not np.any(np.isnan(v.detach().cpu().numpy()))
109109
assert v.shape == (batch,)
110110

111111
# ensure inputs are not modified

examples/inference/classification_cvd.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from cinema import ConvViT
1313

1414

15-
def run(trained_dataset: str, view: str, seed: int) -> None:
15+
def run(trained_dataset: str, view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
1616
"""Run CVD classification using fine-tuned checkpoint."""
1717
# load config to get class names
1818
config_path = hf_hub_download(
@@ -28,13 +28,14 @@ def run(trained_dataset: str, view: str, seed: int) -> None:
2828
model_filename=f"finetuned/classification_cvd/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
2929
config_filename=f"finetuned/classification_cvd/{trained_dataset}_{view}/config.yaml",
3030
)
31+
model.to(device)
3132

3233
# load sample data from mnms2 of class HCM and form a batch of size 1
3334
spatial_size = (192, 192, 16) if view == "sax" else (256, 256)
3435
transform = Compose(
3536
[
3637
ScaleIntensityd(keys=view),
37-
SpatialPadd(keys=view, spatial_size=spatial_size, method="end", lazy=True, allow_missing_keys=True),
38+
SpatialPadd(keys=view, spatial_size=spatial_size, method="end"),
3839
]
3940
)
4041
exp_dir = Path(__file__).parent.parent.resolve()
@@ -43,9 +44,9 @@ def run(trained_dataset: str, view: str, seed: int) -> None:
4344
image = np.stack([ed_image, es_image], axis=0) # (2, x, y, 1) or (2, x, y, z)
4445
if view != "sax":
4546
image = image[..., 0] # (2, x, y, 1) -> (2, x, y)
46-
batch = transform({view: torch.from_numpy(image).to(dtype=torch.float32)})
47-
batch = {k: v[None, ...] for k, v in batch.items()} # batch size 1
48-
with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
47+
batch = transform({view: torch.from_numpy(image)})
48+
batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
49+
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
4950
logits = model(batch) # (1, n_classes)
5051
probs = torch.softmax(logits, dim=1)[0] # (n_classes,)
5152
probs_dict = dict(zip(classes, probs.cpu().numpy(), strict=False))
@@ -55,10 +56,16 @@ def run(trained_dataset: str, view: str, seed: int) -> None:
5556

5657

5758
if __name__ == "__main__":
59+
dtype, device = torch.float32, torch.device("cpu")
60+
if torch.cuda.is_available():
61+
device = torch.device("cuda")
62+
if torch.cuda.is_bf16_supported():
63+
dtype = torch.bfloat16
64+
5865
for trained_dataset, view in zip(
5966
["acdc", "mnms", "mnms2", "mnms2"],
6067
["sax", "sax", "sax", "lax_4c"],
6168
strict=False,
6269
):
6370
for seed in range(3):
64-
run(trained_dataset, view, seed)
71+
run(trained_dataset, view, seed, device, dtype)

examples/inference/classification_sex.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from cinema import ConvViT
1313

1414

15-
def run(seed: int) -> None:
15+
def run(seed: int, device: torch.device, dtype: torch.dtype) -> None:
1616
"""Run sex classification using fine-tuned checkpoint."""
1717
trained_dataset, view = "mnms", "sax"
1818
# load config to get class names
@@ -29,13 +29,14 @@ def run(seed: int) -> None:
2929
model_filename=f"finetuned/classification_sex/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
3030
config_filename=f"finetuned/classification_sex/{trained_dataset}_{view}/config.yaml",
3131
)
32+
model.to(device)
3233

3334
# load sample data from mnms2 of class HCM and form a batch of size 1
3435
spatial_size = (192, 192, 16) if view == "sax" else (256, 256)
3536
transform = Compose(
3637
[
3738
ScaleIntensityd(keys=view),
38-
SpatialPadd(keys=view, spatial_size=spatial_size, method="end", lazy=True, allow_missing_keys=True),
39+
SpatialPadd(keys=view, spatial_size=spatial_size, method="end"),
3940
]
4041
)
4142
exp_dir = Path(__file__).parent.parent.resolve()
@@ -44,9 +45,9 @@ def run(seed: int) -> None:
4445
image = np.stack([ed_image, es_image], axis=0) # (2, x, y, 1) or (2, x, y, z)
4546
if view != "sax":
4647
image = image[..., 0] # (2, x, y, 1) -> (2, x, y)
47-
batch = transform({view: torch.from_numpy(image).to(dtype=torch.float32)})
48-
batch = {k: v[None, ...] for k, v in batch.items()} # batch size 1
49-
with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
48+
batch = transform({view: torch.from_numpy(image)})
49+
batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
50+
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
5051
logits = model(batch) # (1, n_classes)
5152
probs = torch.softmax(logits, dim=1)[0] # (n_classes,)
5253
probs_dict = dict(zip(classes, probs.cpu().numpy(), strict=False))
@@ -56,5 +57,11 @@ def run(seed: int) -> None:
5657

5758

5859
if __name__ == "__main__":
60+
dtype, device = torch.float32, torch.device("cpu")
61+
if torch.cuda.is_available():
62+
device = torch.device("cuda")
63+
if torch.cuda.is_bf16_supported():
64+
dtype = torch.bfloat16
65+
5966
for seed in range(3):
60-
run(seed)
67+
run(seed, device, dtype)

examples/inference/classification_vendor.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from cinema import ConvViT
1313

1414

15-
def run(view: str, seed: int) -> None:
15+
def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
1616
"""Run vendor classification using fine-tuned checkpoint."""
1717
trained_dataset = "mnms2"
1818
# load config to get class names
@@ -29,13 +29,14 @@ def run(view: str, seed: int) -> None:
2929
model_filename=f"finetuned/classification_vendor/{trained_dataset}_{view}/{trained_dataset}_{view}_{seed}.safetensors",
3030
config_filename=f"finetuned/classification_vendor/{trained_dataset}_{view}/config.yaml",
3131
)
32+
model.to(device)
3233

3334
# load sample data from mnms2 of class HCM and form a batch of size 1
3435
spatial_size = (192, 192, 16) if view == "sax" else (256, 256)
3536
transform = Compose(
3637
[
3738
ScaleIntensityd(keys=view),
38-
SpatialPadd(keys=view, spatial_size=spatial_size, method="end", lazy=True, allow_missing_keys=True),
39+
SpatialPadd(keys=view, spatial_size=spatial_size, method="end"),
3940
]
4041
)
4142
exp_dir = Path(__file__).parent.parent.resolve()
@@ -44,9 +45,9 @@ def run(view: str, seed: int) -> None:
4445
image = np.stack([ed_image, es_image], axis=0) # (2, x, y, 1) or (2, x, y, z)
4546
if view != "sax":
4647
image = image[..., 0] # (2, x, y, 1) -> (2, x, y)
47-
batch = transform({view: torch.from_numpy(image).to(dtype=torch.float32)})
48-
batch = {k: v[None, ...] for k, v in batch.items()} # batch size 1
49-
with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
48+
batch = transform({view: torch.from_numpy(image)})
49+
batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
50+
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
5051
logits = model(batch) # (1, n_classes)
5152
probs = torch.softmax(logits, dim=1)[0] # (n_classes,)
5253
probs_dict = dict(zip(classes, probs.cpu().numpy(), strict=False))
@@ -56,6 +57,12 @@ def run(view: str, seed: int) -> None:
5657

5758

5859
if __name__ == "__main__":
60+
dtype, device = torch.float32, torch.device("cpu")
61+
if torch.cuda.is_available():
62+
device = torch.device("cuda")
63+
if torch.cuda.is_bf16_supported():
64+
dtype = torch.bfloat16
65+
5966
for view in ["sax", "lax_4c"]:
6067
for seed in range(3):
61-
run(view, seed)
68+
run(view, seed, device, dtype)

examples/inference/landmark_coordinate.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
from cinema import ConvViT
1313

1414

15-
def run(view: str, seed: int) -> None:
15+
def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
1616
"""Run landmark localization on LAX images using fine-tuned checkpoint."""
1717
# load model
1818
model = ConvViT.from_finetuned(
1919
repo_id="mathpluscode/CineMA",
2020
model_filename=f"finetuned/landmark_coordinate/{view}/{view}_{seed}.safetensors",
2121
config_filename=f"finetuned/landmark_coordinate/{view}/config.yaml",
2222
)
23+
model.to(device)
2324

2425
# load sample data and form a batch of size 1
2526
transform = ScaleIntensityd(keys=view)
@@ -31,9 +32,9 @@ def run(view: str, seed: int) -> None:
3132
preds_list = []
3233
lv_lengths = []
3334
for t in tqdm(range(n_frames), total=n_frames):
34-
batch = transform({view: torch.from_numpy(images[None, ..., 0, t]).to(dtype=torch.float32)})
35-
batch = {k: v[None, ...] for k, v in batch.items()} # batch size 1
36-
with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
35+
batch = transform({view: torch.from_numpy(images[None, ..., 0, t])})
36+
batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
37+
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
3738
coords = model(batch)[0].numpy() # (6,)
3839
coords *= np.array([w, h, w, h, w, h])
3940
coords = [int(x) for x in coords]
@@ -85,6 +86,12 @@ def run(view: str, seed: int) -> None:
8586

8687

8788
if __name__ == "__main__":
89+
dtype, device = torch.float32, torch.device("cpu")
90+
if torch.cuda.is_available():
91+
device = torch.device("cuda")
92+
if torch.cuda.is_bf16_supported():
93+
dtype = torch.bfloat16
94+
8895
for view in ["lax_2c", "lax_4c"]:
8996
for seed in range(3):
90-
run(view, seed)
97+
run(view, seed, device, dtype)

examples/inference/landmark_heatmap.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
from cinema import ConvUNetR, heatmap_soft_argmax
1313

1414

15-
def run(view: str, seed: int) -> None:
15+
def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
1616
"""Run landmark localization on LAX images using fine-tuned checkpoint."""
1717
# load model
1818
model = ConvUNetR.from_finetuned(
1919
repo_id="mathpluscode/CineMA",
2020
model_filename=f"finetuned/landmark_heatmap/{view}/{view}_{seed}.safetensors",
2121
config_filename=f"finetuned/landmark_heatmap/{view}/config.yaml",
2222
)
23+
model.to(device)
2324

2425
# load sample data and form a batch of size 1
2526
transform = ScaleIntensityd(keys=view)
@@ -32,9 +33,9 @@ def run(view: str, seed: int) -> None:
3233
preds_list = []
3334
lv_lengths = []
3435
for t in tqdm(range(n_frames), total=n_frames):
35-
batch = transform({view: torch.from_numpy(images[None, ..., 0, t]).to(dtype=torch.float32)})
36-
batch = {k: v[None, ...] for k, v in batch.items()} # batch size 1
37-
with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
36+
batch = transform({view: torch.from_numpy(images[None, ..., 0, t])})
37+
batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
38+
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
3839
logits = model(batch)[view] # (1, 3, x, y)
3940
probs = torch.sigmoid(logits) # (1, 3, width, height)
4041
probs_list.append(probs[0].detach().cpu().numpy())
@@ -106,6 +107,12 @@ def run(view: str, seed: int) -> None:
106107

107108

108109
if __name__ == "__main__":
110+
dtype, device = torch.float32, torch.device("cpu")
111+
if torch.cuda.is_available():
112+
device = torch.device("cuda")
113+
if torch.cuda.is_bf16_supported():
114+
dtype = torch.bfloat16
115+
109116
for view in ["lax_2c", "lax_4c"]:
110117
for seed in range(3):
111-
run(view, seed)
118+
run(view, seed, device, dtype)

examples/inference/mae.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,18 @@
1111
from cinema import CineMA, patchify, unpatchify
1212

1313

14-
def run() -> None:
14+
def run(device: torch.device, dtype: torch.dtype) -> None:
1515
"""Run MAE reconstruction."""
1616
# load model
1717
model = CineMA.from_pretrained()
18+
model.to(device)
1819
model.eval()
1920

2021
# load sample data and form a batch of size 1
2122
transform = Compose(
2223
[
2324
ScaleIntensityd(keys=("sax", "lax_2c", "lax_3c", "lax_4c"), allow_missing_keys=True),
24-
SpatialPadd(keys="sax", spatial_size=(192, 192, 16), method="end", lazy=True, allow_missing_keys=True),
25+
SpatialPadd(keys="sax", spatial_size=(192, 192, 16), method="end"),
2526
SpatialPadd(
2627
keys=("lax_2c", "lax_3c", "lax_4c"),
2728
spatial_size=(256, 256),
@@ -47,17 +48,17 @@ def run() -> None:
4748
)
4849
t = 25 # which time frame to use
4950
batch = {
50-
"sax": sax_image[None, ..., t].to(dtype=torch.float32),
51-
"lax_2c": lax_2c_image[None, ..., 0, t].to(dtype=torch.float32),
52-
"lax_3c": lax_3c_image[None, ..., 0, t].to(dtype=torch.float32),
53-
"lax_4c": lax_4c_image[None, ..., 0, t].to(dtype=torch.float32),
51+
"sax": sax_image[None, ..., t],
52+
"lax_2c": lax_2c_image[None, ..., 0, t],
53+
"lax_3c": lax_3c_image[None, ..., 0, t],
54+
"lax_4c": lax_4c_image[None, ..., 0, t],
5455
}
5556
batch = transform(batch)
5657
print(f"SAX view had originally {sax_image.shape[-2]} slices, now zero-padded to {batch['sax'].shape[-1]} slices.") # noqa: T201
57-
batch = {k: v[None, ...] for k, v in batch.items()} # batch size 1
58+
batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
5859

5960
# forward
60-
with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
61+
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
6162
_, pred_dict, enc_mask_dict, _ = model(batch, enc_mask_ratio=0.75)
6263

6364
# visualize
@@ -76,8 +77,8 @@ def run() -> None:
7677
patch_size=model.dec_patch_size_dict[view],
7778
grid_size=model.enc_down_dict[view].patch_embed.grid_size,
7879
)
79-
reconstructed = reconstructed[0, 0].detach().numpy()
80-
image = batch[view][0, 0].detach().numpy()
80+
reconstructed = reconstructed[0, 0].detach().cpu().numpy()
81+
image = batch[view][0, 0].detach().cpu().numpy()
8182
error = np.abs(reconstructed - image)
8283

8384
if view == "sax":
@@ -104,4 +105,10 @@ def run() -> None:
104105

105106

106107
if __name__ == "__main__":
107-
run()
108+
dtype, device = torch.float32, torch.device("cpu")
109+
if torch.cuda.is_available():
110+
device = torch.device("cuda")
111+
if torch.cuda.is_bf16_supported():
112+
dtype = torch.bfloat16
113+
114+
run(device, dtype)

0 commit comments

Comments
 (0)