|
27 | 27 | import abc |
28 | 28 | from struct2tensor import calculate_options |
29 | 29 | from struct2tensor import expression |
| 30 | +from struct2tensor import expression_add |
30 | 31 | from struct2tensor import path |
31 | 32 | from struct2tensor import prensor |
32 | 33 | from struct2tensor.expression_impl import parse_message_level_ex |
33 | 34 | from struct2tensor.ops import struct2tensor_ops |
34 | 35 | import tensorflow as tf |
35 | | -from typing import FrozenSet, Mapping, Optional, Sequence, Set, Text, Tuple, Union |
| 36 | +from typing import Callable, FrozenSet, Mapping, Optional, Sequence, Set, Text, Tuple, Union |
36 | 37 |
|
37 | 38 |
|
38 | 39 | from google.protobuf.descriptor_pb2 import FileDescriptorSet |
@@ -105,6 +106,88 @@ def create_expression_from_proto( |
105 | 106 | return _ProtoRootExpression(desc, tensor_of_protos, message_format) |
106 | 107 |
|
107 | 108 |
|
| 109 | +# The function signature expected by `create_transformed_field`. |
| 110 | +# It describes functions of the form: |
| 111 | +# |
| 112 | +# def transform_fn(parent_indices, values): |
| 113 | +# ... |
| 114 | +# return (transformed_parent_indices, transformed_values). |
| 115 | +# |
| 116 | +# Where values/transformed_values are serialized protos of the same type |
| 117 | +# and parent_indices/transformed_parent_indices are non-decreasing int64 |
| 118 | +# vectors. Each pair of indices and values must have the same shape. |
| 119 | +TransformFn = Callable[[tf.Tensor, tf.Tensor], Tuple[tf.Tensor, tf.Tensor]] |
| 120 | + |
| 121 | + |
def create_transformed_field(
    expr: expression.Expression, source_path: path.CoercableToPath,
    dest_field: StrStep, transform_fn: TransformFn) -> expression.Expression:
  """Adds a sibling field whose serialized protos pass through `transform_fn`.

  `transform_fn` is called as:

    transformed_parent_indices, transformed_values = transform_fn(
        parent_indices, values)

  where:
    - parent_indices is an int64 vector of non-decreasing parent message
      indices.
    - values is a string vector of serialized protos with the same shape as
      `parent_indices`.

  The two returned vectors must have the same size as each other (though not
  necessarily the same size as the inputs), and the returned values must be
  serialized protos of the same message type as the input `values`.

  Args:
    expr: a source expression containing `source_path`.
    source_path: the path to the proto message field to transform.
    dest_field: the name of the newly created field. It is added as a sibling
      of the field identified by `source_path`.
    transform_fn: a callable that accepts parent_indices and serialized proto
      values and returns possibly modified parent_indices and values.

  Returns:
    The expression produced by adding the transformed field to `expr` at the
    destination path.

  Raises:
    ValueError: if the source path is not a proto message field.
  """
  resolved_path = path.create_path(source_path)
  origin = expr.get_descendant_or_error(resolved_path)
  if not isinstance(origin, _ProtoChildExpression):
    raise ValueError(
        "Expected _ProtoChildExpression for field {}, but found {}.".format(
            str(resolved_path), origin))

  # By default the caller's function is applied as-is.
  effective_fn = transform_fn
  if isinstance(origin, _TransformProtoChildExpression):
    # A _TransformProtoChildExpression is always built from the parent and
    # descriptor of the original, untransformed _ProtoChildExpression so that
    # the fields needed for parsing can still be propagated. As a result, a
    # second transform stacked on a first one would see the raw source rather
    # than the first transform's output. To get chaining semantics, the
    # caller's function is composed with the source's own transform function.
    # The cost is that the earlier transform may be evaluated more than once
    # when other expressions are also derived directly from it.
    def _chained_transform(parent_indices: tf.Tensor,
                           values: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
      earlier_indices, earlier_values = origin.transform_fn(
          parent_indices, values)
      return transform_fn(earlier_indices, earlier_values)

    effective_fn = _chained_transform

  # pylint: disable=protected-access
  origin_parent = origin._parent
  origin_desc = origin._desc
  # pylint: enable=protected-access
  new_field_expr = _TransformProtoChildExpression(
      parent=origin_parent,
      desc=origin_desc,
      is_repeated=origin.is_repeated,
      name_as_field=origin.name_as_field,
      transform_fn=effective_fn)

  # The new field lives next to the source field, under the same parent path.
  sibling_path = resolved_path.get_parent().get_child(dest_field)
  return expression_add.add_paths(expr, {sibling_path: new_field_expr})
| 189 | + |
| 190 | + |
108 | 191 | class _ProtoRootNodeTensor(prensor.RootNodeTensor): |
109 | 192 | """The value of the root node. |
110 | 193 |
|
@@ -309,6 +392,44 @@ def __str__(self) -> str: # pylint: disable=g-ambiguous-str-annotation |
309 | 392 | str(self.name_as_field), str(self._desc.full_name), self._parent) |
310 | 393 |
|
311 | 394 |
|
class _TransformProtoChildExpression(_ProtoChildExpression):
  """A proto child expression that rewrites indices and values before parsing.

  The user-supplied `transform_fn` receives the parent indices and serialized
  values of the parsed field and returns the pair that is actually parsed into
  the resulting node tensor.
  """

  def __init__(self, parent: "_ParentProtoExpression",
               desc: descriptor.Descriptor, is_repeated: bool,
               name_as_field: StrStep, transform_fn: TransformFn):
    """Initializes the expression.

    Args:
      parent: the parent proto expression; forwarded to the base class.
      desc: the descriptor of the message held by this field.
      is_repeated: whether the field is repeated; forwarded to the base class.
      name_as_field: the name of this field; forwarded to the base class.
      transform_fn: callable applied to (parent_indices, values) before the
        values are parsed.
    """
    super(_TransformProtoChildExpression, self).__init__(
        parent, desc, is_repeated, name_as_field)
    self._transform_fn = transform_fn

  @property
  def transform_fn(self):
    """The callable applied to (parent_indices, values) before parsing."""
    return self._transform_fn

  def calculate_from_parsed_field(
      self, parsed_field: struct2tensor_ops._ParsedField,
      destinations: Sequence[expression.Expression]) -> prensor.NodeTensor:
    """Transforms the parsed field, then parses the transformed values.

    Args:
      parsed_field: the parsed field whose `index` and `value` are fed to the
        transform function.
      destinations: the expressions consuming this node; used to determine
        which fields need to be parsed.

    Returns:
      A _ProtoChildNodeTensor built from the transformed parent indices and the
      fields parsed out of the transformed values.
    """
    # Kept first so ops are created in the same order as before.
    fields_to_parse = _get_needed_fields(destinations)
    new_parent_indices, new_values = self._transform_fn(
        parsed_field.index, parsed_field.value)
    parsed_fields = parse_message_level_ex.parse_message_level_ex(
        new_values, self._desc, fields_to_parse)
    return _ProtoChildNodeTensor(new_parent_indices, self.is_repeated,
                                 parsed_fields)

  def calculation_equal(self, expr: expression.Expression) -> bool:
    """Returns whether `expr` performs the same calculation as this one.

    Two expressions match only if `expr` is also a
    _TransformProtoChildExpression with an equal descriptor, an equal field
    name and an equal transform function.
    """
    if not isinstance(expr, _TransformProtoChildExpression):
      return False
    if not self._desc == expr._desc:  # pylint: disable=protected-access
      return False
    if not self.name_as_field == expr.name_as_field:
      return False
    return self.transform_fn == expr.transform_fn

  def __str__(self) -> str:  # pylint: disable=g-ambiguous-str-annotation
    """Returns a description with the field name, message type and parent."""
    template = (
        "_TransformProtoChildExpression: name_as_field: {} desc: {} from {}")
    field_name = str(self.name_as_field)
    message_name = str(self._desc.full_name)
    return template.format(field_name, message_name, self._parent)
| 431 | + |
| 432 | + |
312 | 433 | class _ProtoRootExpression(expression.Expression): |
313 | 434 | """The expression representing the parse of the root of a proto. |
314 | 435 |
|
|
0 commit comments