Skip to content

Commit 42f1026

Browse files
ai-edge-botcopybara-github
authored andcommitted
Enable multi subgraph support.
LiteRT-PiperOrigin-RevId: 826298676
1 parent ec7edf1 commit 42f1026

File tree

3 files changed

+231
-1
lines changed

3 files changed

+231
-1
lines changed

litert/c/litert_model.cc

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616

1717
#include <cstddef>
1818
#include <cstdint>
19-
#include <memory>
19+
#include <string>
2020
#include <tuple>
2121
#include <utility>
22+
#include <vector>
2223

2324
#include "absl/strings/string_view.h" // from @com_google_absl
2425
#include "litert/c/litert_common.h"
26+
#include "litert/c/litert_model_types.h"
2527
#include "litert/c/litert_op_code.h"
2628
#include "litert/cc/litert_buffer_ref.h"
2729
#include "litert/cc/litert_macros.h"
@@ -137,6 +139,47 @@ LiteRtStatus LiteRtGetModelSignature(LiteRtModel model,
137139

138140
void LiteRtDestroyModel(LiteRtModel model) { delete model; }
139141

142+
LiteRtStatus LiteRtSerializeModelWithSignatures(
143+
LiteRtModel model, uint8_t** buf, size_t* size, size_t* offset,
144+
bool destroy_model, char** signatures, LiteRtParamIndex num_signatures,
145+
LiteRtModelSerializationOptions options) {
146+
size_t num_subgraphs = model->NumSubgraphs();
147+
if (num_subgraphs != num_signatures) {
148+
return kLiteRtStatusErrorInvalidArgument;
149+
}
150+
for (size_t i = 0; i < num_subgraphs; ++i) {
151+
if (signatures[i] == nullptr) {
152+
// If the signature is null, we will use the default signature.
153+
// This is to support the backward compatibility with the previous version
154+
// of the compiler.
155+
continue;
156+
}
157+
std::string signature_key(signatures[i]);
158+
159+
LiteRtSubgraphT& subgraph = model->Subgraph(i);
160+
161+
std::vector<std::string> input_names;
162+
std::vector<LiteRtTensor> input_tensors;
163+
for (auto& tensor : subgraph.Inputs()) {
164+
input_names.push_back(std::string(tensor->Name()));
165+
input_tensors.push_back(tensor);
166+
}
167+
std::vector<std::string> output_names;
168+
std::vector<LiteRtTensor> output_tensors;
169+
for (auto& tensor : subgraph.Outputs()) {
170+
output_names.push_back(std::string(tensor->Name()));
171+
output_tensors.push_back(tensor);
172+
}
173+
174+
// Use EmplaceSignature to add a new signature
175+
model->EmplaceSignature(&subgraph, std::move(input_names),
176+
std::move(input_tensors), std::move(output_names),
177+
std::move(output_tensors),
178+
std::move(signature_key));
179+
}
180+
return LiteRtSerializeModel(model, buf, size, offset, destroy_model, options);
181+
}
182+
140183
LiteRtStatus LiteRtSerializeModel(LiteRtModel model, uint8_t** buf,
141184
size_t* size, size_t* offset,
142185
bool destroy_model,

litert/c/litert_model.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,17 @@ LiteRtStatus LiteRtPushOp(LiteRtOpList op_list, LiteRtOp op,
262262
// Serialization related functions
263263
//
264264

265+
// Serializes model to valid tflite flatbuffer bytes with signatures.
266+
//
267+
// This destroys the model before it returns unless destroy_model is false.
268+
// Caller takes ownership of `buf`. Flatbuffers are packed into their arrays
269+
// back to front, so the valid flatbuffer is buf[offset, size]. See the above
270+
// options for more details.
271+
LiteRtStatus LiteRtSerializeModelWithSignatures(
272+
LiteRtModel model, uint8_t** buf, size_t* size, size_t* offset,
273+
bool destroy_model, char** signatures, LiteRtParamIndex num_signatures,
274+
LiteRtModelSerializationOptions options);
275+
265276
// Serializes model to valid tflite flatbuffer bytes.
266277
//
267278
// This destroys the model before it returns unless destroy_model is false.

litert/c/litert_model_test.cc

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <array>
1818
#include <cstddef>
1919
#include <cstdint>
20+
#include <cstdlib>
2021
#include <initializer_list>
2122
#include <string>
2223
#include <utility>
@@ -26,6 +27,7 @@
2627
#include "absl/strings/string_view.h" // from @com_google_absl
2728
#include "absl/types/span.h" // from @com_google_absl
2829
#include "litert/c/litert_common.h"
30+
#include "litert/c/litert_model_types.h"
2931
#include "litert/c/litert_op_code.h"
3032
#include "litert/cc/litert_buffer_ref.h"
3133
#include "litert/core/model/model.h"
@@ -418,6 +420,180 @@ TEST(LiteRtModelTest, GetSubgraphOOB) {
418420
IsError(kLiteRtStatusErrorIndexOOB));
419421
}
420422

423+
TEST(LiteRtModelTest, SerializeModelWithSignaturesWithOneSignature) {
424+
// This test checks that the serialization succeeds and the signature is
425+
// added correctly even if the model has only one subgraph.
426+
LiteRtModelT model;
427+
auto& subgraph = model.EmplaceSubgraph();
428+
auto& input_tensor = subgraph.EmplaceTensor();
429+
input_tensor.SetName("input");
430+
subgraph.Inputs().push_back(&input_tensor);
431+
432+
auto& output_tensor = subgraph.EmplaceTensor();
433+
output_tensor.SetName("output");
434+
subgraph.Outputs().push_back(&output_tensor);
435+
436+
const char* signature_key = "serving_default";
437+
char* signatures[] = {const_cast<char*>(signature_key)};
438+
439+
uint8_t* buf = nullptr;
440+
size_t size = 0;
441+
size_t offset = 0;
442+
const LiteRtModelSerializationOptions options = {/*bytecode_alignment=*/64};
443+
444+
// We expect this to fail on serialization if NPU is disabled, but signature
445+
// should be added regardless.
446+
const LiteRtStatus status = LiteRtSerializeModelWithSignatures(
447+
&model, &buf, &size, &offset,
448+
/*destroy_model=*/false, signatures,
449+
/*num_signatures=*/1, options);
450+
451+
#ifdef LITERT_BUILD_INCLUDE_NPU
452+
LITERT_ASSERT_OK(status);
453+
EXPECT_NE(buf, nullptr);
454+
EXPECT_GT(size, 0);
455+
#else
456+
EXPECT_NE(status, kLiteRtStatusOk);
457+
EXPECT_EQ(buf, nullptr);
458+
EXPECT_EQ(size, 0);
459+
#endif
460+
461+
// The model should now have one signature.
462+
LiteRtParamIndex num_signatures;
463+
LITERT_ASSERT_OK(LiteRtGetNumModelSignatures(&model, &num_signatures));
464+
ASSERT_EQ(num_signatures, 1);
465+
466+
LiteRtSignature signature;
467+
LITERT_ASSERT_OK(LiteRtGetModelSignature(&model, 0, &signature));
468+
469+
const char* key;
470+
LITERT_ASSERT_OK(LiteRtGetSignatureKey(signature, &key));
471+
EXPECT_STREQ(key, "serving_default");
472+
473+
LiteRtParamIndex num_inputs;
474+
LITERT_ASSERT_OK(LiteRtGetNumSignatureInputs(signature, &num_inputs));
475+
ASSERT_EQ(num_inputs, 1);
476+
477+
const char* input_name;
478+
LITERT_ASSERT_OK(LiteRtGetSignatureInputName(signature, 0, &input_name));
479+
EXPECT_STREQ(input_name, "input");
480+
481+
LiteRtTensor sig_input_tensor;
482+
LITERT_ASSERT_OK(
483+
LiteRtGetSignatureInputTensorByIndex(signature, 0, &sig_input_tensor));
484+
EXPECT_EQ(sig_input_tensor, &input_tensor);
485+
486+
LiteRtParamIndex num_outputs;
487+
LITERT_ASSERT_OK(LiteRtGetNumSignatureOutputs(signature, &num_outputs));
488+
ASSERT_EQ(num_outputs, 1);
489+
490+
const char* output_name;
491+
LITERT_ASSERT_OK(LiteRtGetSignatureOutputName(signature, 0, &output_name));
492+
EXPECT_STREQ(output_name, "output");
493+
494+
LiteRtTensor sig_output_tensor;
495+
LITERT_ASSERT_OK(
496+
LiteRtGetSignatureOutputTensorByIndex(signature, 0, &sig_output_tensor));
497+
EXPECT_EQ(sig_output_tensor, &output_tensor);
498+
499+
// Clean up buffer if serialization succeeded.
500+
free(buf);
501+
}
502+
503+
TEST(LiteRtModelTest, SerializeModelWithSignaturesMultipleSubgraphs) {
504+
// This test checks that the serialization succeeds and the signatures are
505+
// added correctly even if the model has multiple subgraphs.
506+
LiteRtModelT model;
507+
auto& subgraph1 = model.EmplaceSubgraph();
508+
auto& input_tensor1 = subgraph1.EmplaceTensor();
509+
input_tensor1.SetName("input1");
510+
subgraph1.Inputs().push_back(&input_tensor1);
511+
auto& output_tensor1 = subgraph1.EmplaceTensor();
512+
output_tensor1.SetName("output1");
513+
subgraph1.Outputs().push_back(&output_tensor1);
514+
515+
auto& subgraph2 = model.EmplaceSubgraph();
516+
auto& input_tensor2 = subgraph2.EmplaceTensor();
517+
input_tensor2.SetName("input2");
518+
subgraph2.Inputs().push_back(&input_tensor2);
519+
auto& output_tensor2 = subgraph2.EmplaceTensor();
520+
output_tensor2.SetName("output2");
521+
subgraph2.Outputs().push_back(&output_tensor2);
522+
523+
const char* signature_key1 = "sig1";
524+
const char* signature_key2 = "sig2";
525+
char* signatures[] = {const_cast<char*>(signature_key1),
526+
const_cast<char*>(signature_key2)};
527+
528+
uint8_t* buf = nullptr;
529+
size_t size = 0;
530+
size_t offset = 0;
531+
const LiteRtModelSerializationOptions options = {/*bytecode_alignment=*/64};
532+
533+
const LiteRtStatus status = LiteRtSerializeModelWithSignatures(
534+
&model, &buf, &size, &offset,
535+
/*destroy_model=*/false, signatures,
536+
/*num_signatures=*/2, options);
537+
538+
#ifdef LITERT_BUILD_INCLUDE_NPU
539+
LITERT_ASSERT_OK(status);
540+
EXPECT_NE(buf, nullptr);
541+
EXPECT_GT(size, 0);
542+
#else
543+
EXPECT_NE(status, kLiteRtStatusOk);
544+
EXPECT_EQ(buf, nullptr);
545+
EXPECT_EQ(size, 0);
546+
#endif
547+
548+
// The model should now have two signatures.
549+
LiteRtParamIndex num_signatures;
550+
LITERT_ASSERT_OK(LiteRtGetNumModelSignatures(&model, &num_signatures));
551+
ASSERT_EQ(num_signatures, 2);
552+
553+
// Check first signature.
554+
{
555+
LiteRtSignature signature;
556+
LITERT_ASSERT_OK(LiteRtGetModelSignature(&model, 0, &signature));
557+
558+
const char* key;
559+
LITERT_ASSERT_OK(LiteRtGetSignatureKey(signature, &key));
560+
EXPECT_STREQ(key, "sig1");
561+
562+
LiteRtTensor sig_input_tensor;
563+
LITERT_ASSERT_OK(
564+
LiteRtGetSignatureInputTensorByIndex(signature, 0, &sig_input_tensor));
565+
EXPECT_EQ(sig_input_tensor, &input_tensor1);
566+
567+
LiteRtTensor sig_output_tensor;
568+
LITERT_ASSERT_OK(LiteRtGetSignatureOutputTensorByIndex(
569+
signature, 0, &sig_output_tensor));
570+
EXPECT_EQ(sig_output_tensor, &output_tensor1);
571+
}
572+
573+
// Check second signature.
574+
{
575+
LiteRtSignature signature;
576+
LITERT_ASSERT_OK(LiteRtGetModelSignature(&model, 1, &signature));
577+
578+
const char* key;
579+
LITERT_ASSERT_OK(LiteRtGetSignatureKey(signature, &key));
580+
EXPECT_STREQ(key, "sig2");
581+
582+
LiteRtTensor sig_input_tensor;
583+
LITERT_ASSERT_OK(
584+
LiteRtGetSignatureInputTensorByIndex(signature, 0, &sig_input_tensor));
585+
EXPECT_EQ(sig_input_tensor, &input_tensor2);
586+
587+
LiteRtTensor sig_output_tensor;
588+
LITERT_ASSERT_OK(LiteRtGetSignatureOutputTensorByIndex(
589+
signature, 0, &sig_output_tensor));
590+
EXPECT_EQ(sig_output_tensor, &output_tensor2);
591+
}
592+
593+
// Clean up buffer if serialization succeeded.
594+
free(buf);
595+
}
596+
421597
TEST(LiteRtOpListTest, PushOps) {
422598
LiteRtOpListT op_list;
423599
LiteRtOpT op;

0 commit comments

Comments
 (0)