Skip to content

Commit 75da942

Browse files
committed
Modernise: importlib, pathlib, omega_q becomes "points"
Refactored instrument tests a bit to patch importlib.resources.files() as a fixture. This is a bit more risky than changing our own INSTRUMENT_DATA_PATH variable; if we run into conflict with other uses for this function then we might need to reintroduce the variable. Alternatively we could consider using an environment variable so additional paths can be set at runtime.
1 parent c87c371 commit 75da942

File tree

3 files changed

+41
-55
lines changed

3 files changed

+41
-55
lines changed

src/resins/instrument.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from collections import ChainMap
88
import dataclasses
9-
import os
9+
import importlib
1010

1111
import numpy as np
1212
import yaml
@@ -18,8 +18,6 @@
1818
from .models.model_base import ModelData, InstrumentModel
1919
from inspect import Signature
2020

21-
INSTRUMENT_DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'instrument_data')
22-
2321
INSTRUMENT_MAP: dict[str, tuple[str, None | str]] = {
2422
'ARCS': ('arcs.yaml', None),
2523
'CNCS': ('cncs.yaml', None),
@@ -360,7 +358,7 @@ def _get_file(instrument_name: str) -> tuple[str, Union[str, None]]:
360358
f'"{instrument_name}" is not a valid instrument name. Only the following instruments are '
361359
f'supported: {list(INSTRUMENT_MAP)}')
362360

363-
return os.path.join(INSTRUMENT_DATA_PATH, file_name), implied_version
361+
return str(importlib.resources.files("resins.instrument_data") / file_name), implied_version
364362

365363
def get_model_data(self, model_name: Optional[str] = None, **kwargs) -> ModelData:
366364
"""

src/resins/models/lookuptables.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88
from __future__ import annotations
99

1010
from dataclasses import dataclass
11+
import importlib
1112
from typing import ClassVar, TYPE_CHECKING
1213

1314
import numpy as np
14-
from numpy.polynomial.polynomial import Polynomial
1515

1616
from .model_base import InstrumentModel, ModelData
1717
from .mixins import SimpleBroaden1DMixin
18-
from ..instrument import INSTRUMENT_DATA_PATH
1918

2019
if TYPE_CHECKING:
2120
from jaxtyping import Float
@@ -67,16 +66,14 @@ class ScaledTabulatedModel(SimpleBroaden1DMixin, InstrumentModel):
6766
"""
6867
input = ('energy_transfer',)
6968

70-
data_class: ClassVar[type[ScaledTabultedModelData]] = ScaledTabulatedModelData
69+
data_class: ClassVar[type[ScaledTabulatedModelData]] = ScaledTabulatedModelData
7170

7271
def __init__(self, model_data: ScaledTabulatedModelData, **_):
73-
from pathlib import Path
74-
7572
from numpy.polynomial import Polynomial
7673
from scipy.interpolate import RegularGridInterpolator
7774

7875
super().__init__(model_data)
79-
self.data = np.load(Path(INSTRUMENT_DATA_PATH) / model_data.npz)
76+
self.data = np.load(importlib.resources.files("resins.instrument_data") / model_data.npz)
8077

8178
self.polynomial = Polynomial(coef=self.data["coef"],
8279
domain=self.data["domain"],
@@ -110,7 +107,7 @@ def get_characteristics(self, omega_q: Float[np.ndarray, 'energy_transfer dimens
110107
return {'sigma': self.polynomial(omega_q[:, 0])}
111108

112109
def get_kernel(self,
113-
omega_q: Float[np.ndarray, 'sample dimension=1'],
110+
points: Float[np.ndarray, 'sample dimension=1'],
114111
mesh: Float[np.ndarray, 'mesh'],
115112
) -> Float[np.ndarray, 'sample mesh']:
116113

@@ -130,14 +127,14 @@ def get_kernel(self,
130127
return interp_kernels
131128

132129
def get_peak(self,
133-
omega_q: Float[np.ndarray, 'sample dimension=1'],
130+
points: Float[np.ndarray, 'sample dimension=1'],
134131
mesh: Float[np.ndarray, 'mesh'],
135132
) -> Float[np.ndarray, 'sample mesh']:
136133
shifted_meshes = [mesh - energy for energy in omega_q[:, 0]]
137134

138135
shifted_kernels = [
139-
self.get_kernel(np.array([omega_q]), shifted_mesh)
140-
for omega_q, shifted_mesh in zip(omega_q, shifted_meshes)
136+
self.get_kernel(np.array([point]), shifted_mesh)
137+
for point, shifted_mesh in zip(points, shifted_meshes)
141138
]
142139

143140
return np.array(np.vstack(shifted_kernels))

tests/unit_tests/test_instrument.py

Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass
22
import inspect
3-
import os
3+
from pathlib import Path
44
import typing
55

66
import pytest
@@ -11,8 +11,8 @@
1111
from resins.models.model_base import ModelData, InstrumentModel
1212

1313

14-
TEST_DIR = os.path.dirname(os.path.abspath(__file__))
15-
FAKE_YAML = os.path.join(TEST_DIR, 'fake_instrument.yaml')
14+
TEST_DIR = Path(__file__).parent
15+
FAKE_YAML = TEST_DIR / 'fake_instrument.yaml'
1616

1717

1818
@dataclass(init=True, repr=True, frozen=True, slots=True)
@@ -68,13 +68,6 @@ def broaden(self, omega_q, data, mesh):
6868
def mock_models():
6969
return {'mock': MockModel}
7070

71-
@pytest.fixture
72-
def mock_instrument_map():
73-
return {
74-
'TEST': ('fake_instrument.yaml', None),
75-
'ALIAS': ('fake_instrument.yaml', 'VERSION1'),
76-
}
77-
7871

7972
@pytest.fixture(scope='module')
8073
def data():
@@ -127,35 +120,45 @@ def test_instrument(data):
127120
)
128121

129122

130-
def test_available_instruments(mock_instrument_map, mocker):
131-
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
132-
assert i.Instrument.available_instruments() == ['TEST', 'ALIAS']
123+
@pytest.fixture
124+
def patch_importlib_files(mocker):
125+
"""Make all calls to importlib.resources.files() return TEST_DIR
126+
127+
This is used to replace resins.instrument_data with test data
128+
"""
129+
mocker.patch('importlib.resources.files', lambda _: Path(TEST_DIR))
133130

134131

135-
def test_private_available_versions(mock_instrument_map, mocker):
136-
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
137-
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)
132+
@pytest.fixture
133+
def patch_instrument_map(mocker):
134+
"""Replace main INSTRUMENT_MAP with test examples"""
138135

136+
instrument_map = {
137+
'TEST': ('fake_instrument.yaml', None),
138+
'ALIAS': ('fake_instrument.yaml', 'VERSION1'),
139+
}
140+
mocker.patch('resins.instrument.INSTRUMENT_MAP', instrument_map)
141+
142+
143+
def test_available_instruments(patch_instrument_map):
144+
assert i.Instrument.available_instruments() == ['TEST', 'ALIAS']
145+
146+
147+
def test_private_available_versions(patch_instrument_map, patch_importlib_files):
139148
actual_versions, actual_default = i.Instrument._available_versions(FAKE_YAML)
140149

141150
assert actual_default == 'TEST'
142151
assert actual_versions == ['VERSION1', 'TEST']
143152

144153

145-
def test_available_versions(mock_instrument_map, mocker):
146-
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
147-
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)
148-
154+
def test_available_versions(patch_instrument_map, patch_importlib_files):
149155
actual_versions, actual_default = i.Instrument.available_versions('TEST')
150156

151157
assert actual_default == 'TEST'
152158
assert actual_versions == ['VERSION1', 'TEST']
153159

154160

155-
def test_available_versions_alias(mock_instrument_map, mocker):
156-
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
157-
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)
158-
161+
def test_available_versions_alias(patch_instrument_map, patch_importlib_files):
159162
actual_versions, actual_default = i.Instrument.available_versions('ALIAS')
160163

161164
assert actual_default == 'VERSION1'
@@ -185,10 +188,7 @@ def test_from_file_invalid_version():
185188
i.Instrument.from_file(FAKE_YAML, 'INVALID_VERSION')
186189

187190

188-
def test_from_default(data, mock_instrument_map, mocker):
189-
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
190-
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)
191-
191+
def test_from_default(data, patch_instrument_map, patch_importlib_files):
192192
instrument = i.Instrument.from_default('TEST', 'VERSION1')
193193

194194
assert isinstance(instrument, i.Instrument)
@@ -197,10 +197,7 @@ def test_from_default(data, mock_instrument_map, mocker):
197197
assert instrument._models == data['version']['VERSION1']['models']
198198

199199

200-
def test_from_default_default(data, mock_instrument_map, mocker):
201-
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
202-
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)
203-
200+
def test_from_default_default(data, patch_instrument_map, patch_importlib_files):
204201
instrument = i.Instrument.from_default('TEST')
205202

206203
assert isinstance(instrument, i.Instrument)
@@ -211,20 +208,14 @@ def test_from_default_default(data, mock_instrument_map, mocker):
211208

212209
@pytest.mark.parametrize("name,expected_path,implied_ver",
213210
[('TEST', FAKE_YAML, None), ('ALIAS', FAKE_YAML, 'VERSION1')])
214-
def test_get_file(name, expected_path, implied_ver, mock_instrument_map, mocker):
215-
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
216-
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)
217-
211+
def test_get_file(name, expected_path, implied_ver, patch_instrument_map, patch_importlib_files):
218212
actual_path, actual_version = i.Instrument._get_file(name)
219213

220-
assert actual_path == expected_path
214+
assert Path(actual_path) == expected_path
221215
assert actual_version == implied_ver
222216

223217

224-
def test_get_file_invalid(mock_instrument_map, mocker):
225-
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
226-
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)
227-
218+
def test_get_file_invalid(patch_instrument_map, patch_importlib_files):
228219
with pytest.raises(i.InvalidInstrumentError):
229220
i.Instrument._get_file('INVALID_INSTRUMENT')
230221

0 commit comments

Comments
 (0)