-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
ENH: Support for Multi-Wavelength (>2) NIRS/SNIRF Data Processing #13408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3ee2db8
3e71426
00fca67
2b39648
aa285f4
4bae23e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -56,14 +56,52 @@ def scalp_coupling_index( | |||||||
verbose=verbose, | ||||||||
).get_data() | ||||||||
|
||||||||
# Determine number of wavelengths per source-detector pair | ||||||||
ch_wavelengths = [c["loc"][9] for c in raw.info["chs"]] | ||||||||
n_wavelengths = len(set(ch_wavelengths)) | ||||||||
|
||||||||
# freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float) | ||||||||
# n_wavelengths = len(set(unique_freqs)) | ||||||||
Comment on lines
+62
to
+64
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cruft?
Suggested change
|
||||||||
|
||||||||
sci = np.zeros(picks.shape) | ||||||||
for ii in range(0, len(picks), 2): | ||||||||
with np.errstate(invalid="ignore"): | ||||||||
c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1] | ||||||||
if not np.isfinite(c): # someone had std=0 | ||||||||
c = 0 | ||||||||
sci[ii] = c | ||||||||
sci[ii + 1] = c | ||||||||
|
||||||||
if n_wavelengths == 2: | ||||||||
# Use pairwise correlation for 2 wavelengths (backward compatibility) | ||||||||
for ii in range(0, len(picks), 2): | ||||||||
with np.errstate(invalid="ignore"): | ||||||||
c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1] | ||||||||
if not np.isfinite(c): # someone had std=0 | ||||||||
c = 0 | ||||||||
sci[ii] = c | ||||||||
sci[ii + 1] = c | ||||||||
else: | ||||||||
# For multiple wavelengths: calculate all pairwise correlations within each group | ||||||||
# and use the minimum as the quality metric | ||||||||
Comment on lines
+78
to
+79
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand why you need a conditional here... if |
||||||||
|
||||||||
# Group picks by number of wavelengths | ||||||||
# Drops last incomplete group, but we're assuming valid data | ||||||||
pick_iter = iter(picks) | ||||||||
pick_groups = zip(*[pick_iter] * n_wavelengths) | ||||||||
|
||||||||
for group_picks in pick_groups: | ||||||||
group_data = filtered_data[group_picks] | ||||||||
|
||||||||
# Calculate pairwise correlations within the group | ||||||||
pair_indices = np.triu_indices(len(group_picks), k=1) | ||||||||
correlations = np.zeros(pair_indices[0].shape[0]) | ||||||||
|
||||||||
for n, (ii, jj) in enumerate(zip(*pair_indices)): | ||||||||
with np.errstate(invalid="ignore"): | ||||||||
c = np.corrcoef(group_data[ii], group_data[jj])[0][1] | ||||||||
if np.isfinite(c): | ||||||||
correlations[n] = c | ||||||||
|
||||||||
# Use minimum correlation as the quality metric | ||||||||
group_sci = correlations.min() | ||||||||
|
||||||||
# Assign the same SCI value to all channels in the group | ||||||||
sci[group_picks] = group_sci | ||||||||
|
||||||||
sci[zero_mask] = 0 | ||||||||
sci = sci[np.argsort(picks)] # restore original order | ||||||||
return sci |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,7 +104,7 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr | |
# All chromophore fNIRS data | ||
picks_chroma = _picks_to_idx(info, ["hbo", "hbr"], exclude=[], allow_empty=True) | ||
|
||
if (len(picks_wave) > 0) & (len(picks_chroma) > 0): | ||
if (len(picks_wave) > 0) and (len(picks_chroma) > 0): | ||
picks = _throw_or_return_empty( | ||
"MNE does not support a combination of amplitude, optical " | ||
"density, and haemoglobin data in the same raw structure.", | ||
|
@@ -122,14 +122,14 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr | |
picks = picks_chroma | ||
|
||
pair_vals = np.array(pair_vals) | ||
if pair_vals.shape != (2,): | ||
if pair_vals.shape[0] < 2: | ||
raise ValueError( | ||
f"Exactly two {error_word} must exist in info, got {list(pair_vals)}" | ||
f"At least two {error_word} must exist in info, got {list(pair_vals)}" | ||
) | ||
# In principle we do not need to require that these be sorted -- | ||
# all we need to do is change our sorted() below to make use of a | ||
# pair_vals.index(...) in a sort key -- but in practice we always want | ||
# (hbo, hbr) or (lower_freq, upper_freq) pairings, both of which will | ||
# (hbo, hbr) or (lowest_freq, higher_freq, ...) pairings, both of which will | ||
# work with a naive string sort, so let's just enforce sorted-ness here | ||
is_str = pair_vals.dtype.kind == "U" | ||
pair_vals = list(pair_vals) | ||
|
@@ -145,16 +145,23 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr | |
f"got {pair_vals} instead" | ||
) | ||
|
||
if len(picks) % 2 != 0: | ||
# Check that the total number of channels is divisible by the number of pair values | ||
# (e.g., for 2 wavelengths, we need even number of channels) | ||
if len(picks) % len(pair_vals) != 0: | ||
picks = _throw_or_return_empty( | ||
"NIRS channels not ordered correctly. An even number of NIRS " | ||
f"channels is required. {len(info.ch_names)} channels were" | ||
f"provided", | ||
f"NIRS channels not ordered correctly. The number of channels " | ||
f"must be a multiple of {len(pair_vals)} values, but " | ||
f"{len(picks)} channels were provided.", | ||
throw_errors, | ||
) | ||
|
||
# Ensure wavelength info exists for waveform data | ||
all_freqs = [info["chs"][ii]["loc"][9] for ii in picks_wave] | ||
if len(pair_vals) != len(set(all_freqs)): | ||
picks = _throw_or_return_empty( | ||
f"The {error_word} in info must match the number of wavelengths, " | ||
f"but the data contains {len(set(all_freqs))} wavelengths instead.", | ||
throw_errors, | ||
) | ||
if np.any(np.isnan(all_freqs)): | ||
picks = _throw_or_return_empty( | ||
f"NIRS channels is missing wavelength information in the " | ||
|
@@ -189,40 +196,47 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr | |
# Reorder to be paired (naive sort okay here given validation above) | ||
picks = picks[np.argsort([info["ch_names"][pick] for pick in picks])] | ||
|
||
# Validate our paired ordering | ||
for ii, jj in zip(picks[::2], picks[1::2]): | ||
ch1_name = info["chs"][ii]["ch_name"] | ||
ch2_name = info["chs"][jj]["ch_name"] | ||
ch1_re = use_RE.match(ch1_name) | ||
ch2_re = use_RE.match(ch2_name) | ||
ch1_S, ch1_D, ch1_value = ch1_re.groups()[:3] | ||
ch2_S, ch2_D, ch2_value = ch2_re.groups()[:3] | ||
if len(picks_wave): | ||
ch1_value, ch2_value = float(ch1_value), float(ch2_value) | ||
if ( | ||
(ch1_S != ch2_S) | ||
or (ch1_D != ch2_D) | ||
or (ch1_value != pair_vals[0]) | ||
or (ch2_value != pair_vals[1]) | ||
# Validate channel grouping (same source-detector pairs, all pair_vals match) | ||
for i in range(0, len(picks), len(pair_vals)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We prefer |
||
# Extract a group of channels (e.g., all wavelengths for one S-D pair) | ||
group_picks = picks[i : i + len(pair_vals)] | ||
|
||
# Parse channel names using regex to extract source, detector, and value info | ||
group_info = [ | ||
(use_RE.match(info["ch_names"][pick]).groups() or (pick, 0, 0)) | ||
for pick in group_picks | ||
] | ||
|
||
# Separate the parsed components: source IDs, detector IDs, and values (freq/chromophore) | ||
s_group, d_group, val_group = zip(*group_info) | ||
|
||
# For wavelength data, convert string frequencies to float for comparison | ||
if len(picks_wave) > 0: | ||
val_group = [float(v) for v in val_group] | ||
|
||
# Verify that all channels in this group have the same source-detector pair | ||
# and that the values match the expected pair_vals sequence | ||
if not ( | ||
len(set(s_group)) == 1 and len(set(d_group)) == 1 and val_group == pair_vals | ||
): | ||
picks = _throw_or_return_empty( | ||
"NIRS channels not ordered correctly. Channels must be " | ||
"ordered as source detector pairs with alternating" | ||
f" {error_word} {pair_vals[0]} & {pair_vals[1]}, but got " | ||
f"S{ch1_S}_D{ch1_D} pair " | ||
f"{repr(ch1_name)} and {repr(ch2_name)}", | ||
"grouped by source-detector pairs with alternating {error_word} " | ||
f"values {pair_vals}, but got mismatching names {[info['ch_names'][pick] for pick in group_picks]}.", | ||
throw_errors, | ||
) | ||
break | ||
|
||
if check_bads: | ||
for ii, jj in zip(picks[::2], picks[1::2]): | ||
want = [info.ch_names[ii], info.ch_names[jj]] | ||
for i in range(0, len(picks), len(pair_vals)): | ||
group_picks = picks[i : i + len(pair_vals)] | ||
|
||
want = [info.ch_names[pick] for pick in group_picks] | ||
got = list(set(info["bads"]).intersection(want)) | ||
if len(got) == 1: | ||
if 0 < len(got) < len(want): | ||
raise RuntimeError( | ||
f"NIRS bad labelling is not consistent, found {got} but " | ||
f"needed {want}" | ||
"NIRS bad labelling is not consistent. " | ||
f"Found {got} but needed {want}. " | ||
) | ||
return picks | ||
|
||
|
@@ -276,14 +290,29 @@ def _fnirs_spread_bads(info): | |
# as bad and spread the bad marking to all components of the optode pair. | ||
picks = _validate_nirs_info(info, check_bads=False) | ||
new_bads = set(info["bads"]) | ||
for ii, jj in zip(picks[::2], picks[1::2]): | ||
ch1_name, ch2_name = info.ch_names[ii], info.ch_names[jj] | ||
if ch1_name in new_bads: | ||
new_bads.add(ch2_name) | ||
elif ch2_name in new_bads: | ||
new_bads.add(ch1_name) | ||
info["bads"] = sorted(new_bads) | ||
|
||
# Extract SD pair groups from channel names | ||
# E.g. all channels belonging to S1D1, S1D2, etc. | ||
# Assumes valid channels (naming convention and number) | ||
ch_names = [info.ch_names[i] for i in picks] | ||
match = re.compile(r"^(S\d+_D\d+) ") | ||
|
||
# Create dict with keys corresponding to SD pairs | ||
# Defaultdict would require another import | ||
sd_groups = {} | ||
for ch_name in ch_names: | ||
sd_pair = match.match(ch_name).group(1) | ||
if sd_pair not in sd_groups: | ||
sd_groups[sd_pair] = [ch_name] | ||
else: | ||
sd_groups[sd_pair].append(ch_name) | ||
|
||
# Spread bad labeling across SD pairs | ||
for channels in sd_groups.values(): | ||
if any(channel in new_bads for channel in channels): | ||
new_bads.update(channels) | ||
|
||
info["bads"] = sorted(new_bads) | ||
return info | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a minor style regression