Skip to content

Commit 31a00fc

Browse files
authored
Add basic support for tracing (microsoft#1524)
1 parent bfc8027 commit 31a00fc

File tree

10 files changed

+232
-7
lines changed

10 files changed

+232
-7
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ if(ENABLE_TESTS)
8080
endif()
8181
endif()
8282

83+
if(ENABLE_TRACING)
84+
message(STATUS "Tracing is enabled.")
85+
add_compile_definitions(ORTGENAI_ENABLE_TRACING)
86+
endif()
87+
8388
find_package(Threads REQUIRED)
8489

8590
if(WIN32)

cmake/options.cmake

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,6 @@ option(TEST_PHI2 "Enable tests for Phi2" OFF)
1717

1818
# performance
1919
option(ENABLE_MODEL_BENCHMARK "Build model benchmark program" ON)
20+
21+
# diagnostics
22+
option(ENABLE_TRACING "Enable recording of tracing data" OFF)

src/generators.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "models/decoder_only.h"
99
#include "constrained_logits_processor.h"
1010
#include "search.h"
11+
#include "tracing.h"
1112
#include "cpu/interface.h"
1213
#include "cuda/interface.h"
1314
#include "dml/interface.h"
@@ -47,7 +48,7 @@ static bool _ = (Ort::InitApi(), false);
4748

4849
static OrtLoggingLevel GetDefaultOrtLoggingLevel() {
4950
bool ort_verbose_logging = false;
50-
GetEnvironmentVariable("ORTGENAI_ORT_VERBOSE_LOGGING", ort_verbose_logging);
51+
GetEnv("ORTGENAI_ORT_VERBOSE_LOGGING", ort_verbose_logging);
5152
return ort_verbose_logging ? OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE : OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR;
5253
}
5354

@@ -353,6 +354,8 @@ void Generator::AuxAppendTokens(cpu_span<const int32_t> input_ids) {
353354
}
354355

355356
void Generator::AppendTokens(cpu_span<const int32_t> input_ids) {
357+
DurationTrace trace{"Generator::AppendTokens"};
358+
356359
ThrowErrorIfSessionTerminated(state_->session_terminated_);
357360
if (input_ids.size() == 0)
358361
throw std::runtime_error("input_ids is empty");
@@ -440,6 +443,8 @@ void Generator::SetLogits(DeviceSpan<float> logits) {
440443
}
441444

442445
void Generator::GenerateNextToken() {
446+
DurationTrace trace{"Generator::GenerateNextToken"};
447+
443448
ThrowErrorIfSessionTerminated(state_->session_terminated_);
444449
if (search_->GetSequenceLength() == 0 && !computed_logits_)
445450
throw std::runtime_error("GenerateNextToken called with no prior state. Please call AppendTokens, SetLogits, or params.SetInputs before calling GenerateNextToken.");

src/models/decoder_only_pipeline.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "../generators.h"
55
#include "../logging.h"
6+
#include "../tracing.h"
67
#include "decoder_only_pipeline.h"
78
#include "windowed_kv_cache.h"
89

@@ -195,6 +196,8 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
195196
continue;
196197
}
197198

199+
DurationTrace trace{MakeString("DecoderOnlyPipelineState::RunPipeline[", pipeline_state->id_, "]")};
200+
198201
if (model_.config_->model.decoder.pipeline[pipeline_state->id_].reset_session_idx > -1) {
199202
if (model_.config_->model.decoder.pipeline[pipeline_state->id_].reset_session_idx >=
200203
static_cast<int>(model_.sessions_.size())) {
@@ -332,6 +335,8 @@ void DecoderOnlyPipelineState::RunPipeline(int total_length, DeviceSpan<int32_t>
332335

333336
DeviceSpan<float> DecoderOnlyPipelineState::Run(int total_length, DeviceSpan<int32_t>& next_tokens,
334337
DeviceSpan<int32_t> next_indices) {
338+
DurationTrace trace{"DecoderOnlyPipelineState::Run"};
339+
335340
UpdateInputsOutputs(next_tokens, next_indices, total_length);
336341

337342
size_t num_chunks{1};

src/models/env_utils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
namespace Generators {
1313

14-
std::string GetEnvironmentVariable(const char* var_name) {
14+
std::string GetEnv(const char* var_name) {
1515
#if _MSC_VER
1616
// Why getenv() should be avoided on Windows:
1717
// https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/getenv-wgetenv
@@ -40,8 +40,8 @@ std::string GetEnvironmentVariable(const char* var_name) {
4040
#endif // _MSC_VER
4141
}
4242

43-
void GetEnvironmentVariable(const char* var_name, bool& value) {
44-
std::string str_value = GetEnvironmentVariable(var_name);
43+
void GetEnv(const char* var_name, bool& value) {
44+
std::string str_value = GetEnv(var_name);
4545
if (str_value == "1" || str_value == "true") {
4646
value = true;
4747
} else if (str_value == "0" || str_value == "false") {

src/models/env_utils.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55

66
namespace Generators {
77

8-
std::string GetEnvironmentVariable(const char* var_name);
8+
// Gets the environment variable value. If no environment variable is found, the result will be empty.
9+
std::string GetEnv(const char* var_name);
910

1011
// This overload is used to get boolean environment variables.
1112
// If the environment variable is set to "1" or "true" (case-sensitive), value will be set to true.
1213
// Otherwise, value will not be modified.
13-
void GetEnvironmentVariable(const char* var_name, bool& value);
14+
void GetEnv(const char* var_name, bool& value);
1415

1516
} // namespace Generators

src/models/model.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "../generators.h"
1313
#include "../search.h"
14+
#include "../tracing.h"
1415
#include "model.h"
1516
#include "gpt.h"
1617
#include "decoder_only.h"
@@ -37,6 +38,8 @@ State::State(const GeneratorParams& params, const Model& model)
3738
}
3839

3940
void State::Run(OrtSession& session, bool graph_capture_this_run) {
41+
DurationTrace trace{"State::Run"};
42+
4043
if (params_->use_graph_capture) {
4144
if (graph_capture_this_run)
4245
run_options_->AddConfigEntry("gpu_graph_id", graph_id_.c_str());

src/models/onnxruntime_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ inline void InitApi() {
200200
}
201201

202202
bool ort_lib = false;
203-
Generators::GetEnvironmentVariable("ORTGENAI_LOG_ORT_LIB", ort_lib);
203+
Generators::GetEnv("ORTGENAI_LOG_ORT_LIB", ort_lib);
204204
if (ort_lib) {
205205
Generators::SetLogBool("enabled", true);
206206
Generators::SetLogBool("ort_lib", true);

src/tracing.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "tracing.h"
5+
6+
#include <chrono>
7+
#include <fstream>
8+
#include <mutex>
9+
#include <optional>
10+
#include <sstream>
11+
#include <thread>
12+
13+
#include "models/env_utils.h"
14+
15+
namespace Generators {
16+
17+
#if defined(ORTGENAI_ENABLE_TRACING)
18+
19+
namespace {
20+
21+
// Writes trace events to a file in Chrome tracing format.
22+
// See more details about the format here:
23+
// https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU
24+
class FileTraceSink : public TraceSink {
25+
public:
26+
FileTraceSink(std::string_view file_path)
27+
: ostream_{std::ofstream{file_path.data()}},
28+
start_{Clock::now()},
29+
insert_event_delimiter_{false} {
30+
ostream_ << "[";
31+
}
32+
33+
~FileTraceSink() {
34+
ostream_ << "]\n";
35+
}
36+
37+
void BeginDuration(std::string_view label) {
38+
LogEvent("B", label);
39+
}
40+
41+
void EndDuration() {
42+
LogEvent("E");
43+
}
44+
45+
private:
46+
using Clock = std::chrono::steady_clock;
47+
48+
void LogEvent(std::string_view phase_type, std::optional<std::string_view> label = std::nullopt) {
49+
const auto thread_id = std::this_thread::get_id();
50+
const auto ts = std::chrono::duration_cast<std::chrono::microseconds>(Clock::now() - start_);
51+
52+
std::ostringstream event{};
53+
54+
event << "{";
55+
56+
if (label.has_value()) {
57+
event << "\"name\": \"" << *label << "\", ";
58+
}
59+
60+
event << "\"cat\": \"perf\", "
61+
<< "\"ph\": \"" << phase_type << "\", "
62+
<< "\"pid\": 0, "
63+
<< "\"tid\": " << thread_id << ", "
64+
<< "\"ts\": " << ts.count()
65+
<< "}";
66+
67+
{
68+
std::scoped_lock g{output_mutex_};
69+
70+
// add the delimiter only after writing the first event
71+
if (insert_event_delimiter_) {
72+
ostream_ << ",\n";
73+
} else {
74+
insert_event_delimiter_ = true;
75+
}
76+
77+
ostream_ << event.str();
78+
}
79+
}
80+
81+
std::ofstream ostream_;
82+
const Clock::time_point start_;
83+
bool insert_event_delimiter_;
84+
85+
std::mutex output_mutex_;
86+
};
87+
88+
std::string GetTraceFileName() {
89+
constexpr const char* kTraceFileEnvironmentVariableName = "ORTGENAI_TRACE_FILE_PATH";
90+
auto trace_file_name = GetEnv(kTraceFileEnvironmentVariableName);
91+
if (trace_file_name.empty()) {
92+
trace_file_name = "ortgenai_trace.log";
93+
}
94+
return trace_file_name;
95+
}
96+
97+
} // namespace
98+
99+
#endif // defined(ORTGENAI_ENABLE_TRACING)
100+
101+
// When ORTGENAI_ENABLE_TRACING is defined, installs a FileTraceSink that writes to the file named
// by GetTraceFileName(). Otherwise the constructor body is empty.
Tracer::Tracer() {
#if defined(ORTGENAI_ENABLE_TRACING)
  sink_ = std::make_unique<FileTraceSink>(GetTraceFileName());
#endif
}
107+
108+
// Forwards the start of a traced duration to the sink when ORTGENAI_ENABLE_TRACING is defined.
// Without it the call does nothing and `label` is intentionally unused.
void Tracer::BeginDuration([[maybe_unused]] std::string_view label) {
#if defined(ORTGENAI_ENABLE_TRACING)
  sink_->BeginDuration(label);
#endif
}
115+
116+
// Forwards the end of a traced duration to the sink when ORTGENAI_ENABLE_TRACING is defined;
// does nothing otherwise.
void Tracer::EndDuration() {
#ifdef ORTGENAI_ENABLE_TRACING
  sink_->EndDuration();
#endif
}
121+
122+
// Returns a reference to a single function-local static Tracer, constructed on first call.
Tracer& DefaultTracerInstance() {
  static Tracer instance{};
  return instance;
}
126+
127+
} // namespace Generators

src/tracing.h

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
// Build with CMake option ENABLE_TRACING=ON to enable tracing.
5+
// To avoid performance overhead, tracing is not enabled by default.
6+
7+
// When tracing is enabled, the trace data will be recorded to a file.
8+
// The trace file path can be specified with the environment variable ORTGENAI_TRACE_FILE_PATH.
9+
// The trace file can be viewed with Perfetto UI (https://ui.perfetto.dev/).
10+
11+
#pragma once
12+
13+
#include <memory>
14+
#include <string_view>
15+
16+
namespace Generators {
17+
18+
// Trace consumer interface.
19+
// Abstract interface with two notifications: the start of a labelled duration and the end of a
// duration. Deleting through a TraceSink pointer is safe because the destructor is virtual.
class TraceSink {
 public:
  virtual ~TraceSink() = default;

  // Signals that a duration named `label` has started.
  virtual void BeginDuration(std::string_view label) = 0;

  // Signals that a duration has ended. No label is supplied.
  virtual void EndDuration() = 0;
};
25+
26+
// Main tracing class.
27+
// Entry point for recording traced durations. Instances can be neither copied nor moved.
// The sink member only exists when ORTGENAI_ENABLE_TRACING is defined.
class Tracer {
 public:
  Tracer();

  // Copying and moving are disabled.
  Tracer(const Tracer&) = delete;
  Tracer(Tracer&&) = delete;
  Tracer& operator=(const Tracer&) = delete;
  Tracer& operator=(Tracer&&) = delete;

  // Starts a traced duration named `label`.
  void BeginDuration(std::string_view label);

  // Finishes the duration opened by the latest BeginDuration() call made on the calling thread.
  void EndDuration();

 private:
#if defined(ORTGENAI_ENABLE_TRACING)
  std::unique_ptr<TraceSink> sink_;
#endif
};
47+
48+
// Gets the default tracer instance.
49+
Tracer& DefaultTracerInstance();
50+
51+
// Records a traced duration while in scope.
52+
class DurationTrace {
53+
public:
54+
[[nodiscard]] DurationTrace(std::string_view label)
55+
: DurationTrace{DefaultTracerInstance(), label} {
56+
}
57+
58+
[[nodiscard]] DurationTrace(Tracer& tracer, std::string_view label)
59+
: tracer_{tracer} {
60+
tracer_.BeginDuration(label);
61+
}
62+
63+
~DurationTrace() {
64+
tracer_.EndDuration();
65+
}
66+
67+
private:
68+
DurationTrace(const DurationTrace&) = delete;
69+
DurationTrace& operator=(const DurationTrace&) = delete;
70+
DurationTrace(DurationTrace&&) = delete;
71+
DurationTrace& operator=(DurationTrace&&) = delete;
72+
73+
Tracer& tracer_;
74+
};
75+
76+
} // namespace Generators

0 commit comments

Comments
 (0)