diff --git a/litert/cc/litert_tensor_buffer.cc b/litert/cc/litert_tensor_buffer.cc index 3f1fa19be..e7400221f 100644 --- a/litert/cc/litert_tensor_buffer.cc +++ b/litert/cc/litert_tensor_buffer.cc @@ -62,6 +62,13 @@ Expected TensorBuffer::CreateManaged( return TensorBuffer(tensor_buffer, OwnHandle::kYes); } +Expected TensorBuffer::CreateManaged( + TensorBufferType buffer_type, const RankedTensorType& tensor_type, + size_t buffer_size) { + return CreateManaged(static_cast(buffer_type), + tensor_type, buffer_size); +} + Expected TensorBuffer::CreateFromHostMemory( const Environment&, const RankedTensorType& tensor_type, void* host_mem_addr, size_t buffer_size) { diff --git a/litert/cc/litert_tensor_buffer.h b/litert/cc/litert_tensor_buffer.h index 1d56c0d54..6b4f3e64b 100644 --- a/litert/cc/litert_tensor_buffer.h +++ b/litert/cc/litert_tensor_buffer.h @@ -57,6 +57,13 @@ class TensorBuffer const Environment& env, TensorBufferType buffer_type, const RankedTensorType& tensor_type, size_t buffer_size); + // Creates a managed TensorBuffer object in the given buffer type using the + // default environment (if applicable). The returned object is owned by the + // caller. + static Expected CreateManaged( + TensorBufferType buffer_type, const RankedTensorType& tensor_type, + size_t buffer_size); + [[deprecated( "Use the overload that takes litert::TensorBufferType instead.")]] static Expected CreateManaged(