Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ classifiers = [
dependencies = [
"numpy",
"packaging",
"requests"
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion src/probeinterface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@
generate_multi_columns_probe,
generate_multi_shank,
)
from .library import get_probe
from .library import get_probe, list_manufacturers_in_library, list_probes_in_library, get_tags_in_library
from .wiring import get_available_pathways
177 changes: 159 additions & 18 deletions src/probeinterface/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,30 @@
import os
from pathlib import Path
from urllib.request import urlopen
import requests
from typing import Optional

from .io import read_probeinterface

# OLD URL on gin
# public_url = "https://web.gin.g-node.org/spikeinterface/probeinterface_library/raw/master/"

# Now on github since 2023/06/15
public_url = "https://raw.githubusercontent.com/SpikeInterface/probeinterface_library/main/"
public_url = "https://raw.githubusercontent.com/SpikeInterface/probeinterface_library/"


# check this for windows and osx
cache_folder = Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library"
def get_cache_folder() -> Path:
"""Get the cache folder for probeinterface library files.

Returns
-------
cache_folder : Path
The path to the cache folder.
"""
return Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library"


def download_probeinterface_file(manufacturer: str, probe_name: str):
def download_probeinterface_file(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> None:
"""Download the probeinterface file to the cache directory.
Note that the file is itself a ProbeGroup but on the repo each file
represents one probe.
Expand All @@ -38,16 +47,23 @@ def download_probeinterface_file(manufacturer: str, probe_name: str):
The probe manufacturer
probe_name : str (see probeinterface_libary for options)
The probe name
tag : str | None, default: None
Optional tag for the probe
"""
os.makedirs(cache_folder / manufacturer, exist_ok=True)
localfile = cache_folder / manufacturer / (probe_name + ".json")
distantfile = public_url + f"{manufacturer}/{probe_name}/{probe_name}.json"
dist = urlopen(distantfile)
with open(localfile, "wb") as f:
f.write(dist.read())
cache_folder = get_cache_folder()
if tag is not None:
assert tag in get_tags_in_library(), f"Tag {tag} not found in library"
else:
tag = "main"
os.makedirs(cache_folder / tag / manufacturer, exist_ok=True)
local_file = cache_folder / tag / manufacturer / (probe_name + ".json")
remote_file = public_url + tag + f"/{manufacturer}/{probe_name}/{probe_name}.json"
rem = urlopen(remote_file)
with open(local_file, "wb") as f:
f.write(rem.read())


def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]:
def get_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> Optional["Probe"]:
"""
Get Probe from local cache

Expand All @@ -57,24 +73,66 @@ def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]:
The probe manufacturer
probe_name : str (see probeinterface_libary for options)
The probe name
tag : str | None, default: None
Optional tag for the probe

Returns
-------
probe : Probe object, or None if no probeinterface JSON file is found

"""
cache_folder = get_cache_folder()
if tag is not None:
cache_folder_tag = cache_folder / tag
if not cache_folder_tag.is_dir():
return None
cache_folder = cache_folder_tag
else:
cache_folder_tag = cache_folder / "main"

localfile = cache_folder / manufacturer / (probe_name + ".json")
if not localfile.is_file():
local_file = cache_folder_tag / manufacturer / (probe_name + ".json")
if not local_file.is_file():
return None
else:
probegroup = read_probeinterface(localfile)
probegroup = read_probeinterface(local_file)
probe = probegroup.probes[0]
probe._probe_group = None
return probe


def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> "Probe":
def remove_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> Optional["Probe"]:
"""
Remove Probe from local cache

Parameters
----------
manufacturer : "cambridgeneurotech" | "neuronexus" | "plexon" | "imec" | "sinaps"
The probe manufacturer
probe_name : str (see probeinterface_libary for options)
The probe name
tag : str | None, default: None
Optional tag for the probe

Returns
-------
probe : Probe object, or None if no probeinterface JSON file is found

"""
cache_folder = get_cache_folder()
if tag is not None:
cache_folder_tag = cache_folder / tag
if not cache_folder_tag.is_dir():
return None
cache_folder = cache_folder_tag
else:
cache_folder_tag = cache_folder / "main"

local_file = cache_folder_tag / manufacturer / (probe_name + ".json")
if local_file.is_file():
os.remove(local_file)


def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None, tag: Optional[str] = None) -> "Probe":
"""
Get probe from ProbeInterface library

Expand All @@ -86,21 +144,104 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) ->
The probe name
name : str | None, default: None
Optional name for the probe
tag : str | None, default: None
Optional tag for the probe

Returns
----------
probe : Probe object

"""

probe = get_from_cache(manufacturer, probe_name)
probe = get_from_cache(manufacturer, probe_name, tag=tag)

if probe is None:
download_probeinterface_file(manufacturer, probe_name)
probe = get_from_cache(manufacturer, probe_name)
download_probeinterface_file(manufacturer, probe_name, tag=tag)
probe = get_from_cache(manufacturer, probe_name, tag=tag)
if probe.manufacturer == "":
probe.manufacturer = manufacturer
if name is not None:
probe.name = name

return probe


def list_manufacturers_in_library(tag=None) -> list[str]:
"""
Get the list of available manufacturers in the library

Returns
-------
manufacturers : list of str
List of available manufacturers
"""
return list_github_folders("SpikeInterface", "probeinterface_library", ref=tag)


def list_probes_in_library(manufacturer: str, tag=None) -> list[str]:
"""
Get the list of available probes for a given manufacturer

Parameters
----------
manufacturer : str
The probe manufacturer

Returns
-------
probes : list of str
List of available probes for the given manufacturer
"""
return list_github_folders("SpikeInterface", "probeinterface_library", path=manufacturer, ref=tag)


def get_tags_in_library() -> list[str]:
"""
Get the list of available tags in the library

Returns
-------
tags : list of str
List of available tags
"""
tags = get_all_tags("SpikeInterface", "probeinterface_library")
return tags


### UTILS
def get_all_tags(owner: str, repo: str, token: str = None):
"""
Get all tags for a repo.
Returns a list of tag names, or an empty list if no tags exist.
"""
url = f"https://api.github.com/repos/{owner}/{repo}/tags"
headers = {}
if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"):
token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN")
headers["Authorization"] = f"token {token}"
resp = requests.get(url, headers=headers)
if resp.status_code != 200:
raise RuntimeError(f"GitHub API returned {resp.status_code}: {resp.text}")
tags = resp.json()
return [tag["name"] for tag in tags]


def list_github_folders(owner: str, repo: str, path: str = "", ref: str = None, token: str = None):
"""
Return a list of directory names in the given repo at the specified path.
You can pass a branch, tag, or commit SHA via `ref`.
If token is provided, use it for authenticated requests (higher rate limits).
"""
url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"
params = {}
if ref:
params["ref"] = ref
headers = {}
if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"):
token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN")
headers["Authorization"] = f"token {token}"
resp = requests.get(url, headers=headers, params=params)
if resp.status_code != 200:
raise RuntimeError(f"GitHub API returned status {resp.status_code}: {resp.text}")
items = resp.json()
return [item["name"] for item in items if item.get("type") == "dir" and item["name"][0] != "."]
51 changes: 43 additions & 8 deletions tests/test_library.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,36 @@
from probeinterface import Probe
from probeinterface.library import download_probeinterface_file, get_from_cache, get_probe


from pathlib import Path
import numpy as np

import pytest
from probeinterface.library import (
download_probeinterface_file,
get_from_cache,
remove_from_cache,
get_probe,
get_tags_in_library,
list_manufacturers_in_library,
list_probes_in_library,
)


manufacturer = "neuronexus"
probe_name = "A1x32-Poly3-10mm-50-177"


def test_download_probeinterface_file():
download_probeinterface_file(manufacturer, probe_name)
download_probeinterface_file(manufacturer, probe_name, tag=None)


def test_get_from_cache():
download_probeinterface_file(manufacturer, probe_name)
probe = get_from_cache(manufacturer, probe_name)
assert isinstance(probe, Probe)

tag = get_tags_in_library()[0]
probe = get_from_cache(manufacturer, probe_name, tag=tag)
assert probe is None # because we did not download with this tag
download_probeinterface_file(manufacturer, probe_name, tag=tag)
probe = get_from_cache(manufacturer, probe_name, tag=tag)
remove_from_cache(manufacturer, probe_name, tag=tag)
assert isinstance(probe, Probe)

probe = get_from_cache("yep", "yop")
assert probe is None

Expand All @@ -31,7 +41,32 @@ def test_get_probe():
assert probe.get_contact_count() == 32


def test_available_tags():
tags = get_tags_in_library()
if len(tags) > 0:
for tag in tags:
assert isinstance(tag, str)
assert len(tag) > 0


def test_list_manufacturers_in_library():
manufacturers = list_manufacturers_in_library()
assert isinstance(manufacturers, list)
assert "neuronexus" in manufacturers
assert "imec" in manufacturers


def test_list_probes_in_library():
manufacturers = list_manufacturers_in_library()
for manufacturer in manufacturers:
probes = list_probes_in_library(manufacturer)
assert isinstance(probes, list)
assert len(probes) > 0


if __name__ == "__main__":
test_download_probeinterface_file()
test_get_from_cache()
test_get_probe()
test_list_manufacturers_in_library()
test_list_probes_in_library()