Skip to content

Commit 6da00db

Browse files
LukeBoyercopybara-github
authored andcommitted
Add printer for core op type
LiteRT-PiperOrigin-RevId: 774977275
1 parent 4a11397 commit 6da00db

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

litert/core/model/model.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,25 @@ void AbslStringify(Sink& sink, const std::vector<LiteRtTensor>& tensors) {
10941094
sink.Append(")");
10951095
}
10961096

1097+
// OP PRINTING
1098+
1099+
template <class Sink>
1100+
void AbslStringify(Sink& sink, const LiteRtOpT& op) {
1101+
static constexpr auto kFmt = "%v%v%v->%v";
1102+
const auto& opts = ::litert::internal::GetTflOptions(op);
1103+
if (opts.type != ::tflite::BuiltinOptions_NONE) {
1104+
absl::Format(&sink, kFmt, op.OpCode(), opts, op.Inputs(), op.Outputs());
1105+
return;
1106+
}
1107+
absl::Format(&sink, kFmt, op.OpCode(), ::litert::internal::GetTflOptions2(op),
1108+
op.Inputs(), op.Outputs());
1109+
}
1110+
1111+
template <class Sink>
1112+
void AbslStringify(Sink& sink, const LiteRtOpT* op) {
1113+
absl::Format(&sink, "null");
1114+
}
1115+
10971116
// OPTIONS PRINTING
10981117

10991118
namespace litert::internal {

litert/core/model/model_test.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,36 @@ TEST(PrintingTest, TensoVector) {
580580
EXPECT_EQ(absl::StrFormat("%v", tensors), "(3d_i32<2x2x2>,3d_i32<2x2x2>)");
581581
}
582582

583+
TEST(PrintingTest, Op) {
584+
LiteRtOpT op;
585+
op.SetOpCode(kLiteRtOpCodeTflAdd);
586+
587+
{
588+
::tflite::AddOptionsT add_opts;
589+
add_opts.fused_activation_function = ::tflite::ActivationFunctionType_RELU;
590+
add_opts.pot_scale_int16 = false;
591+
TflOptions opts;
592+
opts.type = ::tflite::BuiltinOptions_AddOptions;
593+
opts.Set(std::move(add_opts));
594+
litert::internal::SetTflOptions(op, std::move(opts));
595+
}
596+
597+
LiteRtTensorT tensor;
598+
tensor.SetType(MakeRankedTensorType(kLiteRtElementTypeInt32, {2, 2, 2}));
599+
op.Inputs().push_back(&tensor);
600+
601+
LiteRtTensorT tensor2;
602+
tensor2.SetType(MakeRankedTensorType(kLiteRtElementTypeInt32, {2}));
603+
op.Inputs().push_back(&tensor2);
604+
605+
LiteRtTensorT tensor3;
606+
tensor3.SetType(MakeRankedTensorType(kLiteRtElementTypeInt32, {2, 2, 2}));
607+
op.Outputs().push_back(&tensor3);
608+
609+
EXPECT_EQ(absl::StrFormat("%v", op),
610+
"tfl.add{fa=RELU}(3d_i32<2x2x2>,1d_i32<2>)->(3d_i32<2x2x2>)");
611+
}
612+
583613
TEST(PrintingTest, TflOptions) {
584614
TflOptions opts;
585615
opts.type = ::tflite::BuiltinOptions_AddOptions;

0 commit comments

Comments
 (0)