Skip to content

Commit 20cf010

Browse files
Add CAPut - get() test cases
This required a large amount of checks in many places to cover edge cases. I think this test code is probably getting too complex and needs simplification.
1 parent c36a50d commit 20cf010

File tree

1 file changed

+64
-17
lines changed

1 file changed

+64
-17
lines changed

tests/test_records.py

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -321,12 +321,16 @@ def run_ioc(creation_func, initial_value, queue, set_enum):
321321

322322
if queue is not None:
323323
queue.put(out_rec.get())
324-
else:
325-
# Keep process alive while main thread works.
326-
asyncio.run_coroutine_threadsafe(
327-
asyncio.sleep(TIMEOUT),
328-
dispatcher.loop
329-
).result()
324+
# Some tests need to do a caput and then cause another .get() to happen
325+
val = queue.get(timeout=TIMEOUT)
326+
if val:
327+
queue.put(out_rec.get())
328+
329+
# Keep process alive while main thread works.
330+
asyncio.run_coroutine_threadsafe(
331+
asyncio.sleep(TIMEOUT),
332+
dispatcher.loop
333+
).result()
330334

331335

332336
def run_test_function(
@@ -371,11 +375,16 @@ def run_test_function(
371375
_channel_cache.purge()
372376

373377
if set_enum == SetValueEnum.CAPUT:
378+
if queue:
379+
queue.get(timeout=TIMEOUT)
374380
caput(
375381
DEVICE_NAME + ":" + RECORD_NAME,
376382
initial_value,
377383
wait=True,
378384
**put_kwargs)
385+
386+
if queue:
387+
queue.put("Do another get!")
379388
# Ensure IOC process has time to execute.
380389
# I saw failures on MacOS where it appeared the IOC had not
381390
# processed the put'ted value as the caget returned the same value
@@ -400,6 +409,9 @@ def run_test_function(
400409
expected_value,
401410
expected_type)
402411
finally:
412+
# Purge cache to suppress spurious "IOC disconnected" exceptions
413+
_channel_cache.purge()
414+
403415
ioc_process.terminate()
404416
ioc_process.join(timeout=TIMEOUT)
405417
if ioc_process.exitcode is None:
@@ -415,14 +427,31 @@ def record_value_asserts(
415427
if type(expected_value) == float and isnan(expected_value):
416428
assert isnan(actual_value) # NaN != Nan, so needs special case
417429
elif creation_func in [builder.WaveformOut, builder.WaveformIn]:
418-
assert numpy.array_equal(actual_value, expected_value), \
430+
431+
# Special case for lack of default value on Out records before init
432+
if actual_value is None and expected_value is None:
433+
assert type(actual_value) == expected_type
434+
return
435+
436+
# Using .get() on the array returns entire length, not just filled part.
437+
# Confirm this by ensuring sliced part of array is all zeros
438+
assert not numpy.any(actual_value[expected_value.size:])
439+
truncated_value = actual_value[:expected_value.size]
440+
assert numpy.array_equal(truncated_value, expected_value), \
419441
"Arrays not equal: {} {}".format(actual_value, expected_value)
420442
assert type(actual_value) == expected_type
421443
else:
422444
assert actual_value == expected_value
423445
assert type(actual_value) == expected_type
424446

425447

448+
def skip_long_strings(record_values):
449+
if (
450+
record_values[0] in [builder.stringIn, builder.stringOut]
451+
and len(record_values[1]) > 40
452+
):
453+
pytest.skip("CAPut blocks strings > 40 characters.")
454+
426455
def test_records(tmp_path):
427456
import sim_records
428457

@@ -514,6 +543,27 @@ def test_value_post_init_set_after_init(self, record_values):
514543
SetValueEnum.SET_AFTER_INIT,
515544
GetValueEnum.GET)
516545

546+
@requires_cothread
547+
def test_value_post_init_caput(self, record_values):
548+
"""Test that records provide the expected values on get calls when using
549+
caput and .get() after IOC initialisation"""
550+
551+
if record_values[0] in [
552+
builder.aIn,
553+
builder.boolIn,
554+
builder.longIn,
555+
builder.mbbIn,
556+
builder.stringIn,
557+
builder.WaveformIn]:
558+
pytest.skip("CAPut to In records doesn't propogate to .get()")
559+
560+
skip_long_strings(record_values)
561+
562+
run_test_function(
563+
record_values,
564+
SetValueEnum.CAPUT,
565+
GetValueEnum.GET)
566+
517567
@requires_cothread
518568
class TestCagetValue:
519569
"""Tests that use Caget to check whether values applied with .set()
@@ -551,7 +601,7 @@ def test_value_post_init_initial_value(self, record_values):
551601
run_test_function(
552602
record_values,
553603
SetValueEnum.INITIAL_VALUE,
554-
GetValueEnum.GET)
604+
GetValueEnum.CAGET)
555605

556606
@requires_cothread
557607
def test_value_post_init_set_after_init(self, record_values):
@@ -561,17 +611,13 @@ def test_value_post_init_set_after_init(self, record_values):
561611
run_test_function(
562612
record_values,
563613
SetValueEnum.SET_AFTER_INIT,
564-
GetValueEnum.GET)
614+
GetValueEnum.CAGET)
565615

566616
def test_value_post_init_caput(self, record_values):
567617
"""Test that records provide the expected values on get calls when using
568618
.set() before IOC initialisation and caget after initialisation"""
569619

570-
if (
571-
record_values[0] in [builder.stringIn, builder.stringOut]
572-
and len(record_values[1]) > 40
573-
):
574-
pytest.skip("CAPut blocks strings > 40 characters.")
620+
skip_long_strings(record_values)
575621

576622
run_test_function(
577623
record_values,
@@ -594,7 +640,7 @@ class TestDefaultValue:
594640
(builder.mbbOut, None, type(None)),
595641
(builder.mbbIn, 0, int),
596642
(builder.WaveformOut, None, type(None)),
597-
(builder.WaveformIn, [], numpy.ndarray),
643+
(builder.WaveformIn, numpy.empty(0), numpy.ndarray),
598644
])
599645
def test_value_default_pre_init(
600646
self,
@@ -604,6 +650,7 @@ def test_value_default_pre_init(
604650
clear_records):
605651
"""Test that the correct default values are returned from .get() (before
606652
record initialisation) when no initial_value or .set() is done"""
653+
# Out records do not have default values until records are initialized
607654

608655
kwarg = {}
609656
if creation_func in [builder.WaveformIn, builder.WaveformOut]:
@@ -628,8 +675,8 @@ def test_value_default_pre_init(
628675
(builder.stringIn, "", str),
629676
(builder.mbbOut, 0, int),
630677
(builder.mbbIn, 0, int),
631-
(builder.WaveformOut, [], numpy.ndarray),
632-
(builder.WaveformIn, [], numpy.ndarray),
678+
(builder.WaveformOut, numpy.empty(0), numpy.ndarray),
679+
(builder.WaveformIn, numpy.empty(0), numpy.ndarray),
633680
])
634681
def test_value_default_post_init(
635682
self,

0 commit comments

Comments
 (0)