Skip to content

Commit 238f1d2

Browse files
authored
feat: SM level profiler (#1305)
<!-- .github/pull_request_template.md --> ## 📌 Description Simply add smid into the profiler tag. Now it looks like this <img width="1284" height="670" alt="image" src="https://github.com/user-attachments/assets/dc754dc3-1e6b-4792-80c0-c0ec65fb5a64" /> Also fixed profiler install in pyproject.toml cc @happierpig @yzh119 ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 60cf6e4 commit 238f1d2

File tree

3 files changed

+48
-29
lines changed

3 files changed

+48
-29
lines changed

flashinfer/profiler/__init__.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,49 +32,58 @@ class EventType(Enum):
3232

3333

3434
def decode_tag(tag, num_blocks, num_groups):
35-
block_group_tag = tag >> 12
35+
"""
36+
Decode a profiler tag into (block_idx, group_idx, event_idx, event_type, sm_id).
37+
Tag layout:
38+
bits 0-1: event_type
39+
bits 2-11: event_idx
40+
bits 12-23: block_group_idx
41+
bits 24-31: sm_id
42+
"""
43+
sm_id = (tag >> 24) & 0xFF
44+
block_group_idx = (tag >> 12) & 0xFFF
3645
event_idx = (tag >> 2) & 0x3FF
3746
event_type = tag & 0x3
38-
return (
39-
block_group_tag // num_groups,
40-
block_group_tag % num_groups,
41-
event_idx,
42-
event_type,
43-
)
47+
block_idx = block_group_idx // num_groups
48+
group_idx = block_group_idx % num_groups
49+
return block_idx, group_idx, event_idx, event_type, sm_id
4450

4551

4652
def export_to_perfetto_trace(
4753
profiler_buffer: torch.Tensor,
4854
event_names: List[str],
4955
file_name: str,
5056
) -> None:
51-
57+
assert profiler_buffer.dtype == torch.uint64
5258
profiler_buffer_host = profiler_buffer.cpu()
5359
num_blocks, num_groups = profiler_buffer_host[:1].view(dtype=torch.int32)
5460
num_blocks = int(num_blocks)
5561
num_groups = int(num_groups)
5662

5763
tgen = TraceGenerator(file_name)
5864

65+
pid_map = {}
5966
tid_map = {}
6067
track_map = {}
61-
for block_idx in range(num_blocks):
62-
pid = tgen.create_group(f"block_{block_idx}")
63-
for group_idx in range(num_groups):
64-
tid = pid.create_group(f"group_{group_idx}")
65-
tid_map[(block_idx, group_idx)] = tid
6668

6769
for i in range(1, len(profiler_buffer_host)):
6870
if profiler_buffer_host[i] == 0:
6971
continue
7072
tag, timestamp = profiler_buffer_host[i : i + 1].view(dtype=torch.uint32)
7173
tag = int(tag)
7274
timestamp = int(timestamp)
73-
block_idx, group_idx, event_idx, event_type = decode_tag(
75+
block_idx, group_idx, event_idx, event_type, sm_id = decode_tag(
7476
tag, num_blocks, num_groups
7577
)
76-
event = event_names[event_idx]
78+
79+
# create trackers
80+
if block_idx not in pid_map:
81+
pid_map[block_idx] = tgen.create_group(f"sm_{sm_id}_block_{block_idx}")
82+
pid = pid_map[block_idx]
83+
if (block_idx, group_idx) not in tid_map:
84+
tid_map[(block_idx, group_idx)] = pid.create_group(f"group_{group_idx}")
7785
tid = tid_map[(block_idx, group_idx)]
86+
event = event_names[event_idx]
7887

7988
if (block_idx, group_idx, event_idx) in track_map:
8089
track = track_map[(block_idx, group_idx, event_idx)]

include/flashinfer/profiler.cuh

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,21 @@ constexpr uint32_t BEGIN_END_MASK = 0x3;
3636

3737
constexpr uint32_t EVENT_IDX_SHIFT = 2;
3838
constexpr uint32_t BLOCK_GROUP_IDX_SHIFT = 12;
39+
constexpr uint32_t SM_ID_SHIFT = 24;
40+
// Tag layout:
41+
// bits 0-1: event_type (start, end, instant)
42+
// bits 2-11: event_idx (translates to event_names in python profiler)
43+
// bits 12-23: block_id (12 bits)
44+
// bits 24-31: sm_id (8 bits)
3945

4046
constexpr uint32_t EVENT_BEGIN = 0x0;
4147
constexpr uint32_t EVENT_END = 0x1;
4248
constexpr uint32_t EVENT_INSTANT = 0x2;
4349

44-
__device__ __forceinline__ uint32_t encode_tag(uint32_t block_group_idx, uint32_t event_idx,
45-
uint32_t event_type) {
46-
return (block_group_idx << BLOCK_GROUP_IDX_SHIFT) | (event_idx << EVENT_IDX_SHIFT) | event_type;
50+
__device__ __forceinline__ uint32_t encode_tag(uint32_t sm_id, uint32_t block_id,
51+
uint32_t event_idx, uint32_t event_type) {
52+
return (sm_id << SM_ID_SHIFT) | (block_id << BLOCK_GROUP_IDX_SHIFT) |
53+
(event_idx << EVENT_IDX_SHIFT) | event_type;
4754
}
4855

4956
__device__ __forceinline__ uint32_t get_timestamp() {
@@ -79,17 +86,19 @@ struct ProfilerEntry {
7986
#define PROFILER_FUNC_PARAMS , at::Tensor profiler_buffer
8087
#define PROFILER_PARAMS_DECL uint64_t* profiler_buffer;
8188

82-
#define PROFILER_INIT(params, smem_storage, closure, group_idx, num_groups, \
83-
write_thread_predicate) \
84-
if (get_block_idx() == 0 && get_thread_idx() == 0) { \
85-
closure.entry.nblocks = get_num_blocks(); \
86-
closure.entry.ngroups = num_groups; \
87-
params.profiler_buffer[0] = closure.entry.raw; \
88-
} \
89-
closure.profiler_write_ptr = \
90-
params.profiler_buffer + 1 + get_block_idx() * num_groups + group_idx; \
91-
closure.profiler_write_stride = get_num_blocks() * num_groups; \
92-
closure.profiler_entry_tag_base = encode_tag(get_block_idx() * num_groups + group_idx, 0, 0); \
89+
#define PROFILER_INIT(params, smem_storage, closure, group_idx, num_groups, \
90+
write_thread_predicate) \
91+
uint32_t _sm_idx; \
92+
asm volatile("mov.u32 %0, %smid;" : "=r"(_sm_idx)); \
93+
if (get_block_idx() == 0 && get_thread_idx() == 0) { \
94+
closure.entry.nblocks = get_num_blocks(); \
95+
closure.entry.ngroups = num_groups; \
96+
params.profiler_buffer[0] = closure.entry.raw; \
97+
} \
98+
closure.profiler_write_ptr = \
99+
params.profiler_buffer + 1 + get_block_idx() * num_groups + group_idx; \
100+
closure.profiler_write_stride = get_num_blocks() * num_groups; \
101+
closure.profiler_entry_tag_base = encode_tag(_sm_idx, get_block_idx(), 0, 0); \
93102
closure.profiler_write_thread_predicate = write_thread_predicate;
94103

95104
#define PROFILER_EVENT_START(closure, event) \

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ packages = [
4949
"flashinfer.jit",
5050
"flashinfer.jit.attention",
5151
"flashinfer.triton",
52+
"flashinfer.profiler",
5253
"flashinfer.triton.kernels",
5354
"flashinfer.comm",
5455
"flashinfer.cudnn",

0 commit comments

Comments
 (0)