Skip to content

Commit 02d13a6

Browse files
committed
add data-utils for cec test functions
1 parent ace21ac commit 02d13a6

File tree

2 files changed

+148
-6
lines changed

2 files changed

+148
-6
lines changed

src/surfaces/test_functions/cec/_base_cec.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44

55
"""Base class for all CEC competition benchmark functions."""
66

7+
from __future__ import annotations
8+
79
from abc import abstractmethod
810
from pathlib import Path
911
from typing import Any, Dict, Optional, Tuple
1012

1113
import numpy as np
1214

1315
from ..algebraic._base_algebraic_function import AlgebraicFunction
16+
from ._data_utils import get_data_file
1417

1518

1619
class CECFunction(AlgebraicFunction):
@@ -104,8 +107,8 @@ def _data_dir(self) -> Path:
104107
def _load_data(self) -> Dict[str, np.ndarray]:
105108
"""Load rotation matrices and shift vectors for this dimension.
106109
107-
Data files are fetched from GitHub releases on first use and cached
108-
locally. Subsequent calls use the cached files.
110+
Data files are loaded from the surfaces-cec-data package or local
111+
development directory. Results are cached in memory.
109112
110113
Returns
111114
-------
@@ -115,14 +118,14 @@ def _load_data(self) -> Dict[str, np.ndarray]:
115118
Raises
116119
------
117120
ImportError
118-
If pooch is not installed.
121+
If surfaces-cec-data is not installed.
122+
FileNotFoundError
123+
If the data file is not found.
119124
"""
120125
cache_key = (self.data_prefix, self.n_dim)
121126
if cache_key not in self._data_cache:
122-
from ..._data import fetch_file
123-
124127
filename = f"{self.data_prefix}_data_dim{self.n_dim}.npz"
125-
data_file = fetch_file(self.data_prefix, filename)
128+
data_file = get_data_file(self.data_prefix, filename)
126129
self._data_cache[cache_key] = dict(np.load(data_file))
127130
return self._data_cache[cache_key]
128131

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Author: Simon Blanke
2+
3+
# License: MIT License
4+
5+
"""Utilities for loading CEC benchmark data files."""
6+
7+
from __future__ import annotations
8+
9+
import sys
10+
from pathlib import Path
11+
from typing import Optional
12+
13+
# Data package name for CEC benchmark data
14+
_DATA_PACKAGE = "surfaces_cec_data"  # import name (underscores); the pip distribution is "surfaces-cec-data"
15+
16+
17+
def _get_local_data_path(dataset: str, filename: str) -> Optional[Path]:
18+
"""Check if file exists in the local data package directory (for development)."""
19+
# This file is at src/surfaces/test_functions/cec/_data_utils.py
20+
# Need 5 parents to get to repo root
21+
repo_root = Path(__file__).parent.parent.parent.parent.parent
22+
local_path = (
23+
repo_root
24+
/ "data-packages"
25+
/ "surfaces-cec-data"
26+
/ "src"
27+
/ "surfaces_cec_data"
28+
/ dataset
29+
/ filename
30+
)
31+
if local_path.exists():
32+
return local_path
33+
return None
34+
35+
36+
def _get_installed_data_path(dataset: str, filename: str) -> Optional[Path]:
    """Get file path from the installed surfaces-cec-data package.

    Parameters
    ----------
    dataset : str
        Sub-directory (sub-package) of the data package to look in.
    filename : str
        Name of the data file inside that sub-directory.

    Returns
    -------
    Path or None
        Filesystem path of the resource, or ``None`` when the data package
        is not importable or the file is not part of it.

    Notes
    -----
    The resource is opened through ``as_file`` (or the legacy
    ``importlib.resources.path`` on Python 3.8 without the
    ``importlib_resources`` backport). These context managers may extract
    the resource to a temporary file that is removed when the context
    exits, so the context is kept open until interpreter shutdown to keep
    the returned path valid for the caller.
    """
    # Function-scope imports keep this change self-contained.
    import atexit
    from contextlib import ExitStack

    def _resource_context():
        """Build the context manager that yields the resource's path."""
        if sys.version_info >= (3, 9):
            from importlib.resources import as_file, files
        else:
            try:
                from importlib_resources import as_file, files
            except ImportError:
                # Python 3.8 without the backport: fall back to the legacy API.
                import importlib.resources as pkg_resources

                return pkg_resources.path(f"{_DATA_PACKAGE}.{dataset}", filename)
        # Use "/" rather than multi-argument joinpath(): Traversable only
        # guarantees joinpath(*descendants) from Python 3.11 onwards.
        return as_file(files(_DATA_PACKAGE) / dataset / filename)

    stack = ExitStack()
    try:
        path = stack.enter_context(_resource_context())
    except (ModuleNotFoundError, FileNotFoundError, TypeError):
        # Package missing, resource missing, or not a regular package.
        stack.close()
        return None

    if not path.exists():
        stack.close()
        return None

    # Keep any extracted temporary file alive for the rest of the process.
    atexit.register(stack.close)
    return path
70+
71+
72+
def _is_data_package_installed() -> bool:
    """Report whether the surfaces-cec-data package can be located.

    Returns
    -------
    bool
        ``True`` when probing the package succeeds, ``False`` when the
        probe raises ``ModuleNotFoundError``.
    """
    try:
        # Pick the probe available on this interpreter, then call it once.
        if sys.version_info >= (3, 9):
            from importlib.resources import files as probe
        else:
            try:
                from importlib_resources import files as probe
            except ImportError:
                from importlib import import_module as probe
        probe(_DATA_PACKAGE)
    except ModuleNotFoundError:
        return False
    return True
91+
92+
93+
def get_data_file(dataset: str, filename: str) -> Path:
    """Resolve the location of a CEC data file.

    The local development directory is consulted first; the installed
    data package is the fallback.

    Parameters
    ----------
    dataset : str
        Name of the dataset (e.g., "cec2014", "cec2017").
    filename : str
        Name of the data file (e.g., "cec2014_data_dim10.npz").

    Returns
    -------
    Path
        Path to the data file.

    Raises
    ------
    ImportError
        If surfaces-cec-data is not installed and local files not found.
    FileNotFoundError
        If the data package is installed but does not contain the file.
    """
    # Try each lookup strategy in priority order; the first hit wins.
    for locate in (_get_local_data_path, _get_installed_data_path):
        found = locate(dataset, filename)
        if found is not None:
            return found

    # Nothing found: choose the error that tells the user what to do next.
    if _is_data_package_installed():
        raise FileNotFoundError(
            f"Data file not found: {dataset}/{filename}\n"
            "The surfaces-cec-data package may be outdated. "
            "Try: pip install -U surfaces-cec-data"
        )
    raise ImportError(
        "CEC benchmark data files are not available.\n"
        "Install the data package with: pip install surfaces-cec-data\n"
        "Or install surfaces with CEC support: pip install surfaces[cec]"
    )

0 commit comments

Comments
 (0)