Commit 2225cf1

klucke authored and Google-ML-Automation committed
Use absl::string_view instead of std::string_view, as some environments (e.g. Android) don't provide std::string_view.
PiperOrigin-RevId: 707210600
1 parent 9ba7f35 commit 2225cf1

35 files changed: +160 -175 lines
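
The swap is mechanical rather than behavioral: on toolchains where the std type exists, absl::string_view is an alias for std::string_view, and elsewhere Abseil supplies its own implementation with the same interface, so call sites keep their semantics while gaining portability. A minimal sketch of that drop-in property (not part of this commit; Greet is an illustrative function):

#include <cstdio>

#include "absl/strings/string_view.h"

// Takes a non-owning view; accepts string literals, std::string, etc.
void Greet(absl::string_view name) {
  // A string_view need not be NUL-terminated, so pass an explicit length.
  std::printf("hello, %.*s\n", static_cast<int>(name.size()), name.data());
}

int main() {
  Greet("xla");  // On C++17 toolchains this is just std::string_view.
  return 0;
}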

xla/python/BUILD

Lines changed: 9 additions & 0 deletions
@@ -263,6 +263,7 @@ cc_library(
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/hash",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@nanobind",
@@ -502,12 +503,15 @@ cc_library(
         "//xla:comparison_util",
         "//xla/pjrt:exceptions",
         "//xla/pjrt:host_callback",
+        "//xla/pjrt:transpose",
         "//xla/service:custom_call_status",
         "//xla/service:custom_call_target_registry",
         "//xla/service:platform_util",
         "@com_google_absl//absl/base",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
         "@nanobind",
         "@tsl//tsl/platform:errors",
     ] + if_rocm(
@@ -589,6 +593,7 @@ cc_library(
         "@nanobind",
         "@local_config_python//:python_headers",  # build_cleaner: keep
         "//xla/pjrt:pjrt_client",
+        "//xla/pjrt:pjrt_layout",
         "//xla/pjrt:status_casters",
         "@tsl//tsl/platform:logging",
         "@tsl//tsl/profiler/lib:traceme",
@@ -631,6 +636,9 @@ cc_library(
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
+        "@tsl//tsl/platform:errors",
+        "@tsl//tsl/platform:status",
+        "@tsl//tsl/platform:statusor",
     ],
 )
 
@@ -1422,6 +1430,7 @@ cc_library(
     copts = ["-fexceptions"],
     features = ["-use_header_modules"],
     deps = [
+        "@com_google_absl//absl/strings:string_view",
         "@com_google_absl//absl/types:span",
         "@nanobind",
         # copybara:uncomment "//third_party/py/numpy:multiarray",

xla/python/callback.cc

Lines changed: 3 additions & 4 deletions
@@ -23,7 +23,6 @@ limitations under the License.
 #include <memory>
 #include <optional>
 #include <string>
-#include <string_view>
 #include <utility>
 #include <vector>
 
@@ -32,10 +31,10 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "nanobind/nanobind.h"
 #include "nanobind/stl/string_view.h"  // IWYU pragma: keep
-#include "xla/pjrt/host_callback.h"
 #include "xla/pjrt/transpose.h"
 #include "xla/primitive_util.h"
 #include "xla/python/nb_numpy.h"
@@ -127,7 +126,7 @@ absl::StatusOr<nb::tuple> CpuCallback::Call(nb::tuple args) {
   if (!PyTuple_Check(result_object.ptr())) {
     return absl::InternalError(
         absl::StrFormat("CPU callback expected a tuple result, got %s",
-                        nb::cast<std::string_view>(nb::repr(result_object))));
+                        nb::cast<absl::string_view>(nb::repr(result_object))));
   }
   if (PyTuple_Size(result_object.ptr()) != results_.size()) {
     return absl::InternalError(
@@ -142,7 +141,7 @@ absl::StatusOr<nb::tuple> CpuCallback::Call(nb::tuple args) {
     if (!output.is_none()) {
       return absl::InternalError(absl::StrFormat(
           "Token output from Python callback should be None, got %s",
-          nb::cast<std::string_view>(nb::repr(output))));
+          nb::cast<absl::string_view>(nb::repr(output))));
     }
     continue;
   }

xla/python/custom_call_sharding.cc

Lines changed: 5 additions & 6 deletions
@@ -19,7 +19,6 @@ limitations under the License.
 #include <memory>
 #include <optional>
 #include <string>
-#include <string_view>
 #include <tuple>
 #include <utility>
 #include <vector>
@@ -93,7 +92,7 @@ class PyCustomCallPartitionerCallbacks {
     xla::Shape result_shape = std::move(std::get<2>(args_tuple));
     std::optional<xla::HloSharding> result_sharding =
         std::move(std::get<3>(args_tuple));
-    std::string_view backend_config = std::move(std::get<4>(args_tuple));
+    absl::string_view backend_config = std::move(std::get<4>(args_tuple));
 
     {
       nb::gil_scoped_acquire gil;
@@ -118,7 +117,7 @@ class PyCustomCallPartitionerCallbacks {
         return xla::Internal(
             "Shardings returned from partitioning: expected "
             "Tuple[bytes, List[HloSharding], HloSharding] got: %s",
-            nb::cast<std::string_view>(nb::repr(py_result)));
+            nb::cast<absl::string_view>(nb::repr(py_result)));
       }
     } catch (const nb::python_error& e) {
       return xla::Internal("custom_partitioner: %s", e.what());
@@ -136,7 +135,7 @@ class PyCustomCallPartitionerCallbacks {
     std::vector<std::optional<xla::HloSharding>> arg_shardings =
         std::move(std::get<1>(args_tuple));
     xla::Shape result_shape = std::move(std::get<2>(args_tuple));
-    std::string_view backend_config = std::move(std::get<3>(args_tuple));
+    absl::string_view backend_config = std::move(std::get<3>(args_tuple));
 
     std::optional<HloSharding> result;
     nb::gil_scoped_acquire gil;
@@ -161,7 +160,7 @@ class PyCustomCallPartitionerCallbacks {
     TF_ASSIGN_OR_RETURN(auto args_tuple, jax::ReadArgs(args));
     xla::HloSharding result_sharding = std::move(std::get<0>(args_tuple));
     xla::Shape result_shape = std::move(std::get<1>(args_tuple));
-    std::string_view backend_config = std::move(std::get<2>(args_tuple));
+    absl::string_view backend_config = std::move(std::get<2>(args_tuple));
 
     nb::gil_scoped_acquire gil;
     try {
@@ -229,7 +228,7 @@ void BuildCustomCallShardingPybindAPI(nb::module_& m) {
           return;
         }
 
-        if (std::string_view(c_api->name()) != "pjrt_c_api") {
+        if (absl::string_view(c_api->name()) != "pjrt_c_api") {
           throw absl::InvalidArgumentError(
               "Argument to register_custom_call_partitioner was not a "
               "pjrt_c_api capsule.");

xla/python/custom_partition_callback.cc

Lines changed: 15 additions & 12 deletions
@@ -21,7 +21,6 @@ limitations under the License.
 #include <memory>
 #include <optional>
 #include <string>
-#include <string_view>
 #include <tuple>
 #include <utility>
 #include <vector>
@@ -31,6 +30,7 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
 #include "xla/debug_options_flags.h"
 #include "xla/hlo/builder/xla_computation.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
@@ -46,8 +46,11 @@ limitations under the License.
 #include "xla/pjrt/mlir_to_hlo.h"
 #include "xla/service/call_inliner.h"
 #include "xla/service/custom_call_sharding_helper.h"
-#include "xla/service/spmd/spmd_partitioner_util.h"
+#include "xla/service/spmd/spmd_partitioner.h"
 #include "xla/util.h"
+#include "tsl/platform/errors.h"
+#include "tsl/platform/status.h"
+#include "tsl/platform/statusor.h"
 
 namespace xla {
 
@@ -202,8 +205,8 @@ void SetCAPIString(JAX_CustomCallPartitioner_string& out, std::string result,
   out.size = scratch.back().size();
 }
 
-std::string_view ToStringView(JAX_CustomCallPartitioner_string data) {
-  return std::string_view(data.data, data.size);
+absl::string_view ToStringView(JAX_CustomCallPartitioner_string data) {
+  return absl::string_view(data.data, data.size);
 }
 
 void SetCAPIAval(JAX_CustomCallPartitioner_aval& result,
@@ -343,7 +346,7 @@ PartitionScratch PopulateArgs(JAX_CustomCallPartitioner_Partition_Args* args,
 
 absl::StatusOr<std::tuple<
     std::vector<xla::Shape>, std::vector<std::optional<xla::HloSharding>>,
-    xla::Shape, std::optional<xla::HloSharding>, std::string_view>>
+    xla::Shape, std::optional<xla::HloSharding>, absl::string_view>>
 ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args) {
   std::vector<xla::Shape> shapes;
   std::vector<std::optional<xla::HloSharding>> shardings;
@@ -369,14 +372,14 @@ ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args) {
   }
   return std::tuple<std::vector<xla::Shape>,
                     std::vector<std::optional<xla::HloSharding>>, xla::Shape,
-                    std::optional<xla::HloSharding>, std::string_view>(
+                    std::optional<xla::HloSharding>, absl::string_view>(
       std::move(shapes), std::move(shardings), std::move(result_shape),
       std::move(result_sharding), ToStringView(args->backend_config));
 }
 
 absl::StatusOr<std::tuple<std::vector<xla::Shape>,
                           std::vector<std::optional<xla::HloSharding>>,
-                          xla::Shape, std::string_view>>
+                          xla::Shape, absl::string_view>>
 ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) {
   std::vector<xla::Shape> shapes;
   std::vector<std::optional<xla::HloSharding>> shardings;
@@ -397,9 +400,9 @@ ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args) {
   TF_ASSIGN_OR_RETURN(auto result_shape, ReadHloShape(args->result_shape));
   return std::tuple<std::vector<xla::Shape>,
                     std::vector<std::optional<xla::HloSharding>>, xla::Shape,
-                    std::string_view>(std::move(shapes), std::move(shardings),
-                                      std::move(result_shape),
-                                      ToStringView(args->backend_config));
+                    absl::string_view>(std::move(shapes), std::move(shardings),
+                                       std::move(result_shape),
+                                       ToStringView(args->backend_config));
 }
 
 PartitionScratch PopulateArgs(
@@ -455,11 +458,11 @@ absl::StatusOr<std::optional<xla::HloSharding>> ConsumeResults(
   return ReadHloSharding(args->result_sharding);
 }
 
-absl::StatusOr<std::tuple<xla::HloSharding, xla::Shape, std::string_view>>
+absl::StatusOr<std::tuple<xla::HloSharding, xla::Shape, absl::string_view>>
 ReadArgs(JAX_CustomCallPartitioner_PropagateUserSharding_Args* args) {
   TF_ASSIGN_OR_RETURN(auto shape, ReadHloShape(args->result_shape));
   TF_ASSIGN_OR_RETURN(auto sharding, ReadHloSharding(args->result_sharding));
-  return std::tuple<xla::HloSharding, xla::Shape, std::string_view>(
+  return std::tuple<xla::HloSharding, xla::Shape, absl::string_view>(
       std::move(sharding), std::move(shape),
       ToStringView(args->backend_config));
 }
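
The ToStringView helper above is the seam where the PJRT C API's pointer-plus-length strings enter C++ code, which is why the type swap ripples through every ReadArgs return type. A standalone sketch of the same non-owning bridge (CApiString is an illustrative stand-in for JAX_CustomCallPartitioner_string, whose real definition lives in the C-API headers):

#include <cstddef>
#include <cstdio>

#include "absl/strings/string_view.h"

// Illustrative stand-in for the C-API struct: a borrowed pointer/length pair.
struct CApiString {
  const char* data;
  size_t size;
};

// Mirrors the commit's ToStringView: wrap the bytes without copying.
absl::string_view ToStringView(CApiString s) {
  return absl::string_view(s.data, s.size);
}

int main() {
  const char buf[] = "backend_config_bytes";
  CApiString s{buf, sizeof(buf) - 1};
  absl::string_view v = ToStringView(s);
  std::printf("%zu bytes\n", v.size());  // prints "20 bytes"
  return 0;
}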

xla/python/custom_partition_callback.h

Lines changed: 3 additions & 4 deletions
@@ -18,7 +18,6 @@ limitations under the License.
 #include <memory>
 #include <optional>
 #include <string>
-#include <string_view>
 #include <tuple>
 
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -37,7 +36,7 @@ PartitionScratch PopulateArgs(JAX_CustomCallPartitioner_Partition_Args* args,
                               const xla::HloInstruction* instruction);
 absl::StatusOr<std::tuple<
     std::vector<xla::Shape>, std::vector<std::optional<xla::HloSharding>>,
-    xla::Shape, std::optional<xla::HloSharding>, std::string_view>>
+    xla::Shape, std::optional<xla::HloSharding>, absl::string_view>>
 ReadArgs(JAX_CustomCallPartitioner_Partition_Args* args);
 void PopulateResults(
     absl::StatusOr<std::tuple<std::string, std::vector<xla::HloSharding>,
@@ -50,7 +49,7 @@ ConsumeResults(JAX_CustomCallPartitioner_Partition_Args* args);
 
 absl::StatusOr<std::tuple<std::vector<xla::Shape>,
                           std::vector<std::optional<xla::HloSharding>>,
-                          xla::Shape, std::string_view>>
+                          xla::Shape, absl::string_view>>
 ReadArgs(JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args);
 PartitionScratch PopulateArgs(
     JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args,
@@ -61,7 +60,7 @@ void PopulateResults(
 absl::StatusOr<std::optional<xla::HloSharding>> ConsumeResults(
     JAX_CustomCallPartitioner_InferShardingFromOperands_Args* args);
 
-absl::StatusOr<std::tuple<xla::HloSharding, xla::Shape, std::string_view>>
+absl::StatusOr<std::tuple<xla::HloSharding, xla::Shape, absl::string_view>>
 ReadArgs(JAX_CustomCallPartitioner_PropagateUserSharding_Args* args);
 PartitionScratch PopulateArgs(
     JAX_CustomCallPartitioner_PropagateUserSharding_Args* args,

xla/python/dlpack.cc

Lines changed: 4 additions & 5 deletions
@@ -22,7 +22,6 @@ limitations under the License.
 #include <memory>
 #include <numeric>
 #include <optional>
-#include <string_view>
 #include <utility>
 #include <vector>
 
@@ -458,11 +457,11 @@ absl::StatusOr<nb::object> DLPackManagedTensorToBuffer(
   auto* cpu_pjrt_client = cpu_client ? (*cpu_client)->pjrt_client() : nullptr;
   auto* gpu_pjrt_client = gpu_client ? (*gpu_client)->pjrt_client() : nullptr;
 
-  if (std::string_view(tensor.name()) != kDlTensorCapsuleName) {
+  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
     return InvalidArgument(
         "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
         "Note that a DLPack tensor may be consumed at most once.",
-        std::string_view(tensor.name()));
+        absl::string_view(tensor.name()));
   }
   DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor.data());
   if (dlmt->dl_tensor.ndim < 0) {
@@ -552,11 +551,11 @@ absl::StatusOr<nb::object> DLPackManagedTensorToBuffer(
         "DLPack is only supported for devices addressable by the current "
         "process.");
   }
-  if (std::string_view(tensor.name()) != kDlTensorCapsuleName) {
+  if (absl::string_view(tensor.name()) != kDlTensorCapsuleName) {
     return InvalidArgument(
         "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
         "Note that a DLPack tensor may be consumed at most once.",
-        std::string_view(tensor.name()));
+        absl::string_view(tensor.name()));
   }
   DLManagedTensor* dlmt = static_cast<DLManagedTensor*>(tensor.data());
   if (dlmt->dl_tensor.ndim < 0) {
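
Both DLPackManagedTensorToBuffer overloads rely on the same idiom: wrapping the capsule's const char* name in a string_view turns pointer comparison into content comparison. A standalone sketch (the constant is redefined here for illustration; "used_dltensor" is DLPack's conventional rename for an already-consumed capsule):

#include <cstdio>

#include "absl/strings/string_view.h"

// DLPack names an unconsumed tensor capsule "dltensor"; consumers
// rename it after taking ownership.
constexpr char kDlTensorCapsuleName[] = "dltensor";

bool IsUnconsumedDlTensor(const char* capsule_name) {
  // absl::string_view compares characters, not pointer identity.
  return absl::string_view(capsule_name) == kDlTensorCapsuleName;
}

int main() {
  std::printf("%d\n", IsUnconsumedDlTensor("dltensor"));       // 1
  std::printf("%d\n", IsUnconsumedDlTensor("used_dltensor"));  // 0
  return 0;
}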

xla/python/jax_jit.cc

Lines changed: 6 additions & 5 deletions
@@ -33,7 +33,6 @@ limitations under the License.
 #include <optional>
 #include <stdexcept>
 #include <string>
-#include <string_view>
 #include <utility>
 #include <vector>
 
@@ -45,6 +44,7 @@ limitations under the License.
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/str_join.h"
+#include "absl/strings/string_view.h"
 #include "absl/types/span.h"
 #include "nanobind/nanobind.h"
 #include "nanobind/stl/optional.h"  // IWYU pragma: keep
@@ -53,6 +53,7 @@ limitations under the License.
 #include "nanobind/stl/string_view.h"  // IWYU pragma: keep
 #include "nanobind/stl/vector.h"  // IWYU pragma: keep
 #include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_layout.h"
 #include "xla/pjrt/status_casters.h"
 #include "xla/python/nb_absl_inlined_vector.h"  // IWYU pragma: keep
 #include "xla/python/nb_absl_span.h"  // IWYU pragma: keep
@@ -147,7 +148,7 @@ bool FetchMemoriesFlag() {
 
 std::string ArgumentSignature::DebugString() const {
   auto py_object_formatter = [](std::string* out, const nb::object& o) {
-    out->append(nb::cast<std::string_view>(nb::str(o)));
+    out->append(nb::cast<absl::string_view>(nb::str(o)));
   };
   auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) {
     out->append(d.ToString());
@@ -188,16 +189,16 @@ bool ArgumentSignature::operator==(const ArgumentSignature& other) const {
               "static arguments should be comparable using __eq__."
               "The following error was raised when comparing two objects of "
               "types ",
-              nb::cast<std::string_view>(nb::str(a.type())), " and ",
-              nb::cast<std::string_view>(nb::str(b.type())),
+              nb::cast<absl::string_view>(nb::str(a.type())), " and ",
+              nb::cast<absl::string_view>(nb::str(b.type())),
               ". The error was:\n", e.what()));
         }
       });
 }
 
 std::string CallSignature::DebugString() const {
   auto py_object_formatter = [](std::string* out, const nb::object& o) {
-    out->append(nb::cast<std::string_view>(nb::str(o)));
+    out->append(nb::cast<absl::string_view>(nb::str(o)));
   };
   auto signature_formatter = [](std::string* out,
                                 const xla::PyArgSignature& s) {

xla/python/jax_jit.h

Lines changed: 3 additions & 4 deletions
@@ -22,7 +22,6 @@ limitations under the License.
 #include <optional>
 #include <stdexcept>
 #include <string>
-#include <string_view>
 #include <utility>
 #include <vector>
 
@@ -140,8 +139,8 @@ H AbslHashValue(H h, const ArgumentSignature& s) {
     throw std::invalid_argument(absl::StrCat(
         "Non-hashable static arguments are not supported. An error occurred "
         "while trying to hash an object of type ",
-        nanobind::cast<std::string_view>(nanobind::str(static_arg.type())),
-        ", ", nanobind::cast<std::string_view>(nanobind::str(static_arg)),
+        nanobind::cast<absl::string_view>(nanobind::str(static_arg.type())),
+        ", ", nanobind::cast<absl::string_view>(nanobind::str(static_arg)),
         ". The error was:\n", e.what(), "\n"));
   }
   h = H::combine(std::move(h), hash);
@@ -185,7 +184,7 @@ absl::Status ParseArguments(
 // (a) equality (delegated to Python) of the static arguments.
 struct CallSignature {
   // Not part of the signature, but we need it for error messages.
-  std::string_view function_name;
+  absl::string_view function_name;
 
   ArgumentSignature arg_signature;
 