Skip to content

Commit 4dd2860

Browse files
ezhulenevGoogle-ML-Automation
authored andcommitted
[xla:ffi] Allow registering type ids assigned by the user
PiperOrigin-RevId: 738942202
1 parent 5f8b10c commit 4dd2860

File tree

10 files changed

+140
-40
lines changed

10 files changed

+140
-40
lines changed

xla/ffi/BUILD

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ xla_cc_test(
8181
":execution_context",
8282
":type_id_registry",
8383
"//xla/tsl/lib/core:status_test_util",
84+
"//xla/tsl/platform:statusor",
85+
"//xla/tsl/platform:test",
8486
"@com_google_absl//absl/status",
8587
"@com_google_googletest//:gtest",
8688
"@com_google_googletest//:gtest_main",
87-
"@tsl//tsl/platform:statusor",
88-
"@tsl//tsl/platform:test",
8989
],
9090
)
9191

@@ -247,9 +247,12 @@ cc_library(
247247
deps = [
248248
"//xla:util",
249249
"//xla/tsl/lib/gtl:int_type",
250+
"@com_google_absl//absl/algorithm:container",
250251
"@com_google_absl//absl/base:core_headers",
251252
"@com_google_absl//absl/container:flat_hash_map",
253+
"@com_google_absl//absl/status",
252254
"@com_google_absl//absl/status:statusor",
255+
"@com_google_absl//absl/strings:string_view",
253256
"@com_google_absl//absl/synchronization",
254257
],
255258
)
@@ -260,10 +263,10 @@ xla_cc_test(
260263
deps = [
261264
":type_id_registry",
262265
"//xla/tsl/lib/core:status_test_util",
266+
"//xla/tsl/platform:statusor",
267+
"//xla/tsl/platform:test",
263268
"@com_google_absl//absl/status",
264269
"@com_google_googletest//:gtest",
265270
"@com_google_googletest//:gtest_main",
266-
"@tsl//tsl/platform:statusor",
267-
"@tsl//tsl/platform:test",
268271
],
269272
)

xla/ffi/api/c_api.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,17 +468,22 @@ typedef XLA_FFI_Error* XLA_FFI_Handler_Register(
468468
// TypeId
469469
//===----------------------------------------------------------------------===//
470470

471+
#define XLA_FFI_UNKNOWN_TYPE_ID XLA_FFI_TypeId{0}
472+
471473
struct XLA_FFI_TypeId_Register_Args {
472474
size_t struct_size;
473475
XLA_FFI_Extension_Base* extension_start;
474476

475477
XLA_FFI_ByteSpan name;
476-
XLA_FFI_TypeId* type_id; // out
478+
XLA_FFI_TypeId* type_id; // in-out
477479
};
478480

479481
XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_TypeId_Register_Args, type_id);
480482

481-
// Registers user type `name` and returns a unique `type_id`.
483+
// Registers user type `name` with XLA. If type id is `XLA_FFI_UNKNOWN_TYPE_ID`,
484+
// XLA will assign a unique type id and return it in `type_id` out argument,
485+
// otherwise XLA will verify that type id is unique and matches the type id of
486+
// the type registered with the same `name` earlier.
482487
typedef XLA_FFI_Error* XLA_FFI_TypeId_Register(
483488
XLA_FFI_TypeId_Register_Args* args);
484489

xla/ffi/execution_context_test.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ limitations under the License.
2222
#include "absl/status/status.h"
2323
#include "xla/ffi/type_id_registry.h"
2424
#include "xla/tsl/lib/core/status_test_util.h"
25-
#include "tsl/platform/statusor.h"
26-
#include "tsl/platform/test.h"
25+
#include "xla/tsl/platform/statusor.h"
26+
#include "xla/tsl/platform/test.h"
2727

2828
namespace xla::ffi {
2929

@@ -62,9 +62,8 @@ TEST(ExecutionContextTest, InsertUserOwned) {
6262
}
6363

6464
TEST(ExecutionContextTest, InsertUserOwnedWithTypeId) {
65-
TF_ASSERT_OK_AND_ASSIGN(
66-
TypeIdRegistry::TypeId type_id,
67-
TypeIdRegistry::RegisterExternalTypeId("I32UserData"));
65+
TF_ASSERT_OK_AND_ASSIGN(TypeIdRegistry::TypeId type_id,
66+
TypeIdRegistry::AssignExternalTypeId("I32UserData"));
6867

6968
I32UserData user_data(42);
7069

xla/ffi/ffi_api.cc

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ limitations under the License.
3636
#include "absl/strings/str_cat.h"
3737
#include "absl/strings/str_format.h"
3838
#include "absl/strings/str_join.h"
39+
#include "absl/strings/string_view.h"
3940
#include "xla/executable_run_options.h"
4041
#include "xla/ffi/api/api.h"
4142
#include "xla/ffi/api/c_api.h"
@@ -635,13 +636,27 @@ static XLA_FFI_Error* XLA_FFI_TypeId_Register(
635636
"XLA_FFI_ExecutionContext_Get_Args",
636637
XLA_FFI_ExecutionContext_Get_Args_STRUCT_SIZE, args->struct_size));
637638

638-
auto type_id = TypeIdRegistry::RegisterExternalTypeId(
639-
std::string_view(args->name.ptr, args->name.len));
640-
if (!type_id.ok()) {
641-
return new XLA_FFI_Error{std::move(type_id).status()};
639+
absl::string_view type_name(args->name.ptr, args->name.len);
640+
TypeIdRegistry::TypeId type_id(args->type_id->type_id);
641+
642+
// If type_id is unknown, we are registering a new type and XLA will assign a
643+
// unique type id to it.
644+
if (type_id == TypeIdRegistry::kUnknownTypeId) {
645+
auto assigned_type_id = TypeIdRegistry::AssignExternalTypeId(type_name);
646+
if (!assigned_type_id.ok()) {
647+
return new XLA_FFI_Error{std::move(assigned_type_id).status()};
648+
}
649+
650+
args->type_id->type_id = assigned_type_id->value();
651+
return nullptr;
652+
}
653+
654+
// If type_id is set, we are relying on the caller-provided unique type id.
655+
if (auto status = TypeIdRegistry::RegisterExternalTypeId(type_name, type_id);
656+
!status.ok()) {
657+
return new XLA_FFI_Error{std::move(status)};
642658
}
643659

644-
args->type_id->type_id = type_id->value();
645660
return nullptr;
646661
}
647662

xla/ffi/type_id_registry.cc

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@ limitations under the License.
2020
#include <string>
2121
#include <string_view>
2222

23+
#include "absl/algorithm/container.h"
2324
#include "absl/base/attributes.h"
2425
#include "absl/base/const_init.h"
2526
#include "absl/container/flat_hash_map.h"
27+
#include "absl/status/status.h"
2628
#include "absl/status/statusor.h"
29+
#include "absl/strings/string_view.h"
2730
#include "absl/synchronization/mutex.h"
2831
#include "xla/util.h"
2932

@@ -39,24 +42,56 @@ static ExternalTypeIdRegistry& StaticExternalTypeIdRegistry() {
3942
return *registry;
4043
}
4144

42-
TypeIdRegistry::TypeId TypeIdRegistry::GetNextTypeId() {
45+
TypeIdRegistry::TypeId TypeIdRegistry::GetNextInternalTypeId() {
4346
static auto* counter = new std::atomic<int64_t>(1);
4447
return TypeId(counter->fetch_add(1));
4548
}
4649

47-
absl::StatusOr<TypeIdRegistry::TypeId> TypeIdRegistry::RegisterExternalTypeId(
50+
TypeIdRegistry::TypeId TypeIdRegistry::GetNextExternalTypeId() {
51+
static auto* counter = new std::atomic<int64_t>(1);
52+
return TypeId(counter->fetch_add(1));
53+
}
54+
55+
absl::StatusOr<TypeIdRegistry::TypeId> TypeIdRegistry::AssignExternalTypeId(
4856
std::string_view name) {
4957
absl::MutexLock lock(&type_registry_mutex);
5058
auto& registry = StaticExternalTypeIdRegistry();
5159

52-
// Try to emplace with type id zero and fill it with real type id only if we
60+
// Try to emplace with unknow type id and fill it with real type id only if we
5361
// successfully acquired an entry for a given name.
54-
auto emplaced = registry.emplace(name, TypeId(0));
62+
auto emplaced = registry.emplace(name, kUnknownTypeId);
5563
if (!emplaced.second) {
56-
return Internal("Type id %d already registered for type name %s",
57-
emplaced.first->second.value(), name);
64+
return Internal("Type name %s already registered with type id %d", name,
65+
emplaced.first->second.value());
5866
}
59-
return emplaced.first->second = GetNextTypeId();
67+
68+
// Returns true if the registry contains an entry with a given type id.
69+
auto type_id_is_in_use = [&registry](TypeId type_id) {
70+
return absl::c_any_of(registry,
71+
[&](const auto& e) { return e.second == type_id; });
72+
};
73+
74+
// Create a new type id that is not already in use.
75+
TypeId type_id = GetNextExternalTypeId();
76+
while (type_id_is_in_use(type_id)) {
77+
type_id = GetNextExternalTypeId();
78+
}
79+
80+
return emplaced.first->second = type_id;
81+
}
82+
83+
absl::Status TypeIdRegistry::RegisterExternalTypeId(absl::string_view name,
84+
TypeId type_id) {
85+
absl::MutexLock lock(&type_registry_mutex);
86+
auto& registry = StaticExternalTypeIdRegistry();
87+
88+
auto emplaced = registry.emplace(name, type_id);
89+
if (!emplaced.second && emplaced.first->second != type_id) {
90+
return Internal("Type name %s already registered with type id %d vs %d)",
91+
name, emplaced.first->second.value(), type_id.value());
92+
}
93+
94+
return absl::OkStatus();
6095
}
6196

6297
} // namespace xla::ffi

xla/ffi/type_id_registry.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ limitations under the License.
1717
#define XLA_FFI_TYPE_ID_REGISTRY_H_
1818

1919
#include <cstdint>
20-
#include <string_view>
2120

21+
#include "absl/status/status.h"
2222
#include "absl/status/statusor.h"
23+
#include "absl/strings/string_view.h"
2324
#include "xla/tsl/lib/gtl/int_type.h"
2425

2526
namespace xla::ffi {
@@ -48,20 +49,30 @@ class TypeIdRegistry {
4849

4950
static constexpr TypeId kUnknownTypeId = TypeId(0);
5051

51-
// Registers external type with a given name in a static type registry.
52-
static absl::StatusOr<TypeId> RegisterExternalTypeId(std::string_view name);
52+
// Assigns a unique type id to an external type with a given name. Returns an
53+
// error if a type with a given name is already registered in the process.
54+
static absl::StatusOr<TypeId> AssignExternalTypeId(absl::string_view name);
55+
56+
// Registers external type with a given name and type id. Type id is provided
57+
// by the caller, and must be unique. Returns an error if a type with a given
58+
// name is already registered with a different type id.
59+
static absl::Status RegisterExternalTypeId(absl::string_view name,
60+
TypeId type_id);
5361

5462
// Returns a type id for a given type. For internal type ids only.
5563
template <typename T>
5664
static TypeId GetTypeId();
5765

5866
private:
59-
static TypeId GetNextTypeId();
67+
// We never mix external and internal type ids, so we can use different type
68+
// id spaces to assign unique ids to each type.
69+
static TypeId GetNextInternalTypeId();
70+
static TypeId GetNextExternalTypeId();
6071
};
6172

6273
template <typename T>
6374
TypeIdRegistry::TypeId TypeIdRegistry::GetTypeId() {
64-
static const TypeId id = GetNextTypeId();
75+
static const TypeId id = GetNextInternalTypeId();
6576
return id;
6677
}
6778

xla/ffi/type_id_registry_test.cc

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ limitations under the License.
1616
#include "xla/ffi/type_id_registry.h"
1717

1818
#include <cstdint>
19+
#include <limits>
1920

2021
#include <gmock/gmock.h>
2122
#include <gtest/gtest.h>
2223
#include "absl/status/status.h"
23-
#include "tsl/platform/statusor.h"
24-
#include "tsl/platform/test.h"
24+
#include "xla/tsl/lib/core/status_test_util.h"
25+
#include "xla/tsl/platform/statusor.h"
26+
#include "xla/tsl/platform/test.h"
2527

2628
namespace xla::ffi {
2729
namespace {
@@ -30,12 +32,25 @@ using ::testing::HasSubstr;
3032

3133
TEST(TypeIdRegistryTest, RegisterExternalTypeId) {
3234
TF_ASSERT_OK_AND_ASSIGN(auto type_id,
33-
TypeIdRegistry::RegisterExternalTypeId("foo"));
35+
TypeIdRegistry::AssignExternalTypeId("foo"));
3436
EXPECT_GE(type_id.value(), 0);
3537

36-
auto duplicate_type_id = TypeIdRegistry::RegisterExternalTypeId("foo");
38+
auto duplicate_type_id = TypeIdRegistry::AssignExternalTypeId("foo");
3739
EXPECT_THAT(duplicate_type_id.status().message(),
38-
HasSubstr("already registered for type name foo"));
40+
HasSubstr("Type name foo already registered with type id"));
41+
42+
// It's ok to register the same type with same type id.
43+
TF_ASSERT_OK(TypeIdRegistry::RegisterExternalTypeId("foo", type_id));
44+
45+
// It's an error to register the same type with a different type id.
46+
auto wrong_type_id = TypeIdRegistry::RegisterExternalTypeId(
47+
"foo", TypeIdRegistry::TypeId(std::numeric_limits<int64_t>::max()));
48+
EXPECT_THAT(wrong_type_id.message(),
49+
HasSubstr("Type name foo already registered with type id"));
50+
51+
// It's ok to register a new type with a user-provided type id.
52+
TF_ASSERT_OK(TypeIdRegistry::RegisterExternalTypeId(
53+
"bar", TypeIdRegistry::TypeId(std::numeric_limits<int64_t>::max())));
3954
}
4055

4156
TEST(TypeIdRegistryTest, RegisterInternalTypeId) {

xla/pjrt/c/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ cc_library(
7070
"//xla/ffi:execution_context",
7171
"//xla/ffi:type_id_registry",
7272
"@com_google_absl//absl/status",
73+
"@com_google_absl//absl/strings:string_view",
7374
],
7475
)
7576

xla/pjrt/c/pjrt_c_api_ffi_extension.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,14 @@ struct PJRT_FFI_TypeID_Register_Args {
3838

3939
const char* type_name;
4040
size_t type_name_size;
41-
int64_t type_id; // out
41+
int64_t type_id; // in-out
4242
};
4343
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_TypeID_Register_Args, type_id);
4444

45-
// Registers external type in a static type registry.
45+
// Registers external type in a static type registry. If `type_id` is set to `0`
46+
// XLA will assign a unique type id to it and return via out argument, otherwise
47+
// it will verify that user-provided type id matches previously registered type
48+
// id for the given type name.
4649
typedef PJRT_Error* PJRT_FFI_TypeID_Register(
4750
PJRT_FFI_TypeID_Register_Args* args);
4851

xla/pjrt/c/pjrt_c_api_ffi_internal.cc

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
1717

1818
#include "absl/status/status.h"
19+
#include "absl/strings/string_view.h"
1920
#include "xla/ffi/execution_context.h"
2021
#include "xla/ffi/type_id_registry.h"
2122
#include "xla/pjrt/c/pjrt_c_api.h"
@@ -31,11 +32,23 @@ static PJRT_Error* PJRT_FFI_TypeID_Register(
3132
"PJRT_FFI_TypeID_Register_Args",
3233
PJRT_FFI_TypeID_Register_Args_STRUCT_SIZE, args->struct_size));
3334

34-
PJRT_ASSIGN_OR_RETURN(
35-
auto type_id,
36-
xla::ffi::TypeIdRegistry::RegisterExternalTypeId(
37-
absl::string_view(args->type_name, args->type_name_size)));
38-
args->type_id = type_id.value();
35+
absl::string_view type_name(args->type_name, args->type_name_size);
36+
xla::ffi::TypeIdRegistry::TypeId type_id(args->type_id);
37+
38+
if (type_id == xla::ffi::TypeIdRegistry::kUnknownTypeId) {
39+
// If type_id is unknown, we are registering a new type and XLA will assign
40+
// a unique type id to it.
41+
PJRT_ASSIGN_OR_RETURN(
42+
auto assigned_type_id,
43+
xla::ffi::TypeIdRegistry::AssignExternalTypeId(type_name));
44+
args->type_id = assigned_type_id.value();
45+
46+
} else {
47+
// If type_id is set, we are relying on the caller-provided unique type id.
48+
PJRT_RETURN_IF_ERROR(
49+
xla::ffi::TypeIdRegistry::RegisterExternalTypeId(type_name, type_id));
50+
}
51+
3952
return nullptr;
4053
}
4154

0 commit comments

Comments
 (0)