Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions litert/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ cc_library(
"//litert/cc:litert_expected",
"//litert/cc:litert_macros",
"//litert/runtime:tensor_buffer",
"//litert/runtime:tensor_buffer_registry",
"@com_google_absl//absl/types:span",
"@opencl_headers",
] + gles_deps(),
Expand Down
5 changes: 3 additions & 2 deletions litert/c/internal/litert_tensor_buffer_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,16 @@ LiteRtStatus LiteRtRegisterTensorBufferHandlers(
LiteRtEnvironment env, LiteRtTensorBufferType buffer_type,
CreateCustomTensorBuffer create_func,
DestroyCustomTensorBuffer destroy_func, LockCustomTensorBuffer lock_func,
UnlockCustomTensorBuffer unlock_func,
ImportCustomTensorBuffer import_func) {
UnlockCustomTensorBuffer unlock_func, ImportCustomTensorBuffer import_func,
GetCustomTensorBufferHandle get_handle_func) {
auto& registry = env->GetTensorBufferRegistry();
litert::internal::CustomTensorBufferHandlers handlers = {
.create_func = create_func,
.destroy_func = destroy_func,
.lock_func = lock_func,
.unlock_func = unlock_func,
.import_func = import_func,
.get_handle_func = get_handle_func,
};
LITERT_RETURN_IF_ERROR(registry.RegisterHandlers(buffer_type, handlers));
return kLiteRtStatusOk;
Expand Down
4 changes: 3 additions & 1 deletion litert/c/internal/litert_tensor_buffer_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ LiteRtStatus LiteRtRegisterTensorBufferHandlers(
LiteRtEnvironment env, LiteRtTensorBufferType buffer_type,
CreateCustomTensorBuffer create_func,
DestroyCustomTensorBuffer destroy_func, LockCustomTensorBuffer lock_func,
UnlockCustomTensorBuffer unlock_func, ImportCustomTensorBuffer import_func);
UnlockCustomTensorBuffer unlock_func,
ImportCustomTensorBuffer import_func = nullptr,
GetCustomTensorBufferHandle get_handle_func = nullptr);

// Retrieves a singleton instance of the tensor buffer registry.
LiteRtStatus LiteRtGetTensorBufferRegistry(LiteRtEnvironment env,
Expand Down
4 changes: 4 additions & 0 deletions litert/c/litert_custom_tensor_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ typedef LiteRtStatus (*LockCustomTensorBuffer)(LiteRtEnvironment env,
typedef LiteRtStatus (*UnlockCustomTensorBuffer)(
LiteRtEnvironment env, HwMemoryInfoPtr hw_memory_info);

// Get the custom H/W memory buffer handle from HwMemoryInfoPtr.
typedef LiteRtStatus (*GetCustomTensorBufferHandle)(
HwMemoryInfoPtr hw_memory_info, HwMemoryHandle* hw_memory_handle);

#ifdef __cplusplus
}
#endif // __cplusplus
Expand Down
13 changes: 10 additions & 3 deletions litert/c/litert_tensor_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
#include "litert/cc/litert_macros.h"
#include "litert/runtime/custom_buffer.h"
#include "litert/runtime/tensor_buffer.h"
#include "litert/runtime/tensor_buffer_registry.h" // IWYU pragma: keep
#include "litert/runtime/tensor_buffer_requirements.h"


#if LITERT_HAS_OPENCL_SUPPORT
#include <CL/cl.h>
#endif // LITERT_HAS_OPENCL_SUPPORT
Expand Down Expand Up @@ -202,8 +202,15 @@ LiteRtStatus LiteRtGetTensorBufferMetalMemory(
LITERT_ASSIGN_OR_RETURN(litert::internal::CustomBuffer * custom_buffer,
tensor_buffer->GetCustomBuffer());

*hw_memory_handle = custom_buffer->hw_buffer_handle();
return kLiteRtStatusOk;
auto* registry = reinterpret_cast<litert::internal::TensorBufferRegistry*>(
tensor_buffer->GetTensorBufferRegistry());
LITERT_ASSIGN_OR_RETURN(
auto handlers, registry->GetCustomHandlers(tensor_buffer->buffer_type()));
if (handlers.get_handle_func) {
return handlers.get_handle_func(custom_buffer->hw_memory_info(),
hw_memory_handle);
}
return kLiteRtStatusErrorUnsupported;
}

#endif // LITERT_HAS_METAL_SUPPORT
Expand Down
1 change: 1 addition & 0 deletions litert/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,7 @@ cc_library(
# ":litert_macros",
# ":litert_model",
# ":litert_tensor_buffer",
# "@com_google_absl//absl/log:absl_log",
# "//third_party/apple_frameworks:XCTest",
# "//litert/c:litert_common",
# "//litert/c:litert_environment",
Expand Down
2 changes: 2 additions & 0 deletions litert/cc/litert_tensor_buffer_test.mm
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#import "third_party/odml/litert/litert/cc/litert_tensor_buffer.h"
#import <XCTest/XCTest.h>
#import <XCTest/XCTestAssertions.h>
#include "absl/log/absl_log.h" // from @com_google_absl
#include "litert/c/litert_common.h"
#include "litert/c/litert_environment.h"
#include "litert/c/litert_model_types.h"
Expand Down Expand Up @@ -127,6 +128,7 @@ - (void)testTensorBufferCreateFromMetalMemory {
XCTAssertTrue(metal_buffer);

// Create a tensor buffer from the existing metal buffer.
ABSL_LOG(INFO) << "Before create from metal buffer";
auto metal_buffer_created = litert::TensorBuffer::CreateFromMetalBuffer(
*env, kTensorType, kTensorBufferType, *metal_buffer, sizeof(kTensorData));
XCTAssertTrue(metal_buffer_created);
Expand Down
1 change: 1 addition & 0 deletions litert/runtime/custom_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class CustomBuffer {
~CustomBuffer();

HwMemoryHandle hw_buffer_handle() { return hw_memory_info_->memory_handle; }
HwMemoryInfoPtr hw_memory_info() { return hw_memory_info_; }
// Allocates a CPU memory and conducts a copy from the Custom buffer to the
// CPU memory.
Expected<void*> Lock(LiteRtTensorBufferLockMode mode);
Expand Down
7 changes: 7 additions & 0 deletions litert/runtime/tensor_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "absl/strings/str_format.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
#include "litert/c/internal/litert_logging.h"
#include "litert/c/internal/litert_tensor_buffer_registry.h"
#include "litert/c/litert_common.h"
#include "litert/c/litert_model_types.h"
#include "litert/c/litert_tensor_buffer_types.h"
Expand Down Expand Up @@ -1003,6 +1004,12 @@ Expected<void*> LiteRtTensorBufferT::Lock(LiteRtTensorBufferLockMode mode) {
}
}

void* LiteRtTensorBufferT::GetTensorBufferRegistry() {
void* registry = nullptr;
LiteRtGetTensorBufferRegistry(env_, &registry);
return registry;
}

Expected<void> LiteRtTensorBufferT::Unlock() {
LITERT_RETURN_IF_ERROR(is_locked_ == true,
Unexpected(kLiteRtStatusErrorRuntimeFailure,
Expand Down
2 changes: 2 additions & 0 deletions litert/runtime/tensor_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ class LiteRtTensorBufferT {
// Gets the current reference count.
int RefCount() const { return ref_.load(std::memory_order_relaxed); }

void* GetTensorBufferRegistry();

private:
struct HostBuffer {
void* addr;
Expand Down
1 change: 1 addition & 0 deletions litert/runtime/tensor_buffer_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct CustomTensorBufferHandlers {
// Optional function to import an existing buffer.
// TODO(b/446717438): Merge this with the create function.
::ImportCustomTensorBuffer import_func;
::GetCustomTensorBufferHandle get_handle_func;
};

class TensorBufferRegistry {
Expand Down
Loading