From b064adc2a4476e242b54b2ba689a2763c11bbcc4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 26 Sep 2025 15:49:34 +0200 Subject: [PATCH 1/8] Improve interaction with probeinterface-library --- src/probeinterface/__init__.py | 2 +- src/probeinterface/library.py | 99 ++++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 1 deletion(-) diff --git a/src/probeinterface/__init__.py b/src/probeinterface/__init__.py index 4f8d746e..0d5281dc 100644 --- a/src/probeinterface/__init__.py +++ b/src/probeinterface/__init__.py @@ -39,5 +39,5 @@ generate_multi_columns_probe, generate_multi_shank, ) -from .library import get_probe +from .library import get_probe, get_manufacturers_in_library, get_probes_in_library, get_tags_in_library from .wiring import get_available_pathways diff --git a/src/probeinterface/library.py b/src/probeinterface/library.py index 7666cd8a..7d86bb3b 100644 --- a/src/probeinterface/library.py +++ b/src/probeinterface/library.py @@ -13,6 +13,7 @@ import os from pathlib import Path from urllib.request import urlopen +import requests from typing import Optional from .io import read_probeinterface @@ -104,3 +105,101 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> probe.name = name return probe + + +def get_manufacturers_in_library(tag=None) -> list[str]: + """ + Get the list of available manufacturers in the library + + Returns + ------- + manufacturers : list of str + List of available manufacturers + """ + return list_github_folders("SpikeInterface", "probeinterface_library", ref=tag) + + +def get_probes_in_library(manufacturer: str, tag=None) -> list[str]: + """ + Get the list of available probes for a given manufacturer + + Parameters + ---------- + manufacturer : str + The probe manufacturer + + Returns + ------- + probes : list of str + List of available probes for the given manufacturer + """ + return list_github_folders("SpikeInterface", "probeinterface_library", path=manufacturer, ref=tag) + + +def get_tags_in_library() -> list[str]: + """ + Get the list of available tags in the library + + Returns + ------- + tags : list of str + List of available tags + """ + tags = [] + tags = get_all_tags("SpikeInterface", "probeinterface_library") + return tags + + +### UTILS +def get_latest_tag(owner: str, repo: str, token: str = None): + """ + Get the latest tag (by order returned from GitHub) for a repo. + Returns the tag name, or None if no tags exist. + """ + url = f"https://api.github.com/repos/{owner}/{repo}/tags" + headers = {} + if token: + headers["Authorization"] = f"token {token}" + resp = requests.get(url, headers=headers) + if resp.status_code != 200: + raise RuntimeError(f"GitHub API returned {resp.status_code}: {resp.text}") + tags = resp.json() + if not tags: + return None + return tags[0]["name"] # first entry is the latest + + +def get_all_tags(owner: str, repo: str, token: str = None): + """ + Get all tags for a repo. + Returns a list of tag names, or an empty list if no tags exist. + """ + url = f"https://api.github.com/repos/{owner}/{repo}/tags" + headers = {} + if token: + headers["Authorization"] = f"token {token}" + resp = requests.get(url, headers=headers) + if resp.status_code != 200: + raise RuntimeError(f"GitHub API returned {resp.status_code}: {resp.text}") + tags = resp.json() + return [tag["name"] for tag in tags] + + +def list_github_folders(owner: str, repo: str, path: str = "", ref: str = None, token: str = None): + """ + Return a list of directory names in the given repo at the specified path. + You can pass a branch, tag, or commit SHA via `ref`. + If token is provided, use it for authenticated requests (higher rate limits). + """ + url = f"https://api.github.com/repos/{owner}/{repo}/contents/{path}" + params = {} + if ref: + params["ref"] = ref + headers = {} + if token: + headers["Authorization"] = f"token {token}" + resp = requests.get(url, headers=headers, params=params) + if resp.status_code != 200: + raise RuntimeError(f"GitHub API returned status {resp.status_code}: {resp.text}") + items = resp.json() + return [item["name"] for item in items if item.get("type") == "dir" and item["name"][0] != "."] From fc4d1bf1a69ba8cb3867ec95c9577f795574dd7c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 26 Sep 2025 16:00:28 +0200 Subject: [PATCH 2/8] Add requests as requirement --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 16129400..e364b2bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ dependencies = [ "numpy", "packaging", + "requests" ] [project.urls] From 11ddb960871803d1d76c5be025df2036be5c0e8d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 26 Sep 2025 16:46:32 +0200 Subject: [PATCH 3/8] Add tests --- tests/test_library.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/test_library.py b/tests/test_library.py index 8d4059da..6b18f4c6 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -1,5 +1,12 @@ from probeinterface import Probe -from probeinterface.library import download_probeinterface_file, get_from_cache, get_probe +from probeinterface.library import ( + download_probeinterface_file, + get_from_cache, + get_probe, + get_tags_in_library, + get_manufacturers_in_library, + get_probes_in_library, +) from pathlib import Path @@ -31,7 +38,33 @@ def test_get_probe(): assert probe.get_contact_count() == 32 +def test_available_tags(): + tags = get_tags_in_library() + if len(tags) > 0: + for tag in tags: + assert isinstance(tag, str) + assert len(tag) > 0 + + +def test_get_manufacturers_in_library(): + manufacturers = get_manufacturers_in_library() + assert isinstance(manufacturers, list) + assert "neuronexus" in manufacturers + assert "imec" in manufacturers + + +def test_get_probes_in_library(): + manufacturers = get_manufacturers_in_library() + for manufacturer in manufacturers: + probes = get_probes_in_library(manufacturer) + assert isinstance(probes, list) + assert len(probes) > 0 + + if __name__ == "__main__": test_download_probeinterface_file() test_get_from_cache() test_get_probe() + test_get_latest_tag() + test_get_manufacturers_in_library() + test_get_probes_in_library() From ddabe5c991a2283a7d34dc0237c5f0bed0c0d4a9 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 26 Sep 2025 16:51:48 +0200 Subject: [PATCH 4/8] remove unused function --- src/probeinterface/library.py | 19 ------------------- tests/test_library.py | 1 - 2 files changed, 20 deletions(-) diff --git a/src/probeinterface/library.py b/src/probeinterface/library.py index 7d86bb3b..13767aa4 100644 --- a/src/probeinterface/library.py +++ b/src/probeinterface/library.py @@ -145,30 +145,11 @@ def get_tags_in_library() -> list[str]: tags : list of str List of available tags """ - tags = [] tags = get_all_tags("SpikeInterface", "probeinterface_library") return tags ### UTILS -def get_latest_tag(owner: str, repo: str, token: str = None): - """ - Get the latest tag (by order returned from GitHub) for a repo. - Returns the tag name, or None if no tags exist. - """ - url = f"https://api.github.com/repos/{owner}/{repo}/tags" - headers = {} - if token: - headers["Authorization"] = f"token {token}" - resp = requests.get(url, headers=headers) - if resp.status_code != 200: - raise RuntimeError(f"GitHub API returned {resp.status_code}: {resp.text}") - tags = resp.json() - if not tags: - return None - return tags[0]["name"] # first entry is the latest - - def get_all_tags(owner: str, repo: str, token: str = None): """ Get all tags for a repo. diff --git a/tests/test_library.py b/tests/test_library.py index 6b18f4c6..acf0a4d5 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -65,6 +65,5 @@ def test_get_probes_in_library(): test_download_probeinterface_file() test_get_from_cache() test_get_probe() - test_get_latest_tag() test_get_manufacturers_in_library() test_get_probes_in_library() From 4f9315d4bc3de8edf1eb592c0a4c578bc5767df6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 3 Oct 2025 10:48:27 +0200 Subject: [PATCH 5/8] Add tag to get_probe and cache and use GITHUB_TOKEN --- src/probeinterface/__init__.py | 2 +- src/probeinterface/library.py | 105 ++++++++++++++++++++++++++------- tests/test_library.py | 35 ++++++----- 3 files changed, 103 insertions(+), 39 deletions(-) diff --git a/src/probeinterface/__init__.py b/src/probeinterface/__init__.py index 0d5281dc..25d1f203 100644 --- a/src/probeinterface/__init__.py +++ b/src/probeinterface/__init__.py @@ -39,5 +39,5 @@ generate_multi_columns_probe, generate_multi_shank, ) -from .library import get_probe, get_manufacturers_in_library, get_probes_in_library, get_tags_in_library +from .library import get_probe, list_manufacturers_in_library, list_probes_in_library, get_tags_in_library from .wiring import get_available_pathways diff --git a/src/probeinterface/library.py b/src/probeinterface/library.py index 13767aa4..c98fa10a 100644 --- a/src/probeinterface/library.py +++ b/src/probeinterface/library.py @@ -20,15 +20,23 @@ # OLD URL on gin # public_url = "https://web.gin.g-node.org/spikeinterface/probeinterface_library/raw/master/" - # Now on github since 2023/06/15 -public_url = "https://raw.githubusercontent.com/SpikeInterface/probeinterface_library/main/" +public_url = "https://raw.githubusercontent.com/SpikeInterface/probeinterface_library/" + # check this for windows and osx -cache_folder = Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library" +def get_cache_folder() -> Path: + """Get the cache folder for probeinterface library files. + + Returns + ------- + cache_folder : Path + The path to the cache folder. + """ + return Path(os.path.expanduser("~")) / ".config" / "probeinterface" / "library" -def download_probeinterface_file(manufacturer: str, probe_name: str): +def download_probeinterface_file(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> None: """Download the probeinterface file to the cache directory. Note that the file is itself a ProbeGroup but on the repo each file represents one probe. @@ -39,16 +47,23 @@ def download_probeinterface_file(manufacturer: str, probe_name: str): The probe manufacturer probe_name : str (see probeinterface_libary for options) The probe name + tag : str | None, default: None + Optional tag for the probe """ - os.makedirs(cache_folder / manufacturer, exist_ok=True) - localfile = cache_folder / manufacturer / (probe_name + ".json") - distantfile = public_url + f"{manufacturer}/{probe_name}/{probe_name}.json" - dist = urlopen(distantfile) - with open(localfile, "wb") as f: - f.write(dist.read()) + cache_folder = get_cache_folder() + if tag is not None: + assert tag in get_tags_in_library(), f"Tag {tag} not found in library" + else: + tag = "main" + os.makedirs(cache_folder / tag / manufacturer, exist_ok=True) + local_file = cache_folder / tag / manufacturer / (probe_name + ".json") + remote_file = public_url + tag + f"/{manufacturer}/{probe_name}/{probe_name}.json" + rem = urlopen(remote_file) + with open(local_file, "wb") as f: + f.write(rem.read()) -def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]: +def get_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> Optional["Probe"]: """ Get Probe from local cache @@ -58,24 +73,66 @@ def get_from_cache(manufacturer: str, probe_name: str) -> Optional["Probe"]: The probe manufacturer probe_name : str (see probeinterface_libary for options) The probe name + tag : str | None, default: None + Optional tag for the probe Returns ------- probe : Probe object, or None if no probeinterface JSON file is found """ + cache_folder = get_cache_folder() + if tag is not None: + cache_folder_tag = cache_folder / tag + if not cache_folder_tag.is_dir(): + return None + cache_folder = cache_folder_tag + else: + cache_folder_tag = cache_folder / "main" - localfile = cache_folder / manufacturer / (probe_name + ".json") - if not localfile.is_file(): + local_file = cache_folder_tag / manufacturer / (probe_name + ".json") + if not local_file.is_file(): return None else: - probegroup = read_probeinterface(localfile) + probegroup = read_probeinterface(local_file) probe = probegroup.probes[0] probe._probe_group = None return probe -def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> "Probe": +def remove_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None) -> Optional["Probe"]: + """ + Remove Probe from local cache + + Parameters + ---------- + manufacturer : "cambridgeneurotech" | "neuronexus" | "plexon" | "imec" | "sinaps" + The probe manufacturer + probe_name : str (see probeinterface_libary for options) + The probe name + tag : str | None, default: None + Optional tag for the probe + + Returns + ------- + probe : Probe object, or None if no probeinterface JSON file is found + + """ + cache_folder = get_cache_folder() + if tag is not None: + cache_folder_tag = cache_folder / tag + if not cache_folder_tag.is_dir(): + return None + cache_folder = cache_folder_tag + else: + cache_folder_tag = cache_folder / "main" + + local_file = cache_folder_tag / manufacturer / (probe_name + ".json") + if local_file.is_file(): + os.remove(local_file) + + +def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None, tag: Optional[str] = None) -> "Probe": """ Get probe from ProbeInterface library @@ -87,6 +144,8 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> The probe name name : str | None, default: None Optional name for the probe + tag : str | None, default: None + Optional tag for the probe Returns ---------- @@ -94,11 +153,11 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> """ - probe = get_from_cache(manufacturer, probe_name) + probe = get_from_cache(manufacturer, probe_name, tag=tag) if probe is None: - download_probeinterface_file(manufacturer, probe_name) - probe = get_from_cache(manufacturer, probe_name) + download_probeinterface_file(manufacturer, probe_name, tag=tag) + probe = get_from_cache(manufacturer, probe_name, tag=tag) if probe.manufacturer == "": probe.manufacturer = manufacturer if name is not None: @@ -107,7 +166,7 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None) -> return probe -def get_manufacturers_in_library(tag=None) -> list[str]: +def list_manufacturers_in_library(tag=None) -> list[str]: """ Get the list of available manufacturers in the library @@ -119,7 +178,7 @@ def get_manufacturers_in_library(tag=None) -> list[str]: return list_github_folders("SpikeInterface", "probeinterface_library", ref=tag) -def get_probes_in_library(manufacturer: str, tag=None) -> list[str]: +def list_probes_in_library(manufacturer: str, tag=None) -> list[str]: """ Get the list of available probes for a given manufacturer @@ -157,7 +216,8 @@ def get_all_tags(owner: str, repo: str, token: str = None): """ url = f"https://api.github.com/repos/{owner}/{repo}/tags" headers = {} - if token: + if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"): + token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN") headers["Authorization"] = f"token {token}" resp = requests.get(url, headers=headers) if resp.status_code != 200: @@ -177,7 +237,8 @@ def list_github_folders(owner: str, repo: str, path: str = "", ref: str = None, if ref: params["ref"] = ref headers = {} - if token: + if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"): + token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN") headers["Authorization"] = f"token {token}" resp = requests.get(url, headers=headers, params=params) if resp.status_code != 200: diff --git a/tests/test_library.py b/tests/test_library.py index acf0a4d5..ab31b6a4 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -2,25 +2,20 @@ from probeinterface.library import ( download_probeinterface_file, get_from_cache, + remove_from_cache, get_probe, get_tags_in_library, - get_manufacturers_in_library, - get_probes_in_library, + list_manufacturers_in_library, + list_probes_in_library, ) -from pathlib import Path -import numpy as np - -import pytest - - manufacturer = "neuronexus" probe_name = "A1x32-Poly3-10mm-50-177" def test_download_probeinterface_file(): - download_probeinterface_file(manufacturer, probe_name) + download_probeinterface_file(manufacturer, probe_name, tag=None) def test_get_from_cache(): @@ -28,6 +23,14 @@ def test_get_from_cache(): probe = get_from_cache(manufacturer, probe_name) assert isinstance(probe, Probe) + tag = get_tags_in_library()[0] + probe = get_from_cache(manufacturer, probe_name, tag=tag) + assert probe is None # because we did not download with this tag + download_probeinterface_file(manufacturer, probe_name, tag=tag) + probe = get_from_cache(manufacturer, probe_name, tag=tag) + remove_from_cache(manufacturer, probe_name, tag=tag) + assert isinstance(probe, Probe) + probe = get_from_cache("yep", "yop") assert probe is None @@ -46,17 +49,17 @@ def test_available_tags(): assert len(tag) > 0 -def test_get_manufacturers_in_library(): - manufacturers = get_manufacturers_in_library() +def test_list_manufacturers_in_library(): + manufacturers = list_manufacturers_in_library() assert isinstance(manufacturers, list) assert "neuronexus" in manufacturers assert "imec" in manufacturers -def test_get_probes_in_library(): - manufacturers = get_manufacturers_in_library() +def test_list_probes_in_library(): + manufacturers = list_manufacturers_in_library() for manufacturer in manufacturers: - probes = get_probes_in_library(manufacturer) + probes = list_probes_in_library(manufacturer) assert isinstance(probes, list) assert len(probes) > 0 @@ -65,5 +68,5 @@ def test_get_probes_in_library(): test_download_probeinterface_file() test_get_from_cache() test_get_probe() - test_get_manufacturers_in_library() - test_get_probes_in_library() + test_list_manufacturers_in_library() + test_list_probes_in_library() From 03f93f9db610624e4231092ce1358b5c012b4e74 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 7 Oct 2025 10:53:53 +0200 Subject: [PATCH 6/8] Add latest_commit.txt --- src/probeinterface/library.py | 42 +++++++++++++++++++++++++++++++++++ tests/test_library.py | 28 +++++++++++++++++++++-- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/src/probeinterface/library.py b/src/probeinterface/library.py index c98fa10a..c86de248 100644 --- a/src/probeinterface/library.py +++ b/src/probeinterface/library.py @@ -11,6 +11,7 @@ from __future__ import annotations import os +import warnings from pathlib import Path from urllib.request import urlopen import requests @@ -55,6 +56,7 @@ def download_probeinterface_file(manufacturer: str, probe_name: str, tag: Option assert tag in get_tags_in_library(), f"Tag {tag} not found in library" else: tag = "main" + os.makedirs(cache_folder / tag / manufacturer, exist_ok=True) local_file = cache_folder / tag / manufacturer / (probe_name + ".json") remote_file = public_url + tag + f"/{manufacturer}/{probe_name}/{probe_name}.json" @@ -88,6 +90,25 @@ def get_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None return None cache_folder = cache_folder_tag else: + # load latest commit if exists + commit_file = cache_folder / "main" / "latest_commit.txt" + commit = None + if commit_file.is_file(): + with open(commit_file, "r") as f: + commit = f.read().strip() + + # check against latest commit on github + try: + latest_commit = get_latest_commit("SpikeInterface", "probeinterface_library")["sha"] + if commit is None or commit != latest_commit: + # in this case we need to redownload the file and update the latest_commit.txt + with open(cache_folder / "main" / "latest_commit.txt", "w") as f: + f.write(latest_commit) + return None + except Exception: + warnings.warn("Could not check for latest commit on github. Using local 'main' cache.") + pass + cache_folder_tag = cache_folder / "main" local_file = cache_folder_tag / manufacturer / (probe_name + ".json") @@ -245,3 +266,24 @@ def list_github_folders(owner: str, repo: str, path: str = "", ref: str = None, raise RuntimeError(f"GitHub API returned status {resp.status_code}: {resp.text}") items = resp.json() return [item["name"] for item in items if item.get("type") == "dir" and item["name"][0] != "."] + + +def get_latest_commit(owner: str, repo: str, branch: str = "main", token: str = None): + """ + Get the latest commit SHA and message from a given branch (default: main). + """ + url = f"https://api.github.com/repos/{owner}/{repo}/commits/{branch}" + headers = {} + if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"): + token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN") + headers["Authorization"] = f"token {token}" + resp = requests.get(url, headers=headers) + if resp.status_code != 200: + raise RuntimeError(f"GitHub API returned {resp.status_code}: {resp.text}") + data = resp.json() + return { + "sha": data["sha"], + "message": data["commit"]["message"], + "author": data["commit"]["author"]["name"], + "date": data["commit"]["author"]["date"], + } diff --git a/tests/test_library.py b/tests/test_library.py index ab31b6a4..c5c4faf7 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -7,6 +7,7 @@ get_tags_in_library, list_manufacturers_in_library, list_probes_in_library, + get_cache_folder, ) @@ -18,9 +19,32 @@ def test_download_probeinterface_file(): download_probeinterface_file(manufacturer, probe_name, tag=None) +def test_latest_commit_mechanism(): + download_probeinterface_file(manufacturer, probe_name, tag=None) + cache_folder = get_cache_folder() + latest_commit_file = cache_folder / "main" / "latest_commit.txt" + if latest_commit_file.is_file(): + latest_commit_file.unlink() + + # first download + download_probeinterface_file(manufacturer, probe_name, tag=None) + assert latest_commit_file.is_file() + with open(latest_commit_file, "r") as f: + commit1 = f.read().strip() + assert len(commit1) == 40 + + # second download should not change latest_commit.txt + download_probeinterface_file(manufacturer, probe_name, tag=None) + assert latest_commit_file.is_file() + with open(latest_commit_file, "r") as f: + commit2 = f.read().strip() + assert commit1 == commit2 + + def test_get_from_cache(): - download_probeinterface_file(manufacturer, probe_name) - probe = get_from_cache(manufacturer, probe_name) + # TODO: fix this test!!! + remove_from_cache(manufacturer, probe_name) + probe = download_probeinterface_file(manufacturer, probe_name) assert isinstance(probe, Probe) tag = get_tags_in_library()[0] From 81968b8588e0597da53fd44866563f5008409491 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 7 Oct 2025 10:57:00 +0200 Subject: [PATCH 7/8] Fix tests --- tests/test_library.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/tests/test_library.py b/tests/test_library.py index c5c4faf7..0e4e8e4a 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -20,31 +20,26 @@ def test_download_probeinterface_file(): def test_latest_commit_mechanism(): - download_probeinterface_file(manufacturer, probe_name, tag=None) + _ = get_probe(manufacturer, probe_name) cache_folder = get_cache_folder() latest_commit_file = cache_folder / "main" / "latest_commit.txt" - if latest_commit_file.is_file(): - latest_commit_file.unlink() - - # first download - download_probeinterface_file(manufacturer, probe_name, tag=None) assert latest_commit_file.is_file() - with open(latest_commit_file, "r") as f: - commit1 = f.read().strip() - assert len(commit1) == 40 - # second download should not change latest_commit.txt - download_probeinterface_file(manufacturer, probe_name, tag=None) + # now we manually change latest_commit.txt to something else + with open(latest_commit_file, "w") as f: + f.write("1234567890123456789012345678901234567890") + + # now we get the probe again and make sure the latest_commit.txt file is updated + _ = get_probe(manufacturer, probe_name) assert latest_commit_file.is_file() with open(latest_commit_file, "r") as f: - commit2 = f.read().strip() - assert commit1 == commit2 + latest_commit = f.read().strip() + assert latest_commit != "123456789012345678901234567890123456789" def test_get_from_cache(): - # TODO: fix this test!!! - remove_from_cache(manufacturer, probe_name) - probe = download_probeinterface_file(manufacturer, probe_name) + download_probeinterface_file(manufacturer, probe_name) + probe = get_from_cache(manufacturer, probe_name) assert isinstance(probe, Probe) tag = get_tags_in_library()[0] From 965ccb05d905903d55d03471ea4d85e7d632e2ef Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 7 Oct 2025 16:44:56 +0200 Subject: [PATCH 8/8] Remove latest_commit mechanism and add cache_full_library --- src/probeinterface/__init__.py | 8 +++- src/probeinterface/library.py | 71 ++++++++++++++-------------------- tests/test_library.py | 18 --------- 3 files changed, 35 insertions(+), 62 deletions(-) diff --git a/src/probeinterface/__init__.py b/src/probeinterface/__init__.py index 25d1f203..7998e142 100644 --- a/src/probeinterface/__init__.py +++ b/src/probeinterface/__init__.py @@ -39,5 +39,11 @@ generate_multi_columns_probe, generate_multi_shank, ) -from .library import get_probe, list_manufacturers_in_library, list_probes_in_library, get_tags_in_library +from .library import ( + get_probe, + list_manufacturers_in_library, + list_probes_in_library, + get_tags_in_library, + cache_full_library, +) from .wiring import get_available_pathways diff --git a/src/probeinterface/library.py b/src/probeinterface/library.py index c86de248..ff770833 100644 --- a/src/probeinterface/library.py +++ b/src/probeinterface/library.py @@ -90,25 +90,6 @@ def get_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = None return None cache_folder = cache_folder_tag else: - # load latest commit if exists - commit_file = cache_folder / "main" / "latest_commit.txt" - commit = None - if commit_file.is_file(): - with open(commit_file, "r") as f: - commit = f.read().strip() - - # check against latest commit on github - try: - latest_commit = get_latest_commit("SpikeInterface", "probeinterface_library")["sha"] - if commit is None or commit != latest_commit: - # in this case we need to redownload the file and update the latest_commit.txt - with open(cache_folder / "main" / "latest_commit.txt", "w") as f: - f.write(latest_commit) - return None - except Exception: - warnings.warn("Could not check for latest commit on github. Using local 'main' cache.") - pass - cache_folder_tag = cache_folder / "main" local_file = cache_folder_tag / manufacturer / (probe_name + ".json") @@ -153,7 +134,13 @@ def remove_from_cache(manufacturer: str, probe_name: str, tag: Optional[str] = N os.remove(local_file) -def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None, tag: Optional[str] = None) -> "Probe": +def get_probe( + manufacturer: str, + probe_name: str, + name: Optional[str] = None, + tag: Optional[str] = None, + force_download: bool = False, +) -> "Probe": """ Get probe from ProbeInterface library @@ -167,14 +154,18 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None, ta Optional name for the probe tag : str | None, default: None Optional tag for the probe + force_download : bool, default: False + If True, force re-download of the probe file. Returns ---------- probe : Probe object """ - - probe = get_from_cache(manufacturer, probe_name, tag=tag) + if not force_download: + probe = get_from_cache(manufacturer, probe_name, tag=tag) + else: + probe = None if probe is None: download_probeinterface_file(manufacturer, probe_name, tag=tag) @@ -187,6 +178,21 @@ def get_probe(manufacturer: str, probe_name: str, name: Optional[str] = None, ta return probe +def cache_full_library(tag=None) -> None: + """ + Download all probes from the library to the cache directory. + """ + manufacturers = list_manufacturers_in_library(tag=tag) + + for manufacturer in manufacturers: + probes = list_probes_in_library(manufacturer, tag=tag) + for probe_name in probes: + try: + download_probeinterface_file(manufacturer, probe_name, tag=tag) + except Exception as e: + warnings.warn(f"Could not download {manufacturer}/{probe_name} (tag: {tag}): {e}") + + def list_manufacturers_in_library(tag=None) -> list[str]: """ Get the list of available manufacturers in the library @@ -266,24 +272,3 @@ def list_github_folders(owner: str, repo: str, path: str = "", ref: str = None, raise RuntimeError(f"GitHub API returned status {resp.status_code}: {resp.text}") items = resp.json() return [item["name"] for item in items if item.get("type") == "dir" and item["name"][0] != "."] - - -def get_latest_commit(owner: str, repo: str, branch: str = "main", token: str = None): - """ - Get the latest commit SHA and message from a given branch (default: main). - """ - url = f"https://api.github.com/repos/{owner}/{repo}/commits/{branch}" - headers = {} - if token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN"): - token = token or os.getenv("GH_TOKEN") or os.getenv("GITHUB_TOKEN") - headers["Authorization"] = f"token {token}" - resp = requests.get(url, headers=headers) - if resp.status_code != 200: - raise RuntimeError(f"GitHub API returned {resp.status_code}: {resp.text}") - data = resp.json() - return { - "sha": data["sha"], - "message": data["commit"]["message"], - "author": data["commit"]["author"]["name"], - "date": data["commit"]["author"]["date"], - } diff --git a/tests/test_library.py b/tests/test_library.py index 0e4e8e4a..2e902162 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -19,24 +19,6 @@ def test_download_probeinterface_file(): download_probeinterface_file(manufacturer, probe_name, tag=None) -def test_latest_commit_mechanism(): - _ = get_probe(manufacturer, probe_name) - cache_folder = get_cache_folder() - latest_commit_file = cache_folder / "main" / "latest_commit.txt" - assert latest_commit_file.is_file() - - # now we manually change latest_commit.txt to something else - with open(latest_commit_file, "w") as f: - f.write("1234567890123456789012345678901234567890") - - # now we get the probe again and make sure the latest_commit.txt file is updated - _ = get_probe(manufacturer, probe_name) - assert latest_commit_file.is_file() - with open(latest_commit_file, "r") as f: - latest_commit = f.read().strip() - assert latest_commit != "123456789012345678901234567890123456789" - - def test_get_from_cache(): download_probeinterface_file(manufacturer, probe_name) probe = get_from_cache(manufacturer, probe_name)