|
42 | 42 | #include "litert/cc/litert_tensor_buffer.h" |
43 | 43 | #include "litert/cc/options/litert_gpu_options.h" |
44 | 44 | #include "litert/core/filesystem.h" |
| 45 | +#include "tflite/c/c_api_types.h" |
| 46 | +#include "tflite/tools/utils.h" |
45 | 47 |
|
46 | 48 | ABSL_FLAG(bool, print_diff_stats, false, |
47 | 49 | "Whether to print the diff stats CSV."); |
@@ -122,59 +124,32 @@ Expected<Options> GetCpuOptions() { |
122 | 124 | return options; |
123 | 125 | } |
124 | 126 |
|
125 | | -Expected<void> FillInputTensor(TensorBuffer& buffer, float scale) { |
| 127 | +Expected<void> FillInputTensor(TensorBuffer& buffer) { |
126 | 128 | LITERT_ASSIGN_OR_RETURN(auto type, buffer.TensorType()); |
127 | 129 | const auto& layout = type.Layout(); |
128 | 130 | size_t total_elements = |
129 | 131 | std::accumulate(layout.Dimensions().begin(), layout.Dimensions().end(), 1, |
130 | 132 | std::multiplies<size_t>()); |
131 | | - |
132 | | - if (type.ElementType() == ElementType::Float16 || |
133 | | - type.ElementType() == ElementType::Float32 || |
134 | | - type.ElementType() == ElementType::BFloat16) { |
135 | | - std::vector<float> data(total_elements); |
136 | | - for (size_t i = 0; i < total_elements; ++i) { |
137 | | - data[i] = std::sin(i * scale); |
138 | | - } |
139 | | - return buffer.Write<float>(absl::MakeConstSpan(data)); |
140 | | - } else if (type.ElementType() == ElementType::Int32) { |
141 | | - std::vector<int32_t> data(total_elements); |
142 | | - for (size_t i = 0; i < total_elements; ++i) { |
143 | | - data[i] = i % 32; |
144 | | - } |
145 | | - return buffer.Write<int32_t>(absl::MakeConstSpan(data)); |
146 | | - } else if (type.ElementType() == ElementType::Int16) { |
147 | | - std::vector<int16_t> data(total_elements); |
148 | | - for (size_t i = 0; i < total_elements; ++i) { |
149 | | - data[i] = i % 2048; |
150 | | - } |
151 | | - return buffer.Write<int16_t>(absl::MakeConstSpan(data)); |
152 | | - } else if (type.ElementType() == ElementType::Int8) { |
153 | | - std::vector<int8_t> data(total_elements); |
154 | | - for (size_t i = 0; i < total_elements; ++i) { |
155 | | - data[i] = i % 256 - 128; |
156 | | - } |
157 | | - return buffer.Write<int8_t>(absl::MakeConstSpan(data)); |
158 | | - } else if (type.ElementType() == ElementType::UInt8) { |
159 | | - std::vector<uint8_t> data(total_elements); |
160 | | - for (size_t i = 0; i < total_elements; ++i) { |
161 | | - data[i] = i % 256; |
162 | | - } |
163 | | - return buffer.Write<uint8_t>(absl::MakeConstSpan(data)); |
164 | | - } else { |
165 | | - return Error(kLiteRtStatusErrorInvalidArgument, |
166 | | - "Unsupported element type for filling tensor."); |
167 | | - } |
| 133 | + float low_range = 0; |
| 134 | + float high_range = 0; |
| 135 | + tflite::utils::GetDataRangesForType( |
| 136 | + static_cast<TfLiteType>(type.ElementType()), &low_range, &high_range); |
| 137 | + auto tensor_data = tflite::utils::CreateRandomTensorData( |
| 138 | + /*name=*/"", static_cast<TfLiteType>(type.ElementType()), total_elements, |
| 139 | + low_range, high_range); |
| 140 | + |
| 141 | + return buffer.Write<char>(absl::MakeSpan( |
| 142 | + reinterpret_cast<char*>(tensor_data.data.get()), tensor_data.bytes)); |
168 | 143 | } |
169 | 144 |
|
170 | 145 | // Creates and fills input buffers for a given compiled model. |
171 | 146 | Expected<std::vector<TensorBuffer>> CreateAndFillInputBuffers( |
172 | | - const CompiledModel& compiled_model, size_t signature_index, float scale) { |
| 147 | + const CompiledModel& compiled_model, size_t signature_index) { |
173 | 148 | LITERT_ASSIGN_OR_RETURN(auto input_buffers, |
174 | 149 | compiled_model.CreateInputBuffers(signature_index)); |
175 | 150 |
|
176 | 151 | for (auto& buffer : input_buffers) { |
177 | | - LITERT_RETURN_IF_ERROR(FillInputTensor(buffer, scale)); |
| 152 | + LITERT_RETURN_IF_ERROR(FillInputTensor(buffer)); |
178 | 153 | } |
179 | 154 | return input_buffers; |
180 | 155 | } |
@@ -386,17 +361,13 @@ Expected<std::vector<BufferDiffStats>> RunModel(absl::string_view model_path) { |
386 | 361 | size_t signature_index = absl::GetFlag(FLAGS_signature_index); |
387 | 362 | ABSL_LOG(INFO) << "Signature index: " << signature_index; |
388 | 363 |
|
389 | | - float input_scale = 0.12345f; |
390 | | - |
391 | 364 | // Create and fill input buffers |
392 | 365 | LITERT_ASSIGN_OR_RETURN( |
393 | 366 | auto cpu_input_buffers, |
394 | | - CreateAndFillInputBuffers(compiled_model_cpu, signature_index, |
395 | | - input_scale)); |
| 367 | + CreateAndFillInputBuffers(compiled_model_cpu, signature_index)); |
396 | 368 | LITERT_ASSIGN_OR_RETURN( |
397 | 369 | auto gpu_input_buffers, |
398 | | - CreateAndFillInputBuffers(compiled_model_gpu, signature_index, |
399 | | - input_scale)); |
| 370 | + CreateAndFillInputBuffers(compiled_model_gpu, signature_index)); |
400 | 371 |
|
401 | 372 | // Create output buffers |
402 | 373 | LITERT_ASSIGN_OR_RETURN( |
|
0 commit comments