@@ -21,7 +21,6 @@ limitations under the License.
2121#include < memory>
2222#include < optional>
2323#include < string>
24- #include < string_view>
2524#include < tuple>
2625#include < utility>
2726#include < vector>
@@ -31,6 +30,7 @@ limitations under the License.
3130#include " absl/status/status.h"
3231#include " absl/status/statusor.h"
3332#include " absl/strings/str_cat.h"
33+ #include " absl/strings/string_view.h"
3434#include " xla/debug_options_flags.h"
3535#include " xla/hlo/builder/xla_computation.h"
3636#include " xla/hlo/ir/hlo_casting_utils.h"
@@ -46,8 +46,11 @@ limitations under the License.
4646#include " xla/pjrt/mlir_to_hlo.h"
4747#include " xla/service/call_inliner.h"
4848#include " xla/service/custom_call_sharding_helper.h"
49- #include " xla/service/spmd/spmd_partitioner_util .h"
49+ #include " xla/service/spmd/spmd_partitioner .h"
5050#include " xla/util.h"
51+ #include " tsl/platform/errors.h"
52+ #include " tsl/platform/status.h"
53+ #include " tsl/platform/statusor.h"
5154
5255namespace xla {
5356
@@ -202,8 +205,8 @@ void SetCAPIString(JAX_CustomCallPartitioner_string& out, std::string result,
202205 out.size = scratch.back ().size ();
203206}
204207
205- std ::string_view ToStringView (JAX_CustomCallPartitioner_string data) {
206- return std ::string_view (data.data , data.size );
208+ absl ::string_view ToStringView (JAX_CustomCallPartitioner_string data) {
209+ return absl ::string_view (data.data , data.size );
207210}
208211
209212void SetCAPIAval (JAX_CustomCallPartitioner_aval& result,
@@ -343,7 +346,7 @@ PartitionScratch PopulateArgs(JAX_CustomCallPartitioner_Partition_Args* args,
343346
344347absl::StatusOr<std::tuple<
345348 std::vector<xla::Shape>, std::vector<std::optional<xla::HloSharding>>,
346- xla::Shape, std::optional<xla::HloSharding>, std ::string_view>>
349+ xla::Shape, std::optional<xla::HloSharding>, absl ::string_view>>
347350ReadArgs (JAX_CustomCallPartitioner_Partition_Args* args) {
348351 std::vector<xla::Shape> shapes;
349352 std::vector<std::optional<xla::HloSharding>> shardings;
@@ -369,14 +372,14 @@ ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args) {
369372 }
370373 return std::tuple<std::vector<xla::Shape>,
371374 std::vector<std::optional<xla::HloSharding>>, xla::Shape,
372- std::optional<xla::HloSharding>, std ::string_view>(
375+ std::optional<xla::HloSharding>, absl ::string_view>(
373376 std::move (shapes), std::move (shardings), std::move (result_shape),
374377 std::move (result_sharding), ToStringView (args->backend_config ));
375378}
376379
377380absl::StatusOr<std::tuple<std::vector<xla::Shape>,
378381 std::vector<std::optional<xla::HloSharding>>,
379- xla::Shape, std ::string_view>>
382+ xla::Shape, absl ::string_view>>
380383ReadArgs (JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) {
381384 std::vector<xla::Shape> shapes;
382385 std::vector<std::optional<xla::HloSharding>> shardings;
@@ -397,9 +400,9 @@ ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) {
397400 TF_ASSIGN_OR_RETURN (auto result_shape, ReadHloShape (args->result_shape ));
398401 return std::tuple<std::vector<xla::Shape>,
399402 std::vector<std::optional<xla::HloSharding>>, xla::Shape,
400- std ::string_view>(std::move (shapes), std::move (shardings),
401- std::move (result_shape),
402- ToStringView (args->backend_config ));
403+ absl ::string_view>(std::move (shapes), std::move (shardings),
404+ std::move (result_shape),
405+ ToStringView (args->backend_config ));
403406}
404407
405408PartitionScratch PopulateArgs (
@@ -455,11 +458,11 @@ absl::StatusOr<std::optional<xla::HloSharding>> ConsumeResults(
455458 return ReadHloSharding (args->result_sharding );
456459}
457460
458- absl::StatusOr<std::tuple<xla::HloSharding, xla::Shape, std ::string_view>>
461+ absl::StatusOr<std::tuple<xla::HloSharding, xla::Shape, absl ::string_view>>
459462ReadArgs (JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) {
460463 TF_ASSIGN_OR_RETURN (auto shape, ReadHloShape (args->result_shape ));
461464 TF_ASSIGN_OR_RETURN (auto sharding, ReadHloSharding (args->result_sharding ));
462- return std::tuple<xla::HloSharding, xla::Shape, std ::string_view>(
465+ return std::tuple<xla::HloSharding, xla::Shape, absl ::string_view>(
463466 std::move (sharding), std::move (shape),
464467 ToStringView (args->backend_config ));
465468}
0 commit comments