Skip to content

Commit be35f6c

Browse files
ai-edge-botcopybara-github
authored andcommitted
Rename Google Tensor float truncation type enum and update flag strings.
LiteRT-PiperOrigin-RevId: 820099975
1 parent 718e0a3 commit be35f6c

File tree

8 files changed

+20
-20
lines changed

8 files changed

+20
-20
lines changed

litert/c/options/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ cc_test(
107107
srcs = ["litert_google_tensor_options_test.cc"],
108108
deps = [
109109
":litert_google_tensor_options",
110+
":litert_google_tensor_options_type",
110111
"//litert/c:litert_common",
111112
"//litert/c:litert_opaque_options",
112113
"//litert/cc/options:litert_google_tensor_options",

litert/c/options/litert_google_tensor_options_test.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <gtest/gtest.h>
2020
#include "litert/c/litert_common.h"
2121
#include "litert/c/litert_opaque_options.h"
22+
#include "litert/c/options/litert_google_tensor_options_type.h"
2223
#include "litert/cc/options/litert_google_tensor_options.h"
2324
#include "litert/test/matchers.h"
2425

@@ -108,7 +109,7 @@ TEST(LiteRtGoogleTensorOptionsTest, DumpOpTimings) {
108109
LiteRtGoogleTensorOptionsTruncationType truncation_type;
109110
LITERT_ASSERT_OK(LiteRtGoogleTensorOptionsGetFloatTruncationType(
110111
options_data, &truncation_type));
111-
ASSERT_EQ(truncation_type, kLiteRtGoogleTensorFloatTruncationTypeUnspecified);
112+
ASSERT_EQ(truncation_type, kLiteRtGoogleTensorFloatTruncationTypeAuto);
112113

113114
LITERT_ASSERT_OK(LiteRtGoogleTensorOptionsSetFloatTruncationType(
114115
options_data, kLiteRtGoogleTensorFloatTruncationTypeBfloat16));
@@ -128,7 +129,7 @@ TEST(GoogleTensorOptionsTest, CppApi) {
128129
EXPECT_TRUE(options->GetInt64ToInt32Truncation());
129130

130131
EXPECT_EQ(options->GetFloatTruncationType(),
131-
kLiteRtGoogleTensorFloatTruncationTypeUnspecified);
132+
kLiteRtGoogleTensorFloatTruncationTypeAuto);
132133
options->SetFloatTruncationType(
133134
kLiteRtGoogleTensorFloatTruncationTypeBfloat16);
134135
EXPECT_EQ(options->GetFloatTruncationType(),

litert/c/options/litert_google_tensor_options_type.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
// float_truncation_type -------------------------------------------------------
1919

2020
typedef enum LiteRtGoogleTensorOptionsTruncationType {
21-
kLiteRtGoogleTensorFloatTruncationTypeUnspecified = 0,
21+
kLiteRtGoogleTensorFloatTruncationTypeAuto = 0,
2222
kLiteRtGoogleTensorFloatTruncationTypeNoTruncation = 1,
2323
kLiteRtGoogleTensorFloatTruncationTypeBfloat16 = 2,
2424
kLiteRtGoogleTensorFloatTruncationTypeHalf = 3,

litert/runtime/litert_google_tensor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
struct LiteRtGoogleTensorOptionsT {
2424
LiteRtGoogleTensorOptionsTruncationType float_truncation_type =
25-
kLiteRtGoogleTensorFloatTruncationTypeUnspecified;
25+
kLiteRtGoogleTensorFloatTruncationTypeAuto;
2626
bool int64_to_int32_truncation = false;
2727
std::string output_dir = "";
2828
bool dump_op_timings = false;

litert/tools/flags/vendors/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ cc_library(
8787
"-DINCLUDE_GOOGLE_TENSOR_RUNTIME_FLAGS",
8888
],
8989
deps = [
90+
"//litert/c/options:litert_google_tensor_options_type",
9091
"//litert/cc:litert_expected",
9192
"//litert/cc:litert_macros",
9293
"//litert/cc/options:litert_google_tensor_options",
@@ -124,6 +125,7 @@ cc_test(
124125
deps = [
125126
":google_tensor_flags",
126127
"//litert/c/options:litert_google_tensor_options",
128+
"//litert/c/options:litert_google_tensor_options_type",
127129
"@com_google_googletest//:gtest_main",
128130
],
129131
)

litert/tools/flags/vendors/google_tensor_flags.cc

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "absl/flags/flag.h" // from @com_google_absl
2020
#include "absl/strings/string_view.h" // from @com_google_absl
21+
#include "litert/c/options/litert_google_tensor_options_type.h"
2122
#include "litert/cc/litert_expected.h"
2223
#include "litert/cc/litert_macros.h"
2324
#include "litert/cc/options/litert_google_tensor_options.h"
@@ -29,15 +30,15 @@
2930
bool AbslParseFlag(absl::string_view text,
3031
LiteRtGoogleTensorOptionsTruncationType* options,
3132
std::string* error) {
32-
if (text == "unspecified") {
33-
*options = kLiteRtGoogleTensorFloatTruncationTypeUnspecified;
33+
if (text == "auto") {
34+
*options = kLiteRtGoogleTensorFloatTruncationTypeAuto;
3435
return true;
3536
}
3637
if (text == "no_truncation") {
3738
*options = kLiteRtGoogleTensorFloatTruncationTypeNoTruncation;
3839
return true;
3940
}
40-
if (text == "bf16") {
41+
if (text == "bfloat16") {
4142
*options = kLiteRtGoogleTensorFloatTruncationTypeBfloat16;
4243
return true;
4344
}
@@ -51,12 +52,12 @@ bool AbslParseFlag(absl::string_view text,
5152

5253
std::string AbslUnparseFlag(LiteRtGoogleTensorOptionsTruncationType options) {
5354
switch (options) {
54-
case kLiteRtGoogleTensorFloatTruncationTypeUnspecified:
55-
return "unspecified";
55+
case kLiteRtGoogleTensorFloatTruncationTypeAuto:
56+
return "auto";
5657
case kLiteRtGoogleTensorFloatTruncationTypeNoTruncation:
5758
return "no_truncation";
5859
case kLiteRtGoogleTensorFloatTruncationTypeBfloat16:
59-
return "bf16";
60+
return "bfloat16";
6061
case kLiteRtGoogleTensorFloatTruncationTypeHalf:
6162
return "half";
6263
}
@@ -101,15 +102,12 @@ std::string AbslUnparseFlag(
101102

102103
ABSL_FLAG(LiteRtGoogleTensorOptionsTruncationType,
103104
google_tensor_truncation_type,
104-
kLiteRtGoogleTensorFloatTruncationTypeUnspecified,
105+
kLiteRtGoogleTensorFloatTruncationTypeAuto,
105106
"Float truncation type for Google Tensor.");
106107

107108
ABSL_FLAG(bool, google_tensor_int64_to_int32, false,
108109
"Whether to truncate int64 to int32.");
109110

110-
ABSL_FLAG(std::string, google_tensor_output_dir, "",
111-
"Output directory for Google Tensor.");
112-
113111
ABSL_FLAG(bool, google_tensor_dump_op_timings, false,
114112
"Whether to dump op timings.");
115113

@@ -137,7 +135,6 @@ Expected<GoogleTensorOptions> GoogleTensorOptionsFromFlags() {
137135
absl::GetFlag(FLAGS_google_tensor_truncation_type));
138136
options.SetInt64ToInt32Truncation(
139137
absl::GetFlag(FLAGS_google_tensor_int64_to_int32));
140-
options.SetOutputDir(absl::GetFlag(FLAGS_google_tensor_output_dir));
141138
options.SetDumpOpTimings(absl::GetFlag(FLAGS_google_tensor_dump_op_timings));
142139
options.SetEnableLargeModelSupport(
143140
absl::GetFlag(FLAGS_google_tensor_enable_large_model_support));

litert/tools/flags/vendors/google_tensor_flags.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ std::string AbslUnparseFlag(LiteRtGoogleTensorOptionsTruncationType options);
3737

3838
ABSL_DECLARE_FLAG(bool, google_tensor_int64_to_int32);
3939

40-
ABSL_DECLARE_FLAG(std::string, google_tensor_output_dir);
41-
4240
ABSL_DECLARE_FLAG(bool, google_tensor_dump_op_timings);
4341

4442
ABSL_DECLARE_FLAG(bool, google_tensor_enable_large_model_support);

litert/tools/flags/vendors/google_tensor_flags_test.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include <gtest/gtest.h>
2020
#include "litert/c/options/litert_google_tensor_options.h"
21+
#include "litert/c/options/litert_google_tensor_options_type.h"
2122

2223
namespace litert::google_tensor {
2324
namespace {
@@ -34,9 +35,9 @@ TEST(TruncationTypeFlagTest, Parse) {
3435
LiteRtGoogleTensorOptionsTruncationType value;
3536

3637
{
37-
static constexpr absl::string_view kLevel = "unspecified";
38+
static constexpr absl::string_view kLevel = "auto";
3839
static constexpr LiteRtGoogleTensorOptionsTruncationType kLevelEnum =
39-
kLiteRtGoogleTensorFloatTruncationTypeUnspecified;
40+
kLiteRtGoogleTensorFloatTruncationTypeAuto;
4041
EXPECT_TRUE(AbslParseFlag(kLevel, &value, &error));
4142
EXPECT_EQ(value, kLevelEnum);
4243
EXPECT_EQ(kLevel, AbslUnparseFlag(value));
@@ -52,7 +53,7 @@ TEST(TruncationTypeFlagTest, Parse) {
5253
}
5354

5455
{
55-
static constexpr absl::string_view kLevel = "bf16";
56+
static constexpr absl::string_view kLevel = "bfloat16";
5657
static constexpr LiteRtGoogleTensorOptionsTruncationType kLevelEnum =
5758
kLiteRtGoogleTensorFloatTruncationTypeBfloat16;
5859
EXPECT_TRUE(AbslParseFlag(kLevel, &value, &error));

0 commit comments

Comments
 (0)