Skip to content

Commit 10d2713

Browse files
authored
Merge branch 'main' into add_i09_1_converter
2 parents d7a1f49 + 635235f commit 10d2713

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

src/daq_config_server/models/converters/lookup_tables/_models.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,45 @@ def check_row_length_matches_n_columns(self):
1717
f"of columns: {self.column_names}"
1818
)
1919
return self
20+
21+
def get_value(
22+
self,
23+
column_name: str,
24+
value: int | float,
25+
target_column_name: str,
26+
value_must_exist: bool = True,
27+
) -> int | float:
28+
"""Look up a value in one column and return the corresponding entry from another
29+
column.
30+
31+
Args:
32+
column_name (str): The name of the column to search in.
33+
value (int | float): The numeric value to look for within `column_name`.
34+
target_column_name (str): The name of the column from which to return the
35+
corresponding entry (same row as the matched value).
36+
value_must_exist (bool, optional): If true, value must exist in the LUT or
37+
an error will be thrown. Otherwise, the closest value will be used.
38+
Defaults to True.
39+
40+
Returns:
41+
int | float: The entry from `target_column_name` in the row where the
42+
matching (or closest) value was found in `column_name`.
43+
"""
44+
column_index = self.column_names.index(column_name)
45+
column = [row[column_index] for row in self.rows]
46+
target_column_index = self.column_names.index(target_column_name)
47+
48+
closest_value = (
49+
min(column, key=lambda x: abs(x - value)) if not value_must_exist else value
50+
)
51+
try:
52+
target_row = self.rows[column.index(closest_value)]
53+
except ValueError as e:
54+
raise ValueError(
55+
f"'{closest_value}' doesn't exist in column '{column_name}': {column}"
56+
) from e
57+
58+
return target_row[target_column_index]
59+
60+
def columns(self) -> list[list[int | float]]:
61+
return [[row[i] for row in self.rows] for i in range(len(self.column_names))]

tests/unit_tests/converters/test_lookup_tables_converters.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import re
2+
13
import pytest
24
from tests.constants import TestDataPaths
35

@@ -142,3 +144,48 @@ def test_i09_hu_undulator_gap_lut_gives_expected_result():
142144
)
143145
result = i09_hu_undulator_energy_gap_lut(input)
144146
assert result == expected
147+
148+
149+
@pytest.mark.parametrize(
150+
"args, expected_value",
151+
[
152+
(("detector_distances_mm", 150, "beam_centre_x_mm", True), 152.2),
153+
(("beam_centre_y_mm", 160.96, "detector_distances_mm", True), 800),
154+
(
155+
("beam_centre_x_mm", 153, "beam_centre_y_mm", False),
156+
166.26, # get closest value when value_must_exist == False
157+
),
158+
],
159+
)
160+
def test_generic_lut_model_get_value_function(
161+
args: tuple[str, int | float, str, bool], expected_value: int | float
162+
):
163+
my_lut = GenericLookupTable(
164+
column_names=["detector_distances_mm", "beam_centre_x_mm", "beam_centre_y_mm"],
165+
rows=[[150, 152.2, 166.26], [800, 152.08, 160.96]],
166+
)
167+
assert my_lut.get_value(*args) == expected_value
168+
169+
170+
def test_generic_lut_model_get_value_errors_if_value_doesnt_exist():
171+
my_lut = GenericLookupTable(
172+
column_names=["detector_distances_mm", "beam_centre_x_mm", "beam_centre_y_mm"],
173+
rows=[[150, 152.2, 166.26], [800, 152.08, 160.96]],
174+
)
175+
with pytest.raises(
176+
ValueError,
177+
match=re.escape(
178+
"'160.97' doesn't exist in column 'beam_centre_y_mm': [166.26, 160.96]"
179+
),
180+
):
181+
# value doesn't exist
182+
my_lut.get_value("beam_centre_y_mm", 160.97, "detector_distances_mm")
183+
184+
185+
def test_generic_lut_model_columns_function():
186+
my_lut = GenericLookupTable(
187+
column_names=["detector_distances_mm", "beam_centre_x_mm", "beam_centre_y_mm"],
188+
rows=[[150, 152.2, 166.26], [800, 152.08, 160.96]],
189+
)
190+
expected_columns = [[150, 800], [152.2, 152.08], [166.26, 160.96]]
191+
assert my_lut.columns() == expected_columns

0 commit comments

Comments
 (0)