@@ -16,25 +16,31 @@ limitations under the License.
1616#include " xla/backends/gpu/runtime/thunk_buffer_debug_pass.h"
1717
1818#include < cstddef>
19+ #include < cstdint>
1920#include < cstring>
2021#include < memory>
22+ #include < string>
2123#include < utility>
2224#include < vector>
2325
26+ #include " absl/algorithm/container.h"
2427#include " absl/base/nullability.h"
2528#include " absl/container/flat_hash_map.h"
29+ #include " absl/functional/any_invocable.h"
2630#include " absl/functional/bind_front.h"
2731#include " absl/log/check.h"
2832#include " absl/log/log.h"
2933#include " absl/status/status.h"
3034#include " absl/status/statusor.h"
35+ #include " re2/re2.h"
3136#include " xla/backends/gpu/runtime/buffer_debug_log_entry_metadata_store.h"
3237#include " xla/backends/gpu/runtime/buffer_debug_log_structs.h"
3338#include " xla/backends/gpu/runtime/buffers_checksum_thunk.h"
3439#include " xla/backends/gpu/runtime/buffers_nan_count_thunk.h"
3540#include " xla/backends/gpu/runtime/custom_call_thunk.h"
3641#include " xla/backends/gpu/runtime/sequential_thunk.h"
3742#include " xla/backends/gpu/runtime/thunk.h"
43+ #include " xla/backends/gpu/runtime/thunk_id.h"
3844#include " xla/backends/gpu/runtime/thunk_pass_pipeline.h"
3945#include " xla/ffi/api/c_api.h"
4046#include " xla/ffi/attribute_map.h"
@@ -226,6 +232,103 @@ absl::Status DumpBufferDebugLog(
226232 return absl::OkStatus ();
227233}
228234
235+ // A boolean-like value returned from thunk filters to indicate whether the
236+ // thunk should be instrumented or left as is.
237+ enum class InstrumentAction : bool {
238+ // Don't instrument the thunk, leave it as is.
239+ kSkip ,
240+ // Instrument the thunk.
241+ kInstrument ,
242+ };
243+
244+ // A function that decides whether the thunk should be instrumented
245+ // (kInstrument) or not (kSkip).
246+ using ThunkFilter = absl::AnyInvocable<InstrumentAction(const Thunk&) const >;
247+
248+ // Creates a thunk filter that filters thunks by their IDs, based the allowed
249+ // ranges passed in debug options.
250+ ThunkFilter CreateThunkIdFilter (const DebugOptions& debug_options) {
251+ std::vector<std::pair<int64_t , int64_t >> thunk_id_ranges;
252+ for (const auto & range :
253+ debug_options.xla_gpu_experimental_thunk_buffer_debug_filter ()
254+ .thunk_id_ranges ()) {
255+ VLOG (1 ) << " Thunk filter: id range [" << range.first () << " , "
256+ << range.last () << " ]" ;
257+ thunk_id_ranges.emplace_back (range.first (), range.last ());
258+ }
259+
260+ return [id_ranges = std::move (thunk_id_ranges)](const Thunk& thunk) {
261+ if (id_ranges.empty ()) {
262+ return InstrumentAction::kInstrument ;
263+ }
264+
265+ const ThunkId thunk_id = thunk.thunk_info ().thunk_id ;
266+ if (absl::c_any_of (id_ranges, [&](const auto & range) {
267+ VLOG (2 ) << " Thunk filter: check ID range: " << range.first
268+ << " <= " << thunk_id.value () << " <= " << range.second ;
269+ return range.first <= thunk_id.value () &&
270+ thunk_id.value () <= range.second ;
271+ })) {
272+ VLOG (2 ) << " Thunk filter: ID matches" ;
273+ return InstrumentAction::kInstrument ;
274+ }
275+
276+ VLOG (2 ) << " Thunk filter: ID does not match" ;
277+ return InstrumentAction::kSkip ;
278+ };
279+ }
280+
281+ // Creates a thunk filter that filters thunks by matching their profile
282+ // annotations against regexes configured in debug options.
283+ ThunkFilter CreateProfileAnnotationRegexFilter (
284+ const DebugOptions& debug_options) {
285+ std::vector<std::unique_ptr<RE2>> profile_annotation_regexes;
286+ for (const auto & regex :
287+ debug_options.xla_gpu_experimental_thunk_buffer_debug_filter ()
288+ .profile_annotation_regexes ()) {
289+ VLOG (1 ) << " Thunk filter: profile annotation regex: " << regex;
290+ profile_annotation_regexes.push_back (std::make_unique<RE2>(regex));
291+ }
292+ return [regexes = std::move (profile_annotation_regexes)](const Thunk& thunk) {
293+ if (regexes.empty ()) {
294+ return InstrumentAction::kInstrument ;
295+ }
296+
297+ const std::string& profile_annotation =
298+ thunk.thunk_info ().profile_annotation ;
299+ if (absl::c_any_of (regexes, [&](const auto & regex) {
300+ VLOG (2 ) << " Thunk filter: check profile annotation regex: "
301+ << regex->pattern ();
302+ return RE2::PartialMatch (profile_annotation, *regex);
303+ })) {
304+ VLOG (2 ) << " Thunk filter: profile annotation matches" ;
305+ return InstrumentAction::kInstrument ;
306+ }
307+
308+ VLOG (2 ) << " Thunk filter: profile annotation does not match" ;
309+ return InstrumentAction::kSkip ;
310+ };
311+ }
312+
313+ // Creates a thunk filter that filters thunks by all the conditions configured
314+ // in debug options.
315+ ThunkFilter CreateThunkFilter (const DebugOptions& debug_options) {
316+ std::vector<ThunkFilter> filters;
317+ filters.push_back (CreateThunkIdFilter (debug_options));
318+ filters.push_back (CreateProfileAnnotationRegexFilter (debug_options));
319+
320+ return [filters = std::move (filters)](const Thunk& thunk) {
321+ VLOG (2 ) << " Thunk filter: check ID " << thunk.thunk_info ().thunk_id
322+ << " , profile annotation " << thunk.thunk_info ().profile_annotation ;
323+ if (absl::c_all_of (filters, [&](const auto & filter) {
324+ return filter (thunk) == InstrumentAction::kInstrument ;
325+ })) {
326+ return InstrumentAction::kInstrument ;
327+ }
328+ return InstrumentAction::kSkip ;
329+ };
330+ }
331+
229332XLA_FFI_DEFINE_HANDLER_SYMBOL (
230333 kDebugLogInitHandler ,
231334 [](se::Stream* absl_nonnull stream, xla::ffi::Buffer<U8> log_buffer) {
@@ -285,7 +388,11 @@ absl::StatusOr<bool> ThunkBufferDebugPass::Run(
285388 /* results=*/ {}, /* attributes=*/ {},
286389 hlo_module->entry_computation ()));
287390
391+ ThunkFilter thunk_filter = CreateThunkFilter (debug_options);
288392 root_thunk->TransformAllNestedThunks ([&](std::unique_ptr<Thunk> thunk) {
393+ if (thunk_filter (*thunk) == InstrumentAction::kSkip ) {
394+ return thunk;
395+ }
289396 switch (mode_) {
290397 case Mode::kChecksum :
291398 VLOG (1 ) << " Wrapping with checksum thunk" ;
0 commit comments