diff --git a/doc/changes/dev/13435.newfeature.rst b/doc/changes/dev/13435.newfeature.rst new file mode 100644 index 00000000000..1aed60dc123 --- /dev/null +++ b/doc/changes/dev/13435.newfeature.rst @@ -0,0 +1 @@ +Add support for BDF export in :func:`mne.export.export_raw`, by `Clemens Brunner`_ \ No newline at end of file diff --git a/environment.yml b/environment.yml index 4586093473c..8e7e083563b 100644 --- a/environment.yml +++ b/environment.yml @@ -9,7 +9,7 @@ dependencies: - decorator - defusedxml - dipy - - edfio >=0.2.1 + - edfio >=0.4.10 - eeglabio - filelock >=3.18.0 - h5io >=0.2.4 diff --git a/mne/conftest.py b/mne/conftest.py index 5cee2258836..57d14205e17 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -78,7 +78,7 @@ event_id, tmin, tmax = 1, -0.1, 1.0 vv_layout = read_layout("Vectorview-all") -collect_ignore = ["export/_brainvision.py", "export/_eeglab.py", "export/_edf.py"] +collect_ignore = ["export/_brainvision.py", "export/_eeglab.py", "export/_edf_bdf.py"] def pytest_configure(config: pytest.Config): diff --git a/mne/export/_edf.py b/mne/export/_edf_bdf.py similarity index 74% rename from mne/export/_edf.py rename to mne/export/_edf_bdf.py index d537d55868f..fa4f9ebcf64 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf_bdf.py @@ -7,11 +7,19 @@ import numpy as np -from ..annotations import _sync_onset -from ..utils import _check_edfio_installed, warn +from mne.annotations import _sync_onset +from mne.utils import _check_edfio_installed, warn _check_edfio_installed() -from edfio import Edf, EdfAnnotation, EdfSignal, Patient, Recording # noqa: E402 +from edfio import ( # noqa: E402 + Bdf, + BdfSignal, + Edf, + EdfAnnotation, + EdfSignal, + Patient, + Recording, +) # copied from edfio (Apache license) @@ -29,44 +37,61 @@ def _round_float_to_8_characters( return round_func(value * factor) / factor -def _export_raw(fname, raw, physical_range, add_ch_type): - """Export Raw objects to EDF files. 
+def _export_raw_edf_bdf(fname, raw, physical_range, add_ch_type, file_format): + """Export Raw objects to EDF/BDF files. + Parameters + ---------- + fname : str + Output file name. + raw : instance of Raw + The raw instance to export. + physical_range : str or tuple + Physical range setting. + add_ch_type : bool + Whether to add channel type to signal label. + file_format : str + File format ("EDF" or "BDF"). + + Notes + ----- TODO: if in future the Info object supports transducer or technician information, allow writing those here. """ - # get voltage-based data in uV units = dict( eeg="uV", ecog="uV", seeg="uV", eog="uV", ecg="uV", emg="uV", bio="uV", dbs="uV" ) - digital_min, digital_max = -32767, 32767 - annotations = [] - - # load data first - raw.load_data() + if file_format == "EDF": + digital_min, digital_max = -32767, 32767 # 16-bit + signal_class = EdfSignal + writer_class = Edf + else: # BDF + digital_min, digital_max = -8388607, 8388607 # 24-bit + signal_class = BdfSignal + writer_class = Bdf ch_types = np.array(raw.get_channel_types()) - n_times = raw.n_times - # get the entire dataset in uV + # load and prepare data + raw.load_data() data = raw.get_data(units=units) - - # Sampling frequency in EDF only supports integers, so to allow for float sampling - # rates from Raw, we adjust the output sampling rate for all channels and the data - # record duration. sfreq = raw.info["sfreq"] + pad_annotations = [] + + # Sampling frequency in EDF/BDF only supports integers, so to allow for float + # sampling rates from Raw, we adjust the output sampling rate for all channels and + # the data record duration. 
if float(sfreq).is_integer(): out_sfreq = int(sfreq) data_record_duration = None # make non-integer second durations work - if (pad_width := int(np.ceil(n_times / sfreq) * sfreq - n_times)) > 0: + if (pad_width := int(np.ceil(raw.n_times / sfreq) * sfreq - raw.n_times)) > 0: warn( - "EDF format requires equal-length data blocks, so " + f"{file_format} format requires equal-length data blocks, so " f"{pad_width / sfreq:.3g} seconds of edge values were appended to all " "channels when writing the final block." ) - orig_shape = data.shape data = np.pad( data, ( @@ -75,10 +100,8 @@ def _export_raw(fname, raw, physical_range, add_ch_type): ), "edge", ) - assert data.shape[0] == orig_shape[0] - assert data.shape[1] > orig_shape[1] - annotations.append( + pad_annotations.append( EdfAnnotation( raw.times[-1] + 1 / sfreq, pad_width / sfreq, "BAD_ACQ_SKIP" ) @@ -89,18 +112,19 @@ def _export_raw(fname, raw, physical_range, add_ch_type): ) out_sfreq = np.floor(sfreq) / data_record_duration warn( - f"Data has a non-integer sampling rate of {sfreq}; writing to EDF format " - "may cause a small change to sample times." + f"Data has a non-integer sampling rate of {sfreq}; writing to " + f"{file_format} format may cause a small change to sample times." 
) - # get any filter information applied to the data + # extract filter information lowpass = raw.info["lowpass"] highpass = raw.info["highpass"] linefreq = raw.info["line_freq"] filter_str_info = f"HP:{highpass}Hz LP:{lowpass}Hz" if linefreq is not None: - filter_str_info += " N:{linefreq}Hz" + filter_str_info += f" N:{linefreq}Hz" + # compute physical range if physical_range == "auto": # get max and min for each channel type data ch_types_phys_max = dict() @@ -136,6 +160,8 @@ def _export_raw(fname, raw, physical_range, add_ch_type): ) data = np.clip(data, pmin, pmax) prange = pmin, pmax + + # create signals signals = [] for idx, ch in enumerate(raw.ch_names): ch_type = ch_types[idx] @@ -143,8 +169,8 @@ def _export_raw(fname, raw, physical_range, add_ch_type): if len(signal_label) > 16: raise RuntimeError( f"Signal label for {ch} ({ch_type}) is longer than 16 characters, which" - " is not supported by the EDF standard. Please shorten the channel name" - "before exporting to EDF." + f" is not supported by the {file_format} standard. Please shorten the " + f"channel name before exporting to {file_format}." 
) if physical_range == "auto": # per channel type @@ -155,7 +181,7 @@ def _export_raw(fname, raw, physical_range, add_ch_type): prange = pmin, pmax signals.append( - EdfSignal( + signal_class( data[idx], out_sfreq, label=signal_label, @@ -167,7 +193,7 @@ def _export_raw(fname, raw, physical_range, add_ch_type): ) ) - # set patient info + # create patient info subj_info = raw.info.get("subject_info") if subj_info is not None: # get the full name of subject if available @@ -197,7 +223,7 @@ def _export_raw(fname, raw, physical_range, add_ch_type): else: patient = None - # set measurement date + # create recording info if (meas_date := raw.info["meas_date"]) is not None: startdate = dt.date(meas_date.year, meas_date.month, meas_date.day) starttime = dt.time( @@ -214,9 +240,11 @@ def _export_raw(fname, raw, physical_range, add_ch_type): else: recording = Recording(startdate=startdate) + # create annotations + annotations = [] for desc, onset, duration, ch_names in zip( raw.annotations.description, - # subtract raw.first_time because EDF marks events starting from the first + # subtract raw.first_time because EDF/BDF marks events starting from the first # available data point and ignores raw.first_time _sync_onset(raw, raw.annotations.onset, inverse=False), raw.annotations.duration, @@ -230,7 +258,10 @@ def _export_raw(fname, raw, physical_range, add_ch_type): else: annotations.append(EdfAnnotation(onset, duration, desc)) - Edf( + annotations.extend(pad_annotations) + + # write to file + writer_class( signals=signals, patient=patient, recording=recording, @@ -238,3 +269,13 @@ def _export_raw(fname, raw, physical_range, add_ch_type): data_record_duration=data_record_duration, annotations=annotations, ).write(fname) + + +def _export_raw_edf(fname, raw, physical_range, add_ch_type): + """Export Raw object to EDF.""" + _export_raw_edf_bdf(fname, raw, physical_range, add_ch_type, file_format="EDF") + + +def _export_raw_bdf(fname, raw, physical_range, add_ch_type): + """Export 
Raw object to BDF.""" + _export_raw_edf_bdf(fname, raw, physical_range, add_ch_type, file_format="BDF") diff --git a/mne/export/_export.py b/mne/export/_export.py index 4b93fda917e..2842b747f21 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -2,10 +2,10 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. -import os.path as op +import os -from ..utils import _check_fname, _validate_type, logger, verbose, warn -from ._egimff import export_evokeds_mff +from mne.export._egimff import export_evokeds_mff +from mne.utils import _check_fname, _validate_type, logger, verbose, warn @verbose @@ -56,13 +56,14 @@ def export_raw( """ fname = str(_check_fname(fname, overwrite=overwrite)) supported_export_formats = { # format : (extensions,) - "eeglab": ("set",), - "edf": ("edf",), + "bdf": ("bdf",), "brainvision": ( "eeg", "vmrk", "vhdr", ), + "edf": ("edf",), + "eeglab": ("set",), } fmt = _infer_check_export_fmt(fmt, fname, supported_export_formats) @@ -73,18 +74,23 @@ def export_raw( "them before exporting with raw.apply_proj()." 
) - if fmt == "eeglab": - from ._eeglab import _export_raw + match fmt: + case "bdf": + from mne.export._edf_bdf import _export_raw_bdf + + _export_raw_bdf(fname, raw, physical_range, add_ch_type) + case "brainvision": + from mne.export._brainvision import _export_raw - _export_raw(fname, raw) - elif fmt == "edf": - from ._edf import _export_raw + _export_raw(fname, raw, overwrite) + case "edf": + from mne.export._edf_bdf import _export_raw_edf - _export_raw(fname, raw, physical_range, add_ch_type) - elif fmt == "brainvision": - from ._brainvision import _export_raw + _export_raw_edf(fname, raw, physical_range, add_ch_type) + case "eeglab": + from mne.export._eeglab import _export_raw - _export_raw(fname, raw, overwrite) + _export_raw(fname, raw) @verbose @@ -127,7 +133,7 @@ def export_epochs(fname, epochs, fmt="auto", *, overwrite=False, verbose=None): ) if fmt == "eeglab": - from ._eeglab import _export_epochs + from mne.export._eeglab import _export_epochs _export_epochs(fname, epochs) @@ -204,7 +210,7 @@ def _infer_check_export_fmt(fmt, fname, supported_formats): _validate_type(fmt, str, "fmt") fmt = fmt.lower() if fmt == "auto": - fmt = op.splitext(fname)[1] + fmt = os.path.splitext(fname)[1] if fmt: fmt = fmt[1:].lower() # find fmt in supported formats dict's tuples diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index 743491f26c9..f9146227d50 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -25,6 +25,7 @@ from mne.fixes import _compare_version from mne.io import ( RawArray, + read_raw_bdf, read_raw_brainvision, read_raw_edf, read_raw_eeglab, @@ -670,3 +671,50 @@ def test_export_evokeds_unsupported_format(fmt, ext): errstr = fmt.lower() if fmt != "auto" else "vhdr" with pytest.raises(ValueError, match=f"Format '{errstr}' is not .*"): export_evokeds(f"output.{ext}", evoked, fmt=fmt) + + +@edfio_mark() +@pytest.mark.parametrize( + ("input_path", "warning_msg"), + [ + (fname_raw, "Data has a 
non-integer"), + pytest.param( + misc_path / "ecog" / "sample_ecog_ieeg.fif", + "BDF format requires", + marks=[pytest.mark.slowtest, misc._pytest_mark()], + ), + ], +) +def test_export_raw_bdf(tmp_path, input_path, warning_msg): + """Test saving a Raw instance to BDF format.""" + raw = read_raw_fif(input_path) + + # only test with EEG, ECoG, and sEEG channels + raw.pick(picks=["eeg", "ecog", "seeg"]).load_data() + temp_fname = tmp_path / "test.bdf" + + with pytest.warns(RuntimeWarning, match=warning_msg): + raw.export(temp_fname) + + if "epoc" in raw.ch_names: + raw.drop_channels(["epoc"]) + + raw_read = read_raw_bdf(temp_fname, preload=True) + assert raw.ch_names == raw_read.ch_names + # only compare the original length, since extra edge values are appended + orig_raw_len = len(raw) + + # assert data and times are not different + # Due to the physical range of the data, reading and writing is not lossless. For + # example, a physical min/max of -/+ 3200 uV will result in a resolution of 0.38 nV. + # This resolution is more than sufficient for EEG. + assert_array_almost_equal( + raw.get_data(), raw_read.get_data()[:, :orig_raw_len], decimal=11 + ) + + # Due to the data record duration limitations of BDF files, one cannot store an + # arbitrary float sampling rate exactly. Usually this results in two sampling rates + # that differ only by a very small amount. For practical purposes, this + # does not matter, but can result in an error when the number of time points is + # very large. 
+ assert_allclose(raw.times, raw_read.times[:orig_raw_len], rtol=0, atol=1e-5) diff --git a/pyproject.toml b/pyproject.toml index 71a28352184..453b8634148 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,7 @@ full-no-qt = [ "darkdetect", "defusedxml", "dipy", - "edfio >= 0.2.1", + "edfio >= 0.4.10", "eeglabio", "filelock >= 3.18.0", "h5py", @@ -160,7 +160,7 @@ test = [ # Dependencies for being able to run additional tests (rare/CIs/advanced devs) # Changes here should be reflected in the mne/utils/config.py dev dependencies section test_extra = [ - "edfio >= 0.2.1", + "edfio >= 0.4.10", "eeglabio", "imageio >= 2.6.1", "imageio-ffmpeg >= 0.4.1",