diff --git a/.gitignore b/.gitignore index d981c8de4e..6c9fa6869f 100644 --- a/.gitignore +++ b/.gitignore @@ -180,6 +180,7 @@ examples/tutorials/*.svg doc/_build/* doc/tutorials/* doc/sources/* +*sg_execution_times.rst examples/getting_started/tmp_* examples/getting_started/phy diff --git a/doc/conf.py b/doc/conf.py index 4373ec3c36..331642260a 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -119,7 +119,7 @@ sphinx_gallery_conf = { 'only_warn_on_example_error': True, 'examples_dirs': ['../examples/tutorials'], - 'gallery_dirs': ['tutorials' ], # path where to save gallery generated examples + 'gallery_dirs': ['tutorials'], # path where to save gallery generated examples 'subsection_order': ExplicitOrder([ '../examples/tutorials/core', '../examples/tutorials/extractors', @@ -130,9 +130,19 @@ 'within_subsection_order': FileNameSortKey, 'ignore_pattern': '/generate_', 'nested_sections': False, - 'copyfile_regex': r'.*\.rst|.*\.png|.*\.svg' + 'copyfile_regex': r'.*\.rst|.*\.png|.*\.svg', + 'filename_pattern': '/plot_', } +if tags.has("handle_drift") or tags.has("all_long_plot"): + + if (handle_drift_path := (Path('long_tutorials/handle_drift'))).is_dir(): + shutil.rmtree(handle_drift_path) + + sphinx_gallery_conf['examples_dirs'].append('../examples/long_tutorials/handle_drift') + sphinx_gallery_conf["gallery_dirs"].append(handle_drift_path.as_posix()) + + intersphinx_mapping = { "neo": ("https://neo.readthedocs.io/en/latest/", None), "probeinterface": ("https://probeinterface.readthedocs.io/en/stable/", None), diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst index 54fd404848..1011c235e6 100644 --- a/doc/how_to/index.rst +++ b/doc/how_to/index.rst @@ -13,3 +13,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to. 
combine_recordings process_by_channel_group load_your_data_into_sorting + /long_tutorials/handle_drift/plot_handle_drift.rst diff --git a/doc/images/no-drift-example.png b/doc/images/no-drift-example.png new file mode 100644 index 0000000000..d4a0e0edfd Binary files /dev/null and b/doc/images/no-drift-example.png differ diff --git a/doc/long_tutorials/handle_drift/handle_drift_jupyter.zip b/doc/long_tutorials/handle_drift/handle_drift_jupyter.zip new file mode 100644 index 0000000000..79606a8739 Binary files /dev/null and b/doc/long_tutorials/handle_drift/handle_drift_jupyter.zip differ diff --git a/doc/long_tutorials/handle_drift/handle_drift_python.zip b/doc/long_tutorials/handle_drift/handle_drift_python.zip new file mode 100644 index 0000000000..f3d5d59d94 Binary files /dev/null and b/doc/long_tutorials/handle_drift/handle_drift_python.zip differ diff --git a/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_001.png b/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_001.png new file mode 100644 index 0000000000..66bd7ce2b0 Binary files /dev/null and b/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_001.png differ diff --git a/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_002.png b/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_002.png new file mode 100644 index 0000000000..29e1e1d4bf Binary files /dev/null and b/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_002.png differ diff --git a/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_003.png b/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_003.png new file mode 100644 index 0000000000..eb25eb8ede Binary files /dev/null and b/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_003.png differ diff --git a/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_004.png b/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_004.png new 
file mode 100644 index 0000000000..bba58ce513 Binary files /dev/null and b/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_004.png differ diff --git a/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_005.png b/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_005.png new file mode 100644 index 0000000000..6b35267d45 Binary files /dev/null and b/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_005.png differ diff --git a/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_006.png b/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_006.png new file mode 100644 index 0000000000..b58087f7e9 Binary files /dev/null and b/doc/long_tutorials/handle_drift/images/sphx_glr_plot_handle_drift_006.png differ diff --git a/doc/long_tutorials/handle_drift/images/thumb/sphx_glr_plot_handle_drift_thumb.png b/doc/long_tutorials/handle_drift/images/thumb/sphx_glr_plot_handle_drift_thumb.png new file mode 100644 index 0000000000..636a35686e Binary files /dev/null and b/doc/long_tutorials/handle_drift/images/thumb/sphx_glr_plot_handle_drift_thumb.png differ diff --git a/doc/long_tutorials/handle_drift/index.rst b/doc/long_tutorials/handle_drift/index.rst new file mode 100644 index 0000000000..8247b44383 --- /dev/null +++ b/doc/long_tutorials/handle_drift/index.rst @@ -0,0 +1,72 @@ +:orphan: + +Handle Drift Tutorial +--------------------- + +This tutorial is not mean to be displayed on +a sphinx gallery. The generated index.rst is not +meant to be linked to in any toctree. + +Instead, sphinx-gallery is used to +automatically build this page, which +takes a long time (~25 minutes), and it is +linked too manually, directly to the +rst (TODO: fill in filename) that +sphinx-gallery generates. + + + +.. raw:: html + +
+ +.. thumbnail-parent-div-open + +.. raw:: html + +
+ +.. only:: html + + .. image:: /long_tutorials/handle_drift/images/thumb/sphx_glr_plot_handle_drift_thumb.png + :alt: + + :ref:`sphx_glr_long_tutorials_handle_drift_plot_handle_drift.py` + +.. raw:: html + +
Handle probe drift with spikeinterface NEW
+
+ + +.. thumbnail-parent-div-close + +.. raw:: html + +
+ + +.. toctree:: + :hidden: + + /long_tutorials/handle_drift/plot_handle_drift + + +.. only:: html + + .. container:: sphx-glr-footer sphx-glr-footer-gallery + + .. container:: sphx-glr-download sphx-glr-download-python + + :download:`Download all examples in Python source code: handle_drift_python.zip ` + + .. container:: sphx-glr-download sphx-glr-download-jupyter + + :download:`Download all examples in Jupyter notebooks: handle_drift_jupyter.zip ` + + +.. only:: html + + .. rst-class:: sphx-glr-signature + + `Gallery generated by Sphinx-Gallery `_ diff --git a/doc/long_tutorials/handle_drift/plot_handle_drift.ipynb b/doc/long_tutorials/handle_drift/plot_handle_drift.ipynb new file mode 100644 index 0000000000..4bbbb33f63 --- /dev/null +++ b/doc/long_tutorials/handle_drift/plot_handle_drift.ipynb @@ -0,0 +1,246 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n# Handle probe drift with spikeinterface NEW\n\nProbe movement is an inevitability when running\n*in vivo* electrophysiology recordings. Motion, caused by physical\nmovement of the probe or the sliding of brain tissue\ndeforming across the probe, can complicate the sorting\nand analysis of units.\n\nSpikeInterface offers a flexible framework to handle motion correction\nas a preprocessing step. In this tutorial we will cover the three main\ndrift-correction algorithms implemented in SpikeInterface\n(**rigid_fast**, **kilosort_like** and **nonrigid_accurate**) with\na focus on running the methods and interpreting the output.\n\nFor more information on the theory and implementation of these methods,\nsee the `motion_correction` section of the documentation and\nthe [kilosort4 page](https://kilosort.readthedocs.io/en/latest/drift.html)\non drift correction. 
Drift correction may not always work as expected\n(for example, if the probe has a small number of channels), see the\n`When do I need to apply drift correction?`_ section for assessing\ndrift correction output.\n\n## What is probe drift?\n\nThe inserted probe can move from side-to-side (*'x' direction*),\nup-or-down (*'y' direction*) or forwards-or-backwards (*'z' direction*).\nMovement in the 'x' and 'z' direction is harder to model than vertical\ndrift (i.e. along the probe depth), and are not handled by most motion\ncorrection algorithms. Fortunately, vertical drift which is most easily\nhandled is most pronounced as the probe is most likely to move along the path\nof insertion.\n\nVertical drift can come in two forms, *'rigid'* and *'non-rigid'*. Rigid drift\nis drift caused by movement of the entire probe, and the motion is\nsimilar across all channels along the probe depth. In contrast,\nnon-rigid drift is instead caused by local movements of neuronal tissue along the\nprobe, and can selectively affect subsets of channels.\n\n## The drift correction steps\n\nThe easiest way to run drift correction in SpikeInterface is with the\nhigh-level :py:func:`~spikeinterface.preprocessing.correct_motion()` function.\nThis function takes a recording as input and returns a motion-corrected\nrecording object. 
As with all other preprocessing steps, the correction (in this\ncase interpolation of the data to correct the detected motion) is lazy and applied on-the-fly when data is needed).\n\nThe :py:func:`~spikeinterface.preprocessing.correct_motion()`\nfunction implements motion correction algorithms in a modular way\nwrapping a number of subfunctions that together implement the\nfull drift correction algorithm.\n\nThese drift-correction modules are:\n\n| **1.** ``localize_peaks()`` (detect spikes and localize their position on the probe)\n| **2.** ``select_peaks()`` (optional, select a subset of peaks to use to estimate motion)\n| **3.** ``estimate_motion()`` (estimate motion using the detected spikes)\n| **4.** ``interpolate_motion()`` (perform interpolation on the raw data to account for the estimated drift).\n\nAll these sub-steps have many parameters which dictate the\nspeed and effectiveness of motion correction. As such, ``correct_motion``\nprovides three setting 'presets' which configure the motion correct\nto proceed either as:\n\n* **rigid_fast** - a fast, not particularly accurate correction assuming rigid drift.\n* **kilosort-like** - Mimics what is done in Kilosort.\n* **nonrigid_accurate** - A decentralized drift correction (DREDGE), introduced by the Paninski group.\n\nWhen using motion correction in your analysis, please make sure to\n`cite the appropriate paper for your chosen method`.\n\n\n**Now, let's dive into running motion correction with these three\nmethods on a simulated dataset.**\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setting up and preprocessing the recording\n\nFirst, we will import the modules we will need for this tutorial:\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\nimport spikeinterface.full as si\nfrom spikeinterface.generation.drifting_generator import 
generate_drifting_recording\nfrom spikeinterface.preprocessing.motion import motion_options_preset\nfrom spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks\nfrom spikeinterface.widgets import plot_peaks_on_probe" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we will generate a synthetic, drifting recording. This recording will\nhave 100 separate units with firing rates randomly distributed between\n15 and 25 Hz.\n\nWe will create a zigzag drift pattern on the recording, starting at\n100 seconds and with a peak-to-peak period of 100 seconds (so we will\nhave 9 zigzags through our recording). We also add some non-linearity\nto the imposed motion.\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Note

This tutorial can take a long time to run with the default arguments.\n If you would like to run this locally, you may want to edit ``num_units``\n and ``duration`` to smaller values (e.g. 25 and 100 respectively).\n\n Also note, the below code uses multiprocessing. If you are on Windows, you may\n need to place the code within a ``if __name__ == \"__main__\":`` block.

\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "num_units = 200 # 250 still too many I think!\nduration = 1000\n\n_, raw_recording, _ = generate_drifting_recording(\n num_units=num_units,\n duration=duration,\n generate_sorting_kwargs=dict(firing_rates=(15, 25), refractory_period_ms=4.0),\n seed=42,\n generate_displacement_vector_kwargs=dict(motion_list=[\n dict(\n drift_mode=\"zigzag\",\n non_rigid_gradient=0.01,\n t_start_drift=int(duration/10),\n t_end_drift=None,\n period_s=int(duration/10),\n ),\n ],\n )\n)\nprint(raw_recording)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before performing motion correction, we will **preprocess** the recording\nwith a bandpass filter and a common median reference.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "filtered_recording = si.bandpass_filter(raw_recording, freq_min=300.0, freq_max=6000.0)\npreprocessed_recording = si.common_reference(filtered_recording, reference=\"global\", operator=\"median\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Warning

It is better to not whiten the recording before motion estimation, as this\n will give a better estimate of the peak locations. Whitening should\n be performed after motion correction.

\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run motion correction with one function!\n\nCorrecting for drift is easy! You just need to run a single function.\nWe will now run motion correction on our recording using the three\npresets described above - **rigid_fast**, **kilosort_like** and\n**nonrigid_accurate**.\n\nWe can run these presents with the ``preset`` argument of\n:py:func:`~spikeinterface.preprocessing.correct_motion()`. Under the\nhood, the presets define a set of parameters by set how to run the\n4 submodules that make up motion correction (described above).\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "print(motion_options_preset[\"kilosort_like\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, lets run motion correction with our three presets. We will\nset the ``job_kwargs`` to parallelize the job over a number of CPU cores\u2014motion\ncorrection is computationally intensive and will run faster with parallelization.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "presets_to_run = (\"rigid_fast\", \"kilosort_like\", \"nonrigid_accurate\")\n\njob_kwargs = dict(n_jobs=40, chunk_duration=\"1s\", progress_bar=True)\n\nresults = {preset: {} for preset in presets_to_run}\nfor preset in presets_to_run:\n\n corrected_recording, motion_info = si.correct_motion(\n preprocessed_recording, preset=preset, output_motion_info=True, **job_kwargs\n )\n results[preset][\"motion_info\"] = motion_info" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ".. seealso::\n It is often very useful to save ``motion_info`` to a\n file, so it can be loaded and visualized later. 
This can be done by setting\n the ``folder`` argument of\n :py:func:`~spikeinterface.preprocessing.correct_motion()` to a path to write\n all motion outputs to. The ``motion_info`` can be loaded back with\n ``si.load_motion_info``.\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plotting the results\n\nNext, let's plot the results of our motion estimation using the ``plot_motion_info()``\nfunction. The plot contains 4 panels, on the x-axis of all plots we have\nthe (binned time). The plots display:\n * **top left:** The estimated peak depth for every detected peak.\n * **top right:** The estimated peak depths after motion correction.\n * **bottom left:** The average motion vector across depths and all motion across spatial depths (for non-rigid estimation).\n * **bottom right:** if motion correction is non-rigid, the motion vector across depths is plotted as a map, with the color code representing the motion in micrometers.\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "for preset in presets_to_run:\n\n fig = plt.figure(figsize=(7, 7))\n\n si.plot_motion_info(\n results[preset][\"motion_info\"],\n recording=corrected_recording, # the recording is only used to get the real times\n figure=fig,\n depth_lim=(400, 600),\n color_amplitude=True,\n amplitude_cmap=\"inferno\",\n scatter_decimate=10, # Only plot every 10th peak\n )\n fig.suptitle(f\"{preset=}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "These plots are quite complicated, so it is worth covering them in detail.\nFor every detected spike in our recording, we first estimate\nits depth (first panel) using a method from\n:py:func:`~spikeinterface.postprocessing.compute_unit_locations()`.\n\nThen, the probe motion is estimated and the location of the\nspikes are adjusted to account for the motion (second panel).\n\nThe motion estimation produces\na 
measure of how much and in what direction the probe is moving at any given\ntime bin (third panel). For non-rigid motion correction, the probe is divided\ninto subsections - the motion vectors displayed are per subsection (i.e. per\n'binned spatial depth') as well as the average.\n\nOn the fourth panel, we see a\nmore detailed representation of the motion vectors. We can see the motion plotted\nas a heatmap at each binned spatial depth across all time bins. It captures\nthe zigzag pattern (alternating light and dark colors) of the injected motion.\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A few comments on the figures:\n * The preset **'rigid_fast'** has only one motion vector for the entire probe because it is a 'rigid' case.\n The motion amplitude is globally underestimated because it averages across depths.\n However, the corrected peaks are flatter than the non-corrected ones, so the job is partially done.\n The big jump at 600s when the probe start moving is recovered quite well.\n * The preset **kilosort_like** gives better results because it is a non-rigid case.\n The motion vector is computed for different depths.\n The corrected peak locations are flatter than the rigid case.\n The motion vector map is still a bit noisy at some depths (e.g around 1000um).\n * The preset **nonrigid_accurate** seems to give the best results on this recording.\n The motion vector seems less noisy globally, but it is not 'perfect' (see at the top of the probe 3200um to 3800um).\n Also note that in the first part of the recording before the imposed motion (0-600s) we clearly have a non-rigid motion:\n the upper part of the probe (2000-3000um) experience some drifts, but the lower part (0-1000um) is relatively stable.\n The method defined by this preset is able to capture this.\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Correcting Peak Locations after Motion Correction\n\nThe result of motion 
correction can be applied to the data in two ways.\nThe first is by interpolating the raw traces to correct for the estimated drift.\nThis changes the data in the\nrecording by shifting the signal across channels, and is given in the\n`corrected_recording` output from :py:func:`~spikeinterface.preprocessing.correct_motion()`.\nThis is useful in most cases, for continuing\nwith preprocessing and sorting with the corrected recording.\n\nThe second way is to apply the results of motion correction directly\nto the ``peak_locations`` object. If you are not familiar with\nSpikeInterface's ``peak`` and ``peak_locations`` objects,\nthese are explored further in the below dropdown.\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + ".. dropdown:: 'Peaks' and 'Peak Locations' in SpikeInterface\n\n Information about detected spikes is represented in\n SpikeInterface's ``peaks`` and ``peak_locations`` objects. The\n ``peaks`` object is an array for containing the\n sample index, channel index (where its signal\n is strongest), amplitude and recording segment index for every detected spike\n in the dataset. It is created by the\n :py:func:`~spikeinterface.sortingcomponents.peak_detection.detect_peaks()`\n function.\n\n The ``peak_locations`` is a partner object to the ``peaks`` object,\n and contains the estimated location (``\"x\"``, ``\"y\"``) of the spike. For every spike in\n ``peaks`` there is a corresponding location in ``peak_locations``.\n The peak locations is estimated using the\n :py:func:`~spikeinterface.sortingcomponents.peak_localization.localise_peaks()`\n function.\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The other way to apply the motion correction is to the ``peaks`` and\n``peaks_location`` objects directly. This is done using the function\n``correct_motion_on_peaks()``. 
Given a set of peaks, peak locations and\nthe ``motion`` object output from :py:func:`~spikeinterface.preprocessing.correct_motion()`,\nit will shift the location of the peaks according to the motion estimate, outputting a new\n``peak_locations`` object. This is done to plot the peak locations in\nthe next section.\n\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

Warning

Note that the ``peak_locations`` output by\n :py:func:`~spikeinterface.preprocessing.correct_motion()`\n (in the ``motion_info`` object) is the original (uncorrected) peak locations.\n To get the corrected peak locations, ``correct_motion_on_peaks()`` must be used!

\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "for preset in presets_to_run:\n\n motion_info = results[preset][\"motion_info\"]\n\n peaks = motion_info[\"peaks\"]\n\n original_peak_locations = motion_info[\"peak_locations\"]\n\n corrected_peak_locations = correct_motion_on_peaks(peaks, original_peak_locations, motion_info['motion'], corrected_recording)\n\n widget = plot_peaks_on_probe(corrected_recording, [peaks, peaks], [original_peak_locations, corrected_peak_locations], ylim=(300,600))\n widget.figure.suptitle(preset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparing the Run Times\n\nThe different methods also have different speeds, the 'nonrigid_accurate'\nrequires more computation time, in particular at the ``estimate_motion`` phase,\nas seen in the run times:\n\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "for preset in presets_to_run:\n print(preset)\n print(results[preset][\"motion_info\"][\"run_times\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## When do I need to apply drift correction?\n\nDrift correction may not always be necessary for your data, for\nexample, for example when there is not much drift in the data to begin with.\nFurther, in some cases (e.g. when the probe has a smaller number of channels,\ne.g. 64 or less) the drift correction algorithms may not perfect well.\n\nTo check whether drift correction is required and how it is performing,\nit is necessary to run drift correction as above and then check the output plots.\nIn the below example, the 'Peak depth' plot shows minimal drift in the peak position.\nIn this example, it does not look like drift correction is that necessary. 
Further,\nbecause there are only 16 channels in this recording, the drift correction is failing.\nThe 'Correct Peak Depth' as erroenously shifted peaks to the wrong position, spreading\nthem across the probe. In this instance, drift correction could be skipped.\n\n\n\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n\nThat's it for our tour of motion correction in\nSpikeInterface. Remember that correcting motion makes some\nassumptions on your data (e.g. number of channels, noise in the recording)\u2014always\nplot the motion correction information for your\nrecordings, to make sure the correction is behaving as expected!\n\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/doc/long_tutorials/handle_drift/plot_handle_drift.py b/doc/long_tutorials/handle_drift/plot_handle_drift.py new file mode 100644 index 0000000000..497040f376 --- /dev/null +++ b/doc/long_tutorials/handle_drift/plot_handle_drift.py @@ -0,0 +1,368 @@ +""" +=========================================== +Handle probe drift with spikeinterface NEW +=========================================== + +Probe movement is an inevitability when running +*in vivo* electrophysiology recordings. Motion, caused by physical +movement of the probe or the sliding of brain tissue +deforming across the probe, can complicate the sorting +and analysis of units. + +SpikeInterface offers a flexible framework to handle motion correction +as a preprocessing step. 
In this tutorial we will cover the three main +drift-correction algorithms implemented in SpikeInterface +(**rigid_fast**, **kilosort_like** and **nonrigid_accurate**) with +a focus on running the methods and interpreting the output. + +For more information on the theory and implementation of these methods, +see the :ref:`motion_correction` section of the documentation and +the `kilosort4 page `_ +on drift correction. Drift correction may not always work as expected +(for example, if the probe has a small number of channels), see the +`When do I need to apply drift correction?`_ section for assessing +drift correction output. + +--------------------- +What is probe drift? +--------------------- + +The inserted probe can move from side-to-side (*'x' direction*), +up-or-down (*'y' direction*) or forwards-or-backwards (*'z' direction*). +Movement in the 'x' and 'z' direction is harder to model than vertical +drift (i.e. along the probe depth), and are not handled by most motion +correction algorithms. Fortunately, vertical drift which is most easily +handled is most pronounced as the probe is most likely to move along the path +of insertion. + +Vertical drift can come in two forms, *'rigid'* and *'non-rigid'*. Rigid drift +is drift caused by movement of the entire probe, and the motion is +similar across all channels along the probe depth. In contrast, +non-rigid drift is instead caused by local movements of neuronal tissue along the +probe, and can selectively affect subsets of channels. + +-------------------------- +The drift correction steps +-------------------------- + +The easiest way to run drift correction in SpikeInterface is with the +high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` function. +This function takes a recording as input and returns a motion-corrected +recording object. 
As with all other preprocessing steps, the correction (in this +case interpolation of the data to correct the detected motion) is lazy and applied on-the-fly when data is needed). + +The :py:func:`~spikeinterface.preprocessing.correct_motion()` +function implements motion correction algorithms in a modular way +wrapping a number of subfunctions that together implement the +full drift correction algorithm. + +These drift-correction modules are: + +| **1.** ``localize_peaks()`` (detect spikes and localize their position on the probe) +| **2.** ``select_peaks()`` (optional, select a subset of peaks to use to estimate motion) +| **3.** ``estimate_motion()`` (estimate motion using the detected spikes) +| **4.** ``interpolate_motion()`` (perform interpolation on the raw data to account for the estimated drift). + +All these sub-steps have many parameters which dictate the +speed and effectiveness of motion correction. As such, ``correct_motion`` +provides three setting 'presets' which configure the motion correct +to proceed either as: + +* **rigid_fast** - a fast, not particularly accurate correction assuming rigid drift. +* **kilosort-like** - Mimics what is done in Kilosort. +* **nonrigid_accurate** - A decentralized drift correction (DREDGE), introduced by the Paninski group. + +When using motion correction in your analysis, please make sure to +:ref:`cite the appropriate paper for your chosen method`. 
+ + +**Now, let's dive into running motion correction with these three +methods on a simulated dataset.** + +""" + +# %% +# ------------------------------------------- +# Setting up and preprocessing the recording +# ------------------------------------------- +# +# First, we will import the modules we will need for this tutorial: + +import matplotlib.pyplot as plt +import spikeinterface.full as si +from spikeinterface.generation.drifting_generator import generate_drifting_recording +from spikeinterface.preprocessing.motion import motion_options_preset +from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks +from spikeinterface.widgets import plot_peaks_on_probe + +# %% read the file +# Next, we will generate a synthetic, drifting recording. This recording will +# have 100 separate units with firing rates randomly distributed between +# 15 and 25 Hz. +# +# We will create a zigzag drift pattern on the recording, starting at +# 100 seconds and with a peak-to-peak period of 100 seconds (so we will +# have 9 zigzags through our recording). We also add some non-linearity +# to the imposed motion. + +# %% +#.. note:: +# This tutorial can take a long time to run with the default arguments. +# If you would like to run this locally, you may want to edit ``num_units`` +# and ``duration`` to smaller values (e.g. 25 and 100 respectively). +# +# Also note, the below code uses multiprocessing. If you are on Windows, you may +# need to place the code within a ``if __name__ == "__main__":`` block. + + +num_units = 200 # 250 still too many I think! 
+duration = 1000 + +_, raw_recording, _ = generate_drifting_recording( + num_units=num_units, + duration=duration, + generate_sorting_kwargs=dict(firing_rates=(15, 25), refractory_period_ms=4.0), + seed=42, + generate_displacement_vector_kwargs=dict(motion_list=[ + dict( + drift_mode="zigzag", + non_rigid_gradient=0.01, + t_start_drift=int(duration/10), + t_end_drift=None, + period_s=int(duration/10), + ), + ], + ) +) +print(raw_recording) + +# %% +# Before performing motion correction, we will **preprocess** the recording +# with a bandpass filter and a common median reference. + +filtered_recording = si.bandpass_filter(raw_recording, freq_min=300.0, freq_max=6000.0) +preprocessed_recording = si.common_reference(filtered_recording, reference="global", operator="median") + +# %% +#.. warning:: +# It is better to not whiten the recording before motion estimation, as this +# will give a better estimate of the peak locations. Whitening should +# be performed after motion correction. + +# %% +# ---------------------------------------- +# Run motion correction with one function! +# ---------------------------------------- +# +# Correcting for drift is easy! You just need to run a single function. +# We will now run motion correction on our recording using the three +# presets described above - **rigid_fast**, **kilosort_like** and +# **nonrigid_accurate**. +# +# We can run these presents with the ``preset`` argument of +# :py:func:`~spikeinterface.preprocessing.correct_motion()`. Under the +# hood, the presets define a set of parameters by set how to run the +# 4 submodules that make up motion correction (described above). +print(motion_options_preset["kilosort_like"]) + +# %% +# Now, lets run motion correction with our three presets. We will +# set the ``job_kwargs`` to parallelize the job over a number of CPU cores—motion +# correction is computationally intensive and will run faster with parallelization. 
+ +presets_to_run = ("rigid_fast", "kilosort_like", "nonrigid_accurate") + +job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True) + +results = {preset: {} for preset in presets_to_run} +for preset in presets_to_run: + + corrected_recording, motion_info = si.correct_motion( + preprocessed_recording, preset=preset, output_motion_info=True, **job_kwargs + ) + results[preset]["motion_info"] = motion_info + +# %% +#.. seealso:: +# It is often very useful to save ``motion_info`` to a +# file, so it can be loaded and visualized later. This can be done by setting +# the ``folder`` argument of +# :py:func:`~spikeinterface.preprocessing.correct_motion()` to a path to write +# all motion outputs to. The ``motion_info`` can be loaded back with +# ``si.load_motion_info``. + +# %% +# -------------------- +# Plotting the results +# -------------------- +# +# Next, let's plot the results of our motion estimation using the ``plot_motion_info()`` +# function. The plot contains 4 panels, on the x-axis of all plots we have +# the (binned time). The plots display: +# * **top left:** The estimated peak depth for every detected peak. +# * **top right:** The estimated peak depths after motion correction. +# * **bottom left:** The average motion vector across depths and all motion across spatial depths (for non-rigid estimation). +# * **bottom right:** if motion correction is non-rigid, the motion vector across depths is plotted as a map, with the color code representing the motion in micrometers. + +for preset in presets_to_run: + + fig = plt.figure(figsize=(7, 7)) + + si.plot_motion_info( + results[preset]["motion_info"], + recording=corrected_recording, # the recording is only used to get the real times + figure=fig, + depth_lim=(400, 600), + color_amplitude=True, + amplitude_cmap="inferno", + scatter_decimate=10, # Only plot every 10th peak + ) + fig.suptitle(f"{preset=}") + +# %% +# These plots are quite complicated, so it is worth covering them in detail. 
+# For every detected spike in our recording, we first estimate +# its depth (first panel) using a method from +# :py:func:`~spikeinterface.postprocessing.compute_unit_locations()`. +# +# Then, the probe motion is estimated and the location of the +# spikes are adjusted to account for the motion (second panel). +# +# The motion estimation produces +# a measure of how much and in what direction the probe is moving at any given +# time bin (third panel). For non-rigid motion correction, the probe is divided +# into subsections - the motion vectors displayed are per subsection (i.e. per +# 'binned spatial depth') as well as the average. +# +# On the fourth panel, we see a +# more detailed representation of the motion vectors. We can see the motion plotted +# as a heatmap at each binned spatial depth across all time bins. It captures +# the zigzag pattern (alternating light and dark colors) of the injected motion. + +# %% +# A few comments on the figures: +# * The preset **'rigid_fast'** has only one motion vector for the entire probe because it is a 'rigid' case. +# The motion amplitude is globally underestimated because it averages across depths. +# However, the corrected peaks are flatter than the non-corrected ones, so the job is partially done. +# The big jump at 600s when the probe start moving is recovered quite well. +# * The preset **kilosort_like** gives better results because it is a non-rigid case. +# The motion vector is computed for different depths. +# The corrected peak locations are flatter than the rigid case. +# The motion vector map is still a bit noisy at some depths (e.g around 1000um). +# * The preset **nonrigid_accurate** seems to give the best results on this recording. +# The motion vector seems less noisy globally, but it is not 'perfect' (see at the top of the probe 3200um to 3800um). 
+# Also note that in the first part of the recording before the imposed motion (0-600s) we clearly have a non-rigid motion: +# the upper part of the probe (2000-3000um) experience some drifts, but the lower part (0-1000um) is relatively stable. +# The method defined by this preset is able to capture this. + +# %% +# ------------------------------------------------- +# Correcting Peak Locations after Motion Correction +# ------------------------------------------------- +# +# The result of motion correction can be applied to the data in two ways. +# The first is by interpolating the raw traces to correct for the estimated drift. +# This changes the data in the +# recording by shifting the signal across channels, and is given in the +# `corrected_recording` output from :py:func:`~spikeinterface.preprocessing.correct_motion()`. +# This is useful in most cases, for continuing +# with preprocessing and sorting with the corrected recording. +# +# The second way is to apply the results of motion correction directly +# to the ``peak_locations`` object. If you are not familiar with +# SpikeInterface's ``peak`` and ``peak_locations`` objects, +# these are explored further in the below dropdown. + +# %% +# .. dropdown:: 'Peaks' and 'Peak Locations' in SpikeInterface +# +# Information about detected spikes is represented in +# SpikeInterface's ``peaks`` and ``peak_locations`` objects. The +# ``peaks`` object is an array for containing the +# sample index, channel index (where its signal +# is strongest), amplitude and recording segment index for every detected spike +# in the dataset. It is created by the +# :py:func:`~spikeinterface.sortingcomponents.peak_detection.detect_peaks()` +# function. +# +# The ``peak_locations`` is a partner object to the ``peaks`` object, +# and contains the estimated location (``"x"``, ``"y"``) of the spike. For every spike in +# ``peaks`` there is a corresponding location in ``peak_locations``. 
+# The peak locations is estimated using the +# :py:func:`~spikeinterface.sortingcomponents.peak_localization.localise_peaks()` +# function. + +# %% +# The other way to apply the motion correction is to the ``peaks`` and +# ``peaks_location`` objects directly. This is done using the function +# ``correct_motion_on_peaks()``. Given a set of peaks, peak locations and +# the ``motion`` object output from :py:func:`~spikeinterface.preprocessing.correct_motion()`, +# it will shift the location of the peaks according to the motion estimate, outputting a new +# ``peak_locations`` object. This is done to plot the peak locations in +# the next section. +# + + +# %% +#.. warning:: +# Note that the ``peak_locations`` output by +# :py:func:`~spikeinterface.preprocessing.correct_motion()` +# (in the ``motion_info`` object) is the original (uncorrected) peak locations. +# To get the corrected peak locations, ``correct_motion_on_peaks()`` must be used! + +for preset in presets_to_run: + + motion_info = results[preset]["motion_info"] + + peaks = motion_info["peaks"] + + original_peak_locations = motion_info["peak_locations"] + + corrected_peak_locations = correct_motion_on_peaks(peaks, original_peak_locations, motion_info['motion'], corrected_recording) + + widget = plot_peaks_on_probe(corrected_recording, [peaks, peaks], [original_peak_locations, corrected_peak_locations], ylim=(300,600)) + widget.figure.suptitle(preset) + +# %% +# ------------------------- +# Comparing the Run Times +# ------------------------- +# +# The different methods also have different speeds, the 'nonrigid_accurate' +# requires more computation time, in particular at the ``estimate_motion`` phase, +# as seen in the run times: + +for preset in presets_to_run: + print(preset) + print(results[preset]["motion_info"]["run_times"]) + +# %% +# ----------------------------------------- +# When do I need to apply drift correction? 
+# -----------------------------------------
+#
+# Drift correction may not always be necessary for your data, for
+# example when there is not much drift in the data to begin with.
+# Further, in some cases (e.g. when the probe has a smaller number of channels,
+# e.g. 64 or less) the drift correction algorithms may not perform well.
+#
+# To check whether drift correction is required and how it is performing,
+# it is necessary to run drift correction as above and then check the output plots.
+# In the below example, the 'Peak depth' plot shows minimal drift in the peak position.
+# In this example, it does not look like drift correction is that necessary. Further,
+# because there are only 16 channels in this recording, the drift correction is failing.
+# The 'Corrected Peak Depth' has erroneously shifted peaks to the wrong position, spreading
+# them across the probe. In this instance, drift correction could be skipped.
+#
+# .. image:: ../../images/no-drift-example.png
+
+# %%
+# ------------------------
+# Summary
+# ------------------------
+#
+# That's it for our tour of motion correction in
+# SpikeInterface. Remember that correcting motion makes some
+# assumptions on your data (e.g. number of channels, noise in the recording)—always
+# plot the motion correction information for your
+# recordings, to make sure the correction is behaving as expected!
diff --git a/doc/long_tutorials/handle_drift/plot_handle_drift.py.md5 b/doc/long_tutorials/handle_drift/plot_handle_drift.py.md5
new file mode 100644
index 0000000000..cf19ae18eb
--- /dev/null
+++ b/doc/long_tutorials/handle_drift/plot_handle_drift.py.md5
@@ -0,0 +1 @@
+8f6fa0b9f5b79377c60cbf789c293c27
diff --git a/doc/long_tutorials/handle_drift/plot_handle_drift.rst b/doc/long_tutorials/handle_drift/plot_handle_drift.rst
new file mode 100644
index 0000000000..1107a9015c
--- /dev/null
+++ b/doc/long_tutorials/handle_drift/plot_handle_drift.rst
@@ -0,0 +1,605 @@
+
+.. DO NOT EDIT.
+.. 
THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. +.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: +.. "long_tutorials/handle_drift/plot_handle_drift.py" +.. LINE NUMBERS ARE GIVEN BELOW. + +.. only:: html + + .. note:: + :class: sphx-glr-download-link-note + + :ref:`Go to the end ` + to download the full example code. + +.. rst-class:: sphx-glr-example-title + +.. _sphx_glr_long_tutorials_handle_drift_plot_handle_drift.py: + + +=========================================== +Handle probe drift with spikeinterface NEW +=========================================== + +Probe movement is an inevitability when running +*in vivo* electrophysiology recordings. Motion, caused by physical +movement of the probe or the sliding of brain tissue +deforming across the probe, can complicate the sorting +and analysis of units. + +SpikeInterface offers a flexible framework to handle motion correction +as a preprocessing step. In this tutorial we will cover the three main +drift-correction algorithms implemented in SpikeInterface +(**rigid_fast**, **kilosort_like** and **nonrigid_accurate**) with +a focus on running the methods and interpreting the output. + +For more information on the theory and implementation of these methods, +see the :ref:`motion_correction` section of the documentation and +the `kilosort4 page `_ +on drift correction. Drift correction may not always work as expected +(for example, if the probe has a small number of channels), see the +`When do I need to apply drift correction?`_ section for assessing +drift correction output. + +--------------------- +What is probe drift? +--------------------- + +The inserted probe can move from side-to-side (*'x' direction*), +up-or-down (*'y' direction*) or forwards-or-backwards (*'z' direction*). +Movement in the 'x' and 'z' direction is harder to model than vertical +drift (i.e. along the probe depth), and are not handled by most motion +correction algorithms. 
Fortunately, vertical drift which is most easily +handled is most pronounced as the probe is most likely to move along the path +of insertion. + +Vertical drift can come in two forms, *'rigid'* and *'non-rigid'*. Rigid drift +is drift caused by movement of the entire probe, and the motion is +similar across all channels along the probe depth. In contrast, +non-rigid drift is instead caused by local movements of neuronal tissue along the +probe, and can selectively affect subsets of channels. + +-------------------------- +The drift correction steps +-------------------------- + +The easiest way to run drift correction in SpikeInterface is with the +high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` function. +This function takes a recording as input and returns a motion-corrected +recording object. As with all other preprocessing steps, the correction (in this +case interpolation of the data to correct the detected motion) is lazy and applied on-the-fly when data is needed). + +The :py:func:`~spikeinterface.preprocessing.correct_motion()` +function implements motion correction algorithms in a modular way +wrapping a number of subfunctions that together implement the +full drift correction algorithm. + +These drift-correction modules are: + +| **1.** ``localize_peaks()`` (detect spikes and localize their position on the probe) +| **2.** ``select_peaks()`` (optional, select a subset of peaks to use to estimate motion) +| **3.** ``estimate_motion()`` (estimate motion using the detected spikes) +| **4.** ``interpolate_motion()`` (perform interpolation on the raw data to account for the estimated drift). + +All these sub-steps have many parameters which dictate the +speed and effectiveness of motion correction. As such, ``correct_motion`` +provides three setting 'presets' which configure the motion correct +to proceed either as: + +* **rigid_fast** - a fast, not particularly accurate correction assuming rigid drift. 
+* **kilosort-like** - Mimics what is done in Kilosort. +* **nonrigid_accurate** - A decentralized drift correction (DREDGE), introduced by the Paninski group. + +When using motion correction in your analysis, please make sure to +:ref:`cite the appropriate paper for your chosen method`. + + +**Now, let's dive into running motion correction with these three +methods on a simulated dataset.** + +.. GENERATED FROM PYTHON SOURCE LINES 85-90 + +------------------------------------------- +Setting up and preprocessing the recording +------------------------------------------- + +First, we will import the modules we will need for this tutorial: + +.. GENERATED FROM PYTHON SOURCE LINES 90-98 + +.. code-block:: Python + + + import matplotlib.pyplot as plt + import spikeinterface.full as si + from spikeinterface.generation.drifting_generator import generate_drifting_recording + from spikeinterface.preprocessing.motion import motion_options_preset + from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks + from spikeinterface.widgets import plot_peaks_on_probe + + + + + + + + +.. GENERATED FROM PYTHON SOURCE LINES 99-107 + +Next, we will generate a synthetic, drifting recording. This recording will +have 100 separate units with firing rates randomly distributed between +15 and 25 Hz. + +We will create a zigzag drift pattern on the recording, starting at +100 seconds and with a peak-to-peak period of 100 seconds (so we will +have 9 zigzags through our recording). We also add some non-linearity +to the imposed motion. + +.. GENERATED FROM PYTHON SOURCE LINES 109-116 + +.. note:: + This tutorial can take a long time to run with the default arguments. + If you would like to run this locally, you may want to edit ``num_units`` + and ``duration`` to smaller values (e.g. 25 and 100 respectively). + + Also note, the below code uses multiprocessing. If you are on Windows, you may + need to place the code within a ``if __name__ == "__main__":`` block. 
+ +.. GENERATED FROM PYTHON SOURCE LINES 116-139 + +.. code-block:: Python + + + + num_units = 200 # 250 still too many I think! + duration = 1000 + + _, raw_recording, _ = generate_drifting_recording( + num_units=num_units, + duration=duration, + generate_sorting_kwargs=dict(firing_rates=(15, 25), refractory_period_ms=4.0), + seed=42, + generate_displacement_vector_kwargs=dict(motion_list=[ + dict( + drift_mode="zigzag", + non_rigid_gradient=0.01, + t_start_drift=int(duration/10), + t_end_drift=None, + period_s=int(duration/10), + ), + ], + ) + ) + print(raw_recording) + + + + + +.. rst-class:: sphx-glr-script-out + + .. code-block:: none + + InjectDriftingTemplatesRecording: 128 channels - 30.0kHz - 1 segments - 30,000,000 samples + 1,000.00s (16.67 minutes) - float32 dtype - 14.31 GiB + + + + +.. GENERATED FROM PYTHON SOURCE LINES 140-142 + +Before performing motion correction, we will **preprocess** the recording +with a bandpass filter and a common median reference. + +.. GENERATED FROM PYTHON SOURCE LINES 142-146 + +.. code-block:: Python + + + filtered_recording = si.bandpass_filter(raw_recording, freq_min=300.0, freq_max=6000.0) + preprocessed_recording = si.common_reference(filtered_recording, reference="global", operator="median") + + + + + + + + +.. GENERATED FROM PYTHON SOURCE LINES 147-151 + +.. warning:: + It is better to not whiten the recording before motion estimation, as this + will give a better estimate of the peak locations. Whitening should + be performed after motion correction. + +.. GENERATED FROM PYTHON SOURCE LINES 153-166 + +---------------------------------------- +Run motion correction with one function! +---------------------------------------- + +Correcting for drift is easy! You just need to run a single function. +We will now run motion correction on our recording using the three +presets described above - **rigid_fast**, **kilosort_like** and +**nonrigid_accurate**. 
+ +We can run these presents with the ``preset`` argument of +:py:func:`~spikeinterface.preprocessing.correct_motion()`. Under the +hood, the presets define a set of parameters by set how to run the +4 submodules that make up motion correction (described above). + +.. GENERATED FROM PYTHON SOURCE LINES 166-168 + +.. code-block:: Python + + print(motion_options_preset["kilosort_like"]) + + + + + +.. rst-class:: sphx-glr-script-out + + .. code-block:: none + + {'doc': 'Mimic the drift correction of kilosort (grid_convolution + iterative_template)', 'detect_kwargs': {'method': 'locally_exclusive', 'peak_sign': 'neg', 'detect_threshold': 8.0, 'exclude_sweep_ms': 0.1, 'radius_um': 50}, 'select_kwargs': {}, 'localize_peaks_kwargs': {'method': 'grid_convolution', 'radius_um': 40.0, 'upsampling_um': 5.0, 'weight_method': {'mode': 'gaussian_2d', 'sigma_list_um': array([ 5., 10., 15., 20., 25.])}, 'sigma_ms': 0.25, 'margin_um': 30.0, 'prototype': None, 'percentile': 5.0}, 'estimate_motion_kwargs': {'method': 'iterative_template', 'bin_duration_s': 2.0, 'rigid': False, 'win_step_um': 50.0, 'win_sigma_um': 150.0, 'margin_um': 0, 'win_shape': 'rect'}, 'interpolate_motion_kwargs': {'border_mode': 'force_extrapolate', 'spatial_interpolation_method': 'kriging', 'sigma_um': 20.0, 'p': 2}} + + + + +.. GENERATED FROM PYTHON SOURCE LINES 169-172 + +Now, lets run motion correction with our three presets. We will +set the ``job_kwargs`` to parallelize the job over a number of CPU cores—motion +correction is computationally intensive and will run faster with parallelization. + +.. GENERATED FROM PYTHON SOURCE LINES 172-185 + +.. 
code-block:: Python + + + presets_to_run = ("rigid_fast", "kilosort_like", "nonrigid_accurate") + + job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True) + + results = {preset: {} for preset in presets_to_run} + for preset in presets_to_run: + + corrected_recording, motion_info = si.correct_motion( + preprocessed_recording, preset=preset, output_motion_info=True, **job_kwargs + ) + results[preset]["motion_info"] = motion_info + + + + + +.. rst-class:: sphx-glr-script-out + + .. code-block:: none + + detect and localize: 0%| | 0/1000 [00:00` + + .. container:: sphx-glr-download sphx-glr-download-python + + :download:`Download Python source code: plot_handle_drift.py ` + + +.. only:: html + + .. rst-class:: sphx-glr-signature + + `Gallery generated by Sphinx-Gallery `_ diff --git a/doc/long_tutorials/handle_drift/plot_handle_drift_codeobj.pickle b/doc/long_tutorials/handle_drift/plot_handle_drift_codeobj.pickle new file mode 100644 index 0000000000..75964e699e Binary files /dev/null and b/doc/long_tutorials/handle_drift/plot_handle_drift_codeobj.pickle differ diff --git a/doc/references.rst b/doc/references.rst index ace51db951..6193eced2f 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -19,6 +19,7 @@ If you use one of the following preprocessing methods, please cite the appropria - :code:`detect_bad_channels(method='coherence+psd')` [IBL]_ - :code:`common_reference` [Rolston]_ +.. _cite-motion-correction: Motion Correction ^^^^^^^^^^^^^^^^^ If you use the :code:`correct_motion` method in the preprocessing module, please cite [Garcia]_ diff --git a/examples/long_tutorials/handle_drift/README.rst b/examples/long_tutorials/handle_drift/README.rst new file mode 100644 index 0000000000..0eb00bb349 --- /dev/null +++ b/examples/long_tutorials/handle_drift/README.rst @@ -0,0 +1,13 @@ +Handle Drift Tutorial +--------------------- + +This tutorial is not mean to be displayed on +a sphinx gallery. 
The generated index.rst is not
+meant to be linked to in any toctree.
+
+Instead, sphinx-gallery is used to
+automatically build this page, which
+takes a long time (~25 minutes), and it is
+linked to manually, directly to the
+rst (TODO: fill in filename) that
+sphinx-gallery generates.
diff --git a/examples/long_tutorials/handle_drift/plot_handle_drift.py b/examples/long_tutorials/handle_drift/plot_handle_drift.py
new file mode 100644
index 0000000000..0adc838dbf
--- /dev/null
+++ b/examples/long_tutorials/handle_drift/plot_handle_drift.py
@@ -0,0 +1,368 @@
+"""
+===========================================
+Handle probe drift with spikeinterface
+===========================================
+
+Probe movement is an inevitability when running
+*in vivo* electrophysiology recordings. Motion, caused by physical
+movement of the probe or the sliding of brain tissue
+deforming across the probe, can complicate the sorting
+and analysis of units.
+
+SpikeInterface offers a flexible framework to handle motion correction
+as a preprocessing step. In this tutorial we will cover the three main
+drift-correction algorithms implemented in SpikeInterface
+(**rigid_fast**, **kilosort_like** and **nonrigid_accurate**) with
+a focus on running the methods and interpreting the output.
+
+For more information on the theory and implementation of these methods,
+see the :ref:`motion_correction` section of the documentation and
+the `kilosort4 page `_
+on drift correction. Drift correction may not always work as expected
+(for example, if the probe has a small number of channels), see the
+`When do I need to apply drift correction?`_ section for assessing
+drift correction output.
+
+---------------------
+What is probe drift?
+---------------------
+
+The inserted probe can move from side-to-side (*'x' direction*),
+up-or-down (*'y' direction*) or forwards-or-backwards (*'z' direction*).
+Movement in the 'x' and 'z' direction is harder to model than vertical
+drift (i.e. 
along the probe depth), and are not handled by most motion +correction algorithms. Fortunately, vertical drift which is most easily +handled is most pronounced as the probe is most likely to move along the path +of insertion. + +Vertical drift can come in two forms, *'rigid'* and *'non-rigid'*. Rigid drift +is drift caused by movement of the entire probe, and the motion is +similar across all channels along the probe depth. In contrast, +non-rigid drift is instead caused by local movements of neuronal tissue along the +probe, and can selectively affect subsets of channels. + +-------------------------- +The drift correction steps +-------------------------- + +The easiest way to run drift correction in SpikeInterface is with the +high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` function. +This function takes a recording as input and returns a motion-corrected +recording object. As with all other preprocessing steps, the correction (in this +case interpolation of the data to correct the detected motion) is lazy and applied on-the-fly when data is needed). + +The :py:func:`~spikeinterface.preprocessing.correct_motion()` +function implements motion correction algorithms in a modular way +wrapping a number of subfunctions that together implement the +full drift correction algorithm. + +These drift-correction modules are: + +| **1.** ``localize_peaks()`` (detect spikes and localize their position on the probe) +| **2.** ``select_peaks()`` (optional, select a subset of peaks to use to estimate motion) +| **3.** ``estimate_motion()`` (estimate motion using the detected spikes) +| **4.** ``interpolate_motion()`` (perform interpolation on the raw data to account for the estimated drift). + +All these sub-steps have many parameters which dictate the +speed and effectiveness of motion correction. 
As such, ``correct_motion`` +provides three setting 'presets' which configure the motion correct +to proceed either as: + +* **rigid_fast** - a fast, not particularly accurate correction assuming rigid drift. +* **kilosort-like** - Mimics what is done in Kilosort. +* **nonrigid_accurate** - A decentralized drift correction (DREDGE), introduced by the Paninski group. + +When using motion correction in your analysis, please make sure to +:ref:`cite the appropriate paper for your chosen method`. + + +**Now, let's dive into running motion correction with these three +methods on a simulated dataset.** + +""" + +# %% +# ------------------------------------------- +# Setting up and preprocessing the recording +# ------------------------------------------- +# +# First, we will import the modules we will need for this tutorial: + +import matplotlib.pyplot as plt +import spikeinterface.full as si +from spikeinterface.generation.drifting_generator import generate_drifting_recording +from spikeinterface.preprocessing.motion import motion_options_preset +from spikeinterface.sortingcomponents.motion_interpolation import correct_motion_on_peaks +from spikeinterface.widgets import plot_peaks_on_probe + +# %% read the file +# Next, we will generate a synthetic, drifting recording. This recording will +# have 100 separate units with firing rates randomly distributed between +# 15 and 25 Hz. +# +# We will create a zigzag drift pattern on the recording, starting at +# 100 seconds and with a peak-to-peak period of 100 seconds (so we will +# have 9 zigzags through our recording). We also add some non-linearity +# to the imposed motion. + +# %% +#.. note:: +# This tutorial can take a long time to run with the default arguments. +# If you would like to run this locally, you may want to edit ``num_units`` +# and ``duration`` to smaller values (e.g. 25 and 100 respectively). +# +# Also note, the below code uses multiprocessing. 
If you are on Windows, you may +# need to place the code within a ``if __name__ == "__main__":`` block. + + +num_units = 200 +duration = 1000 + +_, raw_recording, _ = generate_drifting_recording( + num_units=num_units, + duration=duration, + generate_sorting_kwargs=dict(firing_rates=(15, 25), refractory_period_ms=4.0), + seed=42, + generate_displacement_vector_kwargs=dict(motion_list=[ + dict( + drift_mode="zigzag", + non_rigid_gradient=0.01, + t_start_drift=int(duration/10), + t_end_drift=None, + period_s=int(duration/10), + ), + ], + ) +) +print(raw_recording) + +# %% +# Before performing motion correction, we will **preprocess** the recording +# with a bandpass filter and a common median reference. + +filtered_recording = si.bandpass_filter(raw_recording, freq_min=300.0, freq_max=6000.0) +preprocessed_recording = si.common_reference(filtered_recording, reference="global", operator="median") + +# %% +#.. warning:: +# It is better to not whiten the recording before motion estimation, as this +# will give a better estimate of the peak locations. Whitening should +# be performed after motion correction. + +# %% +# ---------------------------------------- +# Run motion correction with one function! +# ---------------------------------------- +# +# Correcting for drift is easy! You just need to run a single function. +# We will now run motion correction on our recording using the three +# presets described above - **rigid_fast**, **kilosort_like** and +# **nonrigid_accurate**. +# +# We can run these presents with the ``preset`` argument of +# :py:func:`~spikeinterface.preprocessing.correct_motion()`. Under the +# hood, the presets define a set of parameters by set how to run the +# 4 submodules that make up motion correction (described above). +print(motion_options_preset["kilosort_like"]) + +# %% +# Now, lets run motion correction with our three presets. 
# We will
# set the ``job_kwargs`` to parallelize the job over a number of CPU cores—motion
# correction is computationally intensive and will run faster with parallelization.

presets_to_run = ("rigid_fast", "kilosort_like", "nonrigid_accurate")

job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True)

# Run motion correction once per preset, keeping each ``motion_info`` for
# side-by-side comparison below.
results = {preset: {} for preset in presets_to_run}
for preset in presets_to_run:

    corrected_recording, motion_info = si.correct_motion(
        preprocessed_recording, preset=preset, output_motion_info=True, **job_kwargs
    )
    results[preset]["motion_info"] = motion_info

# %%
# .. seealso::
#   It is often very useful to save ``motion_info`` to a
#   file, so it can be loaded and visualized later. This can be done by setting
#   the ``folder`` argument of
#   :py:func:`~spikeinterface.preprocessing.correct_motion()` to a path to write
#   all motion outputs to. The ``motion_info`` can be loaded back with
#   ``si.load_motion_info``.

# %%
# --------------------
# Plotting the results
# --------------------
#
# Next, let's plot the results of our motion estimation using the ``plot_motion_info()``
# function. The plot contains 4 panels; the x-axis of all plots is the binned time.
# The plots display:
#
# * **top left:** The estimated peak depth for every detected peak.
# * **top right:** The estimated peak depths after motion correction.
# * **bottom left:** The average motion vector across depths and all motion across spatial depths (for non-rigid estimation).
# * **bottom right:** if motion correction is non-rigid, the motion vector across depths is plotted as a map, with the color code representing the motion in micrometers.

for preset in presets_to_run:

    fig = plt.figure(figsize=(7, 7))

    si.plot_motion_info(
        results[preset]["motion_info"],
        recording=corrected_recording,  # the recording is only used to get the real times
        figure=fig,
        depth_lim=(400, 600),
        color_amplitude=True,
        amplitude_cmap="inferno",
        scatter_decimate=10,  # Only plot every 10th peak
    )
    fig.suptitle(f"{preset=}")

# %%
# These plots are quite complicated, so it is worth covering them in detail.
# For every detected spike in our recording, we first estimate
# its depth (first panel) using a method from
# :py:func:`~spikeinterface.postprocessing.compute_unit_locations()`.
#
# Then, the probe motion is estimated and the location of the
# spikes is adjusted to account for the motion (second panel).
#
# The motion estimation produces
# a measure of how much and in what direction the probe is moving at any given
# time bin (third panel). For non-rigid motion correction, the probe is divided
# into subsections - the motion vectors displayed are per subsection (i.e. per
# 'binned spatial depth') as well as the average.
#
# On the fourth panel, we see a
# more detailed representation of the motion vectors. We can see the motion plotted
# as a heatmap at each binned spatial depth across all time bins. It captures
# the zigzag pattern (alternating light and dark colors) of the injected motion.

# %%
# A few comments on the figures:
#
# * The preset **'rigid_fast'** has only one motion vector for the entire probe because it is a 'rigid' case.
#   The motion amplitude is globally underestimated because it averages across depths.
#   However, the corrected peaks are flatter than the non-corrected ones, so the job is partially done.
#   The big jump at 600s when the probe starts moving is recovered quite well.
# * The preset **kilosort_like** gives better results because it is a non-rigid case.
#   The motion vector is computed for different depths.
#   The corrected peak locations are flatter than the rigid case.
#   The motion vector map is still a bit noisy at some depths (e.g. around 1000um).
# * The preset **nonrigid_accurate** seems to give the best results on this recording.
#   The motion vector seems less noisy globally, but it is not 'perfect' (see at the top of the probe 3200um to 3800um).
#   Also note that in the first part of the recording before the imposed motion (0-600s) we clearly have a non-rigid motion:
#   the upper part of the probe (2000-3000um) experiences some drift, but the lower part (0-1000um) is relatively stable.
#   The method defined by this preset is able to capture this.

# %%
# -------------------------------------------------
# Correcting Peak Locations after Motion Correction
# -------------------------------------------------
#
# The result of motion correction can be applied to the data in two ways.
# The first is by interpolating the raw traces to correct for the estimated drift.
# This changes the data in the
# recording by shifting the signal across channels, and is given in the
# ``corrected_recording`` output from :py:func:`~spikeinterface.preprocessing.correct_motion()`.
# This is useful in most cases, for continuing
# with preprocessing and sorting with the corrected recording.
#
# The second way is to apply the results of motion correction directly
# to the ``peak_locations`` object. If you are not familiar with
# SpikeInterface's ``peak`` and ``peak_locations`` objects,
# these are explored further in the below dropdown.

# %%
# .. dropdown:: 'Peaks' and 'Peak Locations' in SpikeInterface
#
#   Information about detected spikes is represented in
#   SpikeInterface's ``peaks`` and ``peak_locations`` objects. The
#   ``peaks`` object is an array containing the
#   sample index, channel index (where its signal
#   is strongest), amplitude and recording segment index for every detected spike
#   in the dataset. It is created by the
#   :py:func:`~spikeinterface.sortingcomponents.peak_detection.detect_peaks()`
#   function.
#
#   The ``peak_locations`` is a partner object to the ``peaks`` object,
#   and contains the estimated location (``"x"``, ``"y"``) of the spike. For every spike in
#   ``peaks`` there is a corresponding location in ``peak_locations``.
#   The peak locations are estimated using the
#   :py:func:`~spikeinterface.sortingcomponents.peak_localization.localize_peaks()`
#   function.

# %%
# The other way to apply the motion correction is to the ``peaks`` and
# ``peaks_location`` objects directly. This is done using the function
# ``correct_motion_on_peaks()``. Given a set of peaks, peak locations and
# the ``motion`` object output from :py:func:`~spikeinterface.preprocessing.correct_motion()`,
# it will shift the location of the peaks according to the motion estimate, outputting a new
# ``peak_locations`` object. This is done to plot the peak locations in
# the next section.
#


# %%
# .. warning::
#   Note that the ``peak_locations`` output by
#   :py:func:`~spikeinterface.preprocessing.correct_motion()`
#   (in the ``motion_info`` object) is the original (uncorrected) peak locations.
#   To get the corrected peak locations, ``correct_motion_on_peaks()`` must be used!

# Apply the estimated motion to the detected peak locations (rather than to
# the traces) and plot original vs. corrected locations for each preset.
for preset in presets_to_run:

    motion_info = results[preset]["motion_info"]

    peaks = motion_info["peaks"]

    # NOTE: ``motion_info["peak_locations"]`` holds the UNcorrected locations;
    # the corrected ones must be computed with ``correct_motion_on_peaks()``.
    original_peak_locations = motion_info["peak_locations"]

    corrected_peak_locations = correct_motion_on_peaks(peaks, original_peak_locations, motion_info['motion'], corrected_recording)

    widget = plot_peaks_on_probe(corrected_recording, [peaks, peaks], [original_peak_locations, corrected_peak_locations], ylim=(300,600))
    widget.figure.suptitle(preset)

# %%
# -------------------------
# Comparing the Run Times
# -------------------------
#
# The different methods also have different speeds; the 'nonrigid_accurate'
# requires more computation time, in particular at the ``estimate_motion`` phase,
# as seen in the run times:

for preset in presets_to_run:
    print(preset)
    print(results[preset]["motion_info"]["run_times"])

# %%
# -----------------------------------------
# When do I need to apply drift correction?
# -----------------------------------------
#
# Drift correction may not always be necessary for your data, for
# example when there is not much drift in the data to begin with.
# Further, in some cases (e.g. when the probe has a smaller number of channels,
# e.g. 64 or less) the drift correction algorithms may not perform well.
#
# To check whether drift correction is required and how it is performing,
# it is necessary to run drift correction as above and then check the output plots.
# In the below example, the 'Peak depth' plot shows minimal drift in the peak position.
# In this example, it does not look like drift correction is that necessary. Further,
# because there are only 16 channels in this recording, the drift correction is failing.
# The 'Corrected peak depth' plot has erroneously shifted peaks to the wrong position, spreading
# them across the probe. In this instance, drift correction could be skipped.
#
# .. image:: ../../images/no-drift-example.png

# %%
# ------------------------
# Summary
# ------------------------
#
# That's it for our tour of motion correction in
# SpikeInterface. Remember that correcting motion makes some
# assumptions on your data (e.g. number of channels, noise in the recording)—always
# plot the motion correction information for your
# recordings, to make sure the correction is behaving as expected!