|
33 | 33 | from struct2tensor.expression_impl import parse_message_level_ex |
34 | 34 | from struct2tensor.ops import struct2tensor_ops |
35 | 35 | import tensorflow as tf |
36 | | -from typing import Callable, FrozenSet, Mapping, Optional, Sequence, Set, Text, Tuple, Union |
| 36 | +from typing import cast, Callable, FrozenSet, Mapping, Optional, Sequence, Set, Text, Tuple, Union |
37 | 37 |
|
38 | 38 |
|
39 | 39 | from google.protobuf.descriptor_pb2 import FileDescriptorSet |
@@ -376,8 +376,12 @@ def calculate_from_parsed_field(self, |
376 | 376 | return _ProtoChildNodeTensor(parsed_field.index, self.is_repeated, fields) |
377 | 377 |
|
378 | 378 | def calculation_equal(self, expr: expression.Expression) -> bool: |
379 | | - return (isinstance(expr, _ProtoChildExpression) and |
380 | | - self._desc == expr._desc and # pylint: disable=protected-access |
| 379 | + # Ensure that we're dealing with the _ProtoChildExpression and not any |
| 380 | + # of its subclasses. |
| 381 | + if type(expr) != _ProtoChildExpression: # pylint: disable=unidiomatic-typecheck |
| 382 | + return False |
| 383 | + expr = cast(_ProtoChildExpression, expr) # Keep pytype happy. |
| 384 | + return (self._desc == expr._desc and # pylint: disable=protected-access |
381 | 385 | self.name_as_field == expr.name_as_field) |
382 | 386 |
|
383 | 387 | def _get_child_impl(self, |
@@ -421,7 +425,7 @@ def calculation_equal(self, expr: expression.Expression) -> bool: |
421 | 425 | return (isinstance(expr, _TransformProtoChildExpression) and |
422 | 426 | self._desc == expr._desc and # pylint: disable=protected-access |
423 | 427 | self.name_as_field == expr.name_as_field |
424 | | - and self.transform_fn == expr.transform_fn) |
| 428 | + and self.transform_fn is expr.transform_fn) |
425 | 429 |
|
426 | 430 | def __str__(self) -> str: # pylint: disable=g-ambiguous-str-annotation |
427 | 431 | return ("_TransformProtoChildExpression: name_as_field: {} desc: {} from {}" |
|
0 commit comments