diff --git a/tsl/profiler/lib/BUILD b/tsl/profiler/lib/BUILD index c92b2ff19..d2fc13cb0 100644 --- a/tsl/profiler/lib/BUILD +++ b/tsl/profiler/lib/BUILD @@ -270,6 +270,7 @@ cc_library( ":traceme_encode", "//tsl/platform", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:source_location", "@xla//xla/tsl/platform:logging", "@xla//xla/tsl/platform:macros", "@xla//xla/tsl/platform:types", diff --git a/tsl/profiler/lib/traceme.h b/tsl/profiler/lib/traceme.h index 566dfef0a..3ec0e4e43 100644 --- a/tsl/profiler/lib/traceme.h +++ b/tsl/profiler/lib/traceme.h @@ -23,7 +23,9 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/source_location.h" #include "xla/tsl/platform/logging.h" #include "xla/tsl/platform/macros.h" #include "xla/tsl/profiler/utils/no_init.h" @@ -94,13 +96,20 @@ class TraceMe { // - Can be a value in enum TraceMeLevel. // Users are welcome to use level > 3 in their code, if they wish to filter // out their host traces based on verbosity. - explicit TraceMe(absl::string_view name, int level = 1, - uint64_t filter_mask = kTraceMeDefaultFilterMask) { + explicit TraceMe( + absl::string_view name, int level = 1, + uint64_t filter_mask = kTraceMeDefaultFilterMask, + absl::SourceLocation source_location = absl::SourceLocation::current()) { DCHECK_GE(level, 1); #if !defined(IS_MOBILE_PLATFORM) if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) && TraceMeRecorder::CheckFilter(filter_mask))) { name_.Emplace(std::string(name)); + traceme_internal::AppendMetadata( + &name_.value, + TraceMeEncode( + {{"traceme", absl::StrCat(source_location.file_name(), ":", + source_location.line())}})); start_time_ = GetCurrentTimeNanos(); } #endif @@ -123,9 +132,11 @@ class TraceMe { // This overload is necessary to make TraceMe's with string literals work. // Otherwise, the name_generator template would be used. - explicit TraceMe(const char* raw, int level = 1, - uint64_t filter_mask = kTraceMeDefaultFilterMask) - : TraceMe(absl::string_view(raw), level, filter_mask) {} + explicit TraceMe( + const char* raw, int level = 1, + uint64_t filter_mask = kTraceMeDefaultFilterMask, + absl::SourceLocation source_location = absl::SourceLocation::current()) + : TraceMe(absl::string_view(raw), level, filter_mask, source_location) {} // This overload only generates the name (and possibly metadata) if tracing is // enabled. Useful for avoiding expensive operations (e.g., string @@ -146,13 +157,20 @@ class TraceMe { // }); template , bool> = true> - explicit TraceMe(NameGeneratorT&& name_generator, int level = 1, - uint64_t filter_mask = kTraceMeDefaultFilterMask) { + explicit TraceMe( + NameGeneratorT&& name_generator, int level = 1, + uint64_t filter_mask = kTraceMeDefaultFilterMask, + absl::SourceLocation source_location = absl::SourceLocation::current()) { DCHECK_GE(level, 1); #if !defined(IS_MOBILE_PLATFORM) if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) && TraceMeRecorder::CheckFilter(filter_mask))) { name_.Emplace(std::forward(name_generator)()); + AppendMetadata([&]() { + return TraceMeEncode( + {{"traceme", absl::StrCat(source_location.file_name(), ":", + source_location.line())}}); + }); start_time_ = GetCurrentTimeNanos(); } #endif