Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tsl/profiler/lib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ cc_library(
":traceme_encode",
"//tsl/platform",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:source_location",
"@xla//xla/tsl/platform:logging",
"@xla//xla/tsl/platform:macros",
"@xla//xla/tsl/platform:types",
Expand Down
32 changes: 25 additions & 7 deletions tsl/profiler/lib/traceme.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ limitations under the License.
#include <type_traits>
#include <utility>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/source_location.h"
#include "xla/tsl/platform/logging.h"
#include "xla/tsl/platform/macros.h"
#include "xla/tsl/profiler/utils/no_init.h"
Expand Down Expand Up @@ -94,13 +96,20 @@ class TraceMe {
// - Can be a value in enum TraceMeLevel.
// Users are welcome to use level > 3 in their code, if they wish to filter
// out their host traces based on verbosity.
explicit TraceMe(absl::string_view name, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
explicit TraceMe(
absl::string_view name, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask,
absl::SourceLocation source_location = absl::SourceLocation::current()) {
DCHECK_GE(level, 1);
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) &&
TraceMeRecorder::CheckFilter(filter_mask))) {
name_.Emplace(std::string(name));
traceme_internal::AppendMetadata(
&name_.value,
TraceMeEncode(
{{"traceme", absl::StrCat(source_location.file_name(), ":",
source_location.line())}}));
start_time_ = GetCurrentTimeNanos();
}
#endif
Expand All @@ -123,9 +132,11 @@ class TraceMe {

// This overload is necessary to make TraceMe's with string literals work.
// Otherwise, the name_generator template would be used.
explicit TraceMe(const char* raw, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask)
: TraceMe(absl::string_view(raw), level, filter_mask) {}
explicit TraceMe(
const char* raw, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask,
absl::SourceLocation source_location = absl::SourceLocation::current())
: TraceMe(absl::string_view(raw), level, filter_mask, source_location) {}

// This overload only generates the name (and possibly metadata) if tracing is
// enabled. Useful for avoiding expensive operations (e.g., string
Expand All @@ -146,13 +157,20 @@ class TraceMe {
// });
template <typename NameGeneratorT,
std::enable_if_t<std::is_invocable_v<NameGeneratorT>, bool> = true>
explicit TraceMe(NameGeneratorT&& name_generator, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask) {
explicit TraceMe(
NameGeneratorT&& name_generator, int level = 1,
uint64_t filter_mask = kTraceMeDefaultFilterMask,
absl::SourceLocation source_location = absl::SourceLocation::current()) {
DCHECK_GE(level, 1);
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(TraceMeRecorder::Active(level) &&
TraceMeRecorder::CheckFilter(filter_mask))) {
name_.Emplace(std::forward<NameGeneratorT>(name_generator)());
AppendMetadata([&]() {
return TraceMeEncode(
{{"traceme", absl::StrCat(source_location.file_name(), ":",
source_location.line())}});
});
start_time_ = GetCurrentTimeNanos();
}
#endif
Expand Down