Skip to content

Commit 6b187f8

Browse files
authored
Merge pull request #114 from jhlegarreta/FixBASalientDataIdentification
ENH: Fix Bland-Altman plot salient data identification
2 parents 1cef030 + 24a20d4 commit 6b187f8

File tree

2 files changed

+65
-5
lines changed

2 files changed

+65
-5
lines changed

src/nifreeze/analysis/measure_agreement.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,9 @@ def identify_bland_altman_salient_data(
206206
`top_n` data points from the BA plot.
207207
208208
Once the left-most data points identified, the right-most `percentile` data
209-
points are considered from the remaining data points, and `top_n` data
210-
points are identified out of these.
209+
points are considered from the remaining data points.
210+
The ``top_n`` data points closest to the zero mean difference are
211+
identified among these.
211212
212213
Parameters
213214
----------
@@ -245,6 +246,18 @@ def identify_bland_altman_salient_data(
245246
reliability_mask = get_reliability_mask(diff, loa_lower, loa_upper)
246247
reliability_idx = np.where(reliability_mask)[0]
247248

249+
# Check that there are enough data points left to identify the requested
250+
# number of salient data points
251+
reliability_point_count = len(reliability_idx)
252+
salient_point_count = 2 * top_n
253+
if reliability_point_count < salient_point_count:
254+
raise ValueError(
255+
f"Too few reliable data points ({reliability_point_count}) to "
256+
f"identify the requested Bland-Altman salient points "
257+
f"(2 * {top_n}). Reduce the number of salient data points "
258+
f"requested ({top_n})"
259+
)
260+
248261
# Select the top_n lowest median values from the left side of the BA plot
249262
lower_idx = np.argsort(mean[reliability_idx])[:top_n]
250263
left_indices = reliability_idx[lower_idx]
@@ -258,18 +271,29 @@ def identify_bland_altman_salient_data(
258271
# Sort indices by descending mean (rightmost values first)
259272
right_sort_mean = remaining_idx[np.argsort(mean[remaining_idx])[::-1]]
260273

261-
# Take top percentile of the rightmost points
274+
# Take a percentile of the rightmost points
262275
top_p_count = int(percentile * len(right_sort_mean))
263276
top_p_sorted = right_sort_mean[:top_p_count]
264277

278+
# Check that there are enough data points left to identify the requested
279+
# number of rightmost points
280+
if top_p_count < top_n:
281+
raise ValueError(
282+
f"Too few data points ({top_p_count}) to identify the requested "
283+
f"Bland-Altman right-most salient points ({top_n}). Increase the "
284+
f"percentile requested ({top_n})"
285+
)
286+
265287
# Get absolute difference from mean_diff (closeness to zero mean difference)
266288
diff_distance = np.abs(diff[top_p_sorted] - mean_diff)
267289

268290
# Sort rightmost points by closeness to zero diff
269-
upper_idx = np.argsort(diff_distance)
291+
top_p_idx = np.argsort(diff_distance)
270292

271293
# Take top_n of them
272-
right_mask = right_sort_mean[upper_idx[:top_n]]
294+
upper_idx = top_p_sorted[top_p_idx][:top_n]
295+
right_mask = np.zeros_like(reliability_mask, dtype=bool)
296+
right_mask[upper_idx] = True
273297

274298
return {
275299
BASalientEntity.RELIABILITY_INDICES.value: reliability_idx,

test/test_analysis.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
import pytest
2727

2828
from nifreeze.analysis.measure_agreement import (
29+
BASalientEntity,
2930
compute_bland_altman_features,
3031
compute_z_score,
32+
identify_bland_altman_salient_data,
3133
)
3234

3335

@@ -105,3 +107,37 @@ def test_compute_bland_altman_features(request):
105107
assert loa_lower < loa_upper
106108
assert np.isscalar(ci_mean)
107109
assert np.isscalar(ci_loa)
110+
111+
112+
def test_identify_bland_altman_salient_data():
113+
_data1 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
114+
_data2 = np.array([1.1, 2.1, 1.1, 2.7, 3.4, 5.1, 2.2, 6.3, 7.6, 8.2])
115+
116+
ci = 0.95
117+
118+
# Verify that a sufficient number of data points exists to get the requested
119+
# number of salient data points exists
120+
top_n = 6
121+
with pytest.raises(ValueError):
122+
identify_bland_altman_salient_data(_data1, _data2, ci, top_n)
123+
124+
top_n = 4
125+
126+
# Verify that the percentile is not restrictive enough to get the requested
127+
# number of rightmost salient data points exists
128+
percentile = 0.75
129+
with pytest.raises(ValueError):
130+
identify_bland_altman_salient_data(_data1, _data2, ci, top_n, percentile=percentile)
131+
132+
percentile = 0.8
133+
salient_data = identify_bland_altman_salient_data(
134+
_data1, _data2, ci, top_n, percentile=percentile
135+
)
136+
137+
assert len(salient_data[BASalientEntity.RELIABILITY_MASK.value]) == len(_data1)
138+
139+
assert len(salient_data[BASalientEntity.LEFT_INDICES.value]) == top_n
140+
assert len(salient_data[BASalientEntity.LEFT_MASK.value]) == len(_data1)
141+
142+
assert len(salient_data[BASalientEntity.RIGHT_INDICES.value]) == top_n
143+
assert len(salient_data[BASalientEntity.RIGHT_MASK.value]) == len(_data1)

0 commit comments

Comments
 (0)