Skip to content

Commit df66bc8

Browse files
committed
Fix bug where _ProtoChildExpression.calculation_equal that was not commutative w.r.t. _TrasformProtoChildExpression.calculation_equal.
Also removed some non-determinism from calculate.py. PiperOrigin-RevId: 310041912
1 parent 0e69ae7 commit df66bc8

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

struct2tensor/calculate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def _order_nodes(self) -> None:
406406
expr_id_to_count = {
407407
expr_id: len(node.sources) for expr_id, node in self._node.items()
408408
}
409-
for k, v in expr_id_to_count.items():
409+
for k, v in sorted(expr_id_to_count.items()):
410410
if v == 0:
411411
nodes_to_process.append(self._node[k])
412412
while nodes_to_process:

struct2tensor/expression_impl/proto.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from struct2tensor.expression_impl import parse_message_level_ex
3434
from struct2tensor.ops import struct2tensor_ops
3535
import tensorflow as tf
36-
from typing import Callable, FrozenSet, Mapping, Optional, Sequence, Set, Text, Tuple, Union
36+
from typing import cast, Callable, FrozenSet, Mapping, Optional, Sequence, Set, Text, Tuple, Union
3737

3838

3939
from google.protobuf.descriptor_pb2 import FileDescriptorSet
@@ -376,8 +376,12 @@ def calculate_from_parsed_field(self,
376376
return _ProtoChildNodeTensor(parsed_field.index, self.is_repeated, fields)
377377

378378
def calculation_equal(self, expr: expression.Expression) -> bool:
379-
return (isinstance(expr, _ProtoChildExpression) and
380-
self._desc == expr._desc and # pylint: disable=protected-access
379+
# Ensure that we're dealing with the _ProtoChildExpression and not any
380+
# of its subclasses.
381+
if type(expr) != _ProtoChildExpression: # pylint: disable=unidiomatic-typecheck
382+
return False
383+
expr = cast(_ProtoChildExpression, expr) # Keep pytype happy.
384+
return (self._desc == expr._desc and # pylint: disable=protected-access
381385
self.name_as_field == expr.name_as_field)
382386

383387
def _get_child_impl(self,
@@ -421,7 +425,7 @@ def calculation_equal(self, expr: expression.Expression) -> bool:
421425
return (isinstance(expr, _TransformProtoChildExpression) and
422426
self._desc == expr._desc and # pylint: disable=protected-access
423427
self.name_as_field == expr.name_as_field
424-
and self.transform_fn == expr.transform_fn)
428+
and self.transform_fn is expr.transform_fn)
425429

426430
def __str__(self) -> str: # pylint: disable=g-ambiguous-str-annotation
427431
return ("_TransformProtoChildExpression: name_as_field: {} desc: {} from {}"

0 commit comments

Comments
 (0)