Skip to content

Commit 14df45a

Browse files
authored
Enhance NVTX analysis capabilities (triton-inference-server#106)
* Add a Domain to NVTX analysis Signed-off-by: szalpal <[email protected]> * Fix pre-commit Signed-off-by: szalpal <[email protected]> * Add missing option to CMake Signed-off-by: szalpal <[email protected]> * Working around the problem of unlinked CUDA Signed-off-by: szalpal <[email protected]> * Missing ; Signed-off-by: szalpal <[email protected]> * More compilation fixes Signed-off-by: szalpal <[email protected]> --------- Signed-off-by: szalpal <[email protected]>
1 parent 3ecedb0 commit 14df45a

File tree

1 file changed

+66
-7
lines changed

1 file changed

+66
-7
lines changed

include/triton/common/nvtx.h

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -31,29 +31,88 @@
3131

3232
namespace triton { namespace common {
3333

34+
namespace detail {
35+
36+
class NvtxTritonDomain {
37+
public:
38+
static nvtxDomainHandle_t& GetDomain()
39+
{
40+
static NvtxTritonDomain inst;
41+
return inst.triton_nvtx_domain_;
42+
}
43+
44+
private:
45+
NvtxTritonDomain() { triton_nvtx_domain_ = nvtxDomainCreateA("Triton"); }
46+
47+
~NvtxTritonDomain() { nvtxDomainDestroy(triton_nvtx_domain_); }
48+
49+
nvtxDomainHandle_t triton_nvtx_domain_;
50+
};
51+
52+
} // namespace detail
53+
3454
// Updates a server stat with duration measured by a C++ scope.
3555
class NvtxRange {
3656
public:
37-
explicit NvtxRange(const char* label) { nvtxRangePushA(label); }
57+
explicit NvtxRange(const char* label, uint32_t rgb = kNvGreen)
58+
{
59+
auto attr = GetAttributes(label, rgb);
60+
nvtxDomainRangePushEx(detail::NvtxTritonDomain::GetDomain(), &attr);
61+
}
62+
63+
explicit NvtxRange(const std::string& label, uint32_t rgb = kNvGreen)
64+
: NvtxRange(label.c_str(), rgb)
65+
{
66+
}
3867

39-
explicit NvtxRange(const std::string& label) : NvtxRange(label.c_str()) {}
68+
~NvtxRange() { nvtxDomainRangePop(detail::NvtxTritonDomain::GetDomain()); }
4069

41-
~NvtxRange() { nvtxRangePop(); }
70+
static constexpr uint32_t kNvGreen = 0x76b900;
71+
static constexpr uint32_t kRed = 0xc1121f;
72+
static constexpr uint32_t kGreen = 0x588157;
73+
static constexpr uint32_t kBlue = 0x023047;
74+
static constexpr uint32_t kYellow = 0xffb703;
75+
static constexpr uint32_t kOrange = 0xfb8500;
76+
77+
private:
78+
nvtxEventAttributes_t GetAttributes(const char* label, uint32_t rgb)
79+
{
80+
nvtxEventAttributes_t attr;
81+
attr.version = NVTX_VERSION;
82+
attr.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE;
83+
attr.colorType = NVTX_COLOR_ARGB;
84+
attr.color = rgb | 0xff000000;
85+
attr.messageType = NVTX_MESSAGE_TYPE_ASCII;
86+
attr.message.ascii = label;
87+
return attr;
88+
}
4289
};
4390

4491
}} // namespace triton::common
4592

4693
#endif // TRITON_ENABLE_NVTX
4794

4895
//
49-
// Macros to access NVTX functionality
96+
// Macros to access NVTX functionality.
97+
// For `NVTX_RANGE` macro please refer to the usage below.
5098
//
5199
// With TRITON_ENABLE_NVTX defined the macros below forward to NVTX; otherwise
// every one of them expands to nothing.
#ifdef TRITON_ENABLE_NVTX
52100
// Expands to a call to nvtxInitialize(nullptr).
#define NVTX_INITIALIZE nvtxInitialize(nullptr)
53-
#define NVTX_RANGE(V, L) triton::common::NvtxRange V(L)
101+
// Declares a triton::common::NvtxRange variable named V with label L and the
// default color of the NvtxRange constructor.
#define NVTX_RANGE1(V, L) triton::common::NvtxRange V(L)
102+
// Same as NVTX_RANGE1, additionally passing RGB as the color argument.
#define NVTX_RANGE2(V, L, RGB) triton::common::NvtxRange V(L, RGB)
54103
// Expands to nvtxMarkA(L).
// NOTE(review): ranges are pushed with the domain-specific API while this
// uses nvtxMarkA; presumably markers end up outside the "Triton" domain -
// confirm that this is intended.
#define NVTX_MARKER(L) nvtxMarkA(L)
55104
#else
56105
// NVTX disabled: all macros are no-ops.
#define NVTX_INITIALIZE
57-
#define NVTX_RANGE(V, L)
106+
#define NVTX_RANGE1(V, L)
107+
#define NVTX_RANGE2(V, L, RGB)
58108
#define NVTX_MARKER(L)
59109
#endif // TRITON_ENABLE_NVTX
110+
111+
// "Overload" for `NVTX_RANGE` macro.
112+
// Usage:
// NVTX_RANGE(nvtx1, "My message") -> Records NVTX range with kNvGreen color.
// NVTX_RANGE(nvtx1, "My message", triton::common::NvtxRange::kRed) -> Records
// NVTX range with kRed color.
//
// GET_NVTX_MACRO evaluates to its fourth argument. NVTX_RANGE places its own
// arguments in front of the candidate names, so two arguments select
// NVTX_RANGE1 and three arguments select NVTX_RANGE2. The trailing empty
// argument guarantees that the `...` of GET_NVTX_MACRO always receives at
// least one argument, which is required before C++20 and otherwise triggers
// a pedantic warning for the two-argument form.
#define GET_NVTX_MACRO(_1, _2, _3, NAME, ...) NAME
#define NVTX_RANGE(...) \
  GET_NVTX_MACRO(__VA_ARGS__, NVTX_RANGE2, NVTX_RANGE1, )(__VA_ARGS__)

0 commit comments

Comments
 (0)