-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathgemm_all_scatter.py
More file actions
148 lines (125 loc) · 5.03 KB
/
gemm_all_scatter.py
File metadata and controls
148 lines (125 loc) · 5.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
import triton
import triton.language as tl
from iris.device_utils import read_realtime
import iris
# Persistent all-scatter GEMM kernel.
#
# Each of NUM_SMS persistent programs walks the (num_pid_m * num_pid_n) tile
# grid with stride NUM_SMS, accumulates its C = A @ B tile in registers, and
# then writes the tile into `c_global` on every rank: a plain `tl.store` for
# the current rank and `iris.store` (a put through `heap_bases`) for every
# other rank.  The tile's N offset is shifted by `cur_rank * N`, so each rank
# fills its own slice of the global N dimension (the all-scatter step).
#
# When COLLECT_TIMESTAMPS is set, the GPU realtime clock is sampled per tile
# and folded into mm_begin_timestamp_ptr / mm_end_timestamp_ptr via atomic
# min / max, capturing the earliest start and the latest pre-store time.
#
# NOTE(review): `bias_ptr`, `stride_bias`, and `BIAS` are accepted but never
# referenced in this kernel body, and `stride_cm` / `stride_cn` appear only
# in `tl.assume` hints -- presumably kept for signature parity with sibling
# GEMM variants; confirm against callers.
@triton.jit()
def persistent_gemm_all_scatter(
    A,  # pointer to operand A (M x K)
    B,  # pointer to operand B (K x N)
    C,  # local C pointer; only its element type is read (picks output dtype)
    c_global,  # global result buffer; presumably M x (world_size * N) -- confirm
    bias_ptr,  # unused here (see NOTE above)
    M,
    N,  # per-rank N extent; the global N offset for this rank is cur_rank * N
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,  # only used as a tl.assume hint (see NOTE above)
    stride_cn,  # only used as a tl.assume hint (see NOTE above)
    stride_cm_global,
    stride_cn_global,
    stride_bias,  # unused here (see NOTE above)
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,  # tile-rows per group in the swizzled ordering
    NUM_SMS: tl.constexpr,  # number of persistent programs launched
    NUM_XCDS: tl.constexpr,  # chiplet count used to remap program ids
    BIAS: tl.constexpr,  # unused here (see NOTE above)
    EVEN_K: tl.constexpr,  # True when no masked K-tail iteration is needed
    heap_bases: tl.tensor,  # iris heap base addresses used for remote puts
    cur_rank: tl.constexpr,
    world_size: tl.constexpr,
    COLLECT_TIMESTAMPS: tl.constexpr = False,
    mm_begin_timestamp_ptr: tl.tensor = None,  # per-tile earliest begin clock
    mm_end_timestamp_ptr: tl.tensor = None,  # per-tile latest pre-store clock
):
    pid = tl.program_id(0)
    # Remap program ids round-robin across the NUM_XCDS chiplets so that
    # consecutive logical pids land on different XCDs.
    if NUM_XCDS != 1:
        pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    total_tiles = num_pid_m * num_pid_n

    # Compiler hints: all strides are positive, enabling simpler address math.
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)

    # Accumulate in int32 for int8 outputs, float32 for everything else.
    acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32

    # Persistent loop: this program handles tiles pid, pid + NUM_SMS, ...
    for tile_id in range(pid, total_tiles, NUM_SMS):
        if COLLECT_TIMESTAMPS:
            # atomic_min keeps the earliest observed begin time for the tile.
            timestamp = read_realtime()
            tl.atomic_min(mm_begin_timestamp_ptr + tile_id, timestamp)

        # Grouped (swizzled) tile ordering: GROUP_SIZE_M rows of tiles are
        # traversed together before advancing along N.
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        # The last group may hold fewer than GROUP_SIZE_M tile-rows.
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m

        # Row/col index vectors for this tile; the modulo wraps ragged-edge
        # indices so loads stay in bounds (the final store is masked instead).
        rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        rk = tl.arange(0, BLOCK_SIZE_K)
        # Hints: indices are contiguous, block-size-aligned runs.
        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
        A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
        B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
        tl.assume(pid_m >= 0)
        tl.assume(pid_n >= 0)

        # K loop trip count; with a ragged K the final (partial) iteration is
        # peeled off and handled with masked loads below.
        loop_k = tl.cdiv(K, BLOCK_SIZE_K)
        if not EVEN_K:
            loop_k -= 1

        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
        for k in range(0, loop_k):
            # multiple_of marks 16-element alignment on the fast axis so the
            # compiler can vectorize these unmasked loads.
            a = tl.load(tl.multiple_of(A_BASE, (1, 16)))
            b = tl.load(tl.multiple_of(B_BASE, (16, 1)))
            acc += tl.dot(a, b)
            A_BASE += BLOCK_SIZE_K * stride_ak
            B_BASE += BLOCK_SIZE_K * stride_bk

        if not EVEN_K:
            # Peeled tail iteration: zero-fill the out-of-range K elements.
            k = loop_k
            rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
            B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
            A_BASE = tl.multiple_of(A_BASE, (1, 16))
            B_BASE = tl.multiple_of(B_BASE, (16, 1))
            a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0)
            b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0)
            acc += tl.dot(a, b)

        # Accumulator registers with C results
        c = acc.to(C.type.element_ty)

        # Rebuild the output index vectors (same expressions/values as above;
        # pid_m and pid_n are unchanged).
        rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        # Add compiler hints
        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
        # Define the C-mask (BLOCK_SIZE_M, 1) x (1, BLOCK_SIZE_N)
        sub_mask = (rm[:, None] < M) & (rn[None, :] < N)

        # Calculate the "global" offset of C based on the rank.
        # Note how the N-dimension is being multiplied by current rank.
        # This is because each rank is computing a portion of the N-dimension
        # locally and then scattering it to all other ranks to complete
        # the global N-dimension.
        global_offset = rm[:, None] * stride_cm_global + (rn[None, :] + cur_rank * N) * stride_cn_global

        # Timestamp for GEMM before store
        if COLLECT_TIMESTAMPS:
            # atomic_max keeps the latest observed pre-store time for the tile.
            timestamp = read_realtime()
            tl.atomic_max(mm_end_timestamp_ptr + tile_id, timestamp)

        # Store data to the global result using puts
        for remote_rank in range(world_size):
            if remote_rank == cur_rank:
                # For the current rank, we can use store
                tl.store(c_global + global_offset, c, mask=sub_mask)
            else:
                # Put the tile into remote_rank's c_global; iris translates
                # the local pointer through heap_bases.
                iris.store(
                    c_global + global_offset,
                    c,
                    cur_rank,
                    remote_rank,
                    heap_bases,
                    mask=sub_mask,
                )