Skip to content

Commit 731f9b0

Browse files
LukeBoyercopybara-github
authored andcommitted
Switch ats to use micro seconds for latency instead of nano
LiteRT-PiperOrigin-RevId: 826195550
1 parent c41d2a2 commit 731f9b0

File tree

9 files changed

+84
-37
lines changed

9 files changed

+84
-37
lines changed

litert/ats/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,6 @@ litert_define_ats(
380380
# "^(?:(?!npu_ats_87).)*$$",
381381
],
382382
extra_flags = [
383-
"--f16_range_for_f32=true",
384383
"--data_seed=42",
385384
],
386385
jit_suffix = "",

litert/ats/common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ enum class CompilationStatus {
100100
// Timing related types.
101101
using Clock = std::chrono::steady_clock;
102102
using TimePoint = Clock::time_point;
103-
using Nanoseconds = uint64_t;
103+
using Microseconds = uint64_t;
104104

105105
// Which backend to use as the "actual".
106106
enum class ExecutionBackend { kCpu, kGpu, kNpu };
@@ -177,7 +177,7 @@ void AbslStringify(Sink& sink, const CompilationStatus& status) {
177177
}
178178

179179
template <typename Sink>
180-
void AbslStringify(Sink& sink, const Nanoseconds& ns) {
180+
void AbslStringify(Sink& sink, const Microseconds& ns) {
181181
absl::Format(&sink, "%e", ns);
182182
}
183183

litert/ats/compile_capture.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,25 @@
2929
namespace litert::testing {
3030

3131
// Information about the time taken to compile the model.
32-
class CompilationTime : public Printable<Nanoseconds> {
32+
class CompilationTime : public Printable<Microseconds> {
3333
public:
3434
// Start timing.
3535
TimePoint Start() const { return Clock::now(); }
3636

3737
// Stop timing and record the latency.
3838
void Stop(const TimePoint& start) {
39-
std::chrono::duration<Nanoseconds, std::nano> nano = Clock::now() - start;
39+
std::chrono::duration<Microseconds, std::nano> nano = Clock::now() - start;
4040
nanos_ = nano.count();
4141
}
4242

43-
Nanoseconds Nanos() const { return nanos_; }
43+
Microseconds Nanos() const { return nanos_; }
4444

4545
CompilationTime() : Printable("CompilationTime", "compile_time(ns)") {}
4646

4747
private:
4848
Fields GetFields() const override { return Fields{nanos_}; }
4949

50-
Nanoseconds nanos_ = std::numeric_limits<Nanoseconds>::max();
50+
Microseconds nanos_ = std::numeric_limits<Microseconds>::max();
5151
};
5252

5353
// Type to hold all of the capturable information related compilation test

litert/ats/compile_capture_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ TEST(AtsCompileCaptureTest, Basic) {
3333
e.model.name = "FOO";
3434

3535
ASSERT_NE(e.compilation_time.Nanos(),
36-
std::numeric_limits<Nanoseconds>::max());
36+
std::numeric_limits<Microseconds>::max());
3737

3838
std::ostringstream s;
3939
cap.Print(s);

litert/ats/configure.cc

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
196196
std::regex pos_re(absl::GetFlag(FLAGS_do_register),
197197
std::regex_constants::ECMAScript);
198198
auto extra_models = absl::GetFlag(FLAGS_extra_models);
199-
auto f16_range_for_f32 = absl::GetFlag(FLAGS_f16_range_for_f32);
200199
auto data_seed = absl::GetFlag(FLAGS_data_seed);
201200
auto dispatch_dir = absl::GetFlag(FLAGS_dispatch_dir);
202201
auto plugin_dir = absl::GetFlag(FLAGS_plugin_dir);
@@ -223,12 +222,11 @@ Expected<AtsConf> AtsConf::ParseFlagsAndDoSetup() {
223222
ParsePlugin(plugin_dir, soc_manufacturer, compile_mode, target_options));
224223
AtsConf res(std::move(seeds), backend, quiet, dispatch_dir, plugin_dir,
225224
std::move(neg_re), std::move(pos_re), std::move(extra_models),
226-
f16_range_for_f32, data_seed, iters_per_test,
227-
std::move(max_ms_per_test_opt), fail_on_timeout, dump_report,
228-
std::move(csv), compile_mode, std::move(models_out), limit,
229-
std::move(plugin), std::move(soc_manufacturer),
230-
std::move(soc_model), std::move(target_options),
231-
std::move(reference_options));
225+
data_seed, iters_per_test, std::move(max_ms_per_test_opt),
226+
fail_on_timeout, dump_report, std::move(csv), compile_mode,
227+
std::move(models_out), limit, std::move(plugin),
228+
std::move(soc_manufacturer), std::move(soc_model),
229+
std::move(target_options), std::move(reference_options));
232230
Setup(res);
233231
return res;
234232
}

litert/ats/configure.h

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,6 @@ ABSL_DECLARE_FLAG(std::string, dont_register);
6262
// Regex for explicit inclusions.
6363
ABSL_DECLARE_FLAG(std::string, do_register);
6464

65-
// Will generate values for f32 tensors in the range of f16 values.
66-
ABSL_DECLARE_FLAG(bool, f16_range_for_f32);
67-
6865
// Optional list of directories, or model files to add to the test.
6966
ABSL_DECLARE_FLAG(std::vector<std::string>, extra_models);
7067

@@ -251,8 +248,7 @@ class AtsConf {
251248
bool quiet, std::string dispatch_dir, std::string plugin_dir,
252249
std::regex&& neg_re, std::regex&& pos_re,
253250
std::vector<std::string> extra_models,
254-
bool f16_range_for_f32, std::optional<int> data_seed,
255-
size_t iters_per_test,
251+
std::optional<int> data_seed, size_t iters_per_test,
256252
std::chrono::milliseconds max_ms_per_test,
257253
bool fail_on_timeout, bool dump_report, std::string csv,
258254
bool compile_mode, std::string models_out, int32_t limit,
@@ -267,7 +263,7 @@ class AtsConf {
267263
neg_re_(std::move(neg_re)),
268264
pos_re_(std::move(pos_re)),
269265
extra_models_(std::move(extra_models)),
270-
f16_range_for_f32_(f16_range_for_f32),
266+
271267
data_seed_(data_seed),
272268
iters_per_test_(iters_per_test),
273269
max_ms_per_test_(std::move(max_ms_per_test)),
@@ -282,9 +278,9 @@ class AtsConf {
282278
soc_model_(std::move(soc_model)),
283279
target_options_(std::move(target_options)),
284280
reference_options_(std::move(reference_options)) {
285-
if (f16_range_for_f32_) {
286-
data_builder_.SetF16InF32();
287-
}
281+
// For now, we will provide default settings for data generation.
282+
// More configurability may be introduced later.
283+
data_builder_.SetSin();
288284
}
289285

290286
SeedMap seeds_for_params_;
@@ -296,7 +292,6 @@ class AtsConf {
296292
std::regex neg_re_;
297293
std::regex pos_re_;
298294
std::vector<std::string> extra_models_;
299-
bool f16_range_for_f32_;
300295
std::optional<int> data_seed_;
301296
size_t iters_per_test_;
302297
std::chrono::milliseconds max_ms_per_test_;

litert/ats/inference_capture.h

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace litert::testing {
3535

3636
// Information about the latency of the execution.
3737
class Latency
38-
: public Printable<Nanoseconds, Nanoseconds, Nanoseconds, size_t> {
38+
: public Printable<Microseconds, Microseconds, Microseconds, size_t> {
3939
public:
4040
using Ref = std::reference_wrapper<Latency>;
4141

@@ -58,25 +58,26 @@ class Latency
5858

5959
// Stop timing and record the latency.
6060
void Stop(const TimePoint& start) {
61-
std::chrono::duration<Nanoseconds, std::nano> nano = Clock::now() - start;
62-
latencies_.push_back(nano.count());
61+
const auto micro = std::chrono::duration_cast<
62+
std::chrono::duration<Microseconds, std::micro>>(Clock::now() - start);
63+
latencies_.push_back(micro.count());
6364
}
6465

6566
// Average latency.
66-
Nanoseconds Avg() const {
67+
Microseconds Avg() const {
6768
return ::litert::Avg(latencies_.cbegin(), latencies_.cend());
6869
}
6970

7071
// Maximum latency.
71-
Nanoseconds Max() const {
72+
Microseconds Max() const {
7273
if (latencies_.empty()) {
7374
return 0;
7475
}
7576
return *std::max_element(latencies_.begin(), latencies_.end());
7677
}
7778

7879
// Minimum latency.
79-
Nanoseconds Min() const {
80+
Microseconds Min() const {
8081
if (latencies_.empty()) {
8182
return 0;
8283
}
@@ -87,15 +88,15 @@ class Latency
8788
size_t NumSamples() const { return latencies_.size(); }
8889

8990
Latency()
90-
: Printable("Latency", "avg_latency(ns)", "max_latency(ns)",
91-
"min_latency(ns)", "num_samples") {}
91+
: Printable("Latency", "avg_latency(us)", "max_latency(us)",
92+
"min_latency(us)", "num_samples") {}
9293

9394
private:
9495
Fields GetFields() const override {
9596
return Fields{Avg(), Max(), Min(), NumSamples()};
9697
}
9798

98-
std::vector<Nanoseconds> latencies_;
99+
std::vector<Microseconds> latencies_;
99100
};
100101

101102
// Information about the numerics of the execution.

litert/ats/inference_fixture.h

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef THIRD_PARTY_ODML_LITERT_LITERT_ATS_INFERENCE_FIXTURE_H_
1616
#define THIRD_PARTY_ODML_LITERT_LITERT_ATS_INFERENCE_FIXTURE_H_
1717

18+
#include <algorithm>
1819
#include <cstddef>
1920
#include <cstdint>
2021
#include <limits>
@@ -136,7 +137,24 @@ class AtsInferenceTest : public RngTest {
136137

137138
template <typename Rng>
138139
Expected<VarBuffers> MakeInputs(Rng& device) const {
139-
return graph_->MakeInputs(device, conf_.DataBuilder());
140+
auto inputs = graph_->MakeInputs(device, conf_.DataBuilder());
141+
if (!inputs.HasValue()) return inputs.Error();
142+
#ifndef NDEBUG
143+
LITERT_LOG(LITERT_INFO, "First 5 elements of each input:");
144+
for (size_t i = 0; i < inputs->size(); ++i) {
145+
const auto& input = (*inputs)[i];
146+
LITERT_LOG(LITERT_INFO, " Input %zu:", i);
147+
if (input.Type().ElementType() == ElementType::Float32) {
148+
const auto& view = input.template AsView<float>();
149+
for (int j = 0; j < std::min(5, view.data.size()); ++j) {
150+
LITERT_LOG(LITERT_INFO, " [%d]: %f", j, view.data[j]);
151+
}
152+
} else {
153+
LITERT_LOG(LITERT_INFO, " Unsupported element type for printing.");
154+
}
155+
}
156+
#endif
157+
return inputs;
140158
}
141159

142160
Expected<VarBuffers> Actual(const VarBuffers& inputs,
@@ -185,7 +203,15 @@ class AtsInferenceTest : public RngTest {
185203
template <typename T>
186204
void CheckOutputImpl(const BufferView<T>& actual, const BufferView<T>& ref) {
187205
double mse = std::numeric_limits<double>::max();
188-
EXPECT_THAT(actual.data, MeanSquaredErrorLt(ref.data, 1e-5, &mse));
206+
#ifndef NDEBUG
207+
LITERT_LOG(LITERT_INFO, "First 5 elements:");
208+
for (int i = 0; i < std::min(5, actual.data.size()); ++i) {
209+
LITERT_LOG(LITERT_INFO, " actual[%d]: %f, ref[%d]: %f", i,
210+
static_cast<float>(actual.data[i]), i,
211+
static_cast<float>(ref.data[i]));
212+
}
213+
#endif
214+
EXPECT_THAT(actual.data, MeanSquaredErrorLt(ref.data, 1e-4, &mse));
189215
cap_.numerics.NewMse(mse);
190216
}
191217

litert/cc/internal/litert_rng.h

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,23 @@ class F16InF32Generator<float> final {
260260
}
261261
};
262262

263+
template <typename D>
264+
class SinGenerator final : public DataGeneratorBase<D> {};
265+
266+
// Generates sin values in the range [-1, 1].
267+
template <>
268+
class SinGenerator<float> final : public DataGeneratorBase<float> {
269+
public:
270+
SinGenerator() = default;
271+
272+
template <typename Rng>
273+
float operator()(Rng& rng) {
274+
return std::sin(rng() * 0.12345f);
275+
}
276+
float Max() const override { return 1.0f; }
277+
float Min() const override { return -1.0f; }
278+
};
279+
263280
// Dummy primitive generator that returns a monotonically increasing sequence.
264281
template <typename D>
265282
class DummyGenerator final : public DataGeneratorBase<D> {
@@ -536,6 +553,11 @@ class RandomTensorDataBuilder {
536553
return *this;
537554
}
538555

556+
RandomTensorDataBuilder& SetSin() {
557+
float_config_ = Sin();
558+
return *this;
559+
}
560+
539561
template <typename D>
540562
std::pair<double, double> Bounds() const {
541563
if constexpr (std::is_same_v<D, int32_t>) {
@@ -556,6 +578,8 @@ class RandomTensorDataBuilder {
556578
std::numeric_limits<float>::max()};
557579
} else if (std::holds_alternative<F16InF32>(float_config_)) {
558580
return {std::numeric_limits<float>::lowest(), 65504.0};
581+
} else if (std::holds_alternative<Sin>(float_config_)) {
582+
return {-1.0f, 1.0f};
559583
} else {
560584
auto [min, max] = std::get<std::pair<float, float>>(float_config_);
561585
return {min, max};
@@ -589,6 +613,9 @@ class RandomTensorDataBuilder {
589613
} else if (std::holds_alternative<F16InF32>(float_config_)) {
590614
RandomTensorData<D, F16InF32Generator> data;
591615
return Functor()(data, std::forward<Args>(args)...);
616+
} else if (std::holds_alternative<Sin>(float_config_)) {
617+
RandomTensorData<D, SinGenerator> data;
618+
return Functor()(data, std::forward<Args>(args)...);
592619
} else {
593620
auto [min, max] = std::get<std::pair<D, D>>(float_config_);
594621
RandomTensorData<D, DefaultRangedGenerator> data(min, max);
@@ -603,12 +630,13 @@ class RandomTensorDataBuilder {
603630
struct Dummy {};
604631
struct NullOpt {};
605632
struct F16InF32 {};
633+
struct Sin {};
606634

607635
template <typename D>
608636
using IntConfig = std::variant<std::pair<D, D>, Dummy, NullOpt>;
609637
template <typename D>
610638
using FloatConfig =
611-
std::variant<std::pair<float, float>, Dummy, NullOpt, F16InF32>;
639+
std::variant<std::pair<float, float>, Dummy, NullOpt, F16InF32, Sin>;
612640

613641
IntConfig<int32_t> int_config_ = NullOpt();
614642
FloatConfig<float> float_config_ = NullOpt();

0 commit comments

Comments
 (0)