Skip to content

Commit 01163d2

Browse files
Handle returning a numpy array of strings
Also make all the tests pass. Involved skipping
1 parent f127bf2 commit 01163d2

File tree

3 files changed

+67
-21
lines changed

3 files changed

+67
-21
lines changed

softioc/builder.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,7 @@ def Action(name, **fields):
148148
'uint32': 'ULONG',
149149
'float32': 'FLOAT',
150150
'float64': 'DOUBLE',
151-
'bytes32': 'STRING',
152-
'bytes320': 'STRING',
151+
'bytes320': 'STRING', # Numpy's term for a 40-character string (40*8 bits)
153152
}
154153

155154
# Coverts FTVL string to numpy type
@@ -214,8 +213,8 @@ def _waveform(value, fields):
214213

215214
# Special case for [u]int64: if the initial value comes in as 64 bit
216215
# integers we cannot represent that, so recast it as [u]int32
217-
# Special case for array of strings to correctly identify each element
218-
# of the array as a string type.
216+
# Special case for array of strings to mark each element as conforming
217+
# to EPICS 40-character string limit
219218
if datatype is None:
220219
if initial_value.dtype == numpy.int64:
221220
initial_value = numpy.require(initial_value, numpy.int32)

softioc/device.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,16 @@ def _require_waveform(value, dtype):
356356
if isinstance(value, bytes):
357357
# Special case hack for byte arrays. Surprisingly tricky:
358358
value = numpy.frombuffer(value, dtype = numpy.uint8)
359+
360+
if dtype and dtype.char == 'S':
361+
result = numpy.empty(len(value), 'S40')
362+
for n, s in enumerate(value):
363+
if isinstance(s, str):
364+
result[n] = s.encode('UTF-8')
365+
else:
366+
result[n] = s
367+
return result
368+
359369
value = numpy.require(value, dtype = dtype)
360370
if value.shape == ():
361371
value.shape = (1,)
@@ -391,7 +401,6 @@ def _read_value(self, record):
391401
return result
392402

393403
def _write_value(self, record, value):
394-
value = _require_waveform(value, self._dtype)
395404
nord = len(value)
396405
memmove(
397406
record.BPTR, value.ctypes.data_as(c_void_p),

tests/test_record_values.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
"might think to put into a record that can theoretically hold a huge " \
3232
"string and so lets test it and prove that shall we?"
3333

34+
# The numpy dtype of all arrays of strings
35+
DTYPE_STRING = "S40"
36+
3437

3538
def record_func_names(fixture_value):
3639
"""Provide a nice name for the record_func fixture"""
@@ -113,6 +116,8 @@ def record_values_names(fixture_value):
113116
("strOut_39chars", builder.stringOut, MAX_LEN_STR, MAX_LEN_STR, str),
114117
("strIn_empty", builder.stringIn, "", "", str),
115118
("strOut_empty", builder.stringOut, "", "", str),
119+
# TODO: Add Invalid-utf8 tests?
120+
# TODO: Add tests for bytes-strings to stringIn/Out?
116121
("strin_utf8", builder.stringIn, "%a€b", "%a€b", str), # Valid UTF-8
117122
("strOut_utf8", builder.stringOut, "%a€b", "%a€b", str), # Valid UTF-8
118123
(
@@ -189,39 +194,42 @@ def record_values_names(fixture_value):
189194
),
190195
numpy.ndarray,
191196
),
197+
198+
# TODO: Unicode versions of below tests?
199+
192200
(
193201
"wIn_byte_string_array",
194202
builder.WaveformIn,
195-
[b"AB", b"CD", b"EF"],
203+
[b"AB123", b"CD456", b"EF789"],
196204
numpy.array(
197-
[b"AB", b"CD", b"EF"], dtype=numpy.dtype("|S40")
205+
["AB123", "CD456", "EF789"], dtype=DTYPE_STRING
198206
),
199207
numpy.ndarray,
200208
),
201209
(
202210
"wOut_byte_string_array",
203211
builder.WaveformOut,
204-
[b"AB", b"CD", b"EF"],
212+
[b"12AB", b"34CD", b"56EF"],
205213
numpy.array(
206-
[b"AB", b"CD", b"EF"], dtype=numpy.dtype("|S40")
214+
["12AB", "34CD", "56EF"], dtype=DTYPE_STRING
207215
),
208216
numpy.ndarray,
209217
),
210218
(
211219
"wIn_string_array",
212220
builder.WaveformIn,
213-
["123", "456", "7890"],
221+
["123abc", "456def", "7890ghi"],
214222
numpy.array(
215-
[b"123", b"456", b"7890"], dtype=numpy.dtype("|S40")
223+
["123abc", "456def", "7890ghi"], dtype=DTYPE_STRING
216224
),
217225
numpy.ndarray,
218226
),
219227
(
220228
"wOut_string_array",
221229
builder.WaveformOut,
222-
["123", "456", "7890"],
230+
["123abc", "456def", "7890ghi"],
223231
numpy.array(
224-
[b"123", b"456", b"7890"], dtype=numpy.dtype("|S40")
232+
["123abc", "456def", "7890ghi"], dtype=DTYPE_STRING
225233
),
226234
numpy.ndarray,
227235
),
@@ -275,6 +283,7 @@ def record_values(request):
275283
"""A list of parameters for record value setting/getting tests.
276284
277285
Fields are:
286+
- Record name
278287
- Record builder function
279288
- Input value passed to .set()/initial_value/caput
280289
- Expected output value after doing .get()/caget
@@ -412,8 +421,31 @@ def run_test_function(
412421
expected value. set_enum and get_enum determine when the record's value is
413422
set and how the value is retrieved, respectively."""
414423

415-
ctx = get_multiprocessing_context()
424+
def is_valid(configuration):
425+
"""Remove some cases that cannot succeed.
426+
Waveforms of Strings must have the value specified as initial value."""
427+
(
428+
record_name,
429+
creation_func,
430+
initial_value,
431+
expected_value,
432+
expected_type,
433+
) = configuration
434+
435+
if creation_func in (builder.WaveformIn, builder.WaveformOut):
436+
if isinstance(initial_value, list) and \
437+
all(isinstance(val, (str, bytes)) for val in initial_value):
438+
if set_enum is not SetValueEnum.INITIAL_VALUE:
439+
print(f"Removing {configuration}")
440+
return False
441+
442+
return True
443+
444+
record_configurations = [
445+
x for x in record_configurations if is_valid(x)
446+
]
416447

448+
ctx = get_multiprocessing_context()
417449
parent_conn, child_conn = ctx.Pipe()
418450

419451
ioc_process = ctx.Process(
@@ -498,6 +530,7 @@ def run_test_function(
498530

499531
if (
500532
creation_func in [builder.WaveformOut, builder.WaveformIn]
533+
and expected_value.dtype
501534
and expected_value.dtype in [numpy.float64, numpy.int32]
502535
):
503536
log(
@@ -506,6 +539,10 @@ def run_test_function(
506539
"scalar. Therefore we skip this check.")
507540
continue
508541

542+
if isinstance(rec_val, numpy.ndarray) and len(rec_val) > 1 \
543+
and rec_val.dtype.char in ["S", "U"]:
544+
# caget won't retrieve the full length 40 buffer
545+
rec_val = rec_val.astype(DTYPE_STRING)
509546

510547
record_value_asserts(
511548
creation_func, rec_val, expected_value, expected_type
@@ -521,13 +558,6 @@ def run_test_function(
521558
pytest.fail("Process did not terminate")
522559

523560

524-
def skip_long_strings(record_values):
525-
if (
526-
record_values[0] in [builder.stringIn, builder.stringOut]
527-
and len(record_values[1]) > 40
528-
):
529-
pytest.skip("CAPut blocks strings > 40 characters.")
530-
531561

532562
class TestGetValue:
533563
"""Tests that use .get() to check whether values applied with .set(),
@@ -545,6 +575,14 @@ def test_value_pre_init_set(self, record_values):
545575
expected_type,
546576
) = record_values
547577

578+
if (
579+
creation_func in [builder.WaveformIn, builder.WaveformOut] and
580+
isinstance(initial_value, list) and
581+
all(isinstance(s, (str, bytes)) for s in initial_value)
582+
):
583+
pytest.skip("Cannot .set() a list of strings to a waveform, must"
584+
"initially specify using initial_value or FTVL")
585+
548586
kwarg = {}
549587
if creation_func in [builder.WaveformIn, builder.WaveformOut]:
550588
kwarg = {"length": WAVEFORM_LENGTH} # Must specify when no value

0 commit comments

Comments
 (0)