@@ -35,10 +35,10 @@ limitations under the License
3535#include " tensorflow/compiler/tf2xla/type_util.h"
3636#include " tensorflow/compiler/xla/client/client.h"
3737#include " tensorflow/compiler/xla/client/client_library.h"
38- #include " tensorflow/compiler/xla/client/global_data.h"
3938#include " tensorflow/compiler/xla/client/xla_computation.h"
4039#include " tensorflow/compiler/xla/literal.h"
4140#include " tensorflow/compiler/xla/service/hlo.pb.h"
41+ #include " tensorflow/compiler/xla/service/service.h"
4242#include " tensorflow/compiler/xla/shape.h"
4343#include " tensorflow/compiler/xla/stream_executor/platform.h"
4444#include " tensorflow/compiler/xla/xla.pb.h"
@@ -650,29 +650,34 @@ class XLAExecutor : public ExecutorBase<ValueFuture> {
650650 fn->type ().function ().result ().tensor ().dtype ())));
651651 }
652652 case v0::Xla::Binding::kStruct : {
653- absl::StatusOr<std::vector<std::unique_ptr<xla::GlobalData>>>
654- global_data_vector = xla_client_->DeconstructTuple (**result);
655- if (!global_data_vector.ok ()) {
656- return absl::InternalError (absl::StrCat (
657- " Error destructuring tuple in XLA executor. Message: " ,
658- global_data_vector.status ().message ()));
659- }
653+ const int num_result_elements =
654+ ComputeNumElementsFromBinding (result_binding);
655+ std::vector<std::unique_ptr<xla::GlobalData>> global_data_vector;
656+ absl::StatusOr<std::vector<std::unique_ptr<xla::GlobalData>>>
657+ status_global_data_vector =
658+ xla_client_->DeconstructTuple (**result);
659+ if (!status_global_data_vector.ok ()) {
660+ return absl::InternalError (absl::StrCat (
661+ " Error while destructuring tuple in XLA executor for binding: " ,
662+ result_binding.DebugString (), " \n XLA message: " ,
663+ status_global_data_vector.status ().message ()));
664+ }
665+ global_data_vector = std::move (status_global_data_vector).value ();
660666 // We begin by constructing a vector of tensor-backed XLAExecutorValues.
661667 // For this purpose, we must compute the datatypes of the GlobalData
662668 // elements (XLA will need them to materialize values from the XLA
663669 // client), from the combination of the return type of the function and
664670 // the result binding.
665671 std::vector<XLAExecutorValue> flat_value_vector;
666- int result_elements = ComputeNumElementsFromBinding (result_binding);
667672 // Preallocate the flat types tensor as required to assign directly to
668673 // its elements.
669- std::vector<v0::TensorType> flat_tensor_types (result_elements );
674+ std::vector<v0::TensorType> flat_tensor_types (num_result_elements );
670675 TFF_TRY (FlattenTypeToTensors (fn->type ().function ().result (),
671676 result_binding, &flat_tensor_types));
672677 flat_value_vector.reserve (flat_tensor_types.size ());
673678 for (int i = 0 ; i < flat_tensor_types.size (); i++) {
674679 flat_value_vector.emplace_back (XLAExecutorValue (
675- std::move ((* global_data_vector) [i]),
680+ std::move (global_data_vector[i]),
676681 TFF_TRY (
677682 PrimitiveTypeFromDataType (flat_tensor_types[i].dtype ()))));
678683 }
0 commit comments