Skip to content

Commit 61eec43

Browse files
committed
Add array-api spec as submodule, load its signatures
1 parent 41c338b commit 61eec43

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "array_api_tests/array-api"]
2+
path = array_api_tests/array-api
3+
url = https://github.com/data-apis/array-api/

array_api_tests/array-api

Submodule array-api added at 2b9c402

array_api_tests/stubs.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import sys
2+
from importlib import import_module
3+
from importlib.util import find_spec
4+
from pathlib import Path
5+
from types import FunctionType, ModuleType
6+
from typing import Dict, List
7+
8+
__all__ = ["category_to_funcs", "array", "extension_to_funcs"]
9+
10+
11+
spec_dir = Path(__file__).parent / "array-api" / "spec" / "API_specification"
12+
assert spec_dir.exists(), f"{spec_dir} not found - try `git pull --recurse-submodules`"
13+
sigs_dir = spec_dir / "signatures"
14+
assert sigs_dir.exists()
15+
16+
spec_abs_path: str = str(spec_dir.resolve())
17+
sys.path.append(spec_abs_path)
18+
assert find_spec("signatures") is not None
19+
20+
name_to_mod: Dict[str, ModuleType] = {}
21+
for path in sigs_dir.glob("*.py"):
22+
name = path.name.replace(".py", "")
23+
name_to_mod[name] = import_module(f"signatures.{name}")
24+
25+
26+
category_to_funcs: Dict[str, List[FunctionType]] = {}
27+
for name, mod in name_to_mod.items():
28+
if name.endswith("_functions"):
29+
category = name.replace("_functions", "")
30+
objects = [getattr(mod, name) for name in mod.__all__]
31+
assert all(isinstance(o, FunctionType) for o in objects)
32+
category_to_funcs[category] = objects
33+
34+
35+
array = name_to_mod["array_object"].array
36+
37+
38+
EXTENSIONS = ["linalg"]
39+
extension_to_funcs: Dict[str, List[FunctionType]] = {}
40+
for ext in EXTENSIONS:
41+
mod = name_to_mod[ext]
42+
objects = [getattr(mod, name) for name in mod.__all__]
43+
assert all(isinstance(o, FunctionType) for o in objects)
44+
extension_to_funcs[ext] = objects
45+
46+
47+
sys.path.remove(spec_abs_path)

0 commit comments

Comments
 (0)