Skip to content

Commit 04a102f

Browse files
committed
Fix formatting
1 parent 0bb6549 commit 04a102f

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

cebra/data/multi_session.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ class MultiSessionLoader(cebra_data.Loader):
130130

131131
def __post_init__(self):
132132
super().__post_init__()
133-
self.sampler = cebra.distributions.MultisessionSampler(self.dataset,
134-
self.time_offset)
133+
self.sampler = cebra.distributions.MultisessionSampler(
134+
self.dataset, self.time_offset)
135135

136136
def get_indices(self, num_samples: int) -> List[BatchIndex]:
137137
ref_idx = self.sampler.sample_prior(self.batch_size)
@@ -169,7 +169,8 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader):
169169
# Overwrite sampler with the discrete implementation
170170
# Generalize MultisessionSampler to avoid doing this?
171171
def __post_init__(self):
172-
self.sampler = cebra.distributions.DiscreteMultisessionSampler(self.dataset)
172+
self.sampler = cebra.distributions.DiscreteMultisessionSampler(
173+
self.dataset)
173174

174175
@property
175176
def index(self):

tests/test_datasets.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,8 @@ def test_allen():
153153

154154

155155
@pytest.mark.requires_dataset
156-
@pytest.mark.parametrize("options",
157-
cebra.datasets.get_options("*",
158-
expand_parametrized=False))
156+
@pytest.mark.parametrize(
157+
"options", cebra.datasets.get_options("*", expand_parametrized=False))
159158
def test_options(options):
160159
assert len(options) > 0
161160
assert len(multisubject_options) > 0

0 commit comments

Comments
 (0)