Skip to content

Commit decaa03

Browse files
authored
Improved typing for validators (#995)
* Add typing to ConsoleSummary * Add typing to StationServer * Improved typing for validators
1 parent c34e3cc commit decaa03

File tree

3 files changed

+63
-51
lines changed

3 files changed

+63
-51
lines changed

openhtf/output/callbacks/console_summary.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import sys
5+
from typing import TextIO
56

67
from openhtf.core import measurements
78
from openhtf.core import test_record
@@ -11,7 +12,9 @@ class ConsoleSummary():
1112
"""Print test results with failure info on console."""
1213

1314
# pylint: disable=invalid-name
14-
def __init__(self, indent=2, output_stream=sys.stdout):
15+
def __init__(self,
16+
indent: int = 2,
17+
output_stream: TextIO = sys.stdout) -> None:
1518
self.indent = ' ' * indent
1619
if os.name == 'posix': # Linux and Mac.
1720
self.RED = '\033[91m'
@@ -37,7 +40,7 @@ def __init__(self, indent=2, output_stream=sys.stdout):
3740

3841
# pylint: enable=invalid-name
3942

40-
def __call__(self, record):
43+
def __call__(self, record: test_record.TestRecord) -> None:
4144
output_lines = [
4245
''.join((self.color_table[record.outcome], self.BOLD,
4346
record.code_info.name, ':', record.outcome.name, self.RESET))

openhtf/output/servers/station_server.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import threading
1616
import time
1717
import types
18+
from typing import Optional, Union
1819

1920
import openhtf
2021
from openhtf.output.servers import pub_sub
@@ -558,7 +559,9 @@ class StationServer(web_gui_server.WebGuiServer):
558559
test.execute()
559560
"""
560561

561-
def __init__(self, history_path=None):
562+
def __init__(
563+
self,
564+
history_path: Optional[Union[str, bytes, os.PathLike]] = None) -> None:
562565
# Disable tornado's logging.
563566
# TODO(kenadia): Enable these logs if verbosity flag is at least -vvv.
564567
# I think this will require changing how StoreRepsInModule works.
@@ -614,7 +617,7 @@ def _get_config(self):
614617
'server_type': STATION_SERVER_TYPE,
615618
}
616619

617-
def run(self):
620+
def run(self) -> None:
618621
_LOG.info('Announcing station server via multicast on %s:%s',
619622
self.station_multicast.address, self.station_multicast.port)
620623
self.station_multicast.start()
@@ -624,13 +627,13 @@ def run(self):
624627
host=socket.gethostname(), port=self.port))
625628
super(StationServer, self).run()
626629

627-
def stop(self):
630+
def stop(self) -> None:
628631
_LOG.info('Stopping station server.')
629632
super(StationServer, self).stop()
630633
_LOG.info('Stopping multicast.')
631634
self.station_multicast.stop(timeout_s=0)
632635

633-
def publish_final_state(self, test_record):
636+
def publish_final_state(self, test_record: openhtf.TestRecord) -> None:
634637
"""Test output callback publishing a final state from the test record."""
635638
StationPubSub.publish_test_record(test_record)
636639

openhtf/util/validators.py

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
from openhtf.util import measurements
99
1010
class MyLessThanValidator(ValidatorBase):
11-
def __init__(self, limit):
11+
def __init__(self, limit) -> None:
1212
self.limit = limit
1313
1414
# This will be invoked to test if the measurement is 'PASS' or 'FAIL'.
15-
def __call__(self, value):
15+
def __call__(self, value) -> bool:
1616
return value < self.limit
1717
1818
# Name defaults to the validator's __name__ attribute unless overridden.
@@ -31,12 +31,12 @@ def MyPhase(test):
3131
For simpler validators, you don't need to register them at all, you can
3232
simply attach them to the Measurement with the .with_validator() method:
3333
34-
def LessThan4(value):
34+
def LessThan4(value) -> bool:
3535
return value < 4
3636
3737
@measurements.measures(
3838
measurements.Measurement('my_measurement).with_validator(LessThan4))
39-
def MyPhase(test):
39+
def MyPhase(test: htf.TestApi) -> None:
4040
test.measurements.my_measurement = 5 # Will also 'FAIL'
4141
4242
Notes:
@@ -58,37 +58,43 @@ def MyPhase(test):
5858
import math
5959
import numbers
6060
import re
61+
from typing import Callable, Dict, Optional, Type, TypeVar, Union
62+
6163
from openhtf import util
6264

63-
_VALIDATORS = {}
6465

66+
class ValidatorBase(abc.ABC):
67+
68+
@abc.abstractmethod
69+
def __call__(self, value) -> bool:
70+
"""Should validate value, returning a boolean result."""
71+
72+
73+
_ValidatorT = TypeVar("_ValidatorT", bound=ValidatorBase)
74+
_ValidatorFactoryT = Union[Type[_ValidatorT], Callable[..., _ValidatorT]]
75+
_VALIDATORS: Dict[str, _ValidatorFactoryT] = {}
6576

66-
def register(validator, name=None):
77+
78+
def register(validator: _ValidatorFactoryT,
79+
name: Optional[str] = None) -> _ValidatorFactoryT:
6780
name = name or validator.__name__
6881
if name in _VALIDATORS:
6982
raise ValueError('Duplicate validator name', name)
7083
_VALIDATORS[name] = validator
7184
return validator
7285

7386

74-
def has_validator(name):
87+
def has_validator(name: str) -> bool:
7588
return name in _VALIDATORS
7689

7790

78-
def create_validator(name, *args, **kwargs):
91+
def create_validator(name: str, *args, **kwargs) -> _ValidatorT:
7992
return _VALIDATORS[name](*args, **kwargs)
8093

8194

8295
_identity = lambda x: x
8396

8497

85-
class ValidatorBase(abc.ABC):
86-
87-
@abc.abstractmethod
88-
def __call__(self, value):
89-
"""Should validate value, returning a boolean result."""
90-
91-
9298
class RangeValidatorBase(ValidatorBase, abc.ABC):
9399

94100
@abc.abstractproperty
@@ -120,7 +126,7 @@ def __init__(self,
120126
minimum,
121127
maximum,
122128
marginal_minimum=None,
123-
marginal_maximum=None):
129+
marginal_maximum=None) -> None:
124130
super(AllInRangeValidator, self).__init__()
125131
if minimum is None and maximum is None:
126132
raise ValueError('Must specify minimum, maximum, or both')
@@ -168,7 +174,7 @@ def marginal_minimum(self):
168174
def marginal_maximum(self):
169175
return self._marginal_maximum
170176

171-
def __call__(self, values):
177+
def __call__(self, values) -> bool:
172178
within_maximum = self._maximum is None or all(
173179
value <= self.maximum for value in values)
174180
within_minimum = self._minimum is None or all(
@@ -204,18 +210,18 @@ def __str__(self):
204210
class AllEqualsValidator(ValidatorBase):
205211
"""Validator to verify a list of values are equal to the expected value."""
206212

207-
def __init__(self, spec):
213+
def __init__(self, spec) -> None:
208214
super(AllEqualsValidator, self).__init__()
209215
self._spec = spec
210216

211217
@property
212218
def spec(self):
213219
return self._spec
214220

215-
def __call__(self, values):
221+
def __call__(self, values) -> bool:
216222
return all([value == self.spec for value in values])
217223

218-
def __str__(self):
224+
def __str__(self) -> str:
219225
return "'x' is equal to '%s'" % self._spec
220226

221227

@@ -242,7 +248,7 @@ def __init__(self,
242248
maximum=None,
243249
marginal_minimum=None,
244250
marginal_maximum=None,
245-
type=None): # pylint: disable=redefined-builtin
251+
type=None) -> None: # pylint: disable=redefined-builtin
246252
super(InRange, self).__init__()
247253

248254
if minimum is None and maximum is None:
@@ -292,7 +298,7 @@ def marginal_minimum(self):
292298
return converter(self._marginal_minimum)
293299

294300
@property
295-
def marginal_maximum(self):
301+
def marginal_maximum(self) -> str:
296302
converter = self._type if self._type is not None else _identity
297303
return converter(self._marginal_maximum)
298304

@@ -305,7 +311,7 @@ def with_args(self, **kwargs):
305311
type=self._type,
306312
)
307313

308-
def __call__(self, value):
314+
def __call__(self, value) -> bool:
309315
if value is None:
310316
return False
311317
if math.isnan(value):
@@ -329,7 +335,7 @@ def is_marginal(self, value) -> bool:
329335
return True
330336
return False
331337

332-
def __str__(self):
338+
def __str__(self) -> str:
333339
assert self._minimum is not None or self._maximum is not None
334340
if (self._minimum is not None and self._maximum is not None and
335341
self._minimum == self._maximum):
@@ -347,13 +353,13 @@ def __str__(self):
347353
string_repr += ' <= {}'.format(self._maximum)
348354
return string_repr
349355

350-
def __eq__(self, other):
356+
def __eq__(self, other) -> bool:
351357
return (isinstance(other, type(self)) and self.minimum == other.minimum and
352358
self.maximum == other.maximum and
353359
self.marginal_minimum == other.marginal_minimum and
354360
self.marginal_maximum == other.marginal_maximum)
355361

356-
def __ne__(self, other):
362+
def __ne__(self, other) -> bool:
357363
return not self == other
358364

359365

@@ -373,10 +379,10 @@ def equals(value, type=None): # pylint: disable=redefined-builtin
373379
return Equals(value, type=type)
374380

375381

376-
class Equals(object):
382+
class Equals(ValidatorBase):
377383
"""Validator to verify an object is equal to the expected value."""
378384

379-
def __init__(self, expected, type=None): # pylint: disable=redefined-builtin
385+
def __init__(self, expected, type=None) -> None: # pylint: disable=redefined-builtin
380386
self._expected = expected
381387
self._type = type
382388

@@ -388,21 +394,21 @@ def expected(self):
388394
def __call__(self, value):
389395
return value == self.expected
390396

391-
def __str__(self):
397+
def __str__(self) -> str:
392398
return f"'x' is equal to '{self._expected}'"
393399

394-
def __eq__(self, other):
400+
def __eq__(self, other) -> bool:
395401
return isinstance(other, type(self)) and self.expected == other.expected
396402

397403

398-
class RegexMatcher(object):
404+
class RegexMatcher(ValidatorBase):
399405
"""Validator to verify a string value matches a regex."""
400406

401-
def __init__(self, regex, compiled_regex):
407+
def __init__(self, regex, compiled_regex) -> None:
402408
self._compiled = compiled_regex
403409
self.regex = regex
404410

405-
def __call__(self, value):
411+
def __call__(self, value) -> bool:
406412
return self._compiled.match(str(value)) is not None
407413

408414
def __deepcopy__(self, dummy_memo):
@@ -414,7 +420,7 @@ def __str__(self):
414420
def __eq__(self, other):
415421
return isinstance(other, type(self)) and self.regex == other.regex
416422

417-
def __ne__(self, other):
423+
def __ne__(self, other) -> bool:
418424
return not self == other
419425

420426

@@ -426,7 +432,7 @@ def matches_regex(regex):
426432
class WithinPercent(RangeValidatorBase):
427433
"""Validates that a number is within percent of a value."""
428434

429-
def __init__(self, expected, percent, marginal_percent=None):
435+
def __init__(self, expected, percent, marginal_percent=None) -> None:
430436
super(WithinPercent, self).__init__()
431437
if percent < 0:
432438
raise ValueError('percent argument is {}, must be >0'.format(percent))
@@ -465,7 +471,7 @@ def marginal_maximum(self):
465471
return (self.expected -
466472
self._applied_marginal_percent if self.marginal_percent else None)
467473

468-
def __call__(self, value):
474+
def __call__(self, value) -> bool:
469475
return self.minimum <= value <= self.maximum
470476

471477
def is_marginal(self, value) -> bool:
@@ -475,17 +481,17 @@ def is_marginal(self, value) -> bool:
475481
return (self.minimum < value <= self.marginal_minimum or
476482
self.marginal_maximum <= value < self.maximum)
477483

478-
def __str__(self):
484+
def __str__(self) -> str:
479485
return "'x' is within {}% of {}. Marginal: {}% of {}".format(
480486
self.percent, self.expected, self.marginal_percent, self.expected)
481487

482-
def __eq__(self, other):
488+
def __eq__(self, other) -> bool:
483489
return (isinstance(other, type(self)) and
484490
self.expected == other.expected and
485491
self.percent == other.percent and
486492
self.marginal_percent == other.marginal_percent)
487493

488-
def __ne__(self, other):
494+
def __ne__(self, other) -> bool:
489495
return not self == other
490496

491497

@@ -497,14 +503,14 @@ def within_percent(expected, percent):
497503
class DimensionPivot(ValidatorBase):
498504
"""Runs a validator on each actual value of a dimensioned measurement."""
499505

500-
def __init__(self, sub_validator):
506+
def __init__(self, sub_validator) -> None:
501507
super(DimensionPivot, self).__init__()
502508
self._sub_validator = sub_validator
503509

504-
def __call__(self, dimensioned_value):
510+
def __call__(self, dimensioned_value) -> bool:
505511
return all(self._sub_validator(row[-1]) for row in dimensioned_value)
506512

507-
def __str__(self):
513+
def __str__(self) -> str:
508514
return 'All values pass: {}'.format(str(self._sub_validator))
509515

510516

@@ -516,11 +522,11 @@ def dimension_pivot_validate(sub_validator):
516522
class ConsistentEndDimensionPivot(ValidatorBase):
517523
"""If any rows validate, all following rows must also validate."""
518524

519-
def __init__(self, sub_validator):
525+
def __init__(self, sub_validator) -> None:
520526
super(ConsistentEndDimensionPivot, self).__init__()
521527
self._sub_validator = sub_validator
522528

523-
def __call__(self, dimensioned_value):
529+
def __call__(self, dimensioned_value) -> bool:
524530
for index, row in enumerate(dimensioned_value):
525531
if self._sub_validator(row[-1]):
526532
i = index
@@ -529,7 +535,7 @@ def __call__(self, dimensioned_value):
529535
return False
530536
return all(self._sub_validator(rest[-1]) for rest in dimensioned_value[i:])
531537

532-
def __str__(self):
538+
def __str__(self) -> str:
533539
return 'Once pass, rest must also pass: {}'.format(str(self._sub_validator))
534540

535541

0 commit comments

Comments
 (0)