Skip to content

Commit 7740f70

Browse files
mzinkevitfx-copybara
authored andcommitted
Transform prensors to structured prensors.
There are three limitations: 1. The optional/repeated distinction is lost. 2. If the prensor has fields that do not conform to the structured tensor requirements, the behavior is undefined. 3. There is no way yet to represent multidimensional fields. This can be handled in a future extension. PiperOrigin-RevId: 309090341
1 parent e12fbe8 commit 7740f70

File tree

2 files changed

+221
-0
lines changed

2 files changed

+221
-0
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Converting a prensor to a structured tensor.
15+
16+
Structured tensors are almost completely more general than a prensor.
17+
There are only four issues:
18+
19+
1. There is no disambiguation between an optional and a repeated field in a
20+
structured test. This is lost.
21+
2. The set of field names is more restricted in a structured tensor. This is
22+
ignored for now: behavior of disallowed field names is undefined.
23+
3. We could try to keep more shape information, especially if the size of the
24+
root prensor is known statically.
25+
4. TODO:interpret a given field name to represent "anonymous" dimensions, and
26+
create multidimensional fields when that field name is used.
27+
28+
"""
29+
30+
from __future__ import absolute_import
31+
from __future__ import division
32+
33+
from __future__ import print_function
34+
import six
35+
from struct2tensor import path
36+
from struct2tensor import prensor
37+
import tensorflow as tf
38+
from typing import Mapping, Union
39+
40+
from tensorflow.python.ops.ragged.row_partition import RowPartition # pylint: disable=g-direct-tensorflow-import
41+
from tensorflow.python.ops.structured import structured_tensor # pylint: disable=g-direct-tensorflow-import
42+
43+
44+
def prensor_to_structured_tensor(
45+
p: prensor.Prensor) -> structured_tensor.StructuredTensor:
46+
"""Creates a structured tensor from a prensor.
47+
48+
All information about optional and repeated fields is dropped.
49+
If the field names in the proto do not meet the specifications for
50+
StructuredTensor, the behavior is undefined.
51+
52+
Args:
53+
p: the prensor to convert.
54+
55+
Returns:
56+
An equivalent StructuredTensor.
57+
58+
Raises:
59+
ValueError: if the root of the prensor is not a RootNodeTensor.
60+
"""
61+
node = p.node
62+
if isinstance(node, prensor.RootNodeTensor):
63+
return _root_node_to_structured_tensor(
64+
_prensor_to_field_map(p.get_children(), node.size))
65+
raise ValueError("Must be a root prensor")
66+
67+
68+
def _root_node_to_structured_tensor(
69+
fields: Mapping[path.Step, prensor.Prensor]
70+
) -> structured_tensor.StructuredTensor:
71+
"""Convert a map of prensors to a structured tensor."""
72+
return structured_tensor.StructuredTensor.from_fields(
73+
fields=fields, shape=tf.TensorShape([None]))
74+
75+
76+
def _prensor_to_structured_tensor_helper(
77+
p: prensor.Prensor, nrows: tf.Tensor
78+
) -> Union[tf.RaggedTensor, structured_tensor.StructuredTensor]:
79+
"""Convert a prensor to a structured tensor with a certain number of rows."""
80+
node = p.node
81+
if isinstance(node, prensor.LeafNodeTensor):
82+
return _leaf_node_to_ragged_tensor(node, nrows)
83+
assert isinstance(node, prensor.ChildNodeTensor)
84+
return _child_node_to_structured_tensor(
85+
node, _prensor_to_field_map(p.get_children(), node.size), nrows)
86+
87+
88+
def _prensor_to_field_map(
89+
p_fields: Mapping[path.Step, prensor.Prensor],
90+
nrows: tf.Tensor) -> Mapping[path.Step, structured_tensor.StructuredTensor]:
91+
"""Convert a map of prensors to map of structured tensors."""
92+
return {
93+
step: _prensor_to_structured_tensor_helper(child, nrows)
94+
for step, child in six.iteritems(p_fields)
95+
}
96+
97+
98+
def _child_node_to_structured_tensor(
99+
node: prensor.ChildNodeTensor, fields: Mapping[path.Step, prensor.Prensor],
100+
nrows: tf.Tensor) -> structured_tensor.StructuredTensor:
101+
"""Convert a map of prensors to map of structured tensors."""
102+
st = structured_tensor.StructuredTensor.from_fields(
103+
fields=fields, shape=tf.TensorShape([None]))
104+
row_partition = RowPartition.from_value_rowids(
105+
value_rowids=node.parent_index, nrows=nrows)
106+
return st.partition_outer_dimension(row_partition)
107+
108+
109+
def _leaf_node_to_ragged_tensor(node: prensor.LeafNodeTensor,
110+
nrows: tf.Tensor) -> tf.RaggedTensor:
111+
"""Converts a LeafNodeTensor to a 2D ragged tensor."""
112+
return tf.RaggedTensor.from_value_rowids(
113+
values=node.values, value_rowids=node.parent_index, nrows=nrows)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Lint as: python3
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Tests for StructuredTensor."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
from struct2tensor import prensor
22+
from struct2tensor import prensor_to_structured_tensor
23+
from struct2tensor.test import prensor_test_util
24+
25+
from tensorflow.python.framework import test_util # pylint: disable=g-direct-tensorflow-import
26+
from tensorflow.python.ops.structured import structured_tensor # pylint: disable=g-direct-tensorflow-import
27+
from tensorflow.python.platform import unittest # pylint: disable=g-direct-tensorflow-import
28+
29+
30+
def _make_structured_tensor(shape, fields):
31+
return structured_tensor.StructuredTensor.from_fields(
32+
fields=fields, shape=shape)
33+
34+
35+
# pylint: disable=g-long-lambda
36+
# @test_util.run_all_in_graph_and_eager_modes
37+
class PrensorToStructuredTensorTest(test_util.TensorFlowTestCase):
38+
39+
def test_simple_prensor(self):
40+
pren = prensor_test_util.create_simple_prensor()
41+
st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren)
42+
self.assertAllEqual(st._fields["foo"], [[9], [8], [7]])
43+
self.assertAllEqual(st._fields["foorepeated"], [[9], [8, 7], [6]])
44+
45+
def test_nested_prensor(self):
46+
"""Tests on a deep expression."""
47+
pren = prensor_test_util.create_nested_prensor()
48+
st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren)
49+
self.assertAllEqual(st._fields["doc"]._fields["bar"],
50+
[[[b"a"]], [[b"b", b"c"], [b"d"]], []])
51+
self.assertAllEqual(st._fields["doc"]._fields["keep_me"],
52+
[[[False]], [[True], []], []])
53+
self.assertAllEqual(st._fields["user"]._fields["friends"],
54+
[[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]])
55+
56+
def test_big_prensor(self):
57+
"""Test the big prensor.
58+
59+
a prensor expression representing:
60+
{foo:9, foorepeated:[9], doc:[{bar:["a"], keep_me:False}],
61+
user:[{friends:["a"]}]}
62+
{foo:8, foorepeated:[8,7],
63+
doc:[{bar:["b","c"],keep_me:True},{bar:["d"]}],
64+
user:[{friends:["b", "c"]},{friends:["d"]}],}
65+
{foo:7, foorepeated:[6], user:[friends:["e"]]}
66+
"""
67+
pren = prensor_test_util.create_big_prensor()
68+
st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren)
69+
self.assertAllEqual(st._fields["foo"], [[9], [8], [7]])
70+
self.assertAllEqual(st._fields["foorepeated"], [[9], [8, 7], [6]])
71+
self.assertAllEqual(st._fields["doc"]._fields["keep_me"],
72+
[[[False]], [[True], []], []])
73+
self.assertAllEqual(st._fields["user"]._fields["friends"],
74+
[[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]])
75+
self.assertAllEqual(st._fields["doc"]._fields["bar"],
76+
[[[b"a"]], [[b"b", b"c"], [b"d"]], []])
77+
78+
def test_deep_prensor(self):
79+
"""Test a prensor with three layers: root, event, and doc.
80+
81+
a prensor expression representing:
82+
{foo:9, foorepeated:[9], user:[{friends:["a"]}],
83+
event:{doc:[{bar:["a"], keep_me:False}]}}
84+
{foo:8, foorepeated:[8,7],
85+
event:{doc:[{bar:["b","c"], keep_me:True},{bar:["d"]}]},
86+
user:[{friends:["b", "c"]}, {friends:["d"]}]}
87+
{foo:7, foorepeated:[6], user:[friends:["e"]], event:{}}
88+
"""
89+
pren = prensor_test_util.create_deep_prensor()
90+
st = prensor_to_structured_tensor.prensor_to_structured_tensor(pren)
91+
self.assertAllEqual(st._fields["foo"], [[9], [8], [7]])
92+
self.assertAllEqual(st._fields["foorepeated"], [[9], [8, 7], [6]])
93+
self.assertAllEqual(st._fields["user"]._fields["friends"],
94+
[[[b"a"]], [[b"b", b"c"], [b"d"]], [[b"e"]]])
95+
self.assertAllEqual(st._fields["event"]._fields["doc"]._fields["bar"],
96+
[[[[b"a"]]], [[[b"b", b"c"], [b"d"]]], [[]]])
97+
self.assertAllEqual(st._fields["event"]._fields["doc"]._fields["keep_me"],
98+
[[[[False]]], [[[True], []]], [[]]])
99+
100+
def test_non_root_prensor(self):
101+
child_prensor = prensor.create_prensor_from_root_and_children(
102+
prensor_test_util.create_child_node([0, 0, 1, 3, 7], True), {})
103+
with self.assertRaisesRegexp(ValueError, "Must be a root prensor"):
104+
prensor_to_structured_tensor.prensor_to_structured_tensor(child_prensor)
105+
106+
107+
if __name__ == "__main__":
108+
unittest.main()

0 commit comments

Comments
 (0)