Skip to content

Commit 301c57b

Browse files
ai-edge-botcopybara-github
authored andcommitted
Support dynamic tensor shape in LiteRt Inference Runner.
LiteRT-PiperOrigin-RevId: 815896283
1 parent 0d43b3a commit 301c57b

File tree

12 files changed

+326
-100
lines changed

12 files changed

+326
-100
lines changed

litert/c/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ cc_library(
442442
deps = [
443443
":litert_common",
444444
":litert_environment",
445+
":litert_layout",
445446
":litert_logging",
446447
":litert_metrics",
447448
":litert_model",
@@ -468,11 +469,13 @@ cc_test(
468469
":litert_compiled_model",
469470
":litert_environment",
470471
":litert_environment_options",
472+
":litert_layout",
471473
":litert_logging",
472474
":litert_model",
473475
":litert_options",
474476
":litert_tensor_buffer",
475477
"//litert/test:common",
478+
"//litert/test:matchers",
476479
"//litert/test:simple_model",
477480
"@com_google_absl//absl/log:absl_log",
478481
"@com_google_absl//absl/strings:string_view",
@@ -493,6 +496,7 @@ cc_test(
493496
":litert_environment_options",
494497
":litert_runtime_c_api_shared_lib",
495498
"//litert/test:common",
499+
"//litert/test:matchers",
496500
"//litert/test:simple_model",
497501
"@com_google_absl//absl/log:absl_log",
498502
"@com_google_absl//absl/strings:string_view",

litert/c/litert_compiled_model.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "absl/types/span.h" // from @com_google_absl
2828
#include "litert/c/litert_common.h"
2929
#include "litert/c/litert_environment.h"
30+
#include "litert/c/litert_layout.h"
3031
#include "litert/c/litert_logging.h"
3132
#include "litert/c/litert_metrics.h"
3233
#include "litert/c/litert_model.h"
@@ -96,6 +97,23 @@ LiteRtStatus LiteRtGetCompiledModelEnvironment(
9697
return kLiteRtStatusOk;
9798
}
9899

100+
LiteRtStatus LiteRtGetCompiledModelOutputTensorLayouts(
101+
LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index,
102+
size_t num_layouts, LiteRtLayout* layouts, bool update_allocation) {
103+
if (!compiled_model || !layouts) {
104+
return kLiteRtStatusErrorInvalidArgument;
105+
}
106+
absl::Span<LiteRtLayout> output_layouts(layouts, num_layouts);
107+
LITERT_RETURN_IF_ERROR(compiled_model->GetOutputTensorShapes(
108+
signature_index, output_layouts, update_allocation));
109+
size_t tensors_size = output_layouts.size();
110+
if (tensors_size == 0) {
111+
LITERT_LOG(LITERT_WARNING, "No output tensors found for signature index.");
112+
return kLiteRtStatusErrorInvalidArgument;
113+
}
114+
return kLiteRtStatusOk;
115+
}
116+
99117
LiteRtStatus LiteRtRunCompiledModel(LiteRtCompiledModel compiled_model,
100118
LiteRtParamIndex signature_index,
101119
size_t num_input_buffers,

litert/c/litert_compiled_model.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <stddef.h>
1919

2020
#include "litert/c/litert_common.h"
21+
#include "litert/c/litert_layout.h"
2122

2223
#ifdef __cplusplus
2324
extern "C" {
@@ -85,6 +86,23 @@ LiteRtStatus LiteRtGetCompiledModelOutputBufferRequirements(
8586
LiteRtParamIndex output_index,
8687
LiteRtTensorBufferRequirements* buffer_requirements);
8788

89+
// Returns the tensor layouts for all output tensors.
90+
//
91+
// Parameters:
92+
// - compiled_model: the target `LiteRtCompiledModel` object.
93+
// - signature_index: the index of the signature in `LiteRtModel`.
94+
// - num_layouts: the number of output tensor layouts.
95+
// - layouts: user allocated memory to store `LiteRtLayout` for tensor outputs.
96+
// - update_allocation: whether to update the tensor allocation. Set to true
97+
// for dynamic models after resize input tensors.
98+
//
99+
// Note: This function usually should be called after resizing input tensors
100+
// to get the new output tensor layouts. User should be responsible for
101+
// allocation and deallocating of the layouts memory.
102+
LiteRtStatus LiteRtGetCompiledModelOutputTensorLayouts(
103+
LiteRtCompiledModel compiled_model, LiteRtParamIndex signature_index,
104+
size_t num_layouts, LiteRtLayout* layouts, bool update_allocation);
105+
88106
// Returns the associated environment of the given compiled model.
89107
LiteRtStatus LiteRtGetCompiledModelEnvironment(
90108
LiteRtCompiledModel compiled_model, LiteRtEnvironment* environment);
@@ -183,7 +201,9 @@ LiteRtStatus LiteRtCompiledModelGetProfiler(LiteRtCompiledModel compiled_model,
183201
// - dims: A span containing the new dimensions for the input tensor.
184202
//
185203
// Note: After resizing, the previously obtained buffer requirements may be
186-
// invalidated. Callers should re-query buffer requirements if needed.
204+
// invalidated. Callers should re-query buffer requirements if needed. After
205+
// resizing, LiteRtGetCompiledModelAllOutputTensorLayouts can be used to get
206+
// the new output tensor layouts.
187207
//
188208
// Returns:
189209
// - kLiteRtStatusOk: Success.

0 commit comments

Comments
 (0)