Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ classifiers = [
dependencies = [
"numpy",
"packaging",
"requests"
]

[project.urls]
Expand Down
8 changes: 7 additions & 1 deletion src/probeinterface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,11 @@
generate_multi_columns_probe,
generate_multi_shank,
)
from .library import get_probe
from .library import (
get_probe,
list_manufacturers_in_library,
list_probes_in_library,
get_tags_in_library,
cache_full_library,
)
from .wiring import get_available_pathways
206 changes: 187 additions & 19 deletions src/probeinterface/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,33 @@

from __future__ import annotations
import os
import warnings
from pathlib import Path
from urllib.request import urlopen
import requests
from typing import Optional

from .io import read_probeinterface

# OLD URL on gin
# public_url = "https://web.gin.g-node.org/spikeinterface/probeinterface_library/raw/master/"

# Now on github since 2023/06/15
public_url = "https://raw.githubusercontent.com/SpikeInterface/probeinterface_library/main/"
public_url = "https://raw.githubusercontent.com/SpikeInterface/probeinterface_library/"


# check this for windows and osx
cache_folder = Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library"
def get_cache_folder() -> Path:
"""Get the cache folder for probeinterface library files.

Returns
-------
cache_folder : Path
The path to the cache folder.
"""
return Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library"

def download_probeinterface_file(manufacturer: str, probe_name: str):

def download_probeinterface_file(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> None:
"""Download the probeinterface file to the cache directory.
Note that the file is itself a ProbeGroup but on the repo each file
represents one probe.
Expand All @@ -38,16 +48,24 @@ def download_probeinterface_file(manufacturer: str, probe_name: str):
The probe manufacturer
probe_name : str (see probeinterface_libary for options)
The probe name
tag : str | None, default: None
Optional tag for the probe
"""
os.makedirs(cache_folder / manufacturer, exist_ok=True)
localfile = cache_folder / manufacturer / (probe_name + ".json")
distantfile = public_url + f"{manufacturer}/{probe_name}/{probe_name}.json"
dist = urlopen(distantfile)
with open(localfile, "wb") as f:
f.write(dist.read())
cache_folder = get_cache_folder()
if tag is not None:
assert tag in get_tags_in_library(), f"Tag {tag} not found in library"
else:
tag = "main"

os.makedirs(cache_folder / tag / manufacturer, exist_ok=True)
local_file = cache_folder / tag / manufacturer / (probe_name + ".json")
remote_file = public_url + tag + f"/{manufacturer}/{probe_name}/{probe_name}.json"
rem = urlopen(remote_file)
with open(local_file, "wb") as f:
f.write(rem.read())

def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]:

def get_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> Optional["Probe"]:
"""
Get Probe from local cache

Expand All @@ -57,24 +75,72 @@ def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]:
The probe manufacturer
probe_name : str (see probeinterface_libary for options)
The probe name
tag : str | None, default: None
Optional tag for the probe

Returns
-------
probe : Probe object, or None if no probeinterface JSON file is found

"""
cache_folder = get_cache_folder()
if tag is not None:
cache_folder_tag = cache_folder / tag
if not cache_folder_tag.is_dir():
return None
cache_folder = cache_folder_tag
else:
cache_folder_tag = cache_folder / "main"

localfile = cache_folder / manufacturer / (probe_name + ".json")
if not localfile.is_file():
local_file = cache_folder_tag / manufacturer / (probe_name + ".json")
if not local_file.is_file():
return None
else:
probegroup = read_probeinterface(localfile)
probegroup = read_probeinterface(local_file)
probe = probegroup.probes[0]
probe._probe_group = None
return probe


def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> "Probe":
def remove_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> Optional["Probe"]:
"""
Remove Probe from local cache

Parameters
----------
manufacturer : "cambridgeneurotech" | "neuronexus" | "plexon" | "imec" | "sinaps"
The probe manufacturer
probe_name : str (see probeinterface_libary for options)
The probe name
tag : str | None, default: None
Optional tag for the probe

Returns
-------
probe : Probe object, or None if no probeinterface JSON file is found

"""
cache_folder = get_cache_folder()
if tag is not None:
cache_folder_tag = cache_folder / tag
if not cache_folder_tag.is_dir():
return None
cache_folder = cache_folder_tag
else:
cache_folder_tag = cache_folder / "main"

local_file = cache_folder_tag / manufacturer / (probe_name + ".json")
if local_file.is_file():
os.remove(local_file)


def get_probe(
manufacturer: str,
probe_name: str,
name: Optional[str] = None,
tag: Optional[str] = None,
force_download: bool = False,
) -> "Probe":
"""
Get probe from ProbeInterface library

Expand All @@ -86,21 +152,123 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) ->
The probe name
name : str | None, default: None
Optional name for the probe
tag : str | None, default: None
Optional tag for the probe
force_download : bool, default: False
If True, force re-download of the probe file.

Returns
----------
probe : Probe object

"""

probe = get_from_cache(manufacturer, probe_name)
if not force_download:
probe = get_from_cache(manufacturer, probe_name, tag=tag)
else:
probe = None

if probe is None:
download_probeinterface_file(manufacturer, probe_name)
probe = get_from_cache(manufacturer, probe_name)
download_probeinterface_file(manufacturer, probe_name, tag=tag)
probe = get_from_cache(manufacturer, probe_name, tag=tag)
if probe.manufacturer == "":
probe.manufacturer = manufacturer
if name is not None:
probe.name = name

return probe


def cache_full_library(tag=None) -> None:
"""
Download all probes from the library to the cache directory.
"""
manufacturers = list_manufacturers_in_library(tag=tag)

for manufacturer in manufacturers:
probes = list_probes_in_library(manufacturer, tag=tag)
for probe_name in probes:
try:
download_probeinterface_file(manufacturer, probe_name, tag=tag)
except Exception as e:
warnings.warn(f"Could not download {manufacturer}/{probe_name} (tag: {tag}): {e}")


def list_manufacturers_in_library(tag=None) -> list[str]:
"""
Get the list of available manufacturers in the library

Returns
-------
manufacturers : list of str
List of available manufacturers
"""
return list_github_folders("SpikeInterface", "probeinterface_library", ref=tag)


def list_probes_in_library(manufacturer: str, tag=None) -> list[str]:
"""
Get the list of available probes for a given manufacturer

Parameters
----------
manufacturer : str
The probe manufacturer

Returns
-------
probes : list of str
List of available probes for the given manufacturer
"""
return list_github_folders("SpikeInterface", "probeinterface_library", path=manufacturer, ref=tag)


def get_tags_in_library() -> list[str]:
"""
Get the list of available tags in the library

Returns
-------
tags : list of str
List of available tags
"""
tags = get_all_tags("SpikeInterface", "probeinterface_library")
return tags


### UTILS
def get_all_tags(owner: str, repo: str, token: str = None):
"""
Get all tags for a repo.
Returns a list of tag names, or an empty list if no tags exist.
"""
url = f"https://api.github.com/repos/{owner}/{repo}/tags"
headers = {}
if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"):
token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN")
headers["Authorization"] = f"token {token}"
resp = requests.get(url, headers=headers)
if resp.status_code != 200:
raise RuntimeError(f"GitHub API returned {resp.status_code}: {resp.text}")
tags = resp.json()
return [tag["name"] for tag in tags]


def list_github_folders(owner: str, repo: str, path: str = "", ref: str = None, token: str = None):
"""
Return a list of directory names in the given repo at the specified path.
You can pass a branch, tag, or commit SHA via `ref`.
If token is provided, use it for authenticated requests (higher rate limits).
"""
url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"
params = {}
if ref:
params["ref"] = ref
headers = {}
if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"):
token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN")
headers["Authorization"] = f"token {token}"
resp = requests.get(url, headers=headers, params=params)
if resp.status_code != 200:
raise RuntimeError(f"GitHub API returned status {resp.status_code}: {resp.text}")
items = resp.json()
return [item["name"] for item in items if item.get("type") == "dir" and item["name"][0] != "."]
52 changes: 44 additions & 8 deletions tests/test_library.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,37 @@
from probeinterface import Probe
from probeinterface.library import download_probeinterface_file, get_from_cache, get_probe


from pathlib import Path
import numpy as np

import pytest
from probeinterface.library import (
download_probeinterface_file,
get_from_cache,
remove_from_cache,
get_probe,
get_tags_in_library,
list_manufacturers_in_library,
list_probes_in_library,
get_cache_folder,
)


manufacturer = "neuronexus"
probe_name = "A1x32-Poly3-10mm-50-177"


def test_download_probeinterface_file():
download_probeinterface_file(manufacturer, probe_name)
download_probeinterface_file(manufacturer, probe_name, tag=None)


def test_get_from_cache():
download_probeinterface_file(manufacturer, probe_name)
probe = get_from_cache(manufacturer, probe_name)
assert isinstance(probe, Probe)

tag = get_tags_in_library()[0]
probe = get_from_cache(manufacturer, probe_name, tag=tag)
assert probe is None # because we did not download with this tag
download_probeinterface_file(manufacturer, probe_name, tag=tag)
probe = get_from_cache(manufacturer, probe_name, tag=tag)
remove_from_cache(manufacturer, probe_name, tag=tag)
assert isinstance(probe, Probe)

probe = get_from_cache("yep", "yop")
assert probe is None

Expand All @@ -31,7 +42,32 @@ def test_get_probe():
assert probe.get_contact_count() == 32


def test_available_tags():
tags = get_tags_in_library()
if len(tags) > 0:
for tag in tags:
assert isinstance(tag, str)
assert len(tag) > 0


def test_list_manufacturers_in_library():
manufacturers = list_manufacturers_in_library()
assert isinstance(manufacturers, list)
assert "neuronexus" in manufacturers
assert "imec" in manufacturers


def test_list_probes_in_library():
manufacturers = list_manufacturers_in_library()
for manufacturer in manufacturers:
probes = list_probes_in_library(manufacturer)
assert isinstance(probes, list)
assert len(probes) > 0


if __name__ == "__main__":
test_download_probeinterface_file()
test_get_from_cache()
test_get_probe()
test_list_manufacturers_in_library()
test_list_probes_in_library()