1818
1919#include " absl/flags/flag.h" // from @com_google_absl
2020#include " absl/strings/string_view.h" // from @com_google_absl
21+ #include " litert/c/options/litert_google_tensor_options_type.h"
2122#include " litert/cc/litert_expected.h"
2223#include " litert/cc/litert_macros.h"
2324#include " litert/cc/options/litert_google_tensor_options.h"
2930bool AbslParseFlag (absl::string_view text,
3031 LiteRtGoogleTensorOptionsTruncationType* options,
3132 std::string* error) {
32- if (text == " unspecified " ) {
33- *options = kLiteRtGoogleTensorFloatTruncationTypeUnspecified ;
33+ if (text == " auto " ) {
34+ *options = kLiteRtGoogleTensorFloatTruncationTypeAuto ;
3435 return true ;
3536 }
3637 if (text == " no_truncation" ) {
3738 *options = kLiteRtGoogleTensorFloatTruncationTypeNoTruncation ;
3839 return true ;
3940 }
40- if (text == " bf16 " ) {
41+ if (text == " bfloat16 " ) {
4142 *options = kLiteRtGoogleTensorFloatTruncationTypeBfloat16 ;
4243 return true ;
4344 }
@@ -51,12 +52,12 @@ bool AbslParseFlag(absl::string_view text,
5152
5253std::string AbslUnparseFlag (LiteRtGoogleTensorOptionsTruncationType options) {
5354 switch (options) {
54- case kLiteRtGoogleTensorFloatTruncationTypeUnspecified :
55- return " unspecified " ;
55+ case kLiteRtGoogleTensorFloatTruncationTypeAuto :
56+ return " auto " ;
5657 case kLiteRtGoogleTensorFloatTruncationTypeNoTruncation :
5758 return " no_truncation" ;
5859 case kLiteRtGoogleTensorFloatTruncationTypeBfloat16 :
59- return " bf16 " ;
60+ return " bfloat16 " ;
6061 case kLiteRtGoogleTensorFloatTruncationTypeHalf :
6162 return " half" ;
6263 }
@@ -101,15 +102,12 @@ std::string AbslUnparseFlag(
101102
102103ABSL_FLAG (LiteRtGoogleTensorOptionsTruncationType,
103104 google_tensor_truncation_type,
104- kLiteRtGoogleTensorFloatTruncationTypeUnspecified ,
105+ kLiteRtGoogleTensorFloatTruncationTypeAuto ,
105106 " Float truncation type for Google Tensor." );
106107
107108ABSL_FLAG (bool , google_tensor_int64_to_int32, false ,
108109 " Whether to truncate int64 to int32." );
109110
110- ABSL_FLAG (std::string, google_tensor_output_dir, " " ,
111- " Output directory for Google Tensor." );
112-
113111ABSL_FLAG (bool , google_tensor_dump_op_timings, false ,
114112 " Whether to dump op timings." );
115113
@@ -137,7 +135,6 @@ Expected<GoogleTensorOptions> GoogleTensorOptionsFromFlags() {
137135 absl::GetFlag (FLAGS_google_tensor_truncation_type));
138136 options.SetInt64ToInt32Truncation (
139137 absl::GetFlag (FLAGS_google_tensor_int64_to_int32));
140- options.SetOutputDir (absl::GetFlag (FLAGS_google_tensor_output_dir));
141138 options.SetDumpOpTimings (absl::GetFlag (FLAGS_google_tensor_dump_op_timings));
142139 options.SetEnableLargeModelSupport (
143140 absl::GetFlag (FLAGS_google_tensor_enable_large_model_support));
0 commit comments