Skip to content

Commit 4a11397

Browse files
LukeBoyercopybara-github
authored andcommitted
Add printer for vector of tenssors
LiteRT-PiperOrigin-RevId: 774969098
1 parent af80f9b commit 4a11397

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

litert/core/model/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ cc_test(
6464
deps = [
6565
":buffer_manager",
6666
":model",
67+
"//litert/c:litert_common",
6768
"//litert/c:litert_op_code",
6869
"//litert/cc:litert_buffer_ref",
6970
"//litert/core:build_stamp",

litert/core/model/model.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,17 @@ void AbslStringify(Sink& sink, const LiteRtTensorT& tensor) {
10831083
absl::Format(&sink, "%v%s", tensor.Type(), weights_str);
10841084
}
10851085

1086+
template <class Sink>
1087+
void AbslStringify(Sink& sink, const std::vector<LiteRtTensor>& tensors) {
1088+
sink.Append("(");
1089+
for (auto it = tensors.begin(); it < tensors.end() - 1; ++it) {
1090+
sink.Append(absl::StrFormat("%v", **it));
1091+
sink.Append(",");
1092+
}
1093+
sink.Append(absl::StrFormat("%v", *tensors.back()));
1094+
sink.Append(")");
1095+
}
1096+
10861097
// OPTIONS PRINTING
10871098

10881099
namespace litert::internal {

litert/core/model/model_test.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "absl/strings/str_format.h" // from @com_google_absl
2727
#include "absl/strings/string_view.h" // from @com_google_absl
2828
#include "absl/types/span.h" // from @com_google_absl
29+
#include "litert/c/litert_common.h"
2930
#include "litert/c/litert_model.h"
3031
#include "litert/c/litert_op_code.h"
3132
#include "litert/cc/litert_buffer_ref.h"
@@ -568,6 +569,17 @@ TEST(PrintingTest, ConstTensor) {
568569
EXPECT_EQ(absl::StrFormat("%v", tensor), "3d_i32<2x2x2>_cst[8B]");
569570
}
570571

572+
TEST(PrintingTest, TensoVector) {
573+
LiteRtTensorT tensor;
574+
tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeInt32, {2, 2, 2}));
575+
576+
LiteRtTensorT tensor2;
577+
tensor2.SetType(MakeRankedTensorType(kLiteRtElementTypeInt32, {2, 2, 2}));
578+
579+
std::vector<LiteRtTensor> tensors = {&tensor, &tensor2};
580+
EXPECT_EQ(absl::StrFormat("%v", tensors), "(3d_i32<2x2x2>,3d_i32<2x2x2>)");
581+
}
582+
571583
TEST(PrintingTest, TflOptions) {
572584
TflOptions opts;
573585
opts.type = ::tflite::BuiltinOptions_AddOptions;

0 commit comments

Comments
 (0)