diff --git a/litert/c/BUILD b/litert/c/BUILD index d7ebeed88c..5543f2fcce 100644 --- a/litert/c/BUILD +++ b/litert/c/BUILD @@ -358,6 +358,7 @@ cc_library( "//litert/cc:litert_expected", "//litert/cc:litert_macros", "//litert/runtime:tensor_buffer", + "//litert/runtime:tensor_buffer_registry", "@com_google_absl//absl/types:span", "@opencl_headers", ] + gles_deps(), diff --git a/litert/c/internal/litert_tensor_buffer_registry.cc b/litert/c/internal/litert_tensor_buffer_registry.cc index af1abdee9e..583232d6d8 100644 --- a/litert/c/internal/litert_tensor_buffer_registry.cc +++ b/litert/c/internal/litert_tensor_buffer_registry.cc @@ -25,8 +25,8 @@ LiteRtStatus LiteRtRegisterTensorBufferHandlers( LiteRtEnvironment env, LiteRtTensorBufferType buffer_type, CreateCustomTensorBuffer create_func, DestroyCustomTensorBuffer destroy_func, LockCustomTensorBuffer lock_func, - UnlockCustomTensorBuffer unlock_func, - ImportCustomTensorBuffer import_func) { + UnlockCustomTensorBuffer unlock_func, ImportCustomTensorBuffer import_func, + GetCustomTensorBufferHandle get_handle_func) { auto& registry = env->GetTensorBufferRegistry(); litert::internal::CustomTensorBufferHandlers handlers = { .create_func = create_func, @@ -34,6 +34,7 @@ LiteRtStatus LiteRtRegisterTensorBufferHandlers( .lock_func = lock_func, .unlock_func = unlock_func, .import_func = import_func, + .get_handle_func = get_handle_func, }; LITERT_RETURN_IF_ERROR(registry.RegisterHandlers(buffer_type, handlers)); return kLiteRtStatusOk; diff --git a/litert/c/internal/litert_tensor_buffer_registry.h b/litert/c/internal/litert_tensor_buffer_registry.h index 28c20af765..07d35eeb57 100644 --- a/litert/c/internal/litert_tensor_buffer_registry.h +++ b/litert/c/internal/litert_tensor_buffer_registry.h @@ -30,7 +30,9 @@ LiteRtStatus LiteRtRegisterTensorBufferHandlers( LiteRtEnvironment env, LiteRtTensorBufferType buffer_type, CreateCustomTensorBuffer create_func, DestroyCustomTensorBuffer destroy_func, LockCustomTensorBuffer lock_func, - UnlockCustomTensorBuffer unlock_func, ImportCustomTensorBuffer import_func); + UnlockCustomTensorBuffer unlock_func, + ImportCustomTensorBuffer import_func = nullptr, + GetCustomTensorBufferHandle get_handle_func = nullptr); // Retrieves a singleton instance of the tensor buffer registry. LiteRtStatus LiteRtGetTensorBufferRegistry(LiteRtEnvironment env, diff --git a/litert/c/litert_custom_tensor_buffer.h b/litert/c/litert_custom_tensor_buffer.h index 7debdc5700..0c2aae9465 100644 --- a/litert/c/litert_custom_tensor_buffer.h +++ b/litert/c/litert_custom_tensor_buffer.h @@ -75,6 +75,10 @@ typedef LiteRtStatus (*LockCustomTensorBuffer)(LiteRtEnvironment env, typedef LiteRtStatus (*UnlockCustomTensorBuffer)( LiteRtEnvironment env, HwMemoryInfoPtr hw_memory_info); +// Get the custom H/W memory buffer handle from HwMemoryInfoPtr. +typedef LiteRtStatus (*GetCustomTensorBufferHandle)( + HwMemoryInfoPtr hw_memory_info, HwMemoryHandle* hw_memory_handle); + #ifdef __cplusplus } #endif // __cplusplus diff --git a/litert/c/litert_tensor_buffer.cc b/litert/c/litert_tensor_buffer.cc index 288b28eb54..5b4f421db7 100644 --- a/litert/c/litert_tensor_buffer.cc +++ b/litert/c/litert_tensor_buffer.cc @@ -26,9 +26,9 @@ #include "litert/cc/litert_macros.h" #include "litert/runtime/custom_buffer.h" #include "litert/runtime/tensor_buffer.h" +#include "litert/runtime/tensor_buffer_registry.h" // IWYU pragma: keep #include "litert/runtime/tensor_buffer_requirements.h" - #if LITERT_HAS_OPENCL_SUPPORT #include #endif // LITERT_HAS_OPENCL_SUPPORT @@ -202,8 +202,15 @@ LiteRtStatus LiteRtGetTensorBufferMetalMemory( LITERT_ASSIGN_OR_RETURN(litert::internal::CustomBuffer * custom_buffer, tensor_buffer->GetCustomBuffer()); - *hw_memory_handle = custom_buffer->hw_buffer_handle(); - return kLiteRtStatusOk; + auto* registry = reinterpret_cast( + tensor_buffer->GetTensorBufferRegistry()); + LITERT_ASSIGN_OR_RETURN( + auto handlers, registry->GetCustomHandlers(tensor_buffer->buffer_type())); + if (handlers.get_handle_func) { + return handlers.get_handle_func(custom_buffer->hw_memory_info(), + hw_memory_handle); + } + return kLiteRtStatusErrorUnsupported; } #endif // LITERT_HAS_METAL_SUPPORT diff --git a/litert/cc/BUILD b/litert/cc/BUILD index 1fbecb407c..33c06a29d2 100644 --- a/litert/cc/BUILD +++ b/litert/cc/BUILD @@ -1052,6 +1052,7 @@ cc_library( # ":litert_macros", # ":litert_model", # ":litert_tensor_buffer", +# "@com_google_absl//absl/log:absl_log", # "//third_party/apple_frameworks:XCTest", # "//litert/c:litert_common", # "//litert/c:litert_environment", diff --git a/litert/cc/litert_tensor_buffer_test.mm b/litert/cc/litert_tensor_buffer_test.mm index 232b252f7a..579eb17861 100644 --- a/litert/cc/litert_tensor_buffer_test.mm +++ b/litert/cc/litert_tensor_buffer_test.mm @@ -15,6 +15,7 @@ #import "third_party/odml/litert/litert/cc/litert_tensor_buffer.h" #import #import +#include "absl/log/absl_log.h" // from @com_google_absl #include "litert/c/litert_common.h" #include "litert/c/litert_environment.h" #include "litert/c/litert_model_types.h" @@ -127,6 +128,7 @@ - (void)testTensorBufferCreateFromMetalMemory { XCTAssertTrue(metal_buffer); // Create a tensor buffer from the existing metal buffer. + ABSL_LOG(INFO) << "Before create from metal buffer"; auto metal_buffer_created = litert::TensorBuffer::CreateFromMetalBuffer( *env, kTensorType, kTensorBufferType, *metal_buffer, sizeof(kTensorData)); XCTAssertTrue(metal_buffer_created); diff --git a/litert/runtime/custom_buffer.h b/litert/runtime/custom_buffer.h index b49c09b8aa..2ccfbc0b0f 100644 --- a/litert/runtime/custom_buffer.h +++ b/litert/runtime/custom_buffer.h @@ -46,6 +46,7 @@ class CustomBuffer { ~CustomBuffer(); HwMemoryHandle hw_buffer_handle() { return hw_memory_info_->memory_handle; } + HwMemoryInfoPtr hw_memory_info() { return hw_memory_info_; } // Allocates a CPU memory and conducts a copy from the Custom buffer to the // CPU memory. Expected Lock(LiteRtTensorBufferLockMode mode); diff --git a/litert/runtime/tensor_buffer.cc b/litert/runtime/tensor_buffer.cc index bb88b5b863..2429e03e41 100644 --- a/litert/runtime/tensor_buffer.cc +++ b/litert/runtime/tensor_buffer.cc @@ -27,6 +27,7 @@ #include "absl/strings/str_format.h" // from @com_google_absl #include "absl/types/span.h" // from @com_google_absl #include "litert/c/internal/litert_logging.h" +#include "litert/c/internal/litert_tensor_buffer_registry.h" #include "litert/c/litert_common.h" #include "litert/c/litert_model_types.h" #include "litert/c/litert_tensor_buffer_types.h" @@ -1003,6 +1004,12 @@ Expected LiteRtTensorBufferT::Lock(LiteRtTensorBufferLockMode mode) { } } +void* LiteRtTensorBufferT::GetTensorBufferRegistry() { + void* registry = nullptr; + LiteRtGetTensorBufferRegistry(env_, ®istry); + return registry; +} + Expected LiteRtTensorBufferT::Unlock() { LITERT_RETURN_IF_ERROR(is_locked_ == true, Unexpected(kLiteRtStatusErrorRuntimeFailure, diff --git a/litert/runtime/tensor_buffer.h b/litert/runtime/tensor_buffer.h index 909e61a5af..38be4e6dd5 100644 --- a/litert/runtime/tensor_buffer.h +++ b/litert/runtime/tensor_buffer.h @@ -187,6 +187,8 @@ class LiteRtTensorBufferT { // Gets the current reference count. int RefCount() const { return ref_.load(std::memory_order_relaxed); } + void* GetTensorBufferRegistry(); + private: struct HostBuffer { void* addr; diff --git a/litert/runtime/tensor_buffer_registry.h b/litert/runtime/tensor_buffer_registry.h index 11ed3ce4da..9966739049 100644 --- a/litert/runtime/tensor_buffer_registry.h +++ b/litert/runtime/tensor_buffer_registry.h @@ -32,6 +32,7 @@ struct CustomTensorBufferHandlers { // Optional function to import an existing buffer. // TODO(b/446717438): Merge this with the create function. ::ImportCustomTensorBuffer import_func; + ::GetCustomTensorBufferHandle get_handle_func; }; class TensorBufferRegistry {