Skip to content

Commit cccd495

Browse files
terryheocopybara-github
authored andcommitted
Fix broken dispatch_delegate_*_test
Applied recent refactoring. LiteRT-PiperOrigin-RevId: 819963167
1 parent 24b31a9 commit cccd495

File tree

5 files changed

+128
-53
lines changed

5 files changed

+128
-53
lines changed

litert/runtime/dispatch/BUILD

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -329,14 +329,14 @@ litert_device_test(
329329
"//litert/cc:litert_environment",
330330
"//litert/cc:litert_expected",
331331
"//litert/cc:litert_macros",
332-
"//litert/cc/dynamic_runtime:litert_compiled_model",
333-
"//litert/cc/dynamic_runtime:litert_environment",
334332
"//litert/cc/dynamic_runtime:litert_model",
335333
"//litert/cc/dynamic_runtime:litert_options",
336334
"//litert/cc/dynamic_runtime:litert_tensor_buffer",
337335
"//litert/cc/internal:litert_dispatch_delegate",
338336
"//litert/runtime:external_litert_buffer_context",
339337
"//litert/runtime:tensor_buffer",
338+
"//litert/runtime:tensor_identifier",
339+
"//litert/runtime:tfl_utils",
340340
"//litert/test:common",
341341
"//litert/test:matchers",
342342
"//litert/test:simple_model_npu",
@@ -430,14 +430,14 @@ litert_device_test(
430430
"//litert/cc:litert_compiled_model",
431431
"//litert/cc:litert_environment",
432432
"//litert/cc:litert_expected",
433-
"//litert/cc/dynamic_runtime:litert_compiled_model",
434-
"//litert/cc/dynamic_runtime:litert_environment",
433+
"//litert/cc:litert_macros",
435434
"//litert/cc/dynamic_runtime:litert_model",
436435
"//litert/cc/dynamic_runtime:litert_options",
437436
"//litert/cc/dynamic_runtime:litert_tensor_buffer",
438-
"//litert/core/model:model_buffer",
439-
"//litert/core/util:flatbuffer_tools",
440437
"//litert/runtime:external_litert_buffer_context",
438+
"//litert/runtime:tensor_buffer",
439+
"//litert/runtime:tensor_identifier",
440+
"//litert/runtime:tfl_utils",
441441
"//litert/test:common",
442442
"//litert/test:matchers",
443443
"//litert/test:simple_model_npu",

litert/runtime/dispatch/dispatch_delegate_google_tensor_test.cc

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
#include "litert/runtime/external_litert_buffer_context.h"
5050
#include "litert/runtime/tensor_buffer.h"
5151
#include "litert/runtime/tensor_buffer_requirements.h"
52+
#include "litert/runtime/tensor_identifier.h"
53+
#include "litert/runtime/tfl_utils.h"
5254
#include "litert/test/common.h"
5355
#include "litert/test/matchers.h"
5456
#include "litert/test/testdata/simple_model_test_vectors.h"
@@ -90,6 +92,25 @@ litert::Expected<Options> CreateDispatchOptions(const uint8_t* base) {
9092
return options;
9193
}
9294

95+
LiteRtExternalLiteRtBufferContextT CreateBufferContext(
96+
const LiteRtEnvironment& env, const tflite::Interpreter& interpreter) {
97+
auto get_tensor_id = [&interpreter](const TfLiteOpaqueTensor* target_tensor)
98+
-> litert::internal::TfLiteTensorIdentifier {
99+
auto tensor_id = litert::internal::GetTensorIdentifier(
100+
interpreter, reinterpret_cast<const TfLiteTensor*>(target_tensor));
101+
if (!tensor_id) {
102+
LITERT_LOG(LITERT_ERROR, "Failed to get tensor identifier: %s",
103+
tensor_id.Error().Message().c_str());
104+
constexpr litert::internal::TfLiteTensorIdentifier kInvalidTensorId{-1,
105+
-1};
106+
return kInvalidTensorId;
107+
}
108+
return *tensor_id;
109+
};
110+
111+
return LiteRtExternalLiteRtBufferContextT(env, get_tensor_id);
112+
}
113+
93114
TEST(DispatchDelegate, CpuBuffer) {
94115
// The dispatch delegate must be declared before the TFL interpreter so that
95116
// it gets destroyed only after the interpreter and the dispatch delegate
@@ -102,15 +123,16 @@ TEST(DispatchDelegate, CpuBuffer) {
102123
MakeRuntimeFromTestFile(kPrecompiledTfliteFile));
103124
tflite::Interpreter& interpreter = runtime->Interpreter();
104125

105-
LiteRtExternalLiteRtBufferContextT buffer_context;
126+
LITERT_ASSERT_OK_AND_ASSIGN(auto env, CreateDefaultEnvironment());
127+
LiteRtExternalLiteRtBufferContextT buffer_context =
128+
CreateBufferContext(env.Get(), interpreter);
106129
interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context);
107130

108131
EXPECT_EQ(interpreter.nodes_size(), 1);
109132
EXPECT_EQ(interpreter.inputs().size(), 2);
110133
EXPECT_EQ(interpreter.outputs().size(), 1);
111134
ASSERT_EQ(interpreter.execution_plan().size(), 1);
112135

113-
LITERT_ASSERT_OK_AND_ASSIGN(auto env, CreateDefaultEnvironment());
114136
LITERT_ASSERT_OK_AND_ASSIGN(auto env_options, env.GetOptions());
115137
LITERT_ASSERT_OK_AND_ASSIGN(
116138
auto options, CreateDispatchOptions(runtime->Flatbuffer().Buf().Data()));
@@ -173,16 +195,16 @@ TEST(DispatchDelegate, HwBuffer) {
173195
LITERT_ASSERT_OK_AND_ASSIGN(testing::TflRuntime::Ptr runtime,
174196
MakeRuntimeFromTestFile(kPrecompiledTfliteFile));
175197
tflite::Interpreter& interpreter = runtime->Interpreter();
176-
177-
LiteRtExternalLiteRtBufferContextT buffer_context;
198+
LITERT_ASSERT_OK_AND_ASSIGN(auto env, CreateDefaultEnvironment());
199+
LiteRtExternalLiteRtBufferContextT buffer_context =
200+
CreateBufferContext(env.Get(), interpreter);
178201
interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context);
179202

180203
EXPECT_EQ(interpreter.nodes_size(), 1);
181204
EXPECT_EQ(interpreter.inputs().size(), 2);
182205
EXPECT_EQ(interpreter.outputs().size(), 1);
183206
ASSERT_EQ(interpreter.execution_plan().size(), 1);
184207

185-
LITERT_ASSERT_OK_AND_ASSIGN(auto env, CreateDefaultEnvironment());
186208
LITERT_ASSERT_OK_AND_ASSIGN(auto env_options, env.GetOptions());
187209
LITERT_ASSERT_OK_AND_ASSIGN(
188210
auto options, CreateDispatchOptions(runtime->Flatbuffer().Buf().Data()));

litert/runtime/dispatch/dispatch_delegate_mediatek_test.cc

Lines changed: 69 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,23 @@
2626
#include "absl/strings/string_view.h" // from @com_google_absl
2727
#include "absl/types/span.h" // from @com_google_absl
2828
#include "litert/c/litert_common.h"
29-
#include "litert/c/litert_environment.h"
29+
#include "litert/c/litert_tensor_buffer.h"
3030
#include "litert/c/litert_tensor_buffer_types.h"
3131
#include "litert/cc/internal/litert_dispatch_delegate.h"
3232
#include "litert/cc/litert_compiled_model.h"
3333
#include "litert/cc/litert_environment.h"
3434
#include "litert/cc/litert_expected.h"
35+
#include "litert/cc/litert_macros.h"
3536
#include "litert/cc/litert_model.h"
3637
#include "litert/cc/litert_options.h"
3738
#include "litert/cc/litert_tensor_buffer.h"
3839
#include "litert/cc/litert_tensor_buffer_requirements.h"
3940
#include "litert/runtime/dispatch/dispatch_opaque_options.h"
4041
#include "litert/runtime/external_litert_buffer_context.h"
42+
#include "litert/runtime/tensor_buffer.h"
43+
#include "litert/runtime/tensor_buffer_requirements.h"
44+
#include "litert/runtime/tensor_identifier.h"
45+
#include "litert/runtime/tfl_utils.h"
4146
#include "litert/test/common.h"
4247
#include "litert/test/matchers.h"
4348
#include "litert/test/testdata/simple_model_test_vectors.h"
@@ -78,6 +83,25 @@ litert::Expected<Options> CreateDispatchOptions(const uint8_t* base) {
7883
return options;
7984
}
8085

86+
LiteRtExternalLiteRtBufferContextT CreateBufferContext(
87+
const LiteRtEnvironment& env, const tflite::Interpreter& interpreter) {
88+
auto get_tensor_id = [&interpreter](const TfLiteOpaqueTensor* target_tensor)
89+
-> litert::internal::TfLiteTensorIdentifier {
90+
auto tensor_id = litert::internal::GetTensorIdentifier(
91+
interpreter, reinterpret_cast<const TfLiteTensor*>(target_tensor));
92+
if (!tensor_id) {
93+
LITERT_LOG(LITERT_ERROR, "Failed to get tensor identifier: %s",
94+
tensor_id.Error().Message().c_str());
95+
constexpr litert::internal::TfLiteTensorIdentifier kInvalidTensorId{-1,
96+
-1};
97+
return kInvalidTensorId;
98+
}
99+
return *tensor_id;
100+
};
101+
102+
return LiteRtExternalLiteRtBufferContextT(env, get_tensor_id);
103+
}
104+
81105
TEST(DispatchDelegate, CpuBuffer) {
82106
// The dispatch delegate must be declared before the TFL interpreter so that
83107
// it gets destroyed only after the interpreter and the dispatch delegate
@@ -91,15 +115,17 @@ TEST(DispatchDelegate, CpuBuffer) {
91115
MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile));
92116
tflite::Interpreter& interpreter = runtime->Interpreter();
93117

94-
litert::internal::ExternalLiteRtBufferContext buffer_context;
118+
LITERT_ASSERT_OK_AND_ASSIGN(auto env, CreateDefaultEnvironment());
119+
120+
LiteRtExternalLiteRtBufferContextT buffer_context =
121+
CreateBufferContext(env.Get(), interpreter);
95122
interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context);
96123

97124
EXPECT_EQ(interpreter.nodes_size(), 1);
98125
EXPECT_EQ(interpreter.inputs().size(), 2);
99126
EXPECT_EQ(interpreter.outputs().size(), 1);
100127
ASSERT_EQ(interpreter.execution_plan().size(), 1);
101128

102-
LITERT_ASSERT_OK_AND_ASSIGN(auto env, CreateDefaultEnvironment());
103129
LITERT_ASSERT_OK_AND_ASSIGN(auto env_options, env.GetOptions());
104130
LITERT_ASSERT_OK_AND_ASSIGN(
105131
auto options, CreateDispatchOptions(runtime->Flatbuffer().Buf().Data()));
@@ -164,15 +190,17 @@ TEST(DispatchDelegate, HwBuffer) {
164190
MakeRuntimeFromTestFileWithNpuModel(kTfliteFile, kNpuFile));
165191
tflite::Interpreter& interpreter = runtime->Interpreter();
166192

167-
litert::internal::ExternalLiteRtBufferContext buffer_context;
193+
LITERT_ASSERT_OK_AND_ASSIGN(auto env, CreateDefaultEnvironment());
194+
195+
LiteRtExternalLiteRtBufferContextT buffer_context =
196+
CreateBufferContext(env.Get(), interpreter);
168197
interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context);
169198

170199
EXPECT_EQ(interpreter.nodes_size(), 1);
171200
EXPECT_EQ(interpreter.inputs().size(), 2);
172201
EXPECT_EQ(interpreter.outputs().size(), 1);
173202
ASSERT_EQ(interpreter.execution_plan().size(), 1);
174203

175-
LITERT_ASSERT_OK_AND_ASSIGN(auto env, CreateDefaultEnvironment());
176204
LITERT_ASSERT_OK_AND_ASSIGN(auto env_options, env.GetOptions());
177205
LITERT_ASSERT_OK_AND_ASSIGN(
178206
auto options, CreateDispatchOptions(runtime->Flatbuffer().Buf().Data()));
@@ -189,45 +217,41 @@ TEST(DispatchDelegate, HwBuffer) {
189217
kTfLiteOk);
190218

191219
// Create and register tensor buffers for all inputs and outputs.
192-
std::vector<litert::TensorBuffer> input_buffers;
220+
std::vector<LiteRtTensorBufferPtr> input_buffers;
193221
for (int i = 0; i < interpreter.inputs().size(); ++i) {
194222
LITERT_ASSERT_OK_AND_ASSIGN(
195-
auto* input_buffer_requirements,
223+
const LiteRtTensorBufferRequirementsT* input_buffer_requirements,
196224
buffer_context.GetBufferRequirements(interpreter.input_tensor(i)));
197-
LITERT_ASSERT_OK_AND_ASSIGN(const auto supported_types,
198-
input_buffer_requirements->SupportedTypes());
225+
const std::vector<LiteRtTensorBufferType>& supported_types =
226+
input_buffer_requirements->SupportedBufferTypes();
199227
ASSERT_EQ(supported_types.at(0), kLiteRtTensorBufferTypeAhwb);
200228
LITERT_ASSERT_OK_AND_ASSIGN(
201-
TensorBuffer input_buffer,
229+
LiteRtTensorBufferPtr input_buffer,
202230
buffer_context.CreateBufferForTensor(interpreter.input_tensor(i)));
203-
ASSERT_TRUE(input_buffer.IsOwned());
204-
ASSERT_THAT(input_buffer.BufferType(),
205-
IsOkAndHolds(kLiteRtTensorBufferTypeAhwb));
206-
LITERT_ASSERT_OK_AND_ASSIGN(TensorBuffer duplicate_buffer,
207-
input_buffer.Duplicate());
231+
ASSERT_EQ(input_buffer->buffer_type(), kLiteRtTensorBufferTypeAhwb);
232+
input_buffer->Duplicate();
233+
LiteRtTensorBufferPtr duplicate_buffer(input_buffer.get());
208234
auto status = buffer_context.RegisterTensorBuffer(
209235
interpreter.input_tensor(i), std::move(duplicate_buffer));
210236
ASSERT_EQ(status, kLiteRtStatusOk);
211237
input_buffers.push_back(std::move(input_buffer));
212238
}
213239

214-
std::vector<litert::TensorBuffer> output_buffers;
240+
std::vector<LiteRtTensorBufferPtr> output_buffers;
215241
for (int i = 0; i < interpreter.outputs().size(); ++i) {
216242
LITERT_ASSERT_OK_AND_ASSIGN(
217-
auto* output_buffer_requirements,
243+
const auto* output_buffer_requirements,
218244
buffer_context.GetBufferRequirements(interpreter.output_tensor(i)));
219245
ASSERT_NE(output_buffer_requirements, nullptr);
220-
LITERT_ASSERT_OK_AND_ASSIGN(const auto supported_types,
221-
output_buffer_requirements->SupportedTypes());
246+
const auto& supported_types =
247+
output_buffer_requirements->SupportedBufferTypes();
222248
ASSERT_EQ(supported_types.at(0), kLiteRtTensorBufferTypeAhwb);
223249
LITERT_ASSERT_OK_AND_ASSIGN(
224-
TensorBuffer output_buffer,
250+
LiteRtTensorBufferPtr output_buffer,
225251
buffer_context.CreateBufferForTensor(interpreter.output_tensor(i)));
226-
ASSERT_TRUE(output_buffer.IsOwned());
227-
ASSERT_THAT(output_buffer.BufferType(),
228-
IsOkAndHolds(kLiteRtTensorBufferTypeAhwb));
229-
LITERT_ASSERT_OK_AND_ASSIGN(TensorBuffer duplicate_buffer,
230-
output_buffer.Duplicate());
252+
ASSERT_EQ(output_buffer->buffer_type(), kLiteRtTensorBufferTypeAhwb);
253+
output_buffer->Duplicate();
254+
LiteRtTensorBufferPtr duplicate_buffer(output_buffer.get());
231255
auto status = buffer_context.RegisterTensorBuffer(
232256
interpreter.output_tensor(i), std::move(duplicate_buffer));
233257
ASSERT_EQ(status, kLiteRtStatusOk);
@@ -247,29 +271,33 @@ TEST(DispatchDelegate, HwBuffer) {
247271
// Fill model inputs.
248272
ASSERT_STREQ(runner->input_names()[0], "arg0");
249273
auto& input_0_buffer = input_buffers[0];
250-
input_0_buffer.Write<float>(
251-
absl::MakeConstSpan(kTestInput0Tensor, kTestInput0Size));
274+
LITERT_ASSERT_OK_AND_ASSIGN(
275+
void* host_mem_addr,
276+
input_0_buffer->Lock(kLiteRtTensorBufferLockModeWrite));
277+
std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor));
278+
LITERT_ASSERT_OK(input_0_buffer->Unlock());
252279

253280
ASSERT_STREQ(runner->input_names()[1], "arg1");
254281
auto& input_1_buffer = input_buffers[1];
255-
input_1_buffer.Write<float>(
256-
absl::MakeConstSpan(kTestInput1Tensor, kTestInput1Size));
282+
LITERT_ASSERT_OK_AND_ASSIGN(
283+
host_mem_addr, input_1_buffer->Lock(kLiteRtTensorBufferLockModeWrite));
284+
std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor));
285+
LITERT_ASSERT_OK(input_1_buffer->Unlock());
257286

258287
EXPECT_EQ(runner->Invoke(), kTfLiteOk);
259288

260289
// Check model output.
261290
ASSERT_STREQ(runner->output_names()[0], "tfl.custom");
262-
{
263-
LITERT_ASSERT_OK_AND_ASSIGN(
264-
auto lock_and_addr,
265-
litert::TensorBufferScopedLock::Create<const float>(
266-
output_buffers[0], TensorBuffer::LockMode::kRead));
267-
auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize);
268-
for (auto i = 0; i < kTestOutputSize; ++i) {
269-
ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i];
270-
}
271-
EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor));
291+
LITERT_ASSERT_OK_AND_ASSIGN(
292+
void* output_mem_addr,
293+
output_buffers[0]->Lock(kLiteRtTensorBufferLockModeRead));
294+
absl::Span<const float> output = absl::MakeConstSpan(
295+
reinterpret_cast<const float*>(output_mem_addr), kTestOutputSize);
296+
for (auto i = 0; i < kTestOutputSize; ++i) {
297+
ABSL_LOG(INFO) << "Result: " << output[i] << "\t" << kTestOutputTensor[i];
272298
}
299+
EXPECT_THAT(output, Pointwise(FloatNear(1e-5), kTestOutputTensor));
300+
LITERT_ASSERT_OK(output_buffers[0]->Unlock());
273301
}
274302

275303
TEST(DispatchDelegate, CompiledModel) {
@@ -453,7 +481,8 @@ TEST(DispatchDelegate, CompiledModelMultiRun) {
453481
{
454482
LITERT_ASSERT_OK_AND_ASSIGN(
455483
auto lock_and_addr,
456-
litert::TensorBufferScopedLock::Create<const float>(output_buffers[0]));
484+
litert::TensorBufferScopedLock::Create<const float>(
485+
output_buffers[0], TensorBuffer::LockMode::kRead));
457486
auto output = absl::MakeSpan(lock_and_addr.second, kTestOutputSize);
458487
for (auto i = 0; i < kTestOutputSize; ++i) {
459488
ABSL_LOG(INFO) << "Result: " << output[i] << "\t"

litert/runtime/dispatch/dispatch_delegate_qualcomm_test.cc

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
#include "litert/runtime/external_litert_buffer_context.h"
4141
#include "litert/runtime/tensor_buffer.h"
4242
#include "litert/runtime/tensor_buffer_requirements.h"
43+
#include "litert/runtime/tensor_identifier.h"
44+
#include "litert/runtime/tfl_utils.h"
4345
#include "litert/test/common.h"
4446
#include "litert/test/matchers.h"
4547
#include "litert/test/testdata/simple_model_test_vectors.h"
@@ -80,6 +82,25 @@ litert::Expected<Options> CreateDispatchOptions(const uint8_t* base) {
8082
return options;
8183
}
8284

85+
LiteRtExternalLiteRtBufferContextT CreateBufferContext(
86+
const LiteRtEnvironment& env, const tflite::Interpreter& interpreter) {
87+
auto get_tensor_id = [&interpreter](const TfLiteOpaqueTensor* target_tensor)
88+
-> litert::internal::TfLiteTensorIdentifier {
89+
auto tensor_id = litert::internal::GetTensorIdentifier(
90+
interpreter, reinterpret_cast<const TfLiteTensor*>(target_tensor));
91+
if (!tensor_id) {
92+
LITERT_LOG(LITERT_ERROR, "Failed to get tensor identifier: %s",
93+
tensor_id.Error().Message().c_str());
94+
constexpr litert::internal::TfLiteTensorIdentifier kInvalidTensorId{-1,
95+
-1};
96+
return kInvalidTensorId;
97+
}
98+
return *tensor_id;
99+
};
100+
101+
return LiteRtExternalLiteRtBufferContextT(env, get_tensor_id);
102+
}
103+
83104
TEST(DispatchDelegate, CpuBuffer) {
84105
// The dispatch delegate must be declared before the TFL interpreter so that
85106
// it gets destroyed only after the interpreter and the dispatch delegate
@@ -95,7 +116,8 @@ TEST(DispatchDelegate, CpuBuffer) {
95116

96117
LITERT_ASSERT_OK_AND_ASSIGN(auto env, CreateDefaultEnvironment());
97118

98-
LiteRtExternalLiteRtBufferContextT buffer_context(env.Get());
119+
LiteRtExternalLiteRtBufferContextT buffer_context =
120+
CreateBufferContext(env.Get(), interpreter);
99121
interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context);
100122

101123
EXPECT_EQ(interpreter.nodes_size(), 1);
@@ -169,7 +191,8 @@ TEST(DispatchDelegate, HwBuffer) {
169191

170192
LITERT_ASSERT_OK_AND_ASSIGN(auto env, CreateDefaultEnvironment());
171193

172-
LiteRtExternalLiteRtBufferContextT buffer_context(env.Get());
194+
LiteRtExternalLiteRtBufferContextT buffer_context =
195+
CreateBufferContext(env.Get(), interpreter);
173196
interpreter.SetExternalContext(kTfLiteLiteRtBufferContext, &buffer_context);
174197

175198
EXPECT_EQ(interpreter.nodes_size(), 1);

litert/runtime/external_litert_buffer_context.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class LiteRtExternalLiteRtBufferContextT : public TfLiteExternalContext {
6363
explicit LiteRtExternalLiteRtBufferContextT(
6464
LiteRtEnvironment env, GetTensorIdentifierFn get_tensor_identifier_fn)
6565
: env_(env), get_tensor_identifier_fn_(get_tensor_identifier_fn) {}
66+
6667
~LiteRtExternalLiteRtBufferContextT() = default;
6768

6869
static litert::Expected<LiteRtExternalLiteRtBufferContextT*> GetInstance(

0 commit comments

Comments
 (0)