Skip to content

Commit f422263

Browse files
authored
Merge pull request #810 from effigies/fix/X_patterns
ENH: Handle wildcards in model X
2 parents 8b11a02 + 875c57c commit f422263

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

bids/modeling/statsmodels.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from collections import namedtuple, OrderedDict, Counter, defaultdict
55
import itertools
66
from functools import reduce
7+
import re
8+
import fnmatch
79

810
import numpy as np
911
import pandas as pd
@@ -520,6 +522,15 @@ def get_collections(self, **filters):
520522
return [c for c in self._collections if matches_entities(c, filters)]
521523

522524

525+
def expand_wildcards(selectors, pool):
526+
out = list(selectors)
527+
for spec in selectors:
528+
if re.search(r'[\*\?\[\]]', spec):
529+
idx = out.index(spec)
530+
out[idx:idx + 1] = fnmatch.filter(pool, spec)
531+
return out
532+
533+
523534
class BIDSStatsModelsNodeOutput:
524535
"""Represents a single node in a BIDSStatsModelsGraph.
525536
@@ -597,6 +608,8 @@ def merge_dfs(a, b):
597608
else:
598609
var_names.append(int_name)
599610

611+
var_names = expand_wildcards(var_names, df.columns)
612+
600613
# Verify all X names are actually present
601614
missing = list(set(var_names) - set(df.columns))
602615
if missing:
@@ -642,7 +655,9 @@ def _collections_to_dfs(self, collections):
642655

643656
# Take the intersection of variables and Model.X (var_names), ignoring missing
644657
# variables (usually contrasts)
645-
coll.variables = {v: coll.variables[v] for v in var_names if v in coll.variables}
658+
coll.variables = {v: coll.variables[v]
659+
for v in expand_wildcards(var_names, coll.variables)
660+
if v in coll.variables}
646661
if not coll.variables:
647662
continue
648663

bids/modeling/tests/test_statsmodels.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pytest
1111

1212
from bids.modeling import BIDSStatsModelsGraph
13-
from bids.modeling.statsmodels import ContrastInfo
13+
from bids.modeling.statsmodels import ContrastInfo, expand_wildcards
1414
from bids.layout import BIDSLayout
1515
from bids.tests import get_test_data_path
1616
from bids.variables import BIDSVariableCollection
@@ -152,3 +152,20 @@ def test_entire_graph_smoketest(graph):
152152
assert model_spec.X.shape == (2, 1)
153153
assert model_spec.Z is None
154154
assert not set(model_spec.terms.keys()) - {"RT", "gain", "RT:gain"}
155+
156+
157+
def test_expand_wildcards():
158+
# No wildcards == no modification
159+
assert expand_wildcards(["a", "b"], ["a", "c"]) == ["a", "b"]
160+
# No matches == removal
161+
assert expand_wildcards(["a", "b*"], ["a", "c"]) == ["a"]
162+
# Matches expand in-place
163+
assert expand_wildcards(["a*", "b"], ["a", "c"]) == ["a", "b"]
164+
assert expand_wildcards(["a*", "b"], ["a0", "c", "a1", "a2"]) == ["a0", "a1", "a2", "b"]
165+
# Some examples
166+
assert expand_wildcards(
167+
["trial_type.*"], ["trial_type.A", "trial_type.B"]
168+
) == ["trial_type.A", "trial_type.B"]
169+
assert expand_wildcards(
170+
["non_steady_state*"], ["non_steady_state00", "non_steady_state01"]
171+
) == ["non_steady_state00", "non_steady_state01"]

0 commit comments

Comments
 (0)