Skip to content

Commit 77f74ca

Browse files
mzinkevi authored and tfx-copybara committed
Converts structured tensor to prensor.
Any multidimensional fields are converted to be nested messages. PiperOrigin-RevId: 310064947
1 parent df66bc8 commit 77f74ca

File tree

2 files changed

+896
-0
lines changed

2 files changed

+896
-0
lines changed
Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Converting a structured tensor to a prensor.
15+
16+
This conversion handles a variety of differences in the implementation.
17+
18+
1. Prensors disambiguate between optional and repeated fields. All fields
19+
from a structured tensor are treated as repeated.
20+
2. Structured tensors can represent a scalar structured object. This is not
21+
allowed in a prensor, so it is converted to a vector of length one.
22+
3. Structured tensors can have dimensions that are unnamed. For instance,
23+
[{"foo":[[3],[4,5]]}] is a valid structured tensor. This is translated
24+
into [{"foo":[{"data":[3]},{"data":[4,5]}]}], where "data" is the default
25+
value of default_field_name, which can be changed by the user.
26+
4. Structured tensors have a lot of flexibility in how fields are
27+
represented.
28+
4A. First of all, they can be dense tensors. These are converted
29+
to ragged tensors with the same shape. If the rank is unknown, then
30+
this will fail.
31+
4B. Also, ragged tensors may have a multidimensional tensor as the final
32+
value. If the rank is unknown, then this will fail.
33+
5. Prensors encode all parent-child relationships with a parent_index tensor
34+
(equivalent to a value_rowids tensor), and at the head with a size tensor
35+
(equivalent to an nrows tensor).
36+
6. Structured tensors can have "required" fields. Specifically, a structured
37+
tensor could have:
38+
StructuredTensor(shape=[2], fields={t:[3,7]}). This is transformed into a
39+
repeated field.
40+
"""
41+
42+
from __future__ import absolute_import
43+
from __future__ import division
44+
45+
from __future__ import print_function
46+
47+
from struct2tensor import path
48+
from struct2tensor import prensor
49+
import tensorflow.compat.v2 as tf
50+
from typing import Mapping, Union
51+
52+
from tensorflow.python.ops.ragged.row_partition import RowPartition # pylint: disable=g-direct-tensorflow-import
53+
from tensorflow.python.ops.structured import structured_tensor # pylint: disable=g-direct-tensorflow-import
54+
55+
56+
def structured_tensor_to_prensor(
    st: structured_tensor.StructuredTensor,
    default_field_name: path.Step = "data") -> prensor.Prensor:
  """Builds a prensor that mirrors the given structured tensor.

  How the root's children are chosen depends on `st`:

  * If `st` has at least one row partition, the whole of `st` is converted to
    a single child prensor stored under `default_field_name`.
  * If `st` has no row partitions and has rank 1, every field of `st` becomes
    a child of the root.
  * Otherwise `st` is treated as a scalar: a dimension is inserted at axis 0
    and the conversion is retried on the result.

  Certain rank information must be known for the conversion to succeed.

  Args:
    st: the structured tensor to convert.
    default_field_name: the name to use when there is an unnamed dimension.

  Returns:
    a logically equivalent Prensor.

  Raises:
    ValueError: if there is an issue with the structured tensor.
  """
  if st.row_partitions:
    # Partitioned (multidimensional) input: everything hangs off one child.
    children = {
        default_field_name:
            _structured_tensor_to_child_prensor(st, default_field_name)
    }
    root = prensor.RootNodeTensor(st.nrows())
  elif st.rank == 1:
    root = prensor.RootNodeTensor(st.nrows())
    children = _structured_tensor_prensor_map(st, default_field_name)
  else:
    # st is a scalar StructuredTensor; convert it as a vector of length one.
    return structured_tensor_to_prensor(_expand_dims(st, 0), default_field_name)
  return prensor.create_prensor_from_root_and_children(root, children)
87+
88+
89+
def _structured_tensor_prensor_map(
    st: structured_tensor.StructuredTensor,
    default_field_name: path.Step) -> Mapping[path.Step, prensor.Prensor]:
  """Converts every field of `st` into a prensor, keyed by field name.

  Args:
    st: the structured tensor whose fields are converted.
    default_field_name: the name to use for unnamed dimensions inside a field.

  Returns:
    A dict from each field name of `st` to the prensor built from that field's
    value, suitable as the children of a root or child prensor.
  """
  field_prensors = {}
  for field_name in st.field_names():
    field_prensors[field_name] = _structured_tensor_field_to_prensor(
        st.field_value(field_name), default_field_name)
  return field_prensors
97+
98+
99+
# _expand_dims requires special treatment for scalar StructuredTensors, because
100+
# it is not adding a partition dimension. Therefore, we have to expand the
101+
# dimension of each field explicitly.
102+
def _expand_dims_scalar(st: structured_tensor.StructuredTensor):
  """Turns a scalar structured tensor into a structured tensor of shape [1].

  Each field value is expanded along axis 0 individually (via _expand_dims),
  and the expanded fields are assembled into a new StructuredTensor.

  Args:
    st: a scalar StructuredTensor.

  Returns:
    A StructuredTensor built from the expanded fields with shape [1].
  """
  vector_shape = tf.constant([1], dtype=tf.int64)
  expanded_fields = {}
  for field_name in st.field_names():
    expanded_fields[field_name] = _expand_dims(st.field_value(field_name), 0)
  return structured_tensor.StructuredTensor.from_fields(
      expanded_fields, shape=vector_shape)
108+
109+
110+
def _expand_dims_nonnegative_axis(axis, rank):
111+
"""Get the nonnegative axis according to the rules of tf.expand_dims."""
112+
# Implementation note: equivalent to get_positive_axis(axis, rank + 1)
113+
if axis < 0:
114+
new_axis = (1 + rank) + axis
115+
if new_axis < 0:
116+
# Note: this is unreachable in the current code.
117+
raise ValueError("Axis out of range: " + str(axis))
118+
return new_axis
119+
elif axis > rank:
120+
# Note: this is unreachable in the current code.
121+
raise ValueError("Axis larger than rank: " + str(axis) + " > " + str(rank))
122+
return axis
123+
124+
125+
def _expand_dims(st, axis):
  """tf.expand_dims, but works on StructuredTensor too.

  Anything that is not a StructuredTensor is passed straight to
  tf.expand_dims. For a StructuredTensor, a scalar is handled by expanding
  each field, and otherwise a new outer row partition is added.

  Note: for a non-scalar StructuredTensor the implementation does not work if
  the non-negative axis is > 1, and will throw a ValueError.

  Args:
    st: a Tensor, RaggedTensor, or StructuredTensor.
    axis: the axis to insert a dimension before.

  Returns:
    a tensor with one more dimension (see tf.expand_dims).

  Raises:
    ValueError:
      if the axis is not valid.
  """
  if not isinstance(st, structured_tensor.StructuredTensor):
    return tf.expand_dims(st, axis)
  nn_axis = _expand_dims_nonnegative_axis(axis, st.rank)
  if st.rank == 0:
    return _expand_dims_scalar(st)
  if nn_axis > 1:
    # Note: this is unreachable in the current code.
    raise ValueError("Unimplemented: non-negative axis > 1 for _expand_dims")

  num_rows = st.nrows()
  one = tf.constant(1, dtype=num_rows.dtype)
  if nn_axis == 0:
    # Add a dimension of size 1 at the front: a single outer row that holds
    # all existing rows. The row count is given explicitly because it cannot
    # be derived from nvals / uniform_row_length when st is empty (both are
    # zero); without it an empty input would get 0 outer rows instead of 1.
    outer_partition = RowPartition.from_uniform_row_length(
        num_rows, num_rows, nrows=one)
  else:
    # nn_axis == 1: partition the first dimension into vectors of length 1.
    outer_partition = RowPartition.from_uniform_row_length(one, num_rows)
  return st.partition_outer_dimension(outer_partition)
161+
162+
163+
def _structured_tensor_field_to_prensor(
    field_value: Union[structured_tensor.StructuredTensor, tf.RaggedTensor,
                       tf.Tensor],
    default_field_name: path.Step) -> prensor.Prensor:
  """Converts one field value of a structured tensor into a prensor.

  Args:
    field_value: the value of the field; a nested StructuredTensor, or a
      RaggedTensor / Tensor holding leaf values.
    default_field_name: the name to use for unnamed dimensions.

  Returns:
    A child prensor when the field is itself a StructuredTensor, otherwise a
    leaf prensor.
  """
  is_nested = isinstance(field_value, structured_tensor.StructuredTensor)
  convert = _structured_tensor_to_child_prensor if is_nested else _to_leaf_prensor
  return convert(field_value, default_field_name)
172+
173+
174+
def _row_partition_to_child_node_tensor(row_partition: RowPartition):
  """Builds a repeated ChildNodeTensor from a RowPartition.

  The partition is first converted to int64 row splits; its value_rowids()
  are then passed as the first argument of the ChildNodeTensor.

  Args:
    row_partition: the partition describing which parent each child row
      belongs to.

  Returns:
    A ChildNodeTensor with is_repeated=True.
  """
  int64_partition = row_partition.with_row_splits_dtype(tf.int64)
  parent_index = int64_partition.value_rowids()
  return prensor.ChildNodeTensor(parent_index, is_repeated=True)
179+
180+
181+
def _one_child_prensor(row_partition: RowPartition,
                       child_prensor: prensor.Prensor,
                       default_field_name: path.Step) -> prensor.Prensor:
  """Wraps a prensor in a new ChildNodeTensor root, as its only child.

  Args:
    row_partition: the partition from which the new root ChildNodeTensor is
      built.
    child_prensor: the prensor to attach beneath the new root.
    default_field_name: the name under which child_prensor is attached.

  Returns:
    A prensor whose root is a ChildNodeTensor with exactly one child.
  """
  root_node = _row_partition_to_child_node_tensor(row_partition)
  only_child = {default_field_name: child_prensor}
  return prensor.create_prensor_from_root_and_children(root_node, only_child)
188+
189+
190+
def _structured_tensor_to_child_prensor(
    st: structured_tensor.StructuredTensor,
    default_field_name: path.Step) -> prensor.Prensor:
  """Converts a structured tensor into a prensor rooted at a ChildNodeTensor.

  The outermost row partition of `st` becomes the root ChildNodeTensor. With
  exactly one partition, the fields of the flattened tensor become the root's
  children. With more than one, the flattened tensor is converted recursively
  and attached as the single child named `default_field_name`.

  Args:
    st: the structured tensor to convert.
    default_field_name: the name to use for unnamed dimensions.

  Returns:
    A prensor with a ChildNodeTensor at the root.
  """
  partitions = st.row_partitions
  if not partitions:
    # There is no partition to turn into a ChildNodeTensor yet. The fields
    # could be scalars or vectors, so insert a dimension at axis 1 with
    # _expand_dims(...) and convert the result instead.
    return _structured_tensor_to_child_prensor(
        _expand_dims(st, 1), default_field_name)

  flattened_st = st.merge_dims(0, 1)
  outer_partition = partitions[0]
  if len(partitions) == 1:
    return prensor.create_prensor_from_root_and_children(
        _row_partition_to_child_node_tensor(outer_partition),
        _structured_tensor_prensor_map(flattened_st, default_field_name))

  nested_prensor = _structured_tensor_to_child_prensor(
      flattened_st, default_field_name)
  return _one_child_prensor(outer_partition, nested_prensor, default_field_name)
213+
214+
215+
def _to_leaf_prensor_helper(rt: tf.RaggedTensor,
                            default_field_name: path.Step) -> prensor.Prensor:
  """Converts a fully partitioned ragged tensor to a leaf prensor.

  The input is assumed to be fully partitioned, i.e. the values at the end
  are a vector, not a 2D tensor. When the ragged rank is 1 the result is a
  single leaf node. Otherwise the outer partition becomes a ChildNodeTensor
  whose only child, named `default_field_name`, is built recursively from
  rt.values.

  Args:
    rt: a fully partitioned ragged tensor (see
      _fully_partitioned_ragged_tensor).
    default_field_name: a path.Step for unnamed dimensions.

  Returns:
    a prensor, with a leaf as the root node when rt has ragged rank 1, and a
    child node wrapping the inner dimensions otherwise.
  """
  outer_partition = rt._row_partition  # pylint: disable=protected-access
  if rt.ragged_rank != 1:
    inner_prensor = _to_leaf_prensor_helper(rt.values, default_field_name)
    return _one_child_prensor(outer_partition, inner_prensor,
                              default_field_name)
  leaf_values = rt.values
  leaf_node = prensor.LeafNodeTensor(
      outer_partition.value_rowids(), leaf_values, True)
  return prensor.create_prensor_from_root_and_children(leaf_node, {})
239+
240+
241+
def _partition_if_not_vector(values: tf.Tensor, dtype: tf.dtypes.DType):
  """Creates a fully partitioned ragged tensor from a multidimensional tensor.

  If the tensor is 1D, then it is unchanged.

  Args:
    values: the tensor to be transformed
    dtype: the type of the row splits.

  Returns:
    A 1D tensor or a ragged tensor.

  Raises:
    ValueError: if the rank cannot be statically determined or the tensor is
      a scalar.
  """
  values_rank = values.shape.rank
  # Validation raises ValueError (as documented) rather than using assert,
  # so the checks survive `python -O`.
  if values_rank is None:
    # values should not have an unknown rank in a RaggedTensor field in a
    # StructuredTensor.
    raise ValueError("The rank of the values must be statically known.")
  if values_rank == 1:
    return values
  if values_rank < 1:
    # This should not happen inside a ragged tensor.
    raise ValueError("The values must be at least one-dimensional.")
  return tf.RaggedTensor.from_tensor(
      values, ragged_rank=values_rank - 1, row_splits_dtype=dtype)
269+
270+
271+
def _fully_partitioned_ragged_tensor(rt: Union[tf.RaggedTensor, tf.Tensor],
                                     dtype=tf.dtypes.int64):
  """Creates a fully partitioned ragged tensor from a tensor or a ragged tensor.

  A fully partitioned ragged tensor is:
  1. A ragged tensor.
  2. The final values are a vector.

  A dense tensor of rank 0 or 1 is first expanded with trailing dimensions of
  size one until it is two-dimensional.

  Args:
    rt: input to coerce from RaggedTensor or Tensor.
    dtype: requested dtype for partitions: tf.int64 or tf.int32.

  Returns:
    A ragged tensor where the flat values are a 1D tensor.

  Raises:
    ValueError: if rt is a dense tensor whose rank is not statically known.
  """
  if isinstance(rt, tf.RaggedTensor):
    rt = rt.with_row_splits_dtype(dtype)
    flattened_values = _partition_if_not_vector(rt.flat_values, dtype=dtype)
    return rt.with_flat_values(flattened_values)

  rt_rank = rt.shape.rank
  if rt_rank is None:
    # Raise instead of assert so the check survives `python -O`.
    raise ValueError("The rank of a dense tensor must be statically known.")
  if rt_rank < 2:
    # Increase the rank of a scalar or vector. The requested dtype must be
    # forwarded, otherwise the recursive call silently falls back to int64.
    return _fully_partitioned_ragged_tensor(
        tf.expand_dims(rt, -1), dtype=dtype)
  return tf.RaggedTensor.from_tensor(
      rt, ragged_rank=rt_rank - 1, row_splits_dtype=dtype)
303+
304+
305+
def _to_leaf_prensor(rt: Union[tf.RaggedTensor, tf.Tensor],
                     default_field_name: path.Step) -> prensor.Prensor:
  """Creates a leaf prensor from a ragged tensor or tensor.

  The input is first coerced to a fully partitioned ragged tensor (with the
  default row-splits dtype of _fully_partitioned_ragged_tensor) and then
  converted by _to_leaf_prensor_helper.

  Args:
    rt: the field value, as a RaggedTensor or a Tensor.
    default_field_name: a path.Step for unnamed dimensions.

  Returns:
    The prensor produced by _to_leaf_prensor_helper for the coerced input.
  """
  fully_partitioned = _fully_partitioned_ragged_tensor(rt)
  return _to_leaf_prensor_helper(fully_partitioned, default_field_name)

0 commit comments

Comments
 (0)