diff --git a/pyproject.toml b/pyproject.toml index 1612940..e364b2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ dependencies = [ "numpy", "packaging", + "requests" ] [project.urls] diff --git a/src/probeinterface/__init__.py b/src/probeinterface/__init__.py index 4f8d746..7998e14 100644 --- a/src/probeinterface/__init__.py +++ b/src/probeinterface/__init__.py @@ -39,5 +39,11 @@ generate_multi_columns_probe, generate_multi_shank, ) -from .library import get_probe +from .library import ( + get_probe, + list_manufacturers_in_library, + list_probes_in_library, + get_tags_in_library, + cache_full_library, +) from .wiring import get_available_pathways diff --git a/src/probeinterface/library.py b/src/probeinterface/library.py index 7666cd8..ff77083 100644 --- a/src/probeinterface/library.py +++ b/src/probeinterface/library.py @@ -11,23 +11,33 @@ from __future__ import annotations import os +import warnings from pathlib import Path from urllib.request import urlopen +import requests from typing import Optional from .io import read_probeinterface # OLD URL on gin # public_url = "https://web.gin.g-node.org/spikeinterface/probeinterface_library/raw/master/" - # Now on github since 2023/06/15 -public_url = "https://raw.githubusercontent.com/SpikeInterface/probeinterface_library/main/" +public_url = "https://raw.githubusercontent.com/SpikeInterface/probeinterface_library/" + # check this for windows and osx -cache_folder = Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library" +def get_cache_folder() -> Path: + """Get the cache folder for probeinterface library files. + Returns + ------- + cache_folder : Path + The path to the cache folder. + """ + return Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library" -def download_probeinterface_file(manufacturer: str, probe_name: str): + +def download_probeinterface_file(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> None: """Download the probeinterface file to the cache directory. Note that the file is itself a ProbeGroup but on the repo each file represents one probe. @@ -38,16 +48,24 @@ def download_probeinterface_file(manufacturer: str, probe_name: str): The probe manufacturer probe_name : str (see probeinterface_libary for options) The probe name + tag : str | None, default: None + Optional tag for the probe """ - os.makedirs(cache_folder / manufacturer, exist_ok=True) - localfile = cache_folder / manufacturer / (probe_name + ".json") - distantfile = public_url + f"{manufacturer}/{probe_name}/{probe_name}.json" - dist = urlopen(distantfile) - with open(localfile, "wb") as f: - f.write(dist.read()) + cache_folder = get_cache_folder() + if tag is not None: + assert tag in get_tags_in_library(), f"Tag {tag} not found in library" + else: + tag = "main" + os.makedirs(cache_folder / tag / manufacturer, exist_ok=True) + local_file = cache_folder / tag / manufacturer / (probe_name + ".json") + remote_file = public_url + tag + f"/{manufacturer}/{probe_name}/{probe_name}.json" + rem = urlopen(remote_file) + with open(local_file, "wb") as f: + f.write(rem.read()) -def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]: + +def get_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> Optional["Probe"]: """ Get Probe from local cache @@ -57,24 +75,72 @@ def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]: The probe manufacturer probe_name : str (see probeinterface_libary for options) The probe name + tag : str | None, default: None + Optional tag for the probe Returns ------- probe : Probe object, or None if no probeinterface JSON file is found """ + cache_folder = get_cache_folder() + if tag is not None: + cache_folder_tag = cache_folder / tag + if not cache_folder_tag.is_dir(): + return None + cache_folder = cache_folder_tag + else: + cache_folder_tag = cache_folder / "main" - localfile = cache_folder / manufacturer / (probe_name + ".json") - if not localfile.is_file(): + local_file = cache_folder_tag / manufacturer / (probe_name + ".json") + if not local_file.is_file(): return None else: - probegroup = read_probeinterface(localfile) + probegroup = read_probeinterface(local_file) probe = probegroup.probes[0] probe._probe_group = None return probe -def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> "Probe": +def remove_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> Optional["Probe"]: + """ + Remove Probe from local cache + + Parameters + ---------- + manufacturer : "cambridgeneurotech" | "neuronexus" | "plexon" | "imec" | "sinaps" + The probe manufacturer + probe_name : str (see probeinterface_libary for options) + The probe name + tag : str | None, default: None + Optional tag for the probe + + Returns + ------- + probe : Probe object, or None if no probeinterface JSON file is found + + """ + cache_folder = get_cache_folder() + if tag is not None: + cache_folder_tag = cache_folder / tag + if not cache_folder_tag.is_dir(): + return None + cache_folder = cache_folder_tag + else: + cache_folder_tag = cache_folder / "main" + + local_file = cache_folder_tag / manufacturer / (probe_name + ".json") + if local_file.is_file(): + os.remove(local_file) + + +def get_probe( + manufacturer: str, + probe_name: str, + name: Optional[str] = None, + tag: Optional[str] = None, + force_download: bool = False, +) -> "Probe": """ Get probe from ProbeInterface library @@ -86,21 +152,123 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> The probe name name : str | None, default: None Optional name for the probe + tag : str | None, default: None + Optional tag for the probe + force_download : bool, default: False + If True, force re-download of the probe file. Returns ---------- probe : Probe object """ - - probe = get_from_cache(manufacturer, probe_name) + if not force_download: + probe = get_from_cache(manufacturer, probe_name, tag=tag) + else: + probe = None if probe is None: - download_probeinterface_file(manufacturer, probe_name) - probe = get_from_cache(manufacturer, probe_name) + download_probeinterface_file(manufacturer, probe_name, tag=tag) + probe = get_from_cache(manufacturer, probe_name, tag=tag) if probe.manufacturer == "": probe.manufacturer = manufacturer if name is not None: probe.name = name return probe + + +def cache_full_library(tag=None) -> None: + """ + Download all probes from the library to the cache directory. + """ + manufacturers = list_manufacturers_in_library(tag=tag) + + for manufacturer in manufacturers: + probes = list_probes_in_library(manufacturer, tag=tag) + for probe_name in probes: + try: + download_probeinterface_file(manufacturer, probe_name, tag=tag) + except Exception as e: + warnings.warn(f"Could not download {manufacturer}/{probe_name} (tag: {tag}): {e}") + + +def list_manufacturers_in_library(tag=None) -> list[str]: + """ + Get the list of available manufacturers in the library + + Returns + ------- + manufacturers : list of str + List of available manufacturers + """ + return list_github_folders("SpikeInterface", "probeinterface_library", ref=tag) + + +def list_probes_in_library(manufacturer: str, tag=None) -> list[str]: + """ + Get the list of available probes for a given manufacturer + + Parameters + ---------- + manufacturer : str + The probe manufacturer + + Returns + ------- + probes : list of str + List of available probes for the given manufacturer + """ + return list_github_folders("SpikeInterface", "probeinterface_library", path=manufacturer, ref=tag) + + +def get_tags_in_library() -> list[str]: + """ + Get the list of available tags in the library + + Returns + ------- + tags : list of str + List of available tags + """ + tags = get_all_tags("SpikeInterface", "probeinterface_library") + return tags + + +### UTILS +def get_all_tags(owner: str, repo: str, token: str = None): + """ + Get all tags for a repo. + Returns a list of tag names, or an empty list if no tags exist. + """ + url = f"https://api.github.com/repos/{owner}/{repo}/tags" + headers = {} + if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"): + token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN") + headers["Authorization"] = f"token {token}" + resp = requests.get(url, headers=headers) + if resp.status_code != 200: + raise RuntimeError(f"GitHub API returned {resp.status_code}: {resp.text}") + tags = resp.json() + return [tag["name"] for tag in tags] + + +def list_github_folders(owner: str, repo: str, path: str = "", ref: str = None, token: str = None): + """ + Return a list of directory names in the given repo at the specified path. + You can pass a branch, tag, or commit SHA via `ref`. + If token is provided, use it for authenticated requests (higher rate limits). + """ + url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}" + params = {} + if ref: + params["ref"] = ref + headers = {} + if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"): + token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN") + headers["Authorization"] = f"token {token}" + resp = requests.get(url, headers=headers, params=params) + if resp.status_code != 200: + raise RuntimeError(f"GitHub API returned status {resp.status_code}: {resp.text}") + items = resp.json() + return [item["name"] for item in items if item.get("type") == "dir" and item["name"][0] != "."] diff --git a/tests/test_library.py b/tests/test_library.py index 8d4059d..2e90216 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -1,11 +1,14 @@ from probeinterface import Probe -from probeinterface.library import download_probeinterface_file, get_from_cache, get_probe - - -from pathlib import Path -import numpy as np - -import pytest +from probeinterface.library import ( + download_probeinterface_file, + get_from_cache, + remove_from_cache, + get_probe, + get_tags_in_library, + list_manufacturers_in_library, + list_probes_in_library, + get_cache_folder, +) manufacturer = "neuronexus" @@ -13,7 +16,7 @@ def test_download_probeinterface_file(): - download_probeinterface_file(manufacturer, probe_name) + download_probeinterface_file(manufacturer, probe_name, tag=None) def test_get_from_cache(): @@ -21,6 +24,14 @@ def test_get_from_cache(): probe = get_from_cache(manufacturer, probe_name) assert isinstance(probe, Probe) + tag = get_tags_in_library()[0] + probe = get_from_cache(manufacturer, probe_name, tag=tag) + assert probe is None # because we did not download with this tag + download_probeinterface_file(manufacturer, probe_name, tag=tag) + probe = get_from_cache(manufacturer, probe_name, tag=tag) + remove_from_cache(manufacturer, probe_name, tag=tag) + assert isinstance(probe, Probe) + probe = get_from_cache("yep", "yop") assert probe is None @@ -31,7 +42,32 @@ def test_get_probe(): assert probe.get_contact_count() == 32 +def test_available_tags(): + tags = get_tags_in_library() + if len(tags) > 0: + for tag in tags: + assert isinstance(tag, str) + assert len(tag) > 0 + + +def test_list_manufacturers_in_library(): + manufacturers = list_manufacturers_in_library() + assert isinstance(manufacturers, list) + assert "neuronexus" in manufacturers + assert "imec" in manufacturers + + +def test_list_probes_in_library(): + manufacturers = list_manufacturers_in_library() + for manufacturer in manufacturers: + probes = list_probes_in_library(manufacturer) + assert isinstance(probes, list) + assert len(probes) > 0 + + if __name__ == "__main__": test_download_probeinterface_file() test_get_from_cache() test_get_probe() + test_list_manufacturers_in_library() + test_list_probes_in_library()