Skip to content

Commit 4dcdec4

Browse files
ezhulenevGoogle-ML-Automation
authored andcommitted
[xla:ffi] Add TypeRegistry::TypeInfo to be able to register functions to manipulate user-defined types
PiperOrigin-RevId: 820811829
1 parent b14d379 commit 4dcdec4

26 files changed

+478
-264
lines changed

xla/ffi/BUILD

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ cc_library(
6262
srcs = ["execution_context.cc"],
6363
hdrs = ["execution_context.h"],
6464
deps = [
65-
":type_id_registry",
65+
":type_registry",
6666
"//xla:util",
6767
"//xla/tsl/platform:errors",
6868
"//xla/tsl/platform:logging",
@@ -79,7 +79,7 @@ xla_cc_test(
7979
srcs = ["execution_context_test.cc"],
8080
deps = [
8181
":execution_context",
82-
":type_id_registry",
82+
":type_registry",
8383
"//xla/tsl/lib/core:status_test_util",
8484
"//xla/tsl/platform:statusor",
8585
"//xla/tsl/platform:test",
@@ -94,13 +94,14 @@ cc_library(
9494
srcs = ["execution_state.cc"],
9595
hdrs = ["execution_state.h"],
9696
deps = [
97-
":type_id_registry",
97+
":type_registry",
9898
"//xla:util",
99+
"//xla/tsl/platform:statusor",
100+
"//xla/tsl/util:safe_reinterpret_cast",
101+
"@com_google_absl//absl/base:core_headers",
99102
"@com_google_absl//absl/log:check",
100103
"@com_google_absl//absl/status",
101104
"@com_google_absl//absl/status:statusor",
102-
"@tsl//tsl/platform:logging",
103-
"@tsl//tsl/platform:statusor",
104105
],
105106
)
106107

@@ -109,11 +110,12 @@ xla_cc_test(
109110
srcs = ["execution_state_test.cc"],
110111
deps = [
111112
":execution_state",
113+
":type_registry",
112114
"//xla/tsl/lib/core:status_test_util",
115+
"//xla/tsl/platform:statusor",
116+
"//xla/tsl/platform:test",
113117
"@com_google_googletest//:gtest",
114118
"@com_google_googletest//:gtest_main",
115-
"@tsl//tsl/platform:statusor",
116-
"@tsl//tsl/platform:test",
117119
],
118120
)
119121

@@ -124,7 +126,7 @@ cc_library(
124126
":api",
125127
":execution_context",
126128
":execution_state",
127-
":type_id_registry",
129+
":type_registry",
128130
"//xla:executable_run_options",
129131
"//xla:shape_util",
130132
"//xla:types",
@@ -141,6 +143,7 @@ cc_library(
141143
"@com_google_absl//absl/algorithm:container",
142144
"@com_google_absl//absl/base:core_headers",
143145
"@com_google_absl//absl/base:nullability",
146+
"@com_google_absl//absl/log:check",
144147
"@com_google_absl//absl/status",
145148
"@com_google_absl//absl/status:statusor",
146149
"@com_google_absl//absl/strings:string_view",
@@ -160,7 +163,7 @@ cc_library(
160163
":call_frame",
161164
":execution_context",
162165
":execution_state",
163-
":type_id_registry",
166+
":type_registry",
164167
"//xla:executable_run_options",
165168
"//xla:util",
166169
"//xla/ffi/api:c_api",
@@ -218,7 +221,7 @@ xla_cc_test(
218221
":execution_state",
219222
":ffi",
220223
":ffi_api",
221-
":type_id_registry",
224+
":type_registry",
222225
"//xla:executable_run_options",
223226
"//xla:xla_data_proto_cc",
224227
"//xla/ffi/api:c_api",
@@ -243,27 +246,31 @@ xla_cc_test(
243246
)
244247

245248
cc_library(
246-
name = "type_id_registry",
247-
srcs = ["type_id_registry.cc"],
248-
hdrs = ["type_id_registry.h"],
249+
name = "type_registry",
250+
srcs = ["type_registry.cc"],
251+
hdrs = ["type_registry.h"],
249252
deps = [
250253
"//xla:util",
251254
"//xla/tsl/lib/gtl:int_type",
255+
"//xla/tsl/util:safe_reinterpret_cast",
252256
"@com_google_absl//absl/algorithm:container",
253257
"@com_google_absl//absl/base:core_headers",
258+
"@com_google_absl//absl/base:no_destructor",
254259
"@com_google_absl//absl/container:flat_hash_map",
260+
"@com_google_absl//absl/log",
255261
"@com_google_absl//absl/status",
256262
"@com_google_absl//absl/status:statusor",
263+
"@com_google_absl//absl/strings:str_format",
257264
"@com_google_absl//absl/strings:string_view",
258265
"@com_google_absl//absl/synchronization",
259266
],
260267
)
261268

262269
xla_cc_test(
263-
name = "type_id_registry_test",
264-
srcs = ["type_id_registry_test.cc"],
270+
name = "type_registry_test",
271+
srcs = ["type_registry_test.cc"],
265272
deps = [
266-
":type_id_registry",
273+
":type_registry",
267274
"//xla/tsl/lib/core:status_test_util",
268275
"//xla/tsl/platform:statusor",
269276
"//xla/tsl/platform:test",

xla/ffi/api/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ xla_cc_test(
8989
"//xla/ffi:execution_context",
9090
"//xla/ffi:execution_state",
9191
"//xla/ffi:ffi_api",
92-
"//xla/ffi:type_id_registry",
92+
"//xla/ffi:type_registry",
9393
"//xla/stream_executor:device_memory",
9494
"//xla/stream_executor:device_memory_allocator",
9595
"//xla/tsl/concurrency:async_value",

xla/ffi/api/api.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,11 +348,13 @@ inline XLA_FFI_Error* Ffi::RegisterTypeId(const XLA_FFI_Api* api,
348348
std::string_view name,
349349
XLA_FFI_TypeId* type_id,
350350
XLA_FFI_TypeInfo type_info) {
351+
assert(type_id && "type_id must not be null");
351352
XLA_FFI_TypeId_Register_Args args;
352353
args.struct_size = XLA_FFI_TypeId_Register_Args_STRUCT_SIZE;
353354
args.extension_start = nullptr;
354355
args.name = XLA_FFI_ByteSpan{name.data(), name.size()};
355356
args.type_id = type_id;
357+
args.type_info = &type_info;
356358
return api->XLA_FFI_TypeId_Register(&args);
357359
}
358360

xla/ffi/api/c_api.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Extension_Base, next);
6767
// * Deleting a method or argument
6868
// * Changing the type of an argument
6969
// * Rearranging fields in the XLA_FFI_Api or argument structs
70-
#define XLA_FFI_API_MAJOR 0
70+
#define XLA_FFI_API_MAJOR 1
7171

7272
// Incremented when the interface is updated in a way that is potentially
7373
// ABI-compatible with older versions, if supported by the caller and/or
@@ -82,7 +82,7 @@ XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_Extension_Base, next);
8282
// Minor changes include:
8383
// * Adding a new field to the XLA_FFI_Api or argument structs
8484
// * Renaming a method or argument (doesn't affect ABI)
85-
#define XLA_FFI_API_MINOR 1
85+
#define XLA_FFI_API_MINOR 0
8686

8787
struct XLA_FFI_Api_Version {
8888
size_t struct_size;
@@ -491,6 +491,7 @@ struct XLA_FFI_TypeId_Register_Args {
491491

492492
XLA_FFI_ByteSpan name;
493493
XLA_FFI_TypeId* type_id; // in-out
494+
XLA_FFI_TypeInfo* type_info;
494495
};
495496

496497
XLA_FFI_DEFINE_STRUCT_TRAITS(XLA_FFI_TypeId_Register_Args, type_id);

xla/ffi/api/ffi.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ class CountDownPromise {
437437
assert(state_->count.load() >= count && "Invalid count down value");
438438

439439
if (XLA_FFI_PREDICT_FALSE(!error.success())) {
440-
const std::lock_guard<std::mutex> lock(state_->mutex);
440+
std::lock_guard<std::mutex> lock(state_->mutex); // NOLINT
441441
state_->is_error.store(true, std::memory_order_release);
442442
state_->error = error;
443443
}
@@ -448,7 +448,7 @@ class CountDownPromise {
448448
bool is_error = state_->is_error.load(std::memory_order_acquire);
449449
if (XLA_FFI_PREDICT_FALSE(is_error)) {
450450
auto take_error = [&] {
451-
const std::lock_guard<std::mutex> lock(state_->mutex);
451+
std::lock_guard<std::mutex> lock(state_->mutex); // NOLINT
452452
return state_->error;
453453
};
454454
state_->promise.SetError(take_error());
@@ -476,7 +476,7 @@ class CountDownPromise {
476476
std::atomic<int64_t> count;
477477
std::atomic<bool> is_error;
478478

479-
std::mutex mutex;
479+
std::mutex mutex; // NOLINT
480480
Error error;
481481
};
482482

xla/ffi/api/ffi_test.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ limitations under the License.
3939
#include "xla/ffi/execution_context.h"
4040
#include "xla/ffi/execution_state.h"
4141
#include "xla/ffi/ffi_api.h"
42-
#include "xla/ffi/type_id_registry.h"
42+
#include "xla/ffi/type_registry.h"
4343
#include "xla/primitive_util.h"
4444
#include "xla/stream_executor/device_memory.h"
4545
#include "xla/stream_executor/device_memory_allocator.h"
@@ -1220,9 +1220,9 @@ TEST(FfiTest, UserData) {
12201220

12211221
ExecutionContext execution_context;
12221222
TF_ASSERT_OK(execution_context.Insert(
1223-
TypeIdRegistry::TypeId(MyDataWithAutoTypeId::id.type_id), &data0));
1223+
TypeRegistry::TypeId(MyDataWithAutoTypeId::id.type_id), &data0));
12241224
TF_ASSERT_OK(execution_context.Insert(
1225-
TypeIdRegistry::TypeId(MyDataWithExplicitTypeId::id.type_id), &data1));
1225+
TypeRegistry::TypeId(MyDataWithExplicitTypeId::id.type_id), &data1));
12261226

12271227
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
12281228
auto call_frame = builder.Build();

xla/ffi/execution_context.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ limitations under the License.
2525
#include "absl/functional/function_ref.h"
2626
#include "absl/status/status.h"
2727
#include "absl/status/statusor.h"
28-
#include "xla/ffi/type_id_registry.h"
28+
#include "xla/ffi/type_registry.h"
2929
#include "xla/tsl/platform/logging.h"
3030
#include "xla/tsl/platform/statusor.h"
3131

@@ -45,7 +45,7 @@ namespace xla::ffi {
4545
// unique between separate calls to XLA execute.
4646
class ExecutionContext {
4747
public:
48-
using TypeId = TypeIdRegistry::TypeId;
48+
using TypeId = TypeRegistry::TypeId;
4949

5050
template <typename T>
5151
using Deleter = std::function<void(T*)>;
@@ -67,7 +67,7 @@ class ExecutionContext {
6767
template <typename T>
6868
absl::StatusOr<T*> Lookup() const {
6969
TF_ASSIGN_OR_RETURN(auto user_data,
70-
LookupUserData(TypeIdRegistry::GetTypeId<T>()));
70+
LookupUserData(TypeRegistry::GetTypeId<T>()));
7171
return static_cast<T*>(user_data->data());
7272
}
7373

@@ -110,7 +110,7 @@ class ExecutionContext {
110110

111111
template <typename T>
112112
absl::Status ExecutionContext::Insert(T* data, Deleter<T> deleter) {
113-
return InsertUserData(TypeIdRegistry::GetTypeId<T>(),
113+
return InsertUserData(TypeRegistry::GetTypeId<T>(),
114114
std::make_unique<UserData>(
115115
data, [deleter = std::move(deleter)](void* data) {
116116
if (deleter) deleter(static_cast<T*>(data));
@@ -119,7 +119,7 @@ absl::Status ExecutionContext::Insert(T* data, Deleter<T> deleter) {
119119

120120
template <typename T, typename... Args>
121121
absl::Status ExecutionContext::Emplace(Args&&... args) {
122-
return InsertUserData(TypeIdRegistry::GetTypeId<T>(),
122+
return InsertUserData(TypeRegistry::GetTypeId<T>(),
123123
std::make_unique<UserData>(
124124
new T(std::forward<Args>(args)...),
125125
[](void* data) { delete static_cast<T*>(data); }));

xla/ffi/execution_context_test.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ limitations under the License.
2020

2121
#include <gtest/gtest.h>
2222
#include "absl/status/status.h"
23-
#include "xla/ffi/type_id_registry.h"
23+
#include "xla/ffi/type_registry.h"
2424
#include "xla/tsl/lib/core/status_test_util.h"
2525
#include "xla/tsl/platform/statusor.h"
2626
#include "xla/tsl/platform/test.h"
@@ -62,8 +62,9 @@ TEST(ExecutionContextTest, InsertUserOwned) {
6262
}
6363

6464
TEST(ExecutionContextTest, InsertUserOwnedWithTypeId) {
65-
TF_ASSERT_OK_AND_ASSIGN(TypeIdRegistry::TypeId type_id,
66-
TypeIdRegistry::AssignExternalTypeId("I32UserData"));
65+
TF_ASSERT_OK_AND_ASSIGN(TypeRegistry::TypeId type_id,
66+
TypeRegistry::AssignExternalTypeId(
67+
"I32UserData", TypeRegistry::TypeInfo{}));
6768

6869
I32UserData user_data(42);
6970

xla/ffi/execution_state.cc

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,46 +15,62 @@ limitations under the License.
1515

1616
#include "xla/ffi/execution_state.h"
1717

18-
#include <utility>
19-
18+
#include "absl/base/attributes.h"
2019
#include "absl/log/check.h"
2120
#include "absl/status/status.h"
2221
#include "absl/status/statusor.h"
23-
#include "xla/ffi/type_id_registry.h"
22+
#include "xla/ffi/type_registry.h"
23+
#include "xla/tsl/platform/statusor.h"
2424
#include "xla/util.h"
25-
#include "tsl/platform/logging.h"
2625

2726
namespace xla::ffi {
2827

2928
ExecutionState::ExecutionState()
30-
: type_id_(TypeIdRegistry::kUnknownTypeId),
31-
state_(nullptr),
32-
deleter_(nullptr) {}
29+
: type_id_(TypeRegistry::kUnknownTypeId), state_(nullptr) {}
3330

3431
ExecutionState::~ExecutionState() {
35-
if (deleter_) deleter_(state_);
32+
if (type_info_.deleter) {
33+
type_info_.deleter(state_);
34+
}
35+
}
36+
37+
absl::Status ExecutionState::Set(TypeId type_id, void* state) {
38+
TF_ASSIGN_OR_RETURN(auto type_info,
39+
TypeRegistry::GetExternalTypeInfo(type_id));
40+
if (type_info.deleter == nullptr) {
41+
return InvalidArgument(
42+
"Type id %d does not have a registered type info with a deleter",
43+
type_id.value());
44+
}
45+
return Set(type_id, type_info, state);
3646
}
3747

48+
ABSL_DEPRECATED("FFI users must rely in TypeInfo registration")
3849
absl::Status ExecutionState::Set(TypeId type_id, void* state,
39-
Deleter<void> deleter) {
40-
DCHECK(state && deleter) << "State and deleter must not be null";
50+
void (*deleter)(void*)) {
51+
return Set(type_id, TypeInfo{deleter}, state);
52+
}
53+
54+
absl::Status ExecutionState::Set(TypeId type_id, TypeInfo type_info,
55+
void* state) {
56+
DCHECK(state && type_info.deleter) << "State and deleter must not be null";
4157

42-
if (type_id_ != TypeIdRegistry::kUnknownTypeId) {
58+
if (type_id_ != TypeRegistry::kUnknownTypeId) {
4359
return FailedPrecondition("State is already set with a type id %d",
4460
type_id_.value());
4561
}
4662

4763
type_id_ = type_id;
64+
type_info_ = type_info;
4865
state_ = state;
49-
deleter_ = std::move(deleter);
5066

5167
return absl::OkStatus();
5268
}
5369

5470
// Returns opaque state of the given type id. If set state type id does not
5571
// match the requested one, returns an error.
5672
absl::StatusOr<void*> ExecutionState::Get(TypeId type_id) const {
57-
if (type_id_ == TypeIdRegistry::kUnknownTypeId) {
73+
if (type_id_ == TypeRegistry::kUnknownTypeId) {
5874
return NotFound("State is not set");
5975
}
6076

@@ -68,7 +84,7 @@ absl::StatusOr<void*> ExecutionState::Get(TypeId type_id) const {
6884
}
6985

7086
bool ExecutionState::IsSet() const {
71-
return type_id_ != TypeIdRegistry::kUnknownTypeId;
87+
return type_id_ != TypeRegistry::kUnknownTypeId;
7288
}
7389

7490
} // namespace xla::ffi

0 commit comments

Comments
 (0)