Skip to content

Commit 4199154

Browse files
ai-edge-botcopybara-github
authored andcommitted
Reverts 7237356
LiteRT-PiperOrigin-RevId: 820748563
1 parent 2523703 commit 4199154

15 files changed

+368
-66
lines changed

litert/c/litert_model.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,33 @@ LiteRtStatus LiteRtGetSignatureInputName(LiteRtSignature signature,
229229
return kLiteRtStatusOk;
230230
}
231231

232+
LiteRtStatus LiteRtGetSignatureInputTensor(LiteRtSignature signature,
233+
const char* input_name,
234+
LiteRtTensor* tensor) {
235+
if (!signature || !input_name || !tensor) {
236+
return kLiteRtStatusErrorInvalidArgument;
237+
}
238+
auto input_tensor = signature->FindInputTensor(input_name);
239+
if (!input_tensor) {
240+
return input_tensor.Error().Status();
241+
}
242+
*tensor = *input_tensor;
243+
return kLiteRtStatusOk;
244+
}
245+
246+
LiteRtStatus LiteRtGetSignatureInputTensorByIndex(LiteRtSignature signature,
247+
LiteRtParamIndex input_idx,
248+
LiteRtTensor* tensor) {
249+
if (!signature || !tensor) {
250+
return kLiteRtStatusErrorInvalidArgument;
251+
}
252+
if (input_idx >= signature->InputNames().size()) {
253+
return kLiteRtStatusErrorIndexOOB;
254+
}
255+
*tensor = signature->GetInputTensor(input_idx);
256+
return kLiteRtStatusOk;
257+
}
258+
232259
LiteRtStatus LiteRtGetNumSignatureOutputs(LiteRtSignature signature,
233260
LiteRtParamIndex* num_outputs) {
234261
if (!signature || !num_outputs) {
@@ -251,6 +278,33 @@ LiteRtStatus LiteRtGetSignatureOutputName(LiteRtSignature signature,
251278
return kLiteRtStatusOk;
252279
}
253280

281+
LiteRtStatus LiteRtGetSignatureOutputTensor(LiteRtSignature signature,
282+
const char* output_name,
283+
LiteRtTensor* tensor) {
284+
if (!signature || !output_name || !tensor) {
285+
return kLiteRtStatusErrorInvalidArgument;
286+
}
287+
auto output_tensor = signature->FindOutputTensor(output_name);
288+
if (!output_tensor) {
289+
return output_tensor.Error().Status();
290+
}
291+
*tensor = *output_tensor;
292+
return kLiteRtStatusOk;
293+
}
294+
295+
LiteRtStatus LiteRtGetSignatureOutputTensorByIndex(LiteRtSignature signature,
296+
LiteRtParamIndex output_idx,
297+
LiteRtTensor* tensor) {
298+
if (!signature || !tensor) {
299+
return kLiteRtStatusErrorInvalidArgument;
300+
}
301+
if (output_idx >= signature->OutputNames().size()) {
302+
return kLiteRtStatusErrorIndexOOB;
303+
}
304+
*tensor = signature->GetOutputTensor(output_idx);
305+
return kLiteRtStatusOk;
306+
}
307+
254308
//
255309
// Subgraph
256310
//

litert/c/litert_model.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,16 @@ LiteRtStatus LiteRtGetSignatureInputName(LiteRtSignature signature,
176176
LiteRtParamIndex input_idx,
177177
const char** input_name);
178178

179+
// Get the input tensor for the given signature and input name.
180+
LiteRtStatus LiteRtGetSignatureInputTensor(LiteRtSignature signature,
181+
const char* input_name,
182+
LiteRtTensor* tensor);
183+
184+
// Get the input tensor for the given signature and input index.
185+
LiteRtStatus LiteRtGetSignatureInputTensorByIndex(LiteRtSignature signature,
186+
LiteRtParamIndex input_idx,
187+
LiteRtTensor* tensor);
188+
179189
// Get the number of outputs for the given signature.
180190
LiteRtStatus LiteRtGetNumSignatureOutputs(LiteRtSignature signature,
181191
LiteRtParamIndex* num_outputs);
@@ -187,6 +197,16 @@ LiteRtStatus LiteRtGetSignatureOutputName(LiteRtSignature signature,
187197
LiteRtParamIndex output_idx,
188198
const char** output_name);
189199

200+
// Get the output tensor for the given signature and output name.
201+
LiteRtStatus LiteRtGetSignatureOutputTensor(LiteRtSignature signature,
202+
const char* output_name,
203+
LiteRtTensor* tensor);
204+
205+
// Get the output tensor for the given signature and output index.
206+
LiteRtStatus LiteRtGetSignatureOutputTensorByIndex(LiteRtSignature signature,
207+
LiteRtParamIndex output_idx,
208+
LiteRtTensor* tensor);
209+
190210
//
191211
// LiteRtModel
192212
//

litert/c/windows_exported_symbols.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,12 @@ EXPORTS
109109
LiteRtGetRuntimeOptionsIdentifier
110110
LiteRtGetRuntimeOptionsShloCompositeInlining
111111
LiteRtGetSignatureInputName
112+
LiteRtGetSignatureInputTensor
113+
LiteRtGetSignatureInputTensorByIndex
112114
LiteRtGetSignatureKey
113115
LiteRtGetSignatureOutputName
116+
LiteRtGetSignatureOutputTensor
117+
LiteRtGetSignatureOutputTensorByIndex
114118
LiteRtGetSignatureSubgraph
115119
LiteRtGetSubgraphInput
116120
LiteRtGetSubgraphOp

litert/cc/litert_compiled_model.cc

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,9 @@ Expected<TensorBuffer> CompiledModel::CreateInputOutputBuffer(
9393

9494
LITERT_ASSIGN_OR_RETURN(Subgraph subgraph, model_.Subgraph(signature.Key()));
9595

96-
Expected<Tensor> tensor_expected =
97-
is_input ? subgraph.Input(tensor_name) : subgraph.Output(tensor_name);
96+
Expected<Tensor> tensor_expected = is_input
97+
? signature.InputTensor(tensor_name)
98+
: signature.OutputTensor(tensor_name);
9899
Expected<TensorBufferRequirements> buffer_requirements_expected =
99100
is_input ? GetInputBufferRequirements(signature_index, tensor_name)
100101
: GetOutputBufferRequirements(signature_index, tensor_name);
@@ -111,8 +112,6 @@ Expected<std::vector<TensorBuffer>> CompiledModel::CreateInputOutputBuffers(
111112
size_t signature_index, bool is_input) const {
112113
LITERT_ASSIGN_OR_RETURN(const Signature& signature,
113114
model_.GetSignature(signature_index));
114-
LITERT_ASSIGN_OR_RETURN(const Subgraph subgraph,
115-
model_.Subgraph(signature.Key()));
116115
std::vector<TensorBuffer> tensor_buffers;
117116
std::vector<absl::string_view> tensor_names;
118117

@@ -176,24 +175,22 @@ Expected<void> CompiledModel::RunMapHelper(
176175
return Unexpected(kLiteRtStatusErrorNotFound,
177176
"Failed to get signature_index");
178177
}
179-
auto subgraph = model_.Subgraph(signature_key);
180-
if (!subgraph) {
181-
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph");
182-
}
183-
return RunMapWithIndexHelper(*signature_index, *subgraph, input_map,
178+
LITERT_ASSIGN_OR_RETURN(Signature signature,
179+
model_.GetSignature(*signature_index));
180+
return RunMapWithIndexHelper(*signature_index, signature, input_map,
184181
output_map, async);
185182
}
186183

187184
Expected<void> CompiledModel::RunMapWithIndexHelper(
188-
size_t signature_index, const Subgraph& subgraph,
185+
size_t signature_index, const Signature& signature,
189186
const absl::flat_hash_map<absl::string_view, TensorBuffer>& input_map,
190187
const absl::flat_hash_map<absl::string_view, TensorBuffer>& output_map,
191188
bool& async) const {
192-
auto input_tensors = subgraph.Inputs();
193-
size_t num_inputs = input_tensors.size();
189+
auto input_names = signature.InputNames();
190+
size_t num_inputs = input_names.size();
194191
auto input_buffers_ptr = std::make_unique<LiteRtTensorBuffer[]>(num_inputs);
195192
for (int i = 0; i < num_inputs; ++i) {
196-
absl::string_view input_name = input_tensors[i].Name();
193+
absl::string_view input_name = input_names[i];
197194
auto it = input_map.find(input_name);
198195
// if the input is not provided in the input map, we set it to nullptr.
199196
if (it == input_map.end()) {
@@ -202,11 +199,11 @@ Expected<void> CompiledModel::RunMapWithIndexHelper(
202199
}
203200
input_buffers_ptr[i] = it->second.Get();
204201
}
205-
auto output_tensors = subgraph.Outputs();
206-
size_t num_outputs = output_tensors.size();
202+
auto output_names = signature.OutputNames();
203+
size_t num_outputs = output_names.size();
207204
auto output_buffers_ptr = std::make_unique<LiteRtTensorBuffer[]>(num_outputs);
208205
for (int i = 0; i < num_outputs; ++i) {
209-
absl::string_view output_name = output_tensors[i].Name();
206+
absl::string_view output_name = output_names[i];
210207
auto it = output_map.find(output_name);
211208
if (it == output_map.end()) {
212209
return Unexpected(kLiteRtStatusErrorNotFound,

litert/cc/litert_compiled_model.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -394,12 +394,8 @@ class CompiledModel
394394
const absl::flat_hash_map<absl::string_view, TensorBuffer>& output_map)
395395
const {
396396
bool async = false;
397-
auto subgraph = model_.MainSubgraph();
398-
if (!subgraph) {
399-
return Unexpected(kLiteRtStatusErrorNotFound,
400-
"Failed to get main subgraph");
401-
}
402-
return RunMapWithIndexHelper(/*signature_index=*/0, *subgraph, input_map,
397+
LITERT_ASSIGN_OR_RETURN(Signature signature, model_.GetSignature(0));
398+
return RunMapWithIndexHelper(/*signature_index=*/0, signature, input_map,
403399
output_map, async);
404400
}
405401

@@ -698,7 +694,7 @@ class CompiledModel
698694
bool& async) const;
699695

700696
Expected<void> RunMapWithIndexHelper(
701-
size_t signature_index, const Subgraph& subgraph,
697+
size_t signature_index, const Signature& signature,
702698
const absl::flat_hash_map<absl::string_view, TensorBuffer>& input_map,
703699
const absl::flat_hash_map<absl::string_view, TensorBuffer>& output_map,
704700
bool& async) const;

litert/cc/litert_compiled_model_test.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -547,12 +547,12 @@ TEST(CompiledModelTest, ResizeInputTensorWithDynamicModel) {
547547
// Get the element type from the original model.
548548
LITERT_ASSERT_OK_AND_ASSIGN(const Signature& signature,
549549
model.GetSignature(0));
550-
LITERT_ASSERT_OK_AND_ASSIGN(const Subgraph subgraph,
551-
model.Subgraph(signature.Key()));
552-
LITERT_ASSERT_OK_AND_ASSIGN(const Tensor& tensor0, subgraph.Input("arg0"));
550+
LITERT_ASSERT_OK_AND_ASSIGN(const Tensor& tensor0,
551+
signature.InputTensor("arg0"));
553552
LITERT_ASSERT_OK_AND_ASSIGN(const RankedTensorType& type0,
554553
tensor0.RankedTensorType());
555-
LITERT_ASSERT_OK_AND_ASSIGN(const Tensor& tensor1, subgraph.Input("arg1"));
554+
LITERT_ASSERT_OK_AND_ASSIGN(const Tensor& tensor1,
555+
signature.InputTensor("arg1"));
556556
LITERT_ASSERT_OK_AND_ASSIGN(const RankedTensorType& type1,
557557
tensor1.RankedTensorType());
558558

@@ -581,7 +581,7 @@ TEST(CompiledModelTest, ResizeInputTensorWithDynamicModel) {
581581
compiled_model.GetOutputBufferRequirements(size_t(0)));
582582
LITERT_ASSERT_OK_AND_ASSIGN(size_t out_size, out_req.BufferSize());
583583
LITERT_ASSERT_OK_AND_ASSIGN(const Tensor& out_tensor,
584-
subgraph.Output("tfl.add"));
584+
signature.OutputTensor("tfl.add"));
585585
LITERT_ASSERT_OK_AND_ASSIGN(const RankedTensorType& out_type,
586586
out_tensor.RankedTensorType());
587587

litert/cc/litert_model.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "litert/cc/litert_model.h"
1616

17+
#include <cstddef>
1718
#include <vector>
1819

1920
#include "absl/strings/string_view.h" // from @com_google_absl
@@ -150,4 +151,44 @@ std::vector<Op> Subgraph::Ops() const {
150151
return ops;
151152
}
152153

154+
Expected<Tensor> Signature::InputTensor(absl::string_view name) const {
155+
LiteRtTensor tensor;
156+
auto status =
157+
LiteRtGetSignatureInputTensor(Get(), std::string(name).c_str(), &tensor);
158+
if (status != kLiteRtStatusOk) {
159+
return Unexpected(status, "Failed to look up signature input tensor");
160+
}
161+
return Tensor(tensor);
162+
}
163+
164+
Expected<Tensor> Signature::InputTensor(size_t index) const {
165+
LiteRtTensor tensor;
166+
auto status = LiteRtGetSignatureInputTensorByIndex(
167+
Get(), static_cast<LiteRtParamIndex>(index), &tensor);
168+
if (status != kLiteRtStatusOk) {
169+
return Unexpected(status, "Failed to look up signature input tensor");
170+
}
171+
return Tensor(tensor);
172+
}
173+
174+
Expected<Tensor> Signature::OutputTensor(absl::string_view name) const {
175+
LiteRtTensor tensor;
176+
auto status =
177+
LiteRtGetSignatureOutputTensor(Get(), std::string(name).c_str(), &tensor);
178+
if (status != kLiteRtStatusOk) {
179+
return Unexpected(status, "Failed to look up signature output tensor");
180+
}
181+
return Tensor(tensor);
182+
}
183+
184+
Expected<Tensor> Signature::OutputTensor(size_t index) const {
185+
LiteRtTensor tensor;
186+
auto status = LiteRtGetSignatureOutputTensorByIndex(
187+
Get(), static_cast<LiteRtParamIndex>(index), &tensor);
188+
if (status != kLiteRtStatusOk) {
189+
return Unexpected(status, "Failed to look up signature output tensor");
190+
}
191+
return Tensor(tensor);
192+
}
193+
153194
} // namespace litert

litert/cc/litert_model.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,20 @@ class Signature : public internal::NonOwnedHandle<LiteRtSignature> {
371371
}
372372
return output_names;
373373
}
374+
375+
// Returns the input tensor with the given input signature name in the
376+
// signature entry.
377+
Expected<Tensor> InputTensor(absl::string_view name) const;
378+
379+
// Returns the input tensor at the given index in the signature entry.
380+
Expected<Tensor> InputTensor(size_t index) const;
381+
382+
// Returns the output tensor with the given output signature name in the
383+
// signature entry.
384+
Expected<Tensor> OutputTensor(absl::string_view name) const;
385+
386+
// Returns the output tensor at the given index in the signature entry.
387+
Expected<Tensor> OutputTensor(size_t index) const;
374388
};
375389

376390
// Model. C++ equivalent of LiteRtModel.

litert/core/model/model.cc

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -102,21 +102,28 @@ Quantization MakePerTensorQuantization(float scale, int64_t zero_point) {
102102
}
103103

104104
LiteRtSignatureT MakeDefaultSignature(LiteRtSubgraph subgraph) {
105-
auto tensor_name = [](auto* tensor) { return std::string(tensor->Name()); };
106-
107-
auto in_start = subgraph->Inputs().cbegin();
108-
auto in_end = subgraph->Inputs().cend();
109-
std::vector<std::string> input_names(subgraph->NumInputs());
110-
std::transform(in_start, in_end, input_names.begin(), tensor_name);
105+
std::vector<std::string> input_names;
106+
std::vector<LiteRtTensor> input_tensors;
107+
input_names.reserve(subgraph->NumInputs());
108+
input_tensors.reserve(subgraph->NumInputs());
109+
for (auto* tensor : subgraph->Inputs()) {
110+
input_names.push_back(std::string(tensor->Name()));
111+
input_tensors.push_back(tensor);
112+
}
111113

112-
auto out_start = subgraph->Outputs().cbegin();
113-
auto out_end = subgraph->Outputs().cend();
114-
std::vector<std::string> output_names(subgraph->NumOutputs());
115-
std::transform(out_start, out_end, output_names.begin(), tensor_name);
114+
std::vector<std::string> output_names;
115+
std::vector<LiteRtTensor> output_tensors;
116+
output_names.reserve(subgraph->NumOutputs());
117+
output_tensors.reserve(subgraph->NumOutputs());
118+
for (auto* tensor : subgraph->Outputs()) {
119+
output_names.push_back(std::string(tensor->Name()));
120+
output_tensors.push_back(tensor);
121+
}
116122

117123
std::string name(LiteRtSignatureT::kDefaultSignatureKey);
118124
return LiteRtSignatureT(subgraph, std::move(input_names),
119-
std::move(output_names), std::move(name));
125+
std::move(input_tensors), std::move(output_names),
126+
std::move(output_tensors), std::move(name));
120127
}
121128

122129
::litert::Expected<LiteRtSubgraph> LookupSubgraph(

0 commit comments

Comments
 (0)