2626#include " absl/strings/string_view.h" // from @com_google_absl
2727#include " absl/types/span.h" // from @com_google_absl
2828#include " litert/c/litert_common.h"
29- #include " litert/c/litert_environment .h"
29+ #include " litert/c/litert_tensor_buffer .h"
3030#include " litert/c/litert_tensor_buffer_types.h"
3131#include " litert/cc/internal/litert_dispatch_delegate.h"
3232#include " litert/cc/litert_compiled_model.h"
3333#include " litert/cc/litert_environment.h"
3434#include " litert/cc/litert_expected.h"
35+ #include " litert/cc/litert_macros.h"
3536#include " litert/cc/litert_model.h"
3637#include " litert/cc/litert_options.h"
3738#include " litert/cc/litert_tensor_buffer.h"
3839#include " litert/cc/litert_tensor_buffer_requirements.h"
3940#include " litert/runtime/dispatch/dispatch_opaque_options.h"
4041#include " litert/runtime/external_litert_buffer_context.h"
42+ #include " litert/runtime/tensor_buffer.h"
43+ #include " litert/runtime/tensor_buffer_requirements.h"
44+ #include " litert/runtime/tensor_identifier.h"
45+ #include " litert/runtime/tfl_utils.h"
4146#include " litert/test/common.h"
4247#include " litert/test/matchers.h"
4348#include " litert/test/testdata/simple_model_test_vectors.h"
@@ -78,6 +83,25 @@ litert::Expected<Options> CreateDispatchOptions(const uint8_t* base) {
7883 return options;
7984}
8085
86+ LiteRtExternalLiteRtBufferContextT CreateBufferContext (
87+ const LiteRtEnvironment& env, const tflite::Interpreter& interpreter) {
88+ auto get_tensor_id = [&interpreter](const TfLiteOpaqueTensor* target_tensor)
89+ -> litert::internal::TfLiteTensorIdentifier {
90+ auto tensor_id = litert::internal::GetTensorIdentifier (
91+ interpreter, reinterpret_cast <const TfLiteTensor*>(target_tensor));
92+ if (!tensor_id) {
93+ LITERT_LOG (LITERT_ERROR, " Failed to get tensor identifier: %s" ,
94+ tensor_id.Error ().Message ().c_str ());
95+ constexpr litert::internal::TfLiteTensorIdentifier kInvalidTensorId {-1 ,
96+ -1 };
97+ return kInvalidTensorId ;
98+ }
99+ return *tensor_id;
100+ };
101+
102+ return LiteRtExternalLiteRtBufferContextT (env, get_tensor_id);
103+ }
104+
81105TEST (DispatchDelegate, CpuBuffer) {
82106 // The dispatch delegate must be declared before the TFL interpreter so that
83107 // it gets destroyed only after the interpreter and the dispatch delegate
@@ -91,15 +115,17 @@ TEST(DispatchDelegate, CpuBuffer) {
91115 MakeRuntimeFromTestFileWithNpuModel (kTfliteFile , kNpuFile ));
92116 tflite::Interpreter& interpreter = runtime->Interpreter ();
93117
94- litert::internal::ExternalLiteRtBufferContext buffer_context;
118+ LITERT_ASSERT_OK_AND_ASSIGN (auto env, CreateDefaultEnvironment ());
119+
120+ LiteRtExternalLiteRtBufferContextT buffer_context =
121+ CreateBufferContext (env.Get (), interpreter);
95122 interpreter.SetExternalContext (kTfLiteLiteRtBufferContext , &buffer_context);
96123
97124 EXPECT_EQ (interpreter.nodes_size (), 1 );
98125 EXPECT_EQ (interpreter.inputs ().size (), 2 );
99126 EXPECT_EQ (interpreter.outputs ().size (), 1 );
100127 ASSERT_EQ (interpreter.execution_plan ().size (), 1 );
101128
102- LITERT_ASSERT_OK_AND_ASSIGN (auto env, CreateDefaultEnvironment ());
103129 LITERT_ASSERT_OK_AND_ASSIGN (auto env_options, env.GetOptions ());
104130 LITERT_ASSERT_OK_AND_ASSIGN (
105131 auto options, CreateDispatchOptions (runtime->Flatbuffer ().Buf ().Data ()));
@@ -164,15 +190,17 @@ TEST(DispatchDelegate, HwBuffer) {
164190 MakeRuntimeFromTestFileWithNpuModel (kTfliteFile , kNpuFile ));
165191 tflite::Interpreter& interpreter = runtime->Interpreter ();
166192
167- litert::internal::ExternalLiteRtBufferContext buffer_context;
193+ LITERT_ASSERT_OK_AND_ASSIGN (auto env, CreateDefaultEnvironment ());
194+
195+ LiteRtExternalLiteRtBufferContextT buffer_context =
196+ CreateBufferContext (env.Get (), interpreter);
168197 interpreter.SetExternalContext (kTfLiteLiteRtBufferContext , &buffer_context);
169198
170199 EXPECT_EQ (interpreter.nodes_size (), 1 );
171200 EXPECT_EQ (interpreter.inputs ().size (), 2 );
172201 EXPECT_EQ (interpreter.outputs ().size (), 1 );
173202 ASSERT_EQ (interpreter.execution_plan ().size (), 1 );
174203
175- LITERT_ASSERT_OK_AND_ASSIGN (auto env, CreateDefaultEnvironment ());
176204 LITERT_ASSERT_OK_AND_ASSIGN (auto env_options, env.GetOptions ());
177205 LITERT_ASSERT_OK_AND_ASSIGN (
178206 auto options, CreateDispatchOptions (runtime->Flatbuffer ().Buf ().Data ()));
@@ -189,45 +217,41 @@ TEST(DispatchDelegate, HwBuffer) {
189217 kTfLiteOk );
190218
191219 // Create and register tensor buffers for all inputs and outputs.
192- std::vector<litert::TensorBuffer > input_buffers;
220+ std::vector<LiteRtTensorBufferPtr > input_buffers;
193221 for (int i = 0 ; i < interpreter.inputs ().size (); ++i) {
194222 LITERT_ASSERT_OK_AND_ASSIGN (
195- auto * input_buffer_requirements,
223+ const LiteRtTensorBufferRequirementsT * input_buffer_requirements,
196224 buffer_context.GetBufferRequirements (interpreter.input_tensor (i)));
197- LITERT_ASSERT_OK_AND_ASSIGN ( const auto supported_types,
198- input_buffer_requirements->SupportedTypes () );
225+ const std::vector<LiteRtTensorBufferType>& supported_types =
226+ input_buffer_requirements->SupportedBufferTypes ( );
199227 ASSERT_EQ (supported_types.at (0 ), kLiteRtTensorBufferTypeAhwb );
200228 LITERT_ASSERT_OK_AND_ASSIGN (
201- TensorBuffer input_buffer,
229+ LiteRtTensorBufferPtr input_buffer,
202230 buffer_context.CreateBufferForTensor (interpreter.input_tensor (i)));
203- ASSERT_TRUE (input_buffer.IsOwned ());
204- ASSERT_THAT (input_buffer.BufferType (),
205- IsOkAndHolds (kLiteRtTensorBufferTypeAhwb ));
206- LITERT_ASSERT_OK_AND_ASSIGN (TensorBuffer duplicate_buffer,
207- input_buffer.Duplicate ());
231+ ASSERT_EQ (input_buffer->buffer_type (), kLiteRtTensorBufferTypeAhwb );
232+ input_buffer->Duplicate ();
233+ LiteRtTensorBufferPtr duplicate_buffer (input_buffer.get ());
208234 auto status = buffer_context.RegisterTensorBuffer (
209235 interpreter.input_tensor (i), std::move (duplicate_buffer));
210236 ASSERT_EQ (status, kLiteRtStatusOk );
211237 input_buffers.push_back (std::move (input_buffer));
212238 }
213239
214- std::vector<litert::TensorBuffer > output_buffers;
240+ std::vector<LiteRtTensorBufferPtr > output_buffers;
215241 for (int i = 0 ; i < interpreter.outputs ().size (); ++i) {
216242 LITERT_ASSERT_OK_AND_ASSIGN (
217- auto * output_buffer_requirements,
243+ const auto * output_buffer_requirements,
218244 buffer_context.GetBufferRequirements (interpreter.output_tensor (i)));
219245 ASSERT_NE (output_buffer_requirements, nullptr );
220- LITERT_ASSERT_OK_AND_ASSIGN ( const auto supported_types,
221- output_buffer_requirements->SupportedTypes () );
246+ const auto & supported_types =
247+ output_buffer_requirements->SupportedBufferTypes ( );
222248 ASSERT_EQ (supported_types.at (0 ), kLiteRtTensorBufferTypeAhwb );
223249 LITERT_ASSERT_OK_AND_ASSIGN (
224- TensorBuffer output_buffer,
250+ LiteRtTensorBufferPtr output_buffer,
225251 buffer_context.CreateBufferForTensor (interpreter.output_tensor (i)));
226- ASSERT_TRUE (output_buffer.IsOwned ());
227- ASSERT_THAT (output_buffer.BufferType (),
228- IsOkAndHolds (kLiteRtTensorBufferTypeAhwb ));
229- LITERT_ASSERT_OK_AND_ASSIGN (TensorBuffer duplicate_buffer,
230- output_buffer.Duplicate ());
252+ ASSERT_EQ (output_buffer->buffer_type (), kLiteRtTensorBufferTypeAhwb );
253+ output_buffer->Duplicate ();
254+ LiteRtTensorBufferPtr duplicate_buffer (output_buffer.get ());
231255 auto status = buffer_context.RegisterTensorBuffer (
232256 interpreter.output_tensor (i), std::move (duplicate_buffer));
233257 ASSERT_EQ (status, kLiteRtStatusOk );
@@ -247,29 +271,33 @@ TEST(DispatchDelegate, HwBuffer) {
247271 // Fill model inputs.
248272 ASSERT_STREQ (runner->input_names ()[0 ], " arg0" );
249273 auto & input_0_buffer = input_buffers[0 ];
250- input_0_buffer.Write <float >(
251- absl::MakeConstSpan (kTestInput0Tensor , kTestInput0Size ));
274+ LITERT_ASSERT_OK_AND_ASSIGN (
275+ void * host_mem_addr,
276+ input_0_buffer->Lock (kLiteRtTensorBufferLockModeWrite ));
277+ std::memcpy (host_mem_addr, kTestInput0Tensor , sizeof (kTestInput0Tensor ));
278+ LITERT_ASSERT_OK (input_0_buffer->Unlock ());
252279
253280 ASSERT_STREQ (runner->input_names ()[1 ], " arg1" );
254281 auto & input_1_buffer = input_buffers[1 ];
255- input_1_buffer.Write <float >(
256- absl::MakeConstSpan (kTestInput1Tensor , kTestInput1Size ));
282+ LITERT_ASSERT_OK_AND_ASSIGN (
283+ host_mem_addr, input_1_buffer->Lock (kLiteRtTensorBufferLockModeWrite ));
284+ std::memcpy (host_mem_addr, kTestInput1Tensor , sizeof (kTestInput1Tensor ));
285+ LITERT_ASSERT_OK (input_1_buffer->Unlock ());
257286
258287 EXPECT_EQ (runner->Invoke (), kTfLiteOk );
259288
260289 // Check model output.
261290 ASSERT_STREQ (runner->output_names ()[0 ], " tfl.custom" );
262- {
263- LITERT_ASSERT_OK_AND_ASSIGN (
264- auto lock_and_addr,
265- litert::TensorBufferScopedLock::Create<const float >(
266- output_buffers[0 ], TensorBuffer::LockMode::kRead ));
267- auto output = absl::MakeSpan (lock_and_addr.second , kTestOutputSize );
268- for (auto i = 0 ; i < kTestOutputSize ; ++i) {
269- ABSL_LOG (INFO) << " Result: " << output[i] << " \t " << kTestOutputTensor [i];
270- }
271- EXPECT_THAT (output, Pointwise (FloatNear (1e-5 ), kTestOutputTensor ));
291+ LITERT_ASSERT_OK_AND_ASSIGN (
292+ void * output_mem_addr,
293+ output_buffers[0 ]->Lock (kLiteRtTensorBufferLockModeRead ));
294+ absl::Span<const float > output = absl::MakeConstSpan (
295+ reinterpret_cast <const float *>(output_mem_addr), kTestOutputSize );
296+ for (auto i = 0 ; i < kTestOutputSize ; ++i) {
297+ ABSL_LOG (INFO) << " Result: " << output[i] << " \t " << kTestOutputTensor [i];
272298 }
299+ EXPECT_THAT (output, Pointwise (FloatNear (1e-5 ), kTestOutputTensor ));
300+ LITERT_ASSERT_OK (output_buffers[0 ]->Unlock ());
273301}
274302
275303TEST (DispatchDelegate, CompiledModel) {
@@ -453,7 +481,8 @@ TEST(DispatchDelegate, CompiledModelMultiRun) {
453481 {
454482 LITERT_ASSERT_OK_AND_ASSIGN (
455483 auto lock_and_addr,
456- litert::TensorBufferScopedLock::Create<const float >(output_buffers[0 ]));
484+ litert::TensorBufferScopedLock::Create<const float >(
485+ output_buffers[0 ], TensorBuffer::LockMode::kRead ));
457486 auto output = absl::MakeSpan (lock_and_addr.second , kTestOutputSize );
458487 for (auto i = 0 ; i < kTestOutputSize ; ++i) {
459488 ABSL_LOG (INFO) << " Result: " << output[i] << " \t "
0 commit comments