Skip to content

Commit 60c9ac6

Browse files
fengwuyaocopybara-github
authored andcommitted
Internal changes only.
LiteRT-PiperOrigin-RevId: 824601951
1 parent 3fef5c0 commit 60c9ac6

11 files changed

+35
-6
lines changed

litert/c/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ cc_library(
358358
"//litert/cc:litert_expected",
359359
"//litert/cc:litert_macros",
360360
"//litert/runtime:tensor_buffer",
361+
"//litert/runtime:tensor_buffer_registry",
361362
"@com_google_absl//absl/types:span",
362363
"@opencl_headers",
363364
] + gles_deps(),

litert/c/internal/litert_tensor_buffer_registry.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,16 @@ LiteRtStatus LiteRtRegisterTensorBufferHandlers(
2525
LiteRtEnvironment env, LiteRtTensorBufferType buffer_type,
2626
CreateCustomTensorBuffer create_func,
2727
DestroyCustomTensorBuffer destroy_func, LockCustomTensorBuffer lock_func,
28-
UnlockCustomTensorBuffer unlock_func,
29-
ImportCustomTensorBuffer import_func) {
28+
UnlockCustomTensorBuffer unlock_func, ImportCustomTensorBuffer import_func,
29+
GetCustomTensorBufferHandle get_handle_func) {
3030
auto& registry = env->GetTensorBufferRegistry();
3131
litert::internal::CustomTensorBufferHandlers handlers = {
3232
.create_func = create_func,
3333
.destroy_func = destroy_func,
3434
.lock_func = lock_func,
3535
.unlock_func = unlock_func,
3636
.import_func = import_func,
37+
.get_handle_func = get_handle_func,
3738
};
3839
LITERT_RETURN_IF_ERROR(registry.RegisterHandlers(buffer_type, handlers));
3940
return kLiteRtStatusOk;

litert/c/internal/litert_tensor_buffer_registry.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ LiteRtStatus LiteRtRegisterTensorBufferHandlers(
3030
LiteRtEnvironment env, LiteRtTensorBufferType buffer_type,
3131
CreateCustomTensorBuffer create_func,
3232
DestroyCustomTensorBuffer destroy_func, LockCustomTensorBuffer lock_func,
33-
UnlockCustomTensorBuffer unlock_func, ImportCustomTensorBuffer import_func);
33+
UnlockCustomTensorBuffer unlock_func,
34+
ImportCustomTensorBuffer import_func = nullptr,
35+
GetCustomTensorBufferHandle get_handle_func = nullptr);
3436

3537
// Retrieves a singleton instance of the tensor buffer registry.
3638
LiteRtStatus LiteRtGetTensorBufferRegistry(LiteRtEnvironment env,

litert/c/litert_custom_tensor_buffer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ typedef LiteRtStatus (*LockCustomTensorBuffer)(LiteRtEnvironment env,
7575
typedef LiteRtStatus (*UnlockCustomTensorBuffer)(
7676
LiteRtEnvironment env, HwMemoryInfoPtr hw_memory_info);
7777

78+
// Get the custom H/W memory buffer handle from HwMemoryInfoPtr.
79+
typedef LiteRtStatus (*GetCustomTensorBufferHandle)(
80+
HwMemoryInfoPtr hw_memory_info, HwMemoryHandle* hw_memory_handle);
81+
7882
#ifdef __cplusplus
7983
}
8084
#endif // __cplusplus

litert/c/litert_tensor_buffer.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
#include "litert/cc/litert_macros.h"
2727
#include "litert/runtime/custom_buffer.h"
2828
#include "litert/runtime/tensor_buffer.h"
29+
#include "litert/runtime/tensor_buffer_registry.h" // IWYU pragma: keep
2930
#include "litert/runtime/tensor_buffer_requirements.h"
3031

31-
3232
#if LITERT_HAS_OPENCL_SUPPORT
3333
#include <CL/cl.h>
3434
#endif // LITERT_HAS_OPENCL_SUPPORT
@@ -202,8 +202,15 @@ LiteRtStatus LiteRtGetTensorBufferMetalMemory(
202202
LITERT_ASSIGN_OR_RETURN(litert::internal::CustomBuffer * custom_buffer,
203203
tensor_buffer->GetCustomBuffer());
204204

205-
*hw_memory_handle = custom_buffer->hw_buffer_handle();
206-
return kLiteRtStatusOk;
205+
auto* registry = reinterpret_cast<litert::internal::TensorBufferRegistry*>(
206+
tensor_buffer->GetTensorBufferRegistry());
207+
LITERT_ASSIGN_OR_RETURN(
208+
auto handlers, registry->GetCustomHandlers(tensor_buffer->buffer_type()));
209+
if (handlers.get_handle_func) {
210+
return handlers.get_handle_func(custom_buffer->hw_memory_info(),
211+
hw_memory_handle);
212+
}
213+
return kLiteRtStatusErrorUnsupported;
207214
}
208215

209216
#endif // LITERT_HAS_METAL_SUPPORT

litert/cc/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,7 @@ cc_library(
10521052
# ":litert_macros",
10531053
# ":litert_model",
10541054
# ":litert_tensor_buffer",
1055+
# "@com_google_absl//absl/log:absl_log",
10551056
# "//third_party/apple_frameworks:XCTest",
10561057
# "//litert/c:litert_common",
10571058
# "//litert/c:litert_environment",

litert/cc/litert_tensor_buffer_test.mm

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#import "third_party/odml/litert/litert/cc/litert_tensor_buffer.h"
1616
#import <XCTest/XCTest.h>
1717
#import <XCTest/XCTestAssertions.h>
18+
#include "absl/log/absl_log.h" // from @com_google_absl
1819
#include "litert/c/litert_common.h"
1920
#include "litert/c/litert_environment.h"
2021
#include "litert/c/litert_model_types.h"
@@ -127,6 +128,7 @@ - (void)testTensorBufferCreateFromMetalMemory {
127128
XCTAssertTrue(metal_buffer);
128129

129130
// Create a tensor buffer from the existing metal buffer.
131+
ABSL_LOG(INFO) << "Before create from metal buffer";
130132
auto metal_buffer_created = litert::TensorBuffer::CreateFromMetalBuffer(
131133
*env, kTensorType, kTensorBufferType, *metal_buffer, sizeof(kTensorData));
132134
XCTAssertTrue(metal_buffer_created);

litert/runtime/custom_buffer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class CustomBuffer {
4646
~CustomBuffer();
4747

4848
HwMemoryHandle hw_buffer_handle() { return hw_memory_info_->memory_handle; }
49+
HwMemoryInfoPtr hw_memory_info() { return hw_memory_info_; }
4950
// Allocates a CPU memory and conducts a copy from the Custom buffer to the
5051
// CPU memory.
5152
Expected<void*> Lock(LiteRtTensorBufferLockMode mode);

litert/runtime/tensor_buffer.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "absl/strings/str_format.h" // from @com_google_absl
2828
#include "absl/types/span.h" // from @com_google_absl
2929
#include "litert/c/internal/litert_logging.h"
30+
#include "litert/c/internal/litert_tensor_buffer_registry.h"
3031
#include "litert/c/litert_common.h"
3132
#include "litert/c/litert_model_types.h"
3233
#include "litert/c/litert_tensor_buffer_types.h"
@@ -1003,6 +1004,12 @@ Expected<void*> LiteRtTensorBufferT::Lock(LiteRtTensorBufferLockMode mode) {
10031004
}
10041005
}
10051006

1007+
void* LiteRtTensorBufferT::GetTensorBufferRegistry() {
1008+
void* registry = nullptr;
1009+
LiteRtGetTensorBufferRegistry(env_, &registry);
1010+
return registry;
1011+
}
1012+
10061013
Expected<void> LiteRtTensorBufferT::Unlock() {
10071014
LITERT_RETURN_IF_ERROR(is_locked_ == true,
10081015
Unexpected(kLiteRtStatusErrorRuntimeFailure,

litert/runtime/tensor_buffer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,8 @@ class LiteRtTensorBufferT {
187187
// Gets the current reference count.
188188
int RefCount() const { return ref_.load(std::memory_order_relaxed); }
189189

190+
void* GetTensorBufferRegistry();
191+
190192
private:
191193
struct HostBuffer {
192194
void* addr;

0 commit comments

Comments
 (0)