Skip to content

Commit 3255752

Browse files
mzinkevitfx-copybara
authored andcommitted
Fixed a bug where the nrows was not correctly set.
The nrows of the child node and the parent node were confused. Since the outcome is immediately modified, this only matters when we update StructuredTensor to have more rigorous unit tests (see cl/456841301). In addition, added some more helpful debugging so any errors give you information about which field they came from. PiperOrigin-RevId: 457122359
1 parent 53eac85 commit 3255752

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

struct2tensor/prensor_to_structured_tensor.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,21 @@ def _prensor_to_field_map(
8585
p_fields: Mapping[path.Step, prensor.Prensor],
8686
nrows: tf.Tensor) -> Mapping[path.Step, structured_tensor.StructuredTensor]:
8787
"""Convert a map of prensors to map of structured tensors."""
88-
return {
89-
step: _prensor_to_structured_tensor_helper(child, nrows)
90-
for step, child in p_fields.items()
91-
}
88+
result = {}
89+
for step, child in p_fields.items():
90+
try:
91+
result[step] = _prensor_to_structured_tensor_helper(child, nrows)
92+
except ValueError as err:
93+
raise ValueError(f"Error in field: {step}") from err
94+
return result
9295

9396

9497
def _child_node_to_structured_tensor(
9598
node: prensor.ChildNodeTensor, fields: Mapping[path.Step, prensor.Prensor],
9699
nrows: tf.Tensor) -> structured_tensor.StructuredTensor:
97100
"""Convert a map of prensors to map of structured tensors."""
98101
st = structured_tensor.StructuredTensor.from_fields(
99-
fields=fields, shape=tf.TensorShape([None]), nrows=nrows)
102+
fields=fields, shape=tf.TensorShape([None]), nrows=node.size)
100103
row_partition = RowPartition.from_value_rowids(
101104
value_rowids=node.parent_index, nrows=nrows)
102105
return st.partition_outer_dimension(row_partition)

struct2tensor/prensor_to_structured_tensor_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,15 @@
1313
# limitations under the License.
1414
"""Tests for StructuredTensor."""
1515

16+
from google.protobuf import text_format
1617
from struct2tensor import calculate
1718
from struct2tensor import prensor
1819
from struct2tensor import prensor_to_structured_tensor
1920
from struct2tensor.expression_impl import proto
2021
from struct2tensor.test import prensor_test_util
2122
from struct2tensor.test import test_pb2
22-
2323
import tensorflow as tf
2424

25-
from google.protobuf import text_format
26-
2725

2826
# @test_util.run_all_in_graph_and_eager_modes
2927
class PrensorToStructuredTensorTest(tf.test.TestCase):

0 commit comments

Comments
 (0)