Skip to content

Commit 5e1c110

Browse files
authored
Extend lookup table model (#142)
* Extend lookup table model * Add tests * PR comments
1 parent fe41594 commit 5e1c110

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

src/daq_config_server/models/converters/lookup_tables/_models.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,45 @@ def check_row_length_matches_n_columns(self):
1717
f"of columns: {self.column_names}"
1818
)
1919
return self
20+
21+
def get_value(
22+
self,
23+
column_name: str,
24+
value: int | float,
25+
target_column_name: str,
26+
value_must_exist: bool = True,
27+
) -> int | float:
28+
"""Look up a value in one columna nd return the corresponding entry from another
29+
column.
30+
31+
Args:
32+
column_name (str): The name of the column to search in.
33+
value (int | float): The numeric value to look for within `column_name`.
34+
target_column_name (str): The name of the column from which to return the
35+
corresponding entry (same row as the matched value).
36+
value_must_exist (bool, optional): If true, value must exist in the LUT or
37+
an error will be thrown. Otherwise, the closest value will be used.
38+
Defaults to True.
39+
40+
Returns:
41+
int | float: The entry from `target_column_name` in the row where the
42+
matching (or closest) value was found in `column_name`.
43+
"""
44+
column_index = self.column_names.index(column_name)
45+
column = [row[column_index] for row in self.rows]
46+
target_column_index = self.column_names.index(target_column_name)
47+
48+
closest_value = (
49+
min(column, key=lambda x: abs(x - value)) if not value_must_exist else value
50+
)
51+
try:
52+
target_row = self.rows[column.index(closest_value)]
53+
except ValueError as e:
54+
raise ValueError(
55+
f"'{closest_value}' doesn't exist in column '{column_name}': {column}"
56+
) from e
57+
58+
return target_row[target_column_index]
59+
60+
def columns(self) -> list[list[int | float]]:
61+
return [[row[i] for row in self.rows] for i in range(len(self.column_names))]

tests/unit_tests/converters/test_lookup_tables_converters.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import re
2+
13
import pytest
24
from tests.constants import TestDataPaths
35

@@ -113,3 +115,48 @@ def test_undulator_gap_lut_gives_expected_result():
113115
)
114116
result = undulator_energy_gap_lut(input)
115117
assert result == expected
118+
119+
120+
@pytest.mark.parametrize(
121+
"args, expected_value",
122+
[
123+
(("detector_distances_mm", 150, "beam_centre_x_mm", True), 152.2),
124+
(("beam_centre_y_mm", 160.96, "detector_distances_mm", True), 800),
125+
(
126+
("beam_centre_x_mm", 153, "beam_centre_y_mm", False),
127+
166.26, # get closest value when value_must_exist == False
128+
),
129+
],
130+
)
131+
def test_generic_lut_model_get_value_function(
132+
args: tuple[str, int | float, str, bool], expected_value: int | float
133+
):
134+
my_lut = GenericLookupTable(
135+
column_names=["detector_distances_mm", "beam_centre_x_mm", "beam_centre_y_mm"],
136+
rows=[[150, 152.2, 166.26], [800, 152.08, 160.96]],
137+
)
138+
assert my_lut.get_value(*args) == expected_value
139+
140+
141+
def test_generic_lut_model_get_value_errors_if_value_doesnt_exist():
142+
my_lut = GenericLookupTable(
143+
column_names=["detector_distances_mm", "beam_centre_x_mm", "beam_centre_y_mm"],
144+
rows=[[150, 152.2, 166.26], [800, 152.08, 160.96]],
145+
)
146+
with pytest.raises(
147+
ValueError,
148+
match=re.escape(
149+
"'160.97' doesn't exist in column 'beam_centre_y_mm': [166.26, 160.96]"
150+
),
151+
):
152+
# value doesn't exist
153+
my_lut.get_value("beam_centre_y_mm", 160.97, "detector_distances_mm")
154+
155+
156+
def test_generic_lut_model_columns_function():
157+
my_lut = GenericLookupTable(
158+
column_names=["detector_distances_mm", "beam_centre_x_mm", "beam_centre_y_mm"],
159+
rows=[[150, 152.2, 166.26], [800, 152.08, 160.96]],
160+
)
161+
expected_columns = [[150, 800], [152.2, 152.08], [166.26, 160.96]]
162+
assert my_lut.columns() == expected_columns

0 commit comments

Comments
 (0)