|
28 | 28 | from mne.io.tests.test_raw import _test_concat, _test_raw_reader
|
29 | 29 | from mne import (concatenate_events, find_events, equalize_channels,
|
30 | 30 | compute_proj_raw, pick_types, pick_channels, create_info,
|
31 |
| - pick_info) |
| 31 | + pick_info, make_fixed_length_epochs) |
32 | 32 | from mne.utils import (requires_pandas, assert_object_equal, _dt_to_stamp,
|
33 | 33 | requires_mne, run_subprocess, _record_warnings,
|
34 | 34 | assert_and_remove_boundary_annot)
|
@@ -382,6 +382,54 @@ def test_concatenate_raws(on_mismatch):
|
382 | 382 | concatenate_raws(**kws)
|
383 | 383 |
|
384 | 384 |
|
| 385 | +def _create_toy_data(n_channels=3, sfreq=250, seed=None): |
| 386 | + rng = np.random.default_rng(seed) |
| 387 | + data = rng.standard_normal(size=(n_channels, 50 * sfreq)) * 5e-6 |
| 388 | + info = create_info(n_channels, sfreq, "eeg") |
| 389 | + return RawArray(data, info) |
| 390 | + |
| 391 | + |
| 392 | +def test_concatenate_raws_bads_order(): |
| 393 | + """Test concatenation of raw instances.""" |
| 394 | + raw0 = _create_toy_data() |
| 395 | + raw1 = _create_toy_data() |
| 396 | + |
| 397 | + # Test bad channel order |
| 398 | + raw0.info["bads"] = ["0", "1"] |
| 399 | + raw1.info["bads"] = ["1", "0"] |
| 400 | + |
| 401 | + # raw0 is modified in-place and therefore copied |
| 402 | + raw_concat = concatenate_raws([raw0.copy(), raw1]) |
| 403 | + |
| 404 | + # Check data are equal |
| 405 | + data_concat = np.concatenate([raw0.get_data(), raw1.get_data()], 1) |
| 406 | + assert np.all(raw_concat.get_data() == data_concat) |
| 407 | + |
| 408 | + # Check bad channels |
| 409 | + assert set(raw_concat.info["bads"]) == {"0", "1"} |
| 410 | + |
| 411 | + # Bad channel mismatch raises |
| 412 | + raw2 = raw1.copy() |
| 413 | + raw2.info["bads"] = ["0", "2"] |
| 414 | + with pytest.raises(ValueError): |
| 415 | + concatenate_raws([raw0, raw2]) |
| 416 | + |
| 417 | + # Type mismatch raises |
| 418 | + epochs1 = make_fixed_length_epochs(raw1) |
| 419 | + with pytest.raises(ValueError): |
| 420 | + concatenate_raws([raw0, epochs1]) |
| 421 | + |
| 422 | + # Sample rate mismatch |
| 423 | + raw3 = _create_toy_data(sfreq=500) |
| 424 | + with pytest.raises(ValueError): |
| 425 | + concatenate_raws([raw0, raw3]) |
| 426 | + |
| 427 | + # Number of channels mismatch |
| 428 | + raw4 = _create_toy_data(n_channels=4) |
| 429 | + with pytest.raises(ValueError): |
| 430 | + concatenate_raws([raw0, raw4]) |
| 431 | + |
| 432 | + |
385 | 433 | @testing.requires_testing_data
|
386 | 434 | @pytest.mark.parametrize('mod', (
|
387 | 435 | 'meg',
|
|
0 commit comments