Skip to content

Commit 0c52c20

Browse files
Marcin RadomskiGoogle-ML-Automation
authored andcommitted
[XLA:GPU] ThunkBufferDebugPass: implement thunk filtering
Interpret the values passed via command line flags: - `--xla_gpu_experimental_thunk_buffer_debug_filter_by_thunk_id_ranges` - `--xla_gpu_experimental_thunk_buffer_debug_filter_by_profile_annotation_re` And only instrument thunks that pass all the configured filters. PiperOrigin-RevId: 827881515
1 parent e08dbc8 commit 0c52c20

File tree

3 files changed

+332
-11
lines changed

3 files changed

+332
-11
lines changed

xla/backends/gpu/runtime/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2904,6 +2904,7 @@ cc_library(
29042904
":custom_call_thunk",
29052905
":sequential_thunk",
29062906
":thunk",
2907+
":thunk_id",
29072908
":thunk_pass_pipeline",
29082909
"//xla:shape_util",
29092910
"//xla:xla_data_proto_cc",
@@ -2918,14 +2919,17 @@ cc_library(
29182919
"//xla/stream_executor:stream",
29192920
"//xla/stream_executor/gpu:buffer_debug_log",
29202921
"//xla/tsl/platform:statusor",
2922+
"@com_google_absl//absl/algorithm:container",
29212923
"@com_google_absl//absl/base:nullability",
29222924
"@com_google_absl//absl/container:flat_hash_map",
2925+
"@com_google_absl//absl/functional:any_invocable",
29232926
"@com_google_absl//absl/functional:bind_front",
29242927
"@com_google_absl//absl/log",
29252928
"@com_google_absl//absl/log:check",
29262929
"@com_google_absl//absl/status",
29272930
"@com_google_absl//absl/status:statusor",
29282931
"@com_google_absl//absl/strings:string_view",
2932+
"@com_googlesource_code_re2//:re2",
29292933
],
29302934
)
29312935

xla/backends/gpu/runtime/thunk_buffer_debug_pass.cc

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,31 @@ limitations under the License.
1616
#include "xla/backends/gpu/runtime/thunk_buffer_debug_pass.h"
1717

1818
#include <cstddef>
19+
#include <cstdint>
1920
#include <cstring>
2021
#include <memory>
22+
#include <string>
2123
#include <utility>
2224
#include <vector>
2325

26+
#include "absl/algorithm/container.h"
2427
#include "absl/base/nullability.h"
2528
#include "absl/container/flat_hash_map.h"
29+
#include "absl/functional/any_invocable.h"
2630
#include "absl/functional/bind_front.h"
2731
#include "absl/log/check.h"
2832
#include "absl/log/log.h"
2933
#include "absl/status/status.h"
3034
#include "absl/status/statusor.h"
35+
#include "re2/re2.h"
3136
#include "xla/backends/gpu/runtime/buffer_debug_log_entry_metadata_store.h"
3237
#include "xla/backends/gpu/runtime/buffer_debug_log_structs.h"
3338
#include "xla/backends/gpu/runtime/buffers_checksum_thunk.h"
3439
#include "xla/backends/gpu/runtime/buffers_nan_count_thunk.h"
3540
#include "xla/backends/gpu/runtime/custom_call_thunk.h"
3641
#include "xla/backends/gpu/runtime/sequential_thunk.h"
3742
#include "xla/backends/gpu/runtime/thunk.h"
43+
#include "xla/backends/gpu/runtime/thunk_id.h"
3844
#include "xla/backends/gpu/runtime/thunk_pass_pipeline.h"
3945
#include "xla/ffi/api/c_api.h"
4046
#include "xla/ffi/attribute_map.h"
@@ -226,6 +232,103 @@ absl::Status DumpBufferDebugLog(
226232
return absl::OkStatus();
227233
}
228234

235+
// A boolean-like value returned from thunk filters to indicate whether the
236+
// thunk should be instrumented or left as is.
237+
enum class InstrumentAction : bool {
238+
// Don't instrument the thunk, leave it as is.
239+
kSkip,
240+
// Instrument the thunk.
241+
kInstrument,
242+
};
243+
244+
// A function that decides whether the thunk should be instrumented
245+
// (kInstrument) or not (kSkip).
246+
using ThunkFilter = absl::AnyInvocable<InstrumentAction(const Thunk&) const>;
247+
248+
// Creates a thunk filter that filters thunks by their IDs, based the allowed
249+
// ranges passed in debug options.
250+
ThunkFilter CreateThunkIdFilter(const DebugOptions& debug_options) {
251+
std::vector<std::pair<int64_t, int64_t>> thunk_id_ranges;
252+
for (const auto& range :
253+
debug_options.xla_gpu_experimental_thunk_buffer_debug_filter()
254+
.thunk_id_ranges()) {
255+
VLOG(1) << "Thunk filter: id range [" << range.first() << ", "
256+
<< range.last() << "]";
257+
thunk_id_ranges.emplace_back(range.first(), range.last());
258+
}
259+
260+
return [id_ranges = std::move(thunk_id_ranges)](const Thunk& thunk) {
261+
if (id_ranges.empty()) {
262+
return InstrumentAction::kInstrument;
263+
}
264+
265+
const ThunkId thunk_id = thunk.thunk_info().thunk_id;
266+
if (absl::c_any_of(id_ranges, [&](const auto& range) {
267+
VLOG(2) << "Thunk filter: check ID range: " << range.first
268+
<< " <= " << thunk_id.value() << " <= " << range.second;
269+
return range.first <= thunk_id.value() &&
270+
thunk_id.value() <= range.second;
271+
})) {
272+
VLOG(2) << "Thunk filter: ID matches";
273+
return InstrumentAction::kInstrument;
274+
}
275+
276+
VLOG(2) << "Thunk filter: ID does not match";
277+
return InstrumentAction::kSkip;
278+
};
279+
}
280+
281+
// Creates a thunk filter that filters thunks by matching their profile
282+
// annotations against regexes configured in debug options.
283+
ThunkFilter CreateProfileAnnotationRegexFilter(
284+
const DebugOptions& debug_options) {
285+
std::vector<std::unique_ptr<RE2>> profile_annotation_regexes;
286+
for (const auto& regex :
287+
debug_options.xla_gpu_experimental_thunk_buffer_debug_filter()
288+
.profile_annotation_regexes()) {
289+
VLOG(1) << "Thunk filter: profile annotation regex: " << regex;
290+
profile_annotation_regexes.push_back(std::make_unique<RE2>(regex));
291+
}
292+
return [regexes = std::move(profile_annotation_regexes)](const Thunk& thunk) {
293+
if (regexes.empty()) {
294+
return InstrumentAction::kInstrument;
295+
}
296+
297+
const std::string& profile_annotation =
298+
thunk.thunk_info().profile_annotation;
299+
if (absl::c_any_of(regexes, [&](const auto& regex) {
300+
VLOG(2) << "Thunk filter: check profile annotation regex: "
301+
<< regex->pattern();
302+
return RE2::PartialMatch(profile_annotation, *regex);
303+
})) {
304+
VLOG(2) << "Thunk filter: profile annotation matches";
305+
return InstrumentAction::kInstrument;
306+
}
307+
308+
VLOG(2) << "Thunk filter: profile annotation does not match";
309+
return InstrumentAction::kSkip;
310+
};
311+
}
312+
313+
// Creates a thunk filter that filters thunks by all the conditions configured
314+
// in debug options.
315+
ThunkFilter CreateThunkFilter(const DebugOptions& debug_options) {
316+
std::vector<ThunkFilter> filters;
317+
filters.push_back(CreateThunkIdFilter(debug_options));
318+
filters.push_back(CreateProfileAnnotationRegexFilter(debug_options));
319+
320+
return [filters = std::move(filters)](const Thunk& thunk) {
321+
VLOG(2) << "Thunk filter: check ID " << thunk.thunk_info().thunk_id
322+
<< ", profile annotation " << thunk.thunk_info().profile_annotation;
323+
if (absl::c_all_of(filters, [&](const auto& filter) {
324+
return filter(thunk) == InstrumentAction::kInstrument;
325+
})) {
326+
return InstrumentAction::kInstrument;
327+
}
328+
return InstrumentAction::kSkip;
329+
};
330+
}
331+
229332
XLA_FFI_DEFINE_HANDLER_SYMBOL(
230333
kDebugLogInitHandler,
231334
[](se::Stream* absl_nonnull stream, xla::ffi::Buffer<U8> log_buffer) {
@@ -285,7 +388,11 @@ absl::StatusOr<bool> ThunkBufferDebugPass::Run(
285388
/*results=*/{}, /*attributes=*/{},
286389
hlo_module->entry_computation()));
287390

391+
ThunkFilter thunk_filter = CreateThunkFilter(debug_options);
288392
root_thunk->TransformAllNestedThunks([&](std::unique_ptr<Thunk> thunk) {
393+
if (thunk_filter(*thunk) == InstrumentAction::kSkip) {
394+
return thunk;
395+
}
289396
switch (mode_) {
290397
case Mode::kChecksum:
291398
VLOG(1) << "Wrapping with checksum thunk";

0 commit comments

Comments
 (0)