Skip to content

Commit faa825b

Browse files
committed
refactor onsets_to_dm to wrap new nilearn functionality instead
1 parent c4e8b97 commit faa825b

File tree

10 files changed

+216
-375
lines changed

10 files changed

+216
-375
lines changed

docs/tutorials/01_DataOperations/plot_design_matrix.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,7 @@
864864
],
865865
"metadata": {
866866
"kernelspec": {
867-
"display_name": "Python 3",
867+
"display_name": ".venv",
868868
"language": "python",
869869
"name": "python3"
870870
},

nltools/file_reader.py

Lines changed: 65 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -6,162 +6,87 @@
66

77
__all__ = ["onsets_to_dm"]
88

9-
import pandas as pd
109
import numpy as np
11-
from nltools.data import Design_Matrix
12-
import warnings
13-
from pathlib import Path
10+
from typing import Callable
11+
from nilearn.glm.first_level import make_first_level_design_matrix as make_dm
12+
from .external import glover_hrf
13+
from .data import Design_Matrix
1414

1515

1616
def onsets_to_dm(
17-
F,
18-
sampling_freq,
17+
timings,
1918
run_length,
20-
header="infer",
21-
sort=False,
22-
keep_separate=True,
23-
add_poly=None,
24-
unique_cols=None,
19+
TR,
20+
hrf_model="glover",
21+
drift_model=None,
22+
high_pass=0.01,
23+
drift_order=0,
2524
fill_na=None,
26-
verbose=True,
2725
**kwargs,
2826
):
29-
"""
30-
This function can assist in reading in one or several in a 2-3 column onsets files, specified in seconds and converting it to a Design Matrix organized as samples X Stimulus Classes. sampling_freq should be specified in hertz; for TRs use hertz = 1/TR. Onsets files **must** be organized with columns in one of the following 4 formats:
27+
"""Read 1 or more file paths and return 1 or more design matrices.
3128
32-
1) 'Stim, Onset'
33-
2) 'Onset, Stim'
34-
3) 'Stim, Onset, Duration'
35-
4) 'Onset, Duration, Stim'
29+
Your timing file needs to have the following column names:
3630
37-
No other file organizations are currently supported. *Note:* Stimulus offsets (onset + duration) that fall into an adjacent TR include that full TR. E.g. offset of 10.16s with TR = 2 has an offset of TR 5, which spans 10-12s, rather than an offset of TR 4, which spans 8-10s.
31+
- 'onset': required
32+
- 'duration': required
33+
- 'trial_type': optional
34+
- 'modulation': optional
3835
39-
Args:
40-
F (str/Path/pd.DataFrame): filepath or pandas dataframe
41-
sampling_freq (float): samping frequency in hertz, i.e 1 / TR
42-
run_length (int): run length in number of TRs
43-
header (str/None, optional): whether there's an additional header row in the
44-
supplied file/dataframe. See `pd.read_csv` for more details. Defaults to `"infer"`.
45-
sort (bool, optional): whether to sort dataframe columns alphabetically. Defaults to False.
46-
keep_separate (bool, optional): if a list of files or dataframes is supplied,
47-
whether to create separate polynomial columns per file. Defaults to `True`.
48-
add_poly (bool/int, optional): whether to add Nth order polynomials to design
49-
matrix. Defaults to None.
50-
unique_cols (list/None, optional): if a list of files or dataframes is supplied,
51-
what additional columns to keep separate per file (e.g. spikes). Defaults to None.
52-
fill_na (Any, optional): what to replace NaNs with. Defaults to None (no filling).
36+
This function is a wrapper around [`nilearn.glm.first_level.make_first_level_design_matrix`](https://nilearn.github.io/stable/modules/generated/nilearn.glm.first_level.make_first_level_design_matrix.html#nilearn.glm.first_level.make_first_level_design_matrix), which is more robust than older implementations.
5337
38+
However, the default options are **different** and create a design matrix with minimal additional modifications. You can use kwargs to control settings to also convolve predictors with a variety of HRF functions, add nuisance parameters, drift and cosine functions, etc.
39+
40+
Args:
41+
timings (str, Path, pd.DataFrame, list): file(s) or dataframe(s) containing stimulus timing
42+
run_length (int, list): number or list of numbers for run lengths in TRs
43+
TR (float): repetition time in seconds. Required; used to compute the frame times and the sampling frequency (1 / TR) of each design matrix.
44+
hrf_model (str, optional): HRF model used to convolve each column of the design matrix (e.g. 'glover'). Defaults to 'glover'.
45+
drift_model (str, optional): how to add drift ('cosine' or 'polynomial'). Defaults to None.
46+
high_pass (float, optional): high-pass frequency if drift_model='cosine'. Defaults to 0.01
47+
drift_order (int, optional): what order if drift_model='polynomial'. Defaults to 0.
48+
fill_na (Any, optional): value used to replace NaNs in each design matrix. Defaults to None (no filling).
5449
5550
Returns:
56-
nltools.data.Design_Matrix: design matrix organized as TRs x Stims
51+
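Design_Matrix or list of Design_Matrix: a single design matrix when one timing input is given, otherwise a list with one design matrix per input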
5752
"""
58-
59-
if not isinstance(F, list):
60-
F = [F]
61-
62-
if not isinstance(sampling_freq, (float, np.floating)):
63-
raise TypeError("sampling_freq must be a float")
53+
if not isinstance(timings, list):
54+
timings = [timings]
55+
if not isinstance(run_length, list):
56+
run_length = [run_length]
57+
if len(timings) != len(run_length):
58+
raise ValueError("timings and run_length must have the same length")
59+
60+
# Nilearn auto-calculates approximate TR from diff-ing timings
61+
# when passing a string name to hrf_model
62+
# This approach gives us more control using the TR kwarg
63+
if TR is not None:
64+
if hrf_model == "glover":
65+
hrf_model = lambda arg1, oversampling: glover_hrf(TR, oversampling)
6466

6567
out = []
66-
TR = 1.0 / sampling_freq
67-
for f in F:
68-
if isinstance(f, str) or isinstance(f, Path):
69-
df = pd.read_csv(f, header=header, **kwargs)
70-
elif isinstance(f, pd.core.frame.DataFrame):
71-
df = f.copy()
72-
else:
73-
raise TypeError("Input needs to be file path or pandas dataframe!")
74-
75-
if verbose and df.shape[1] == 2:
76-
warnings.warn(
77-
"Only 2 columns detected in onset file (onset, stimulus). "
78-
"Assuming all stimuli have the same duration. "
79-
"Use 3 columns (onset, duration, stimulus) for variable durations.",
80-
UserWarning,
81-
)
82-
elif df.shape[1] == 1 or df.shape[1] > 3:
83-
raise ValueError("Can only handle files with 2 or 3 columns!")
84-
85-
# Try to infer the header
86-
if header is None:
87-
possibleHeaders = ["Stim", "Onset", "Duration"]
88-
if isinstance(df.iloc[0, 0], str):
89-
df.columns = possibleHeaders[: df.shape[1]]
90-
elif isinstance(df.iloc[0, df.shape[1] - 1], str):
91-
df.columns = possibleHeaders[1:] + [possibleHeaders[0]]
92-
else:
93-
raise ValueError(
94-
"Can't figure out onset file organization. Make sure file has no more than 3 columns specified as 'Stim,Onset,Duration' or 'Onset,Duration,Stim'"
95-
)
96-
97-
# Compute an offset in seconds if a Duration is provided
98-
if df.shape[1] == 3:
99-
df["Offset"] = df["Onset"] + df["Duration"]
100-
# Onset always starts at the closest TR rounded down, e.g.
101-
# with TR = 2, and onset = 10.1 or 11.7 will both have onset of TR 5 as it spans the window 10-12s
102-
df["Onset"] = df["Onset"].apply(lambda x: int(np.floor(x / TR)))
103-
104-
# Offset includes the subsequent if Offset falls within window covered by that TR
105-
# but not if it falls exactly on the subsequent TR, e.g. if TR = 2, and offset = 10.16, then TR 5 will be included but if offset = 10.00, TR 5 will not be included, as it covers the window 10-12s
106-
if "Offset" in df.columns:
107-
108-
def conditional_round(x, TR):
109-
"""Conditional rounding to the next TR if offset falls within window, otherwise not"""
110-
dur_in_TRs = x / TR
111-
dur_in_TRs_rounded_down = np.floor(dur_in_TRs)
112-
# If in the future we wanted to enable the ability to include a TR based on a % of that TR we can change the next line to compare to some value, e.g. at least 0.5s into that TR: dur_in_TRs - dur_in_TRs_rounded_down > 0.5
113-
if dur_in_TRs > dur_in_TRs_rounded_down:
114-
return dur_in_TRs_rounded_down
115-
else:
116-
return dur_in_TRs_rounded_down - 1
117-
118-
# Apply function
119-
df["Offset"] = df["Offset"].apply(conditional_round, args=(TR,))
120-
121-
# Build dummy codes
122-
X = Design_Matrix(
123-
np.zeros([run_length, df["Stim"].nunique()]),
124-
columns=df["Stim"].unique(),
125-
sampling_freq=sampling_freq,
68+
for file, run in zip(timings, run_length):
69+
frame_times = np.arange(run) * TR
70+
dm = make_dm(
71+
frame_times,
72+
events=file,
73+
hrf_model=hrf_model,
74+
drift_model=drift_model,
75+
high_pass=high_pass,
76+
drift_order=drift_order,
77+
**kwargs,
12678
)
127-
for i, row in df.iterrows():
128-
if "Offset" in df.columns:
129-
X.loc[row["Onset"] : row["Offset"], row["Stim"]] = 1
130-
else:
131-
X.loc[row["Onset"], row["Stim"]] = 1
132-
# DISABLED cause this isn't quite accurate for stimuli of different durations
133-
# Run a check
134-
# if "Offset" in df.columns:
135-
# onsets = X.sum().values
136-
# stim_counts = data.Stim.value_counts(sort=False)[X.columns]
137-
# durations = data.groupby("Stim").Duration.mean().values
138-
# for i, (o, c, d) in enumerate(zip(onsets, stim_counts, durations)):
139-
# if c * (d / TR) <= o <= c * ((d / TR) + 1):
140-
# pass
141-
# else:
142-
# warnings.warn(
143-
# f"Computed onsets for {data.Stim.unique()[i]} are inconsistent ({o}) with expected values ({c * (d / TR)} to {c * ((d / TR) + 1)}). Please manually verify the outputted Design_Matrix!"
144-
# )
145-
146-
if sort:
147-
X = X.reindex(sorted(X.columns), axis=1)
148-
149-
out.append(X)
150-
if len(out) > 1:
151-
if add_poly is not None:
152-
out = [e.add_poly(add_poly) for e in out]
153-
154-
out_dm = out[0].append(
155-
out[1:],
156-
keep_separate=keep_separate,
157-
unique_cols=unique_cols,
158-
fill_na=fill_na,
159-
)
160-
else:
161-
out_dm = out[0]
162-
if add_poly is not None:
163-
out_dm = out_dm.add_poly(add_poly)
164-
if fill_na is not None:
165-
out_dm = out_dm.fill_na(fill_na)
79+
dm = dm.fill_na(fill_na) if fill_na is not None else dm
80+
if isinstance(hrf_model, Callable):
81+
dm.columns = [c.rstrip("_<lambda>") for c in dm.columns]
82+
if hrf_model is not None:
83+
convolved = [
84+
c for c in dm.columns if "drift" not in c and "constant" not in c
85+
]
86+
polys = [c for c in dm.columns if "drift" in c or "constant" in c]
87+
else:
88+
convolved, polys = [], []
89+
dm = Design_Matrix(dm, convolved=convolved, sampling_freq=1 / TR, polys=polys)
90+
out.append(dm)
16691

167-
return out_dm
92+
return out if len(out) > 1 else out[0]

nltools/prefs.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,24 @@
99
@dataclass
1010
class MNI_Template_Factory:
1111
"""Global MNI template configuration.
12-
12+
1313
This class manages the global MNI template settings used throughout nltools.
1414
Users should interact with the exported MNI_Template instance rather than
1515
creating new instances.
16-
16+
1717
Parameters
1818
----------
1919
template : {'default', 'nilearn', 'fmriprep'}
2020
Template variant to use. Each template represents a different MNI space:
2121
- 'default': Original MNI152 6th generation templates
22-
- 'nilearn': Nilearn's MNI152 templates
22+
- 'nilearn': Nilearn's MNI152 templates
2323
- 'fmriprep': fMRIPrep's MNI152NLin2009cAsym templates
2424
resolution : {1, 2, 3}
2525
Resolution in mm. Not all resolutions are available for all templates:
2626
- 'default': 2mm, 3mm
2727
- 'nilearn': 1mm, 2mm, 3mm
2828
- 'fmriprep': 1mm, 2mm
29-
29+
3030
Attributes
3131
----------
3232
mask : str
@@ -35,51 +35,53 @@ class MNI_Template_Factory:
3535
Path to the brain-extracted image
3636
plot : str
3737
Path to the full T1 image for plotting
38-
38+
3939
Examples
4040
--------
4141
>>> from nltools.prefs import MNI_Template
4242
>>> MNI_Template.template = 'fmriprep'
4343
>>> MNI_Template.resolution = 1
4444
>>> print(MNI_Template.mask)
4545
"""
46-
46+
4747
template: Literal["default", "nilearn", "fmriprep"] = "default"
4848
resolution: Literal[1, 2, 3] = 2
49-
49+
5050
# Auto-populated paths
5151
mask: str = field(init=False)
5252
brain: str = field(init=False)
5353
plot: str = field(init=False)
54-
54+
5555
# Define supported combinations
5656
_supported_combinations = {
5757
"default": [2, 3],
5858
"nilearn": [1, 2, 3],
59-
"fmriprep": [1, 2]
59+
"fmriprep": [1, 2],
6060
}
61-
61+
6262
def __post_init__(self):
6363
"""Initialize paths after dataclass creation."""
6464
self._validate_and_resolve()
65-
65+
6666
def __setattr__(self, name, value):
6767
"""Custom setter to re-resolve paths when attributes change."""
6868
# Use object.__setattr__ to avoid recursion
6969
object.__setattr__(self, name, value)
7070
# Only resolve paths if we're setting template or resolution
7171
# and the object has been fully initialized
72-
if name in ["template", "resolution"] and hasattr(self, "_validate_and_resolve"):
72+
if name in ["template", "resolution"] and hasattr(
73+
self, "_validate_and_resolve"
74+
):
7375
self._validate_and_resolve()
74-
76+
7577
def __repr__(self) -> str:
7678
return (
7779
f"MNI_Template(template='{self.template}', resolution={self.resolution}mm)\n"
7880
f" mask: {os.path.basename(self.mask)}\n"
7981
f" brain: {os.path.basename(self.brain)}\n"
8082
f" plot: {os.path.basename(self.plot)}"
8183
)
82-
84+
8385
def _validate_and_resolve(self):
8486
"""Validate inputs and resolve file paths."""
8587
# Validate resolution is supported for this template
@@ -88,18 +90,22 @@ def _validate_and_resolve(self):
8890
f"Resolution {self.resolution}mm is not supported for template '{self.template}'. "
8991
f"Supported resolutions: {self._supported_combinations[self.template]}"
9092
)
91-
93+
9294
# Build paths based on template and resolution
9395
base_path = join(dirname(__file__), "resources", "niftis", self.template)
9496
res_str = f"{self.resolution}mm"
95-
97+
9698
# Set paths following the naming convention
9799
self.mask = join(base_path, f"MNI152_{res_str}_mask.nii.gz")
98100
self.brain = join(base_path, f"MNI152_{res_str}_brain.nii.gz")
99101
self.plot = join(base_path, f"MNI152_{res_str}_T1.nii.gz")
100-
102+
101103
# Verify files exist
102-
for attr, path in [("mask", self.mask), ("brain", self.brain), ("plot", self.plot)]:
104+
for attr, path in [
105+
("mask", self.mask),
106+
("brain", self.brain),
107+
("plot", self.plot),
108+
]:
103109
if not os.path.exists(path):
104110
raise FileNotFoundError(
105111
f"Template file not found: {path}\n"
@@ -109,4 +115,4 @@ def _validate_and_resolve(self):
109115

110116
# NOTE: We export this from the module and expect users to interact with it instead of
111117
# the class constructor above
112-
MNI_Template = MNI_Template_Factory()
118+
MNI_Template = MNI_Template_Factory()

0 commit comments

Comments
 (0)