Skip to content

Commit c41d2a2

Browse files
fengwuyaocopybara-github
authored andcommitted
Add bool type support in gpu numerics check
LiteRT-PiperOrigin-RevId: 829055207
1 parent 81ca1f2 commit c41d2a2

File tree

2 files changed

+25
-46
lines changed

2 files changed

+25
-46
lines changed

litert/tools/BUILD

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,8 @@ cc_binary(
518518
deps = NUMERICS_CHECK_DEPS + [
519519
"//litert/cc/options:litert_gpu_options",
520520
"//litert/core:filesystem",
521+
"//tflite/c:c_api_types",
522+
"//tflite/tools:utils",
521523
] + GPU_ACCELERATOR_DEPS,
522524
)
523525

@@ -529,6 +531,8 @@ cc_binary(
529531
"//litert/core:filesystem",
530532
# copybara:uncomment_begin(google-only)
531533
# "//litert/runtime/accelerators/gpu:ml_drift_cl_gl_accelerator", # buildcleaner: keep
534+
# "//tflite/c:c_api_types",
535+
# "//tflite/tools:utils",
532536
# copybara:uncomment_end
533537
],
534538
)
@@ -542,6 +546,8 @@ cc_binary(
542546
"//litert/core:filesystem",
543547
# copybara:uncomment_begin(google-only)
544548
# "//litert/runtime/accelerators/gpu:ml_drift_vulkan_accelerator", # buildcleaner: keep
549+
# "//tflite/c:c_api_types",
550+
# "//tflite/tools:utils",
545551
# "//third_party/vulkan_loader",
546552
# copybara:uncomment_end
547553
],
@@ -560,6 +566,8 @@ cc_binary(
560566
"//litert/core:filesystem",
561567
# copybara:uncomment_begin(google-only)
562568
# "//litert/runtime/accelerators/gpu/google:jet_gpu_accelerator", # buildcleaner: keep
569+
# "//tflite/c:c_api_types",
570+
# "//tflite/tools:utils",
563571
# copybara:uncomment_end
564572
],
565573
)

litert/tools/gpu_numerics_check.cc

Lines changed: 17 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
#include "litert/cc/litert_tensor_buffer.h"
4343
#include "litert/cc/options/litert_gpu_options.h"
4444
#include "litert/core/filesystem.h"
45+
#include "tflite/c/c_api_types.h"
46+
#include "tflite/tools/utils.h"
4547

4648
ABSL_FLAG(bool, print_diff_stats, false,
4749
"Whether to print the diff stats CSV.");
@@ -122,59 +124,32 @@ Expected<Options> GetCpuOptions() {
122124
return options;
123125
}
124126

125-
Expected<void> FillInputTensor(TensorBuffer& buffer, float scale) {
127+
Expected<void> FillInputTensor(TensorBuffer& buffer) {
126128
LITERT_ASSIGN_OR_RETURN(auto type, buffer.TensorType());
127129
const auto& layout = type.Layout();
128130
size_t total_elements =
129131
std::accumulate(layout.Dimensions().begin(), layout.Dimensions().end(), 1,
130132
std::multiplies<size_t>());
131-
132-
if (type.ElementType() == ElementType::Float16 ||
133-
type.ElementType() == ElementType::Float32 ||
134-
type.ElementType() == ElementType::BFloat16) {
135-
std::vector<float> data(total_elements);
136-
for (size_t i = 0; i < total_elements; ++i) {
137-
data[i] = std::sin(i * scale);
138-
}
139-
return buffer.Write<float>(absl::MakeConstSpan(data));
140-
} else if (type.ElementType() == ElementType::Int32) {
141-
std::vector<int32_t> data(total_elements);
142-
for (size_t i = 0; i < total_elements; ++i) {
143-
data[i] = i % 32;
144-
}
145-
return buffer.Write<int32_t>(absl::MakeConstSpan(data));
146-
} else if (type.ElementType() == ElementType::Int16) {
147-
std::vector<int16_t> data(total_elements);
148-
for (size_t i = 0; i < total_elements; ++i) {
149-
data[i] = i % 2048;
150-
}
151-
return buffer.Write<int16_t>(absl::MakeConstSpan(data));
152-
} else if (type.ElementType() == ElementType::Int8) {
153-
std::vector<int8_t> data(total_elements);
154-
for (size_t i = 0; i < total_elements; ++i) {
155-
data[i] = i % 256 - 128;
156-
}
157-
return buffer.Write<int8_t>(absl::MakeConstSpan(data));
158-
} else if (type.ElementType() == ElementType::UInt8) {
159-
std::vector<uint8_t> data(total_elements);
160-
for (size_t i = 0; i < total_elements; ++i) {
161-
data[i] = i % 256;
162-
}
163-
return buffer.Write<uint8_t>(absl::MakeConstSpan(data));
164-
} else {
165-
return Error(kLiteRtStatusErrorInvalidArgument,
166-
"Unsupported element type for filling tensor.");
167-
}
133+
float low_range = 0;
134+
float high_range = 0;
135+
tflite::utils::GetDataRangesForType(
136+
static_cast<TfLiteType>(type.ElementType()), &low_range, &high_range);
137+
auto tensor_data = tflite::utils::CreateRandomTensorData(
138+
/*name=*/"", static_cast<TfLiteType>(type.ElementType()), total_elements,
139+
low_range, high_range);
140+
141+
return buffer.Write<char>(absl::MakeSpan(
142+
reinterpret_cast<char*>(tensor_data.data.get()), tensor_data.bytes));
168143
}
169144

170145
// Creates and fills input buffers for a given compiled model.
171146
Expected<std::vector<TensorBuffer>> CreateAndFillInputBuffers(
172-
const CompiledModel& compiled_model, size_t signature_index, float scale) {
147+
const CompiledModel& compiled_model, size_t signature_index) {
173148
LITERT_ASSIGN_OR_RETURN(auto input_buffers,
174149
compiled_model.CreateInputBuffers(signature_index));
175150

176151
for (auto& buffer : input_buffers) {
177-
LITERT_RETURN_IF_ERROR(FillInputTensor(buffer, scale));
152+
LITERT_RETURN_IF_ERROR(FillInputTensor(buffer));
178153
}
179154
return input_buffers;
180155
}
@@ -386,17 +361,13 @@ Expected<std::vector<BufferDiffStats>> RunModel(absl::string_view model_path) {
386361
size_t signature_index = absl::GetFlag(FLAGS_signature_index);
387362
ABSL_LOG(INFO) << "Signature index: " << signature_index;
388363

389-
float input_scale = 0.12345f;
390-
391364
// Create and fill input buffers
392365
LITERT_ASSIGN_OR_RETURN(
393366
auto cpu_input_buffers,
394-
CreateAndFillInputBuffers(compiled_model_cpu, signature_index,
395-
input_scale));
367+
CreateAndFillInputBuffers(compiled_model_cpu, signature_index));
396368
LITERT_ASSIGN_OR_RETURN(
397369
auto gpu_input_buffers,
398-
CreateAndFillInputBuffers(compiled_model_gpu, signature_index,
399-
input_scale));
370+
CreateAndFillInputBuffers(compiled_model_gpu, signature_index));
400371

401372
// Create output buffers
402373
LITERT_ASSIGN_OR_RETURN(

0 commit comments

Comments
 (0)