|
| 1 | +import sys |
| 2 | +from importlib import import_module |
| 3 | +from importlib.util import find_spec |
| 4 | +from pathlib import Path |
| 5 | +from types import FunctionType, ModuleType |
| 6 | +from typing import Dict, List |
| 7 | + |
| 8 | +__all__ = ["category_to_funcs", "array", "extension_to_funcs"] |
| 9 | + |
| 10 | + |
| 11 | +spec_dir = Path(__file__).parent / "array-api" / "spec" / "API_specification" |
| 12 | +assert spec_dir.exists(), f"{spec_dir} not found - try `git pull --recurse-submodules`" |
| 13 | +sigs_dir = spec_dir / "signatures" |
| 14 | +assert sigs_dir.exists() |
| 15 | + |
| 16 | +spec_abs_path: str = str(spec_dir.resolve()) |
| 17 | +sys.path.append(spec_abs_path) |
| 18 | +assert find_spec("signatures") is not None |
| 19 | + |
| 20 | +name_to_mod: Dict[str, ModuleType] = {} |
| 21 | +for path in sigs_dir.glob("*.py"): |
| 22 | + name = path.name.replace(".py", "") |
| 23 | + name_to_mod[name] = import_module(f"signatures.{name}") |
| 24 | + |
| 25 | + |
| 26 | +category_to_funcs: Dict[str, List[FunctionType]] = {} |
| 27 | +for name, mod in name_to_mod.items(): |
| 28 | + if name.endswith("_functions"): |
| 29 | + category = name.replace("_functions", "") |
| 30 | + objects = [getattr(mod, name) for name in mod.__all__] |
| 31 | + assert all(isinstance(o, FunctionType) for o in objects) |
| 32 | + category_to_funcs[category] = objects |
| 33 | + |
| 34 | + |
| 35 | +array = name_to_mod["array_object"].array |
| 36 | + |
| 37 | + |
| 38 | +EXTENSIONS = ["linalg"] |
| 39 | +extension_to_funcs: Dict[str, List[FunctionType]] = {} |
| 40 | +for ext in EXTENSIONS: |
| 41 | + mod = name_to_mod[ext] |
| 42 | + objects = [getattr(mod, name) for name in mod.__all__] |
| 43 | + assert all(isinstance(o, FunctionType) for o in objects) |
| 44 | + extension_to_funcs[ext] = objects |
| 45 | + |
| 46 | + |
| 47 | +sys.path.remove(spec_abs_path) |
0 commit comments