7 changes: 3 additions & 4 deletions mne/io/snirf/_snirf.py
@@ -148,14 +148,13 @@ def __init__(
# Extract wavelengths
fnirs_wavelengths = np.array(dat.get("nirs/probe/wavelengths"))
fnirs_wavelengths = [int(w) for w in fnirs_wavelengths]
if len(fnirs_wavelengths) != 2:
if len(fnirs_wavelengths) < 2:
raise RuntimeError(
f"The data contains "
f"{len(fnirs_wavelengths)}"
f" wavelengths: {fnirs_wavelengths}. "
f"MNE only supports reading continuous"
" wave amplitude SNIRF files "
"with two wavelengths."
f"MNE requires at least two wavelengths for "
"continuous wave amplitude SNIRF files."
)

# Extract channels
81 changes: 58 additions & 23 deletions mne/preprocessing/nirs/_beer_lambert_law.py
@@ -36,23 +36,34 @@ def beer_lambert_law(raw, ppf=6.0):
_validate_type(raw, BaseRaw, "raw")
_validate_type(ppf, ("numeric", "array-like"), "ppf")
ppf = np.array(ppf, float)
if ppf.ndim == 0: # upcast single float to shape (2,)
ppf = np.array([ppf, ppf])
if ppf.shape != (2,):
raise ValueError(
f"ppf must be float or array-like of shape (2,), got shape {ppf.shape}"
)
ppf = ppf[:, np.newaxis] # shape (2, 1)
picks = _validate_nirs_info(raw.info, fnirs="od", which="Beer-lambert")
# This is the one place we *really* need the actual/accurate frequencies
freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float)
abs_coef = _load_absorption(freqs)

# Get unique wavelengths and determine number of wavelengths
unique_freqs = np.unique(freqs)
n_wavelengths = len(unique_freqs)

# PPF validation for multiple wavelengths
if ppf.ndim == 0: # single float
# same PPF for all wavelengths, shape (n_wavelengths, 1)
ppf = np.full((n_wavelengths, 1), ppf)
elif ppf.ndim == 1 and len(ppf) == n_wavelengths:
# separate ppf for each wavelength
ppf = ppf[:, np.newaxis] # shape (n_wavelengths, 1)
else:
raise ValueError(
f"ppf must be a single float or an array-like of length {n_wavelengths} "
f"(number of wavelengths), got shape {ppf.shape}"
)

abs_coef = _load_absorption(unique_freqs) # shape (n_wavelengths, 2)
distances = source_detector_distances(raw.info, picks="all")
bad = ~np.isfinite(distances[picks])
bad |= distances[picks] <= 0
if bad.any():
warn(
"Source-detector distances are zero on NaN, some resulting "
"Source-detector distances are zero or NaN, some resulting "
"concentrations will be zero. Consider setting a montage "
"with raw.set_montage."
)
@@ -64,20 +75,42 @@ def beer_lambert_law(raw, ppf=6.0):
"likely due to optode locations being stored in a "
" unit other than meters."
)

rename = dict()
for ii, jj in zip(picks[::2], picks[1::2]):
EL = abs_coef * distances[ii] * ppf
iEL = pinv(EL)
channels_to_drop_all = [] # Accumulate all channels to drop

raw._data[[ii, jj]] = iEL @ raw._data[[ii, jj]] * 1e-3
# Iterate over channel groups ([Si_Di all wavelengths, Sj_Dj all wavelengths, ...])
# Materialize each group as a list so it can be used for fancy indexing
# (a tuple would be interpreted as a multi-dimensional index by numpy)
pick_groups = [list(g) for g in zip(*[iter(picks)] * n_wavelengths)]
for group_picks in pick_groups:
# Calculate Δc based on the system: ΔOD = E * L * PPF * Δc
# where E is (n_wavelengths, 2), Δc is (2, n_timepoints)
# using pseudo-inverse
EL = abs_coef * distances[group_picks[0]] * ppf
iEL = pinv(EL) # Pseudo-inverse for numerical stability
conc_data = iEL @ raw._data[group_picks] * 1e-3

# Replace the first two channels with HbO and HbR
raw._data[group_picks[:2]] = conc_data[:2] # HbO, HbR

# Update channel information
coil_dict = dict(hbo=FIFF.FIFFV_COIL_FNIRS_HBO, hbr=FIFF.FIFFV_COIL_FNIRS_HBR)
for ki, kind in zip((ii, jj), ("hbo", "hbr")):
for ki, kind in zip(group_picks[:2], ("hbo", "hbr")):
ch = raw.info["chs"][ki]
ch.update(coil_type=coil_dict[kind], unit=FIFF.FIFF_UNIT_MOL)
new_name = f"{ch['ch_name'].split(' ')[0]} {kind}"
rename[ch["ch_name"]] = new_name

# Accumulate extra wavelength channels to drop (keep only HbO and HbR)
if n_wavelengths > 2:
channels_to_drop = group_picks[2:]
channel_names_to_drop = [raw.ch_names[idx] for idx in channels_to_drop]
channels_to_drop_all.extend(channel_names_to_drop)

# Drop all accumulated extra wavelength channels after processing all groups
# This preserves channel indexing during the loop iterations
if channels_to_drop_all:
raw.drop_channels(channels_to_drop_all)

raw.rename_channels(rename)

# Validate that the data format after transformation is valid
@@ -86,7 +119,7 @@ def beer_lambert_law(raw, ppf=6.0):


def _load_absorption(freqs):
"""Load molar extinction coefficients."""
"""Load molar extinction coefficients"""
This is a minor style regression

Suggested change
"""Load molar extinction coefficients"""
"""Load molar extinction coefficients."""

# Data from https://omlc.org/spectra/hemoglobin/summary.html
# The text was copied to a text file. The text before and
# after the table was deleted. Then the following was run in
@@ -95,7 +128,9 @@ def _load_absorption(freqs):
# save('extinction_coef.mat', 'extinct_coef')
#
# Returns data as [[HbO2(freq1), Hb(freq1)],
# [HbO2(freq2), Hb(freq2)]]
# [HbO2(freq2), Hb(freq2)],
# ...,
# [HbO2(freqN), Hb(freqN)]]
extinction_fname = op.join(
op.dirname(__file__), "..", "..", "data", "extinction_coef.mat"
)
@@ -104,12 +139,12 @@ def _load_absorption(freqs):
interp_hbo = interp1d(a[:, 0], a[:, 1], kind="linear")
interp_hb = interp1d(a[:, 0], a[:, 2], kind="linear")

ext_coef = np.array(
[
[interp_hbo(freqs[0]), interp_hb(freqs[0])],
[interp_hbo(freqs[1]), interp_hb(freqs[1])],
]
)
abs_coef = ext_coef * 0.2303
# Build coefficient matrix for all wavelengths
# Shape: (n_wavelengths, 2) where columns are [HbO2, Hb]
ext_coef = np.zeros((len(freqs), 2))
for i, freq in enumerate(freqs):
ext_coef[i, 0] = interp_hbo(freq) # HbO2
ext_coef[i, 1] = interp_hb(freq) # Hb

abs_coef = ext_coef * 0.2303
return abs_coef
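
To sanity-check the linear algebra in the new loop, here is a minimal standalone sketch of the pseudo-inverse solve, with made-up extinction coefficients and distance (not the values shipped in extinction_coef.mat):

```python
import numpy as np
from numpy.linalg import pinv

# Hypothetical values for illustration only.
abs_coef = np.array([[1.5, 3.8],   # [HbO2, Hb] at wavelength 1
                     [2.5, 1.8],   # [HbO2, Hb] at wavelength 2
                     [3.0, 1.2]])  # [HbO2, Hb] at wavelength 3
distance = 0.03                    # source-detector distance (m)
ppf = np.full((3, 1), 6.0)         # one PPF per wavelength

# Forward model: delta_od = (abs_coef * distance * ppf) @ delta_conc
EL = abs_coef * distance * ppf           # shape (n_wavelengths, 2)
rng = np.random.default_rng(0)
delta_od = rng.normal(size=(3, 100))     # optical density, 100 time points

# With more than two wavelengths the system is overdetermined, and
# pinv(EL) gives the least-squares solution for [HbO, HbR].
delta_conc = pinv(EL) @ delta_od * 1e-3  # shape (2, n_times)
```

For exactly two wavelengths, EL is square and the pseudo-inverse reduces to the ordinary inverse, so the two-wavelength path should behave as before.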
52 changes: 45 additions & 7 deletions mne/preprocessing/nirs/_scalp_coupling_index.py
@@ -56,14 +56,52 @@ def scalp_coupling_index(
verbose=verbose,
).get_data()

# Determine number of wavelengths per source-detector pair
ch_wavelengths = [c["loc"][9] for c in raw.info["chs"]]
n_wavelengths = len(set(ch_wavelengths))

# freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float)
# n_wavelengths = len(set(unique_freqs))
Comment on lines +62 to +64
Cruft?

Suggested change
# freqs = np.array([raw.info["chs"][pick]["loc"][9] for pick in picks], float)
# n_wavelengths = len(set(unique_freqs))


sci = np.zeros(picks.shape)
for ii in range(0, len(picks), 2):
with np.errstate(invalid="ignore"):
c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1]
if not np.isfinite(c): # someone had std=0
c = 0
sci[ii] = c
sci[ii + 1] = c

if n_wavelengths == 2:
# Use pairwise correlation for 2 wavelengths (backward compatibility)
for ii in range(0, len(picks), 2):
with np.errstate(invalid="ignore"):
c = np.corrcoef(filtered_data[ii], filtered_data[ii + 1])[0][1]
if not np.isfinite(c): # someone had std=0
c = 0
sci[ii] = c
sci[ii + 1] = c
else:
# For multiple wavelengths: calculate all pairwise correlations within each group
# and use the minimum as the quality metric
Comment on lines +78 to +79
I don't understand why you need a conditional here... if n_wavelengths==2 then "all pairwise correlations" should be a single correlation and give an equivalent output to the branch above... so we shouldn't need it
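
A quick standalone check of this point, using synthetic signals (hypothetical data, not from the test suite):

```python
import numpy as np

rng = np.random.default_rng(42)
x, y = rng.normal(size=(2, 100))

# Two-wavelength branch: one pairwise correlation.
c_pair = np.corrcoef(x, y)[0, 1]

# General branch with n_wavelengths == 2: the upper triangle holds
# exactly one pair, so its minimum is that same correlation.
data = np.vstack([x, y])
ii, jj = np.triu_indices(2, k=1)
c_min = min(np.corrcoef(data[a], data[b])[0, 1] for a, b in zip(ii, jj))

assert np.isclose(c_pair, c_min)
```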


# Group picks by number of wavelengths
# Drops last incomplete group, but we're assuming valid data
pick_iter = iter(picks)
# Materialize each group as a list so it can be used for fancy indexing
pick_groups = [list(g) for g in zip(*[pick_iter] * n_wavelengths)]

for group_picks in pick_groups:
group_data = filtered_data[group_picks]

# Calculate pairwise correlations within the group
pair_indices = np.triu_indices(len(group_picks), k=1)
correlations = np.zeros(pair_indices[0].shape[0])

for n, (ii, jj) in enumerate(zip(*pair_indices)):
with np.errstate(invalid="ignore"):
c = np.corrcoef(group_data[ii], group_data[jj])[0][1]
if np.isfinite(c):
correlations[n] = c

# Use minimum correlation as the quality metric
group_sci = correlations.min()

# Assign the same SCI value to all channels in the group
sci[group_picks] = group_sci

sci[zero_mask] = 0
sci = sci[np.argsort(picks)] # restore original order
return sci
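
Both files rely on the same grouping idiom, zip(*[iter(picks)] * n_wavelengths), which is compact but easy to misread; a minimal illustration of what it does:

```python
picks = [0, 1, 2, 3, 4, 5, 6]
n_wavelengths = 3

# zip(*[iter(picks)] * n) passes the *same* iterator n times, so zip
# draws consecutive items for each tuple; the trailing incomplete
# group (here: 6) is silently dropped.
groups = list(zip(*[iter(picks)] * n_wavelengths))
assert groups == [(0, 1, 2), (3, 4, 5)]
```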
109 changes: 69 additions & 40 deletions mne/preprocessing/nirs/nirs.py
@@ -104,7 +104,7 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr
# All chromophore fNIRS data
picks_chroma = _picks_to_idx(info, ["hbo", "hbr"], exclude=[], allow_empty=True)

if (len(picks_wave) > 0) & (len(picks_chroma) > 0):
if (len(picks_wave) > 0) and (len(picks_chroma) > 0):
picks = _throw_or_return_empty(
"MNE does not support a combination of amplitude, optical "
"density, and haemoglobin data in the same raw structure.",
@@ -122,14 +122,14 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr
picks = picks_chroma

pair_vals = np.array(pair_vals)
if pair_vals.shape != (2,):
if pair_vals.shape[0] < 2:
raise ValueError(
f"Exactly two {error_word} must exist in info, got {list(pair_vals)}"
f"At least two {error_word} must exist in info, got {list(pair_vals)}"
)
# In principle we do not need to require that these be sorted --
# all we need to do is change our sorted() below to make use of a
# pair_vals.index(...) in a sort key -- but in practice we always want
# (hbo, hbr) or (lower_freq, upper_freq) pairings, both of which will
# (hbo, hbr) or (lowest_freq, higher_freq, ...) pairings, both of which will
# work with a naive string sort, so let's just enforce sorted-ness here
is_str = pair_vals.dtype.kind == "U"
pair_vals = list(pair_vals)
@@ -145,16 +145,23 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr
f"got {pair_vals} instead"
)

if len(picks) % 2 != 0:
# Check that the total number of channels is divisible by the number of pair values
# (e.g., for 2 wavelengths, we need even number of channels)
if len(picks) % len(pair_vals) != 0:
picks = _throw_or_return_empty(
"NIRS channels not ordered correctly. An even number of NIRS "
f"channels is required. {len(info.ch_names)} channels were"
f"provided",
f"NIRS channels not ordered correctly. The number of channels "
f"must be a multiple of {len(pair_vals)} values, but "
f"{len(picks)} channels were provided.",
throw_errors,
)

# Ensure wavelength info exists for waveform data
all_freqs = [info["chs"][ii]["loc"][9] for ii in picks_wave]
if len(pair_vals) != len(set(all_freqs)):
picks = _throw_or_return_empty(
f"The {error_word} in info must match the number of wavelengths, "
f"but the data contains {len(set(all_freqs))} wavelengths instead.",
throw_errors,
)
if np.any(np.isnan(all_freqs)):
picks = _throw_or_return_empty(
f"NIRS channels is missing wavelength information in the "
@@ -189,40 +196,47 @@ def _check_channels_ordered(info, pair_vals, *, throw_errors=True, check_bads=Tr
# Reorder to be paired (naive sort okay here given validation above)
picks = picks[np.argsort([info["ch_names"][pick] for pick in picks])]

# Validate our paired ordering
for ii, jj in zip(picks[::2], picks[1::2]):
ch1_name = info["chs"][ii]["ch_name"]
ch2_name = info["chs"][jj]["ch_name"]
ch1_re = use_RE.match(ch1_name)
ch2_re = use_RE.match(ch2_name)
ch1_S, ch1_D, ch1_value = ch1_re.groups()[:3]
ch2_S, ch2_D, ch2_value = ch2_re.groups()[:3]
if len(picks_wave):
ch1_value, ch2_value = float(ch1_value), float(ch2_value)
if (
(ch1_S != ch2_S)
or (ch1_D != ch2_D)
or (ch1_value != pair_vals[0])
or (ch2_value != pair_vals[1])
# Validate channel grouping (same source-detector pairs, all pair_vals match)
for i in range(0, len(picks), len(pair_vals)):
We prefer ii or pi to i, more clearly different from 1j and readable etc.

# Extract a group of channels (e.g., all wavelengths for one S-D pair)
group_picks = picks[i : i + len(pair_vals)]

# Parse channel names using regex to extract source, detector, and value info
group_info = [
use_RE.match(info["ch_names"][pick]).groups()[:3] for pick in group_picks
]

# Separate the parsed components: source IDs, detector IDs, and values (freq/chromophore)
s_group, d_group, val_group = zip(*group_info)

# For wavelength data, convert string frequencies to float for comparison
if len(picks_wave) > 0:
val_group = [float(v) for v in val_group]

# Verify that all channels in this group have the same source-detector pair
# and that the values match the expected pair_vals sequence
if not (
len(set(s_group)) == 1 and len(set(d_group)) == 1 and list(val_group) == list(pair_vals)
):
picks = _throw_or_return_empty(
"NIRS channels not ordered correctly. Channels must be "
"ordered as source detector pairs with alternating"
f" {error_word} {pair_vals[0]} & {pair_vals[1]}, but got "
f"S{ch1_S}_D{ch1_D} pair "
f"{repr(ch1_name)} and {repr(ch2_name)}",
"grouped by source-detector pairs with alternating {error_word} "
f"values {pair_vals}, but got mismatching names {[info['ch_names'][pick] for pick in group_picks]}.",
throw_errors,
)
break

if check_bads:
for ii, jj in zip(picks[::2], picks[1::2]):
want = [info.ch_names[ii], info.ch_names[jj]]
for i in range(0, len(picks), len(pair_vals)):
group_picks = picks[i : i + len(pair_vals)]

want = [info.ch_names[pick] for pick in group_picks]
got = list(set(info["bads"]).intersection(want))
if len(got) == 1:
if 0 < len(got) < len(want):
raise RuntimeError(
f"NIRS bad labelling is not consistent, found {got} but "
f"needed {want}"
"NIRS bad labelling is not consistent. "
f"Found {got} but needed {want}. "
)
return picks

@@ -276,14 +290,29 @@ def _fnirs_spread_bads(info):
# as bad and spread the bad marking to all components of the optode pair.
picks = _validate_nirs_info(info, check_bads=False)
new_bads = set(info["bads"])
for ii, jj in zip(picks[::2], picks[1::2]):
ch1_name, ch2_name = info.ch_names[ii], info.ch_names[jj]
if ch1_name in new_bads:
new_bads.add(ch2_name)
elif ch2_name in new_bads:
new_bads.add(ch1_name)
info["bads"] = sorted(new_bads)

# Extract SD pair groups from channel names
# E.g. all channels belonging to S1D1, S1D2, etc.
# Assumes valid channels (naming convention and number)
ch_names = [info.ch_names[i] for i in picks]
match = re.compile(r"^(S\d+_D\d+) ")

# Create dict with keys corresponding to SD pairs
# Defaultdict would require another import
sd_groups = {}
for ch_name in ch_names:
sd_pair = match.match(ch_name).group(1)
if sd_pair not in sd_groups:
sd_groups[sd_pair] = [ch_name]
else:
sd_groups[sd_pair].append(ch_name)

# Spread bad labeling across SD pairs
for channels in sd_groups.values():
if any(channel in new_bads for channel in channels):
new_bads.update(channels)

info["bads"] = sorted(new_bads)
return info


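
Regarding the "Defaultdict would require another import" comment in _fnirs_spread_bads: collections is stdlib, so the grouping could be simplified at the cost of one import; a sketch with hypothetical channel names:

```python
import re
from collections import defaultdict

ch_names = ["S1_D1 760", "S1_D1 850", "S2_D1 760", "S2_D1 850"]  # hypothetical
new_bads = {"S1_D1 850"}

sd_pattern = re.compile(r"^(S\d+_D\d+) ")
sd_groups = defaultdict(list)
for ch_name in ch_names:
    sd_groups[sd_pattern.match(ch_name).group(1)].append(ch_name)

# Spread the bad label across each source-detector group.
for channels in sd_groups.values():
    if any(ch in new_bads for ch in channels):
        new_bads.update(channels)

assert new_bads == {"S1_D1 760", "S1_D1 850"}
```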