Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/resins/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from collections import ChainMap
import dataclasses
import os
import importlib

import numpy as np
import yaml
Expand All @@ -18,8 +18,6 @@
from .models.model_base import ModelData, InstrumentModel
from inspect import Signature

INSTRUMENT_DATA_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'instrument_data')

INSTRUMENT_MAP: dict[str, tuple[str, None | str]] = {
'ARCS': ('arcs.yaml', None),
'CNCS': ('cncs.yaml', None),
Expand Down Expand Up @@ -360,7 +358,7 @@ def _get_file(instrument_name: str) -> tuple[str, Union[str, None]]:
f'"{instrument_name}" is not a valid instrument name. Only the following instruments are '
f'supported: {list(INSTRUMENT_MAP)}')

return os.path.join(INSTRUMENT_DATA_PATH, file_name), implied_version
return str(importlib.resources.files("resins.instrument_data") / file_name), implied_version

def get_model_data(self, model_name: Optional[str] = None, **kwargs) -> ModelData:
"""
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/test_ideal_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_bad_width_boxcar():
- boxcar values are normalised based on _actual_ kernel area
"""
result, _ = _get_data("boxcar", Feature.KERNEL, case_name="bad_width_kernel")
assert_allclose(result, np.load(DATA_PATH / f"_get_boxcar_kernel.npy"))
assert_allclose(result, np.load(DATA_PATH / "_get_boxcar_kernel.npy"))


def test_bad_width_triangle():
Expand All @@ -133,7 +133,7 @@ def test_bad_width_triangle():
"""
result, mesh = _get_data("triangle", Feature.KERNEL, case_name="bad_width_kernel")

ref_triangle = np.load(DATA_PATH / f"_get_triangle_kernel.npy")
ref_triangle = np.load(DATA_PATH / "_get_triangle_kernel.npy")

assert np.flatnonzero(result).tolist() == np.flatnonzero(ref_triangle).tolist()

Expand All @@ -151,7 +151,7 @@ def test_bad_width_trapezoid():
"""
result, mesh = _get_data("trapezoid", Feature.PEAK, case_name="bad_width_peak")

ref = np.load(DATA_PATH / f"_get_trapezoid_peak.npy")
ref = np.load(DATA_PATH / "_get_trapezoid_peak.npy")

# Same non-zero points
assert np.flatnonzero(result).tolist() == np.flatnonzero(ref).tolist()
Expand Down
88 changes: 38 additions & 50 deletions tests/unit_tests/test_instrument.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from dataclasses import dataclass
import inspect
import os
from pathlib import Path
import typing

import numpy as np
import pytest
import yaml

from resins import instrument as i
from resins.models import MODELS
from resins.models.model_base import ModelData, InstrumentModel


TEST_DIR = os.path.dirname(os.path.abspath(__file__))
FAKE_YAML = os.path.join(TEST_DIR, 'fake_instrument.yaml')
TEST_DIR = Path(__file__).parent
FAKE_YAML = TEST_DIR / 'fake_instrument.yaml'


@dataclass(init=True, repr=True, frozen=True, slots=True)
Expand Down Expand Up @@ -45,36 +45,26 @@ def __init__(self,
self.kwarg1 = kwarg1
self.kwarg2 = kwarg2

def get_characteristics(self, energy_transfer):
return {}

def __call__(self, frequencies, *args, **kwargs):
return frequencies

def get_characteristics(self, omega_q):
return {"sigma": np.ones((len(omega_q), 1))}
def get_characteristics(self, points):
return {"sigma": np.ones((len(points), 1))}

def get_kernel(self, omega_q, mesh):
def get_kernel(self, points, mesh):
return np.zeros_like(mesh)

def get_peak(self, omega_q, mesh):
def get_peak(self, points, mesh):
return np.zeros_like(mesh)

def broaden(self, omega_q, data, mesh):
def broaden(self, points, data, mesh):
return np.zeros_like(mesh)


@pytest.fixture
def mock_models():
return {'mock': MockModel}

@pytest.fixture
def mock_instrument_map():
return {
'TEST': ('fake_instrument.yaml', None),
'ALIAS': ('fake_instrument.yaml', 'VERSION1'),
}


@pytest.fixture(scope='module')
def data():
Expand Down Expand Up @@ -127,35 +117,45 @@ def test_instrument(data):
)


def test_available_instruments(mock_instrument_map, mocker):
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
assert i.Instrument.available_instruments() == ['TEST', 'ALIAS']
@pytest.fixture
def patch_importlib_files(mocker):
"""Make all calls to importlib.resources.files() return TEST_DIR

This is used to replace resins.instrument_data with test data
"""
mocker.patch('importlib.resources.files', lambda _: Path(TEST_DIR))


@pytest.fixture
def patch_instrument_map(mocker):
"""Replace main INSTRUMENT_MAP with test examples"""

instrument_map = {
'TEST': ('fake_instrument.yaml', None),
'ALIAS': ('fake_instrument.yaml', 'VERSION1'),
}
mocker.patch('resins.instrument.INSTRUMENT_MAP', instrument_map)


def test_available_instruments(patch_instrument_map):
assert i.Instrument.available_instruments() == ['TEST', 'ALIAS']

def test_private_available_versions(mock_instrument_map, mocker):
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)

def test_private_available_versions(patch_instrument_map, patch_importlib_files):
actual_versions, actual_default = i.Instrument._available_versions(FAKE_YAML)

assert actual_default == 'TEST'
assert actual_versions == ['VERSION1', 'TEST']


def test_available_versions(mock_instrument_map, mocker):
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)

def test_available_versions(patch_instrument_map, patch_importlib_files):
actual_versions, actual_default = i.Instrument.available_versions('TEST')

assert actual_default == 'TEST'
assert actual_versions == ['VERSION1', 'TEST']


def test_available_versions_alias(mock_instrument_map, mocker):
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)

def test_available_versions_alias(patch_instrument_map, patch_importlib_files):
actual_versions, actual_default = i.Instrument.available_versions('ALIAS')

assert actual_default == 'VERSION1'
Expand Down Expand Up @@ -185,10 +185,7 @@ def test_from_file_invalid_version():
i.Instrument.from_file(FAKE_YAML, 'INVALID_VERSION')


def test_from_default(data, mock_instrument_map, mocker):
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)

def test_from_default(data, patch_instrument_map, patch_importlib_files):
instrument = i.Instrument.from_default('TEST', 'VERSION1')

assert isinstance(instrument, i.Instrument)
Expand All @@ -197,10 +194,7 @@ def test_from_default(data, mock_instrument_map, mocker):
assert instrument._models == data['version']['VERSION1']['models']


def test_from_default_default(data, mock_instrument_map, mocker):
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)

def test_from_default_default(data, patch_instrument_map, patch_importlib_files):
instrument = i.Instrument.from_default('TEST')

assert isinstance(instrument, i.Instrument)
Expand All @@ -211,20 +205,14 @@ def test_from_default_default(data, mock_instrument_map, mocker):

@pytest.mark.parametrize("name,expected_path,implied_ver",
[('TEST', FAKE_YAML, None), ('ALIAS', FAKE_YAML, 'VERSION1')])
def test_get_file(name, expected_path, implied_ver, mock_instrument_map, mocker):
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)

def test_get_file(name, expected_path, implied_ver, patch_instrument_map, patch_importlib_files):
actual_path, actual_version = i.Instrument._get_file(name)

assert actual_path == expected_path
assert Path(actual_path) == expected_path
assert actual_version == implied_ver


def test_get_file_invalid(mock_instrument_map, mocker):
mocker.patch('resins.instrument.INSTRUMENT_MAP', mock_instrument_map)
mocker.patch('resins.instrument.INSTRUMENT_DATA_PATH', TEST_DIR)

def test_get_file_invalid(patch_instrument_map, patch_importlib_files):
with pytest.raises(i.InvalidInstrumentError):
i.Instrument._get_file('INVALID_INSTRUMENT')

Expand Down
1 change: 0 additions & 1 deletion tests/unit_tests/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
from numpy.testing import assert_allclose
import pytest

from resins.models.mixins import GaussianKernel1DMixin, SimpleBroaden1DMixin

Expand Down
Loading