Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ classifiers = [
dependencies = [
"numpy",
"packaging",
"requests"
]

[project.urls]
Expand Down
2 changes: 1 addition & 1 deletion src/probeinterface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@
generate_multi_columns_probe,
generate_multi_shank,
)
from .library import get_probe
from .library import get_probe, get_manufacturers_in_library, get_probes_in_library, get_tags_in_library
from .wiring import get_available_pathways
80 changes: 80 additions & 0 deletions src/probeinterface/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os
from pathlib import Path
from urllib.request import urlopen
import requests
from typing import Optional

from .io import read_probeinterface
Expand Down Expand Up @@ -104,3 +105,82 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) ->
probe.name = name

return probe


def get_manufacturers_in_library(tag=None) -> list[str]:
"""
Get the list of available manufacturers in the library

Returns
-------
manufacturers : list of str
List of available manufacturers
"""
return list_github_folders("SpikeInterface", "probeinterface_library", ref=tag)


def get_probes_in_library(manufacturer: str, tag=None) -> list[str]:
"""
Get the list of available probes for a given manufacturer

Parameters
----------
manufacturer : str
The probe manufacturer

Returns
-------
probes : list of str
List of available probes for the given manufacturer
"""
return list_github_folders("SpikeInterface", "probeinterface_library", path=manufacturer, ref=tag)


def get_tags_in_library() -> list[str]:
"""
Get the list of available tags in the library

Returns
-------
tags : list of str
List of available tags
"""
tags = get_all_tags("SpikeInterface", "probeinterface_library")
return tags


### UTILS
def get_all_tags(owner: str, repo: str, token: str = None):
"""
Get all tags for a repo.
Returns a list of tag names, or an empty list if no tags exist.
"""
url = f"https://api.github.com/repos/{owner}/{repo}/tags"
headers = {}
if token:
headers["Authorization"] = f"token {token}"
resp = requests.get(url, headers=headers)
if resp.status_code != 200:
raise RuntimeError(f"GitHub API returned {resp.status_code}: {resp.text}")
tags = resp.json()
return [tag["name"] for tag in tags]


def list_github_folders(owner: str, repo: str, path: str = "", ref: str = None, token: str = None):
"""
Return a list of directory names in the given repo at the specified path.
You can pass a branch, tag, or commit SHA via `ref`.
If token is provided, use it for authenticated requests (higher rate limits).
"""
url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}"
params = {}
if ref:
params["ref"] = ref
headers = {}
if token:
headers["Authorization"] = f"token {token}"
resp = requests.get(url, headers=headers, params=params)
if resp.status_code != 200:
raise RuntimeError(f"GitHub API returned status {resp.status_code}: {resp.text}")
items = resp.json()
return [item["name"] for item in items if item.get("type") == "dir" and item["name"][0] != "."]
34 changes: 33 additions & 1 deletion tests/test_library.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from probeinterface import Probe
from probeinterface.library import download_probeinterface_file, get_from_cache, get_probe
from probeinterface.library import (
download_probeinterface_file,
get_from_cache,
get_probe,
get_tags_in_library,
get_manufacturers_in_library,
get_probes_in_library,
)


from pathlib import Path
Expand Down Expand Up @@ -31,7 +38,32 @@ def test_get_probe():
assert probe.get_contact_count() == 32


def test_available_tags():
tags = get_tags_in_library()
if len(tags) > 0:
for tag in tags:
assert isinstance(tag, str)
assert len(tag) > 0


def test_get_manufacturers_in_library():
manufacturers = get_manufacturers_in_library()
assert isinstance(manufacturers, list)
assert "neuronexus" in manufacturers
assert "imec" in manufacturers


def test_get_probes_in_library():
manufacturers = get_manufacturers_in_library()
for manufacturer in manufacturers:
probes = get_probes_in_library(manufacturer)
assert isinstance(probes, list)
assert len(probes) > 0


if __name__ == "__main__":
test_download_probeinterface_file()
test_get_from_cache()
test_get_probe()
test_get_manufacturers_in_library()
test_get_probes_in_library()