Skip to content

Commit edc0774

Browse files
committed
Refactor landmark
1 parent 3ace4d7 commit edc0774

File tree

2 files changed

+121
-96
lines changed

2 files changed

+121
-96
lines changed

cinema/examples/inference/landmark_coordinate.py

Lines changed: 8 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from tqdm import tqdm
1111

1212
from cinema import ConvViT
13+
from cinema.examples.inference.landmark_heatmap import plot_landmarks, plot_lv
1314

1415

1516
def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
@@ -30,59 +31,25 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
3031
exp_dir = Path(__file__).parent.parent.resolve()
3132
images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(exp_dir / f"data/ukb/1/1_{view}.nii.gz")))
3233
w, h, _, n_frames = images.shape
33-
preds_list = []
34-
lv_lengths = []
34+
coords_list = []
3535
for t in tqdm(range(n_frames), total=n_frames):
3636
batch = transform({view: torch.from_numpy(images[None, ..., 0, t])})
3737
batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
3838
with torch.no_grad(), torch.autocast("cuda", dtype=dtype, enabled=torch.cuda.is_available()):
3939
coords = model(batch)[0].numpy() # (6,)
4040
coords *= np.array([w, h, w, h, w, h])
4141
coords = [int(x) for x in coords]
42-
43-
# draw predictions with cross
44-
preds = images[..., t] * np.array([1, 1, 1])[None, None, :]
45-
preds = preds.clip(0, 255).astype(np.uint8)
46-
for i in range(3):
47-
pred_x, pred_y = coords[2 * i], coords[2 * i + 1]
48-
x1, x2 = max(0, pred_x - 9), min(preds.shape[0], pred_x + 10)
49-
y1, y2 = max(0, pred_y - 9), min(preds.shape[1], pred_y + 10)
50-
preds[pred_x, y1:y2] = [255, 0, 0]
51-
preds[x1:x2, pred_y] = [255, 0, 0]
52-
preds_list.append(preds)
53-
54-
# record LV length
55-
x1, y1, x2, y2, x3, y3 = coords
56-
lv_len = (((x1 + x2) / 2 - x3) ** 2 + ((y1 + y2) / 2 - y3) ** 2) ** 0.5
57-
lv_lengths.append(lv_len)
58-
preds = np.stack(preds_list, axis=-1) # (3, x, y, t)
42+
coords_list.append(coords)
43+
coords = np.stack(coords_list, axis=-1) # (6, t)
5944

6045
# visualise landmarks
61-
_, axs = plt.subplots(10, 5, figsize=(10, 20))
62-
for i in range(10):
63-
for j in range(5):
64-
t = i * 5 + j
65-
axs[i, j].imshow(preds[..., t])
66-
axs[i, j].set_xticks([])
67-
axs[i, j].set_yticks([])
68-
if j == 0:
69-
axs[i, j].set_ylabel(f"t = {t}")
70-
plt.subplots_adjust(wspace=0.02, hspace=0.02)
71-
plt.savefig(f"landmark_coordinate_landmark_{view}_{seed}.png", dpi=300, bbox_inches="tight")
46+
fig = plot_landmarks(images, coords)
47+
fig.savefig(f"landmark_coordinate_landmark_{view}_{seed}.png", dpi=300, bbox_inches="tight")
7248
plt.show(block=False)
7349

7450
# visualise LV length changes
75-
plt.figure(figsize=(4, 3))
76-
if view == "lax_2c":
77-
# first frame is empty for this particular example
78-
lv_lengths = lv_lengths[1:]
79-
lvef = (max(lv_lengths) - min(lv_lengths)) / max(lv_lengths) * 100
80-
plt.plot(lv_lengths, color="#82B366", label="LV")
81-
plt.xlabel("Frame")
82-
plt.ylabel("Length (mm)")
83-
plt.title(f"LVEF = {lvef:.2f}%")
84-
plt.legend(loc="lower right")
85-
plt.savefig(f"landmark_coordinate_lv_length_{view}_{seed}.png", dpi=300, bbox_inches="tight")
51+
fig = plot_lv(coords)
52+
plt.savefig(f"landmark_coordinate_gls_{view}_{seed}.png", dpi=300, bbox_inches="tight")
8653
plt.show(block=False)
8754

8855

cinema/examples/inference/landmark_heatmap.py

Lines changed: 113 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,110 @@
1212
from cinema import ConvUNetR, heatmap_soft_argmax
1313

1414

15+
def plot_heatmaps(images: np.ndarray, probs: np.ndarray, n_cols: int = 5) -> plt.Figure:
16+
"""Plot heatmaps.
17+
18+
Args:
19+
images: (x, y, t)
20+
probs: (3, x, y, t)
21+
n_cols: number of columns
22+
23+
Returns:
24+
figure
25+
"""
26+
n_frames = probs.shape[-1]
27+
n_rows = n_frames // n_cols
28+
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), dpi=300)
29+
for i in range(n_rows):
30+
for j in range(n_cols):
31+
t = i * n_cols + j
32+
axs[i, j].imshow(images[..., 0, t], cmap="gray")
33+
axs[i, j].imshow(probs[0, ..., t, None] * np.array([1.0, 0.0, 0.0, 1.0]))
34+
axs[i, j].imshow(probs[1, ..., t, None] * np.array([1.0, 0.0, 0.0, 1.0]))
35+
axs[i, j].imshow(probs[2, ..., t, None] * np.array([1.0, 0.0, 0.0, 1.0]))
36+
axs[i, j].set_xticks([])
37+
axs[i, j].set_yticks([])
38+
if j == 0:
39+
axs[i, j].set_ylabel(f"t = {t}")
40+
fig.tight_layout()
41+
fig.subplots_adjust(wspace=0, hspace=0)
42+
return fig
43+
44+
45+
def plot_landmarks(images: np.ndarray, coords: np.ndarray, n_cols: int = 5) -> plt.Figure:
46+
"""Plot landmarks.
47+
48+
Args:
49+
images: (x, y, t)
50+
coords: (6, t)
51+
n_cols: number of columns
52+
53+
Returns:
54+
figure
55+
"""
56+
n_frames = images.shape[-1]
57+
n_rows = n_frames // n_cols
58+
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), dpi=300)
59+
for i in range(n_rows):
60+
for j in range(n_cols):
61+
t = i * n_cols + j
62+
63+
# draw predictions with cross
64+
preds = images[..., t] * np.array([1, 1, 1])[None, None, :]
65+
preds = preds.clip(0, 255).astype(np.uint8)
66+
for k in range(3):
67+
pred_x, pred_y = coords[2 * k, t], coords[2 * k + 1, t]
68+
x1, x2 = max(0, pred_x - 9), min(preds.shape[0], pred_x + 10)
69+
y1, y2 = max(0, pred_y - 9), min(preds.shape[1], pred_y + 10)
70+
preds[pred_x, y1:y2] = [255, 0, 0]
71+
preds[x1:x2, pred_y] = [255, 0, 0]
72+
73+
axs[i, j].imshow(preds)
74+
axs[i, j].set_xticks([])
75+
axs[i, j].set_yticks([])
76+
if j == 0:
77+
axs[i, j].set_ylabel(f"t = {t}")
78+
fig.tight_layout()
79+
fig.subplots_adjust(wspace=0, hspace=0)
80+
return fig
81+
82+
83+
def plot_lv(coords: np.ndarray) -> plt.Figure:
84+
"""Plot GL shortening.
85+
86+
Args:
87+
coords: (6, t)
88+
89+
Returns:
90+
figure
91+
"""
92+
# GL shortening
93+
x1, y1 = coords[0], coords[1]
94+
x2, y2 = coords[2], coords[3]
95+
x3, y3 = coords[4], coords[5]
96+
lv_lengths = (((x1 + x2) / 2 - x3) ** 2 + ((y1 + y2) / 2 - y3) ** 2) ** 0.5
97+
gls = (max(lv_lengths) - min(lv_lengths)) / max(lv_lengths) * 100
98+
99+
# MAPSE
100+
ed_idx = np.argmin(lv_lengths)
101+
es_idx = np.argmax(lv_lengths)
102+
x1_ed, y1_ed = coords[0, ed_idx], coords[1, ed_idx]
103+
x2_ed, y2_ed = coords[2, ed_idx], coords[3, ed_idx]
104+
x1_es, y1_es = coords[0, es_idx], coords[1, es_idx]
105+
x2_es, y2_es = coords[2, es_idx], coords[3, es_idx]
106+
mapse = (
107+
((x1_ed - x1_es) ** 2 + (y1_ed - y1_es) ** 2) ** 0.5 + ((x2_ed - x2_es) ** 2 + (y2_ed - y2_es) ** 2) ** 0.5
108+
) / 2
109+
110+
fig = plt.figure(figsize=(4, 4), dpi=120)
111+
plt.plot(lv_lengths, color="#82B366", label="LV")
112+
plt.xlabel("Frame")
113+
plt.ylabel("Length (mm)")
114+
plt.title(f"GLS = {gls:.2f}%, MAPSE = {mapse:.2f} mm")
115+
plt.legend(loc="lower right")
116+
return fig
117+
118+
15119
def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
16120
"""Run landmark localization on LAX images using fine-tuned checkpoint."""
17121
# load model
@@ -31,8 +135,7 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
31135
images = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(exp_dir / f"data/ukb/1/1_{view}.nii.gz")))
32136
n_frames = images.shape[-1]
33137
probs_list = []
34-
preds_list = []
35-
lv_lengths = []
138+
coords_list = []
36139
for t in tqdm(range(n_frames), total=n_frames):
37140
batch = transform({view: torch.from_numpy(images[None, ..., 0, t])})
38141
batch = {k: v[None, ...].to(device=device, dtype=dtype) for k, v in batch.items()}
@@ -42,68 +145,23 @@ def run(view: str, seed: int, device: torch.device, dtype: torch.dtype) -> None:
42145
probs_list.append(probs[0].detach().to(torch.float32).cpu().numpy())
43146
coords = heatmap_soft_argmax(probs)[0].numpy()
44147
coords = [int(x) for x in coords]
45-
46-
# draw predictions with cross
47-
preds = images[..., t] * np.array([1, 1, 1])[None, None, :]
48-
preds = preds.clip(0, 255).astype(np.uint8)
49-
for i in range(3):
50-
pred_x, pred_y = coords[2 * i], coords[2 * i + 1]
51-
x1, x2 = max(0, pred_x - 9), min(preds.shape[0], pred_x + 10)
52-
y1, y2 = max(0, pred_y - 9), min(preds.shape[1], pred_y + 10)
53-
preds[pred_x, y1:y2] = [255, 0, 0]
54-
preds[x1:x2, pred_y] = [255, 0, 0]
55-
preds_list.append(preds)
56-
57-
# record LV length
58-
x1, y1, x2, y2, x3, y3 = coords
59-
lv_len = (((x1 + x2) / 2 - x3) ** 2 + ((y1 + y2) / 2 - y3) ** 2) ** 0.5
60-
lv_lengths.append(lv_len)
148+
coords_list.append(coords)
61149
probs = np.stack(probs_list, axis=-1) # (3, x, y, t)
62-
preds = np.stack(preds_list, axis=-1) # (3, x, y, t)
150+
coords = np.stack(coords_list, axis=-1) # (6, t)
63151

64152
# visualise heatmaps
65-
_, axs = plt.subplots(10, 5, figsize=(10, 20))
66-
for i in range(10):
67-
for j in range(5):
68-
t = i * 5 + j
69-
axs[i, j].imshow(images[..., 0, t], cmap="gray")
70-
axs[i, j].imshow((probs[0, ..., t, None]) * np.array([108 / 255, 142 / 255, 191 / 255, 1.0]))
71-
axs[i, j].imshow((probs[1, ..., t, None]) * np.array([214 / 255, 182 / 255, 86 / 255, 1.0]))
72-
axs[i, j].imshow((probs[2, ..., t, None]) * np.array([130 / 255, 179 / 255, 102 / 255, 1.0]))
73-
axs[i, j].set_xticks([])
74-
axs[i, j].set_yticks([])
75-
if j == 0:
76-
axs[i, j].set_ylabel(f"t = {t}")
77-
plt.subplots_adjust(wspace=0.02, hspace=0.02)
78-
plt.savefig(f"landmark_heatmap_{view}_{seed}.png", dpi=300, bbox_inches="tight")
153+
fig = plot_heatmaps(images, probs)
154+
fig.savefig(f"landmark_heatmap_probs_{view}_{seed}.png", dpi=300, bbox_inches="tight")
79155
plt.show(block=False)
80156

81157
# visualise landmarks
82-
_, axs = plt.subplots(10, 5, figsize=(10, 20))
83-
for i in range(10):
84-
for j in range(5):
85-
t = i * 5 + j
86-
axs[i, j].imshow(preds[..., t])
87-
axs[i, j].set_xticks([])
88-
axs[i, j].set_yticks([])
89-
if j == 0:
90-
axs[i, j].set_ylabel(f"t = {t}")
91-
plt.subplots_adjust(wspace=0.02, hspace=0.02)
92-
plt.savefig(f"landmark_heatmap_landmark_{view}_{seed}.png", dpi=300, bbox_inches="tight")
158+
fig = plot_landmarks(images, coords)
159+
fig.savefig(f"landmark_heatmap_landmark_{view}_{seed}.png", dpi=300, bbox_inches="tight")
93160
plt.show(block=False)
94161

95162
# visualise LV length changes
96-
plt.figure(figsize=(4, 3))
97-
if view == "lax_2c":
98-
# first frame is empty for this particular example
99-
lv_lengths = lv_lengths[1:]
100-
lvef = (max(lv_lengths) - min(lv_lengths)) / max(lv_lengths) * 100
101-
plt.plot(lv_lengths, color="#82B366", label="LV")
102-
plt.xlabel("Frame")
103-
plt.ylabel("Length (mm)")
104-
plt.title(f"LVEF = {lvef:.2f}%")
105-
plt.legend(loc="lower right")
106-
plt.savefig(f"landmark_heatmap_lv_length_{view}_{seed}.png", dpi=300, bbox_inches="tight")
163+
fig = plot_lv(coords)
164+
plt.savefig(f"landmark_heatmap_gls_{view}_{seed}.png", dpi=300, bbox_inches="tight")
107165
plt.show(block=False)
108166

109167

0 commit comments

Comments
 (0)