Skip to content

Commit 87c6a10

Browse files
ZacharyGarrettcopybara-github
authored andcommitted
Avoid trying to deconstruct GlobalData for single-element struct result types.
PiperOrigin-RevId: 656443088
1 parent 0220477 commit 87c6a10

File tree

4 files changed

+20
-15
lines changed

4 files changed

+20
-15
lines changed

RELEASE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ and this project adheres to
2525

2626
* Fixed a bug in `tff.jax.computation` that raised an error when the
2727
computation had unused arguments.
28+
* Fixed a bug when using `tff.backends.xla` execution stack that raised errors
29+
when single element structures were returned from `tff.jax.computation`
30+
wrapped methods.
2831
* Modified the model output release frequency to every 10 rounds and the final
2932
round in `tff.learning.programs.train_model`.
3033
* Loosened the `kEpsilonThreshold` constant and updated the tests of

tensorflow_federated/cc/core/impl/executors/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1126,8 +1126,8 @@ cc_library(
11261126
"@org_tensorflow//tensorflow/compiler/xla:xla_proto_cc",
11271127
"@org_tensorflow//tensorflow/compiler/xla/client",
11281128
"@org_tensorflow//tensorflow/compiler/xla/client:client_library",
1129-
"@org_tensorflow//tensorflow/compiler/xla/client:global_data",
11301129
"@org_tensorflow//tensorflow/compiler/xla/client:xla_computation",
1130+
"@org_tensorflow//tensorflow/compiler/xla/service",
11311131
"@org_tensorflow//tensorflow/compiler/xla/service:hlo_proto_cc",
11321132
"@org_tensorflow//tensorflow/compiler/xla/stream_executor",
11331133
"@org_tensorflow//tensorflow/compiler/xla/stream_executor:multi_platform_manager",

tensorflow_federated/cc/core/impl/executors/xla_executor.cc

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ limitations under the License
3535
#include "tensorflow/compiler/tf2xla/type_util.h"
3636
#include "tensorflow/compiler/xla/client/client.h"
3737
#include "tensorflow/compiler/xla/client/client_library.h"
38-
#include "tensorflow/compiler/xla/client/global_data.h"
3938
#include "tensorflow/compiler/xla/client/xla_computation.h"
4039
#include "tensorflow/compiler/xla/literal.h"
4140
#include "tensorflow/compiler/xla/service/hlo.pb.h"
41+
#include "tensorflow/compiler/xla/service/service.h"
4242
#include "tensorflow/compiler/xla/shape.h"
4343
#include "tensorflow/compiler/xla/stream_executor/platform.h"
4444
#include "tensorflow/compiler/xla/xla.pb.h"
@@ -650,29 +650,34 @@ class XLAExecutor : public ExecutorBase<ValueFuture> {
650650
fn->type().function().result().tensor().dtype())));
651651
}
652652
case v0::Xla::Binding::kStruct: {
653-
absl::StatusOr<std::vector<std::unique_ptr<xla::GlobalData>>>
654-
global_data_vector = xla_client_->DeconstructTuple(**result);
655-
if (!global_data_vector.ok()) {
656-
return absl::InternalError(absl::StrCat(
657-
"Error destructuring tuple in XLA executor. Message: ",
658-
global_data_vector.status().message()));
659-
}
653+
const int num_result_elements =
654+
ComputeNumElementsFromBinding(result_binding);
655+
std::vector<std::unique_ptr<xla::GlobalData>> global_data_vector;
656+
absl::StatusOr<std::vector<std::unique_ptr<xla::GlobalData>>>
657+
status_global_data_vector =
658+
xla_client_->DeconstructTuple(**result);
659+
if (!status_global_data_vector.ok()) {
660+
return absl::InternalError(absl::StrCat(
661+
"Error while destructuring tuple in XLA executor for binding: ",
662+
result_binding.DebugString(), "\nXLA message: ",
663+
status_global_data_vector.status().message()));
664+
}
665+
global_data_vector = std::move(status_global_data_vector).value();
660666
// We begin by constructing a vector of tensor-backed XLAExecutorValues.
661667
// For this purpose, we must compute the datatypes of the GlobalData
662668
// elements (XLA will need them to materialize values from the XLA
663669
// client), from the combination of the return type of the function and
664670
// the result binding.
665671
std::vector<XLAExecutorValue> flat_value_vector;
666-
int result_elements = ComputeNumElementsFromBinding(result_binding);
667672
// Preallocate the flat types tensor as required to assign directly to
668673
// its elements.
669-
std::vector<v0::TensorType> flat_tensor_types(result_elements);
674+
std::vector<v0::TensorType> flat_tensor_types(num_result_elements);
670675
TFF_TRY(FlattenTypeToTensors(fn->type().function().result(),
671676
result_binding, &flat_tensor_types));
672677
flat_value_vector.reserve(flat_tensor_types.size());
673678
for (int i = 0; i < flat_tensor_types.size(); i++) {
674679
flat_value_vector.emplace_back(XLAExecutorValue(
675-
std::move((*global_data_vector)[i]),
680+
std::move(global_data_vector[i]),
676681
TFF_TRY(
677682
PrimitiveTypeFromDataType(flat_tensor_types[i].dtype()))));
678683
}

tensorflow_federated/cc/core/impl/executors/xla_executor_test.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -544,9 +544,7 @@ TEST_F(XLAExecutorTest, CreateAndMaterializeIdentityScalar) {
544544
CheckMaterializeEqual(called_fn, arg_value);
545545
}
546546

547-
// NOLINTBEGIN
548547
TEST_F(XLAExecutorTest, CreateAndMaterializeIdentitySingletonStruct) {
549-
GTEST_SKIP() << "b/355521231 - re-enable once single structs work";
550548
xla::XlaBuilder builder("float_scalar_singleton_struct");
551549
xla::XlaOp parameter = xla::Parameter(
552550
&builder, 0, xla::ShapeUtil::MakeScalarShape(xla::F32), "x");
@@ -574,7 +572,6 @@ TEST_F(XLAExecutorTest, CreateAndMaterializeIdentitySingletonStruct) {
574572
test_executor_->CreateCall(embedded_fn.ref(), embedded_arg));
575573
CheckMaterializeEqual(called_fn, arg_value);
576574
}
577-
// NOLINTEND
578575

579576
TEST_F(XLAExecutorTest, CreateAndMaterializeIdentityNestedStruct) {
580577
xla::XlaBuilder builder("float_nested_struct_identity");

0 commit comments

Comments
 (0)