Skip to content

Commit 0e69ae7

Browse files
committed
Add create_transformed_field for proto expressions. This allows users to interpose a transform function between proto expressions (i.e. between levels of parsing).
PiperOrigin-RevId: 309968343
1 parent 7740f70 commit 0e69ae7

File tree

7 files changed

+265
-14
lines changed

7 files changed

+265
-14
lines changed

struct2tensor/calculate_with_source_paths_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,66 @@ def test_calculate_prensors_with_source_paths(self):
7676
expected = [path.Path(["event", "action", "doc_id"])]
7777
self.equal_ignore_order(list_of_paths, expected)
7878

79+
def test_calculate_prensors_with_source_paths_with_transform(self):
  """Checks prensors and source paths for a field under a transformed event.

  The "event" field is copied to "reversed_event" through a transform that
  reverses the serialized values, then "doc_id" is promoted and calculated.
  The reported source path must still point at the original "event" field.
  """

  def _flip_values(parent_indices, values):
    # Keep the parent indices untouched; only the value order is reversed,
    # which makes the transformed data distinguishable from the source.
    return parent_indices, tf.reverse(values, axis=[-1])

  session_expr = proto_test_util._get_expression_from_session_empty_user_info()
  transformed_expr = proto.create_transformed_field(
      session_expr, path.Path(["event"]), "reversed_event", _flip_values)
  promoted_root = promote.promote(
      transformed_expr,
      path.Path(["reversed_event", "action", "doc_id"]),
      "action_doc_ids")

  # A poor-man's reroot: calculate only the promoted descendant.
  promoted_field = promoted_root.get_descendant_or_error(
      path.Path(["reversed_event", "action_doc_ids"]))
  prensor_result, proto_summary_result = (
      calculate_with_source_paths.calculate_prensors_with_source_paths(
          [promoted_field]))

  self.assertLen(prensor_result, 1)
  self.assertLen(proto_summary_result, 1)

  leaf_node = prensor_result[0].node
  expected_parent_index = [0, 0, 0, 1, 2, 2, 3, 4, 4]
  expected_values = [b"h", b"i", b"j", b"g", b"e", b"f", b"c", b"a", b"b"]
  self.assertAllEqual(leaf_node.parent_index, expected_parent_index)
  self.assertAllEqual(leaf_node.values, expected_values)

  expected_paths = [path.Path(["event", "action", "doc_id"])]
  self.equal_ignore_order(proto_summary_result[0].paths, expected_paths)
107+
108+
def test_calculate_prensors_with_source_paths_with_multiple_transforms(self):
  """Checks prensors and source paths when two transforms are chained.

  "event" is reversed into "reversed_event", which is reversed again into
  "reversed_reversed_event". The expected values are therefore back in their
  original order, and the source path still points at "event".
  """

  def _flip_values(parent_indices, values):
    # Keep the parent indices untouched; only the value order is reversed.
    return parent_indices, tf.reverse(values, axis=[-1])

  session_expr = proto_test_util._get_expression_from_session_empty_user_info()
  once_reversed = proto.create_transformed_field(
      session_expr, path.Path(["event"]), "reversed_event", _flip_values)
  twice_reversed = proto.create_transformed_field(
      once_reversed, path.Path(["reversed_event"]),
      "reversed_reversed_event", _flip_values)
  promoted_root = promote.promote(
      twice_reversed,
      path.Path(["reversed_reversed_event", "action", "doc_id"]),
      "action_doc_ids")

  # A poor-man's reroot: calculate only the promoted descendant.
  promoted_field = promoted_root.get_descendant_or_error(
      path.Path(["reversed_reversed_event", "action_doc_ids"]))
  prensor_result, proto_summary_result = (
      calculate_with_source_paths.calculate_prensors_with_source_paths(
          [promoted_field]))

  self.assertLen(prensor_result, 1)
  self.assertLen(proto_summary_result, 1)

  leaf_node = prensor_result[0].node
  expected_parent_index = [0, 0, 1, 2, 2, 3, 4, 4, 4]
  expected_values = [b"a", b"b", b"c", b"e", b"f", b"g", b"h", b"i", b"j"]
  self.assertAllEqual(leaf_node.parent_index, expected_parent_index)
  self.assertAllEqual(leaf_node.values, expected_values)

  expected_paths = [path.Path(["event", "action", "doc_id"])]
  self.equal_ignore_order(proto_summary_result[0].paths, expected_paths)
138+
79139
def test_requirements_to_metadata_proto(self):
80140
proto_summary_result_0 = calculate_with_source_paths.ProtoRequirements(
81141
None, test_pb2.Session.DESCRIPTOR, [

struct2tensor/expression_add.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ def _get_child_impl(self,
9797
if child_from_origin is None:
9898
if set_root_expr is None:
9999
raise ValueError("Must have a value in the original if there are paths")
100-
return _AddPathsExpression(set_root_expr, subtrees)
100+
if subtrees:
101+
return _AddPathsExpression(set_root_expr, subtrees)
102+
return set_root_expr
101103
if set_root_expr is not None:
102104
raise ValueError("Tried to overwrite an existing expression")
103105
return _AddPathsExpression(child_from_origin, subtrees)

struct2tensor/expression_impl/broadcast_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_broadcast_anonymous(self):
3737
prensor_test_util.create_big_prensor())
3838
new_root, p = broadcast.broadcast_anonymous(expr, path.Path(["foo"]),
3939
"user")
40-
[new_field] = new_root.get_descendant_or_error(p).get_source_expressions()
40+
new_field = new_root.get_descendant_or_error(p)
4141
self.assertFalse(new_field.is_repeated)
4242
self.assertEqual(new_field.type, tf.int32)
4343
self.assertTrue(new_field.is_leaf)

struct2tensor/expression_impl/promote_test.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ def test_promote_anonymous(self):
3939
prensor_test_util.create_nested_prensor())
4040
new_root, new_field = promote.promote_anonymous(
4141
expr, path.Path(["user", "friends"]))
42-
[new_field
43-
] = new_root.get_descendant_or_error(new_field).get_source_expressions()
42+
new_field = new_root.get_descendant_or_error(new_field)
4443
self.assertTrue(new_field.is_repeated)
4544
self.assertEqual(new_field.type, tf.string)
4645
self.assertTrue(new_field.is_leaf)
@@ -65,8 +64,7 @@ def test_promote_with_schema(self):
6564

6665
new_root, new_field = promote.promote_anonymous(
6766
expr, path.Path(["user", "friends"]))
68-
[new_field
69-
] = new_root.get_descendant_or_error(new_field).get_source_expressions()
67+
new_field = new_root.get_descendant_or_error(new_field)
7068
new_schema_feature = new_field.schema_feature
7169
self.assertIsNotNone(new_schema_feature)
7270
self.assertEqual(new_schema_feature.string_domain.value[0], "a")
@@ -96,8 +94,7 @@ def test_promote_with_schema_dense_parent(self):
9694

9795
new_root, new_field = promote.promote_anonymous(
9896
expr, path.Path(["user", "friends"]))
99-
[new_field
100-
] = new_root.get_descendant_or_error(new_field).get_source_expressions()
97+
new_field = new_root.get_descendant_or_error(new_field)
10198
new_schema_feature = new_field.schema_feature
10299
self.assertIsNotNone(new_schema_feature)
103100
self.assertEqual(new_schema_feature.string_domain.value[0], "a")
@@ -140,8 +137,7 @@ def _check_lifecycle_stage(a, b):
140137

141138
new_root, new_field = promote.promote_anonymous(
142139
expr, path.Path(["user", "friends"]))
143-
[new_field
144-
] = new_root.get_descendant_or_error(new_field).get_source_expressions()
140+
new_field = new_root.get_descendant_or_error(new_field)
145141
return new_field.schema_feature.lifecycle_stage
146142

147143
self.assertEqual(
@@ -225,8 +221,7 @@ def test_promote_with_schema_dense_fraction(self):
225221

226222
new_root, new_field = promote.promote_anonymous(
227223
expr, path.Path(["user", "friends"]))
228-
[new_field
229-
] = new_root.get_descendant_or_error(new_field).get_source_expressions()
224+
new_field = new_root.get_descendant_or_error(new_field)
230225
new_schema_feature = new_field.schema_feature
231226
self.assertIsNotNone(new_schema_feature)
232227
self.assertEqual(new_schema_feature.presence.min_fraction, 0.3)

struct2tensor/expression_impl/proto.py

Lines changed: 122 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,13 @@
2727
import abc
2828
from struct2tensor import calculate_options
2929
from struct2tensor import expression
30+
from struct2tensor import expression_add
3031
from struct2tensor import path
3132
from struct2tensor import prensor
3233
from struct2tensor.expression_impl import parse_message_level_ex
3334
from struct2tensor.ops import struct2tensor_ops
3435
import tensorflow as tf
35-
from typing import FrozenSet, Mapping, Optional, Sequence, Set, Text, Tuple, Union
36+
from typing import Callable, FrozenSet, Mapping, Optional, Sequence, Set, Text, Tuple, Union
3637

3738

3839
from google.protobuf.descriptor_pb2 import FileDescriptorSet
@@ -105,6 +106,88 @@ def create_expression_from_proto(
105106
return _ProtoRootExpression(desc, tensor_of_protos, message_format)
106107

107108

109+
# The function signature expected by `create_transformed_field`.
110+
# It describes functions of the form:
111+
#
112+
# def transform_fn(parent_indices, values):
113+
# ...
114+
# return (transformed_parent_indices, transformed_values).
115+
#
116+
# Where values/transformed_values are serialized protos of the same type
117+
# and parent_indices/transformed_parent_indices are non-decreasing int64
118+
# vectors. Each pair of indices and values must have the same shape.
119+
TransformFn = Callable[[tf.Tensor, tf.Tensor], Tuple[tf.Tensor, tf.Tensor]]
120+
121+
122+
def create_transformed_field(
    expr: expression.Expression, source_path: path.CoercableToPath,
    dest_field: StrStep, transform_fn: TransformFn) -> expression.Expression:
  """Adds a sibling field whose serialized protos pass through `transform_fn`.

  The callable must look like:

    def transform_fn(parent_indices, values):
      ...
      return (transformed_parent_indices, transformed_values)

  where:
    - parent_indices is an int64 vector of non-decreasing parent message
      indices.
    - values is a string vector of serialized protos with the same shape as
      `parent_indices`.
  The returned parent indices and serialized values must encode the same proto
  message type as the incoming `values`. The two returned vectors must have the
  same size as each other, but that size may differ from the inputs.

  Args:
    expr: a source expression containing `source_path`.
    source_path: the path to the field to transform.
    dest_field: the name of the newly created field. This field will be a
      sibling of the field identified by `source_path`.
    transform_fn: a callable that accepts parent_indices and serialized proto
      values and returns possibly modified parent_indices and values.

  Returns:
    An expression equal to `expr` with the transformed field added.

  Raises:
    ValueError: if the source path is not a proto message field.
  """
  resolved_path = path.create_path(source_path)
  source_expr = expr.get_descendant_or_error(resolved_path)
  if not isinstance(source_expr, _ProtoChildExpression):
    raise ValueError(
        "Expected _ProtoChildExpression for field {}, but found {}.".format(
            str(resolved_path), source_expr))

  if not isinstance(source_expr, _TransformProtoChildExpression):
    effective_transform = transform_fn
  else:
    # A _TransformProtoChildExpression must always be built on top of the
    # original, untransformed _ProtoChildExpression so that the fields needed
    # for parsing can be propagated. Two transforms applied one after the other
    # would therefore share the same source, and the second would see the raw
    # data rather than the output of the first.
    # To get true chaining, wrap the user's function so that it first runs the
    # source's own transform. The cost is that the earlier transform may run
    # more than once if other expressions are also derived directly from it.
    upstream_transform = source_expr.transform_fn

    def effective_transform(parent_indices: tf.Tensor,
                            values: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
      upstream_indices, upstream_values = upstream_transform(
          parent_indices, values)
      return transform_fn(upstream_indices, upstream_values)

  # pylint: disable=protected-access
  transformed_expr = _TransformProtoChildExpression(
      parent=source_expr._parent,
      desc=source_expr._desc,
      is_repeated=source_expr.is_repeated,
      name_as_field=source_expr.name_as_field,
      transform_fn=effective_transform)
  # pylint: enable=protected-access

  # The new field lives next to the source field, under the same parent.
  sibling_path = resolved_path.get_parent().get_child(dest_field)
  return expression_add.add_paths(expr, {sibling_path: transformed_expr})
189+
190+
108191
class _ProtoRootNodeTensor(prensor.RootNodeTensor):
109192
"""The value of the root node.
110193
@@ -309,6 +392,44 @@ def __str__(self) -> str: # pylint: disable=g-ambiguous-str-annotation
309392
str(self.name_as_field), str(self._desc.full_name), self._parent)
310393

311394

395+
class _TransformProtoChildExpression(_ProtoChildExpression):
  """A proto child whose parent indices and values are transformed first.

  The stored transform function is applied to the parsed field's index and
  serialized values before the values are parsed into sub-fields.
  """

  def __init__(self, parent: "_ParentProtoExpression",
               desc: descriptor.Descriptor, is_repeated: bool,
               name_as_field: StrStep, transform_fn: TransformFn):
    super(_TransformProtoChildExpression, self).__init__(
        parent, desc, is_repeated, name_as_field)
    self._transform_fn = transform_fn

  @property
  def transform_fn(self):
    """The callable applied to (parent_indices, values) before parsing."""
    return self._transform_fn

  def calculate_from_parsed_field(
      self, parsed_field: struct2tensor_ops._ParsedField,
      destinations: Sequence[expression.Expression]) -> prensor.NodeTensor:
    """Transforms the parsed field, then parses the fields destinations need."""
    new_parent_indices, new_values = self._transform_fn(
        parsed_field.index, parsed_field.value)
    parsed_children = parse_message_level_ex.parse_message_level_ex(
        new_values, self._desc, _get_needed_fields(destinations))
    return _ProtoChildNodeTensor(new_parent_indices, self.is_repeated,
                                 parsed_children)

  def calculation_equal(self, expr: expression.Expression) -> bool:
    """True only for the same type, descriptor, field name and transform."""
    if not isinstance(expr, _TransformProtoChildExpression):
      return False
    # pylint: disable=protected-access
    return (self._desc == expr._desc and
            self.name_as_field == expr.name_as_field and
            self.transform_fn == expr.transform_fn)

  def __str__(self) -> str:  # pylint: disable=g-ambiguous-str-annotation
    template = (
        "_TransformProtoChildExpression: name_as_field: {} desc: {} from {}")
    return template.format(
        str(self.name_as_field), str(self._desc.full_name), self._parent)
431+
432+
312433
class _ProtoRootExpression(expression.Expression):
313434
"""The expression representing the parse of the root of a proto.
314435

struct2tensor/expression_impl/proto_test.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,35 @@ def test_create_expression_from_proto_with_any_value(self):
210210
self.assertLen(sources, 1)
211211
self.assertIs(any_expr, sources[0])
212212

213+
def test_create_transformed_field(self):
  """Checks the metadata and dtypes of a field created by a transform."""
  session_expr = proto_test_util._get_expression_from_session_empty_user_info()
  with_reversed = proto.create_transformed_field(
      session_expr, path.Path(["event"]), "reversed_event", _reverse_values)

  original_events = session_expr.get_child_or_error("event")
  transformed_events = with_reversed.get_child_or_error("reversed_event")
  # The transformed field must look like its source: repeated, non-leaf and
  # of the same type.
  self.assertTrue(transformed_events.is_repeated)
  self.assertFalse(transformed_events.is_leaf)
  self.assertEqual(original_events.type, transformed_events.type)

  doc_id_path = path.Path(["reversed_event", "action", "doc_id"])
  doc_id_expr = with_reversed.get_descendant_or_error(doc_id_path)
  doc_id_tensor = expression_test_util.calculate_value_slowly(doc_id_expr)
  self.assertEqual(doc_id_tensor.parent_index.dtype, tf.int64)
  self.assertEqual(doc_id_tensor.values.dtype, tf.string)
227+
228+
def test_create_reversed_field_nested(self):
  """Checks dtypes when a transformed field sits under a transformed parent."""
  session_expr = proto_test_util._get_expression_from_session_empty_user_info()
  outer_reversed = proto.create_transformed_field(
      session_expr, path.Path(["event"]), "reversed_event", _reverse_values)
  # The second transform targets a child of the first transformed field.
  inner_reversed = proto.create_transformed_field(
      outer_reversed, path.Path(["reversed_event", "action"]),
      "reversed_action", _reverse_values)

  doc_id_path = path.Path(["reversed_event", "reversed_action", "doc_id"])
  doc_id_expr = inner_reversed.get_descendant_or_error(doc_id_path)
  doc_id_tensor = expression_test_util.calculate_value_slowly(doc_id_expr)
  self.assertEqual(doc_id_tensor.parent_index.dtype, tf.int64)
  self.assertEqual(doc_id_tensor.values.dtype, tf.string)
241+
213242

214243
@test_util.run_all_in_graph_and_eager_modes
215244
class ProtoValuesTest(tf.test.TestCase):
@@ -321,6 +350,50 @@ def test_project_proto_map_leaf_value(self):
321350
self.assertAllEqual(result["int32_string_map[222]"], [[b"2"]])
322351
self.assertAllEqual(result["int32_string_map[223]"], [[]])
323352

353+
def test_transformed_field_values(self):
  """Checks the doc_id values read through a single reversing transform."""
  session_expr = proto_test_util._get_expression_from_session_empty_user_info()
  with_reversed = proto.create_transformed_field(
      session_expr, path.Path(["event"]), "reversed_event", _reverse_values)

  projected = with_reversed.project(["reversed_event.action.doc_id"])
  result = expression_test_util.calculate_list_map(projected, self)

  expected_doc_ids = [
      [[[b"h"], [b"i"], [b"j"]], [[b"g"]], [[b"e"], [b"f"]]],
      [[[b"c"], []], [[b"a"], [b"b"]]],
  ]
  self.assertAllEqual(result["reversed_event.action.doc_id"],
                      expected_doc_ids)
362+
363+
def test_transformed_field_values_with_transformed_parent(self):
  """Checks doc_id values when both the event and its action are reversed."""
  session_expr = proto_test_util._get_expression_from_session_empty_user_info()
  outer_reversed = proto.create_transformed_field(
      session_expr, path.Path(["event"]), "reversed_event", _reverse_values)
  inner_reversed = proto.create_transformed_field(
      outer_reversed, path.Path(["reversed_event", "action"]),
      "reversed_action", _reverse_values)

  projected = inner_reversed.project(
      ["reversed_event.reversed_action.doc_id"])
  result = expression_test_util.calculate_list_map(projected, self)

  expected_doc_ids = [
      [[[b"b"], [b"a"], []], [[b"c"]], [[b"f"], [b"e"]]],
      [[[b"g"], [b"j"]], [[b"i"], [b"h"]]],
  ]
  self.assertAllEqual(result["reversed_event.reversed_action.doc_id"],
                      expected_doc_ids)
376+
377+
def test_transformed_field_values_with_multiple_transforms(self):
  """Checks doc_id values after the same field is reversed twice in a row."""
  session_expr = proto_test_util._get_expression_from_session_empty_user_info()
  once_reversed = proto.create_transformed_field(
      session_expr, path.Path(["event"]), "reversed_event", _reverse_values)
  # Transforming an already-transformed field chains the two transforms.
  twice_reversed = proto.create_transformed_field(
      once_reversed, path.Path(["reversed_event"]),
      "reversed_reversed_event", _reverse_values)

  projected = twice_reversed.project(
      ["reversed_reversed_event.action.doc_id"])
  result = expression_test_util.calculate_list_map(projected, self)

  expected_doc_ids = [
      [[[b"a"], [b"b"]], [[b"c"], []], [[b"e"], [b"f"]]],
      [[[b"g"]], [[b"h"], [b"i"], [b"j"]]],
  ]
  self.assertAllEqual(result["reversed_reversed_event.action.doc_id"],
                      expected_doc_ids)
391+
392+
393+
def _reverse_values(parent_indices, values):
  """Test transform: reverses `values` on the last axis, keeps the indices."""
  flipped_values = tf.reverse(values, axis=[-1])
  return parent_indices, flipped_values
396+
324397

325398
if __name__ == "__main__":
326399
absltest.main()

struct2tensor/ops/decode_proto_sparse_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ to `DT_STRING` (the serialized submessage). This is to reduce the
7878
complexity of the API. The resulting string can be used as input
7979
to another instance of the decode_proto op.
8080
81-
- TensorFlow lacks support for unsigned integers. The ops represent uint64
81+
- TensorFlow lacks support for unsigned integers. The ops represent uint64_t
8282
types as a `DT_INT64` with the same twos-complement bit pattern
8383
(the obvious way). Unsigned int32_t values can be represented exactly by
8484
specifying type `DT_INT64`, or using twos-complement if the caller

0 commit comments

Comments
 (0)