Skip to content

Commit 9ed3bb0

Browse files
committed
initial go at dae file
1 parent c90f579 commit 9ed3bb0

File tree

2 files changed

+75
-17
lines changed

2 files changed

+75
-17
lines changed

src/genie_python/genie.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import types
1010
from builtins import FileNotFoundError, str
1111
from io import open
12-
from typing import Any, Callable, TypedDict
12+
from typing import Any, Callable, TypedDict, Union
1313

1414
import numpy as np
1515
import numpy.typing as npt
@@ -2175,12 +2175,18 @@ def change_rb(rb: int | str) -> None:
21752175

21762176

21772177
class _GetspectrumReturn(TypedDict):
2178-
time: list[float]
2179-
signal: list[float]
2178+
time: Union[list[float], None]
2179+
signal: Union[list[float], None]
21802180
sum: None
21812181
mode: str
21822182

21832183

2184+
class _GetspectrumReturnNumpy(TypedDict):
2185+
time: Union[npt.NDArray[float], None]
2186+
signal: Union[npt.NDArray[float], None]
2187+
sum: None
2188+
mode: str
2189+
21842190
@usercommand
21852191
@helparglist("spectrum[, period][, dist]")
21862192
@log_command_and_handle_exception

src/genie_python/genie_dae.py

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import os
55
import re
6+
import typing
67
import xml.etree.ElementTree as ET
78
import zlib
89
from binascii import hexlify
@@ -36,7 +37,7 @@
3637
)
3738

3839
if TYPE_CHECKING:
39-
from genie_python.genie import PVValue, _GetspectrumReturn
40+
from genie_python.genie import PVValue, _GetspectrumReturn, _GetspectrumReturnNumpy
4041
from genie_python.genie_epics_api import API
4142

4243
## for beginrun etc. there exists both the PV specified here and also a PV with
@@ -202,6 +203,22 @@ def _get_dae_pv_name(self, name: str, base: bool = False) -> str:
202203
else:
203204
return self._prefix_pv_name(DAE_PVS_LOOKUP[name.lower()])
204205

206+
@typing.overload
207+
def _get_pv_value(
208+
self,
209+
name: str,
210+
to_string: typing.Literal[False] = False,
211+
use_numpy: bool | None = None
212+
) -> "PVValue": ...
213+
214+
@typing.overload
215+
def _get_pv_value(
216+
self,
217+
name: str,
218+
to_string: typing.Literal[True],
219+
use_numpy: bool | None = None
220+
) -> str: ...
221+
205222
def _get_pv_value(
206223
self, name: str, to_string: bool = False, use_numpy: bool | None = None
207224
) -> "PVValue":
@@ -1886,9 +1903,27 @@ def _change_period_settings(self) -> None:
18861903
"set a number that is too large for the DAE memory. Try a smaller number!"
18871904
)
18881905

1906+
1907+
@typing.overload
1908+
def get_spectrum(
1909+
self, spectrum: int, period: int = 1, dist: bool = True, use_numpy: typing.Literal[
1910+
False] = False
1911+
) -> "_GetspectrumReturn": ...
1912+
1913+
@typing.overload
18891914
def get_spectrum(
1890-
self, spectrum: int, period: int = 1, dist: bool = True, use_numpy: bool | None = None
1891-
) -> "_GetspectrumReturn":
1915+
self, spectrum: int, period: int = 1, dist: bool = True, *, use_numpy: typing.Literal[
1916+
True]
1917+
) -> "_GetspectrumReturnNumpy": ...
1918+
1919+
@typing.overload
1920+
def get_spectrum(
1921+
self, spectrum: int, period: int = 1, dist: bool = True, *, use_numpy: bool
1922+
) -> typing.Union["_GetspectrumReturnNumpy", "_GetspectrumReturn"]: ...
1923+
1924+
def get_spectrum(
1925+
self, spectrum: int, period: int = 1, dist: bool = True, use_numpy: bool = False
1926+
) -> typing.Union["_GetspectrumReturnNumpy", "_GetspectrumReturn"]:
18921927
"""
18931928
Gets a spectrum from the DAE via a PV.
18941929
@@ -1909,6 +1944,10 @@ def get_spectrum(
19091944
y_size = self._get_pv_value(
19101945
self._get_dae_pv_name("getspectrum_y_size").format(period, spectrum)
19111946
)
1947+
if use_numpy:
1948+
assert isinstance(y_data, npt.NDArray[float])
1949+
else:
1950+
assert isinstance(y_data, list)
19121951
y_data = y_data[:y_size]
19131952
mode = "distribution"
19141953
x_size = y_size
@@ -1920,12 +1959,20 @@ def get_spectrum(
19201959
y_size = self._get_pv_value(
19211960
self._get_dae_pv_name("getspectrum_yc_size").format(period, spectrum)
19221961
)
1962+
if use_numpy:
1963+
assert isinstance(y_data, npt.NDArray[float])
1964+
else:
1965+
assert isinstance(y_data, list)
19231966
y_data = y_data[:y_size]
19241967
mode = "non-distribution"
19251968
x_size = y_size + 1
19261969
x_data = self._get_pv_value(
19271970
self._get_dae_pv_name("getspectrum_x").format(period, spectrum), use_numpy=use_numpy
19281971
)
1972+
if use_numpy:
1973+
assert isinstance(x_data, npt.NDArray[float])
1974+
else:
1975+
assert isinstance(x_data, list)
19291976
x_data = x_data[:x_size]
19301977

19311978
return {"time": x_data, "signal": y_data, "sum": None, "mode": mode}
@@ -2009,20 +2056,22 @@ def get_tcb_settings(self, trange: int, regime: int = 1) -> dict[str, int]:
20092056

20102057
for top in root.iter("DBL"):
20112058
n = top.find("Name")
2012-
match = regex.search(n.text)
2013-
if match is not None:
2059+
if isinstance(n, ET.Element) and n.text is not None:
2060+
match = regex.search(n.text)
20142061
v = top.find("Val")
2015-
out[match.group(1)] = v.text
2016-
2062+
if match is not None and isinstance(v, ET.Element) and v.text is not None:
2063+
out[match.group(1)] = v.text
20172064
return out
20182065

20192066
def get_table_path(self, table_type: str) -> str:
20202067
dae_xml = self._get_dae_settings_xml()
20212068
for top in dae_xml.iter("String"):
20222069
n = top.find("Name")
2023-
if n.text == "{} Table".format(table_type):
2070+
if isinstance(n, ET.Element) and n.text == "{} Table".format(table_type):
20242071
val = top.find("Val")
2072+
assert val is not None and val.text is not None
20252073
return val.text
2074+
raise Exception("No tables found in DAE settings XML")
20262075

20272076
def _get_dae_settings_xml(self) -> ET.Element:
20282077
xml_value = self._get_pv_value(self._get_dae_pv_name("daesettings"))
@@ -2039,7 +2088,7 @@ def _wait_for_isis_dae_state(self, state: str, timeout: int) -> tuple[bool, str]
20392088
state_attained = False
20402089
current_state = ""
20412090
for _ in range(timeout):
2042-
current_state = self._get_pv_value(self._prefix_pv_name("CS:PS:ISISDAE_01:STATUS"))
2091+
current_state = str(self._get_pv_value(self._prefix_pv_name("CS:PS:ISISDAE_01:STATUS")))
20432092
if current_state == state:
20442093
state_attained = True
20452094
break
@@ -2053,7 +2102,7 @@ def _isis_dae_triggered_state_was_reached(
20532102
state: str,
20542103
timeout_per_trigger: int = 20,
20552104
max_number_of_triggers: int = 5,
2056-
) -> str:
2105+
) -> bool:
20572106
"""
20582107
Trigger a state and wait for the state to be reached. For example stop the
20592108
ISIS DAE and wait for it to be
@@ -2152,7 +2201,7 @@ def is_changing(self) -> bool:
21522201

21532202
def integrate_spectrum(
21542203
self, spectrum: int, period: int = 1, t_min: float | None = None, t_max: float | None = None
2155-
) -> float:
2204+
) -> float | None:
21562205
"""
21572206
Integrates the spectrum within the time period and returns neutron counts.
21582207
@@ -2168,9 +2217,9 @@ def integrate_spectrum(
21682217
Returns:
21692218
float: integral of the spectrum (neutron counts)
21702219
"""
2171-
spectrum = self.get_spectrum(spectrum, period, False, use_numpy=True)
2172-
time = spectrum["time"]
2173-
count = spectrum["signal"]
2220+
spectrum_dict = self.get_spectrum(spectrum, period, False, use_numpy=True)
2221+
time = spectrum_dict["time"]
2222+
count = spectrum_dict["signal"]
21742223

21752224
if time is None or count is None:
21762225
return None
@@ -2201,6 +2250,9 @@ def integrate_spectrum(
22012250
)
22022251
last_complete_bin = time.searchsorted(t_max, side="left")
22032252

2253+
assert t_max is not None
2254+
assert t_min is not None
2255+
22042256
# Error check
22052257
if t_max < t_min:
22062258
raise ValueError("Time range is not valid, to_time is less than from_time.")

0 commit comments

Comments
 (0)