Skip to content

Commit 167eea4

Browse files
authored
Optimize range gather for small hidden size and minor fix others (#202)
* implemented range_gather_per_range kernel to optimize cases with small hidden size, with ut updated * fixed range_gather_per_range_kernel bug for stride!=1 * fixed test_range_op bug for missing assert * fixed zero ROWS_PER_BLOCK; raised the meta args from int32 to int64 for range ops * added print_rank arg to assert_close with docstring updated; added rank info for test ffa test_case_string * moved random seed set to init_pg * updated the error thresholds for test_dist_attn * updated the base image version from 25.10.2 to 25.10.3 * updated the workflow to skip testing installation on cuda12 when magi_attention/csrc has no change * minor workflow fix
1 parent 4ea4003 commit 167eea4

File tree

14 files changed

+284
-130
lines changed

14 files changed

+284
-130
lines changed

.github/workflows/build_test.yaml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,18 @@ jobs:
5050
filters: |
5151
MagiAttention:
5252
- 'magi_attention/**'
53+
MagiAttentionCsrc:
54+
- 'magi_attention/csrc/**'
5355
- name: print filter results
5456
run: |
5557
echo "is MagiAttention modified: ${{ steps.filter.outputs.MagiAttention }}"
58+
echo "is MagiAttention csrc modified: ${{ steps.filter.outputs.MagiAttentionCsrc }}"
5659
5760
install_MagiAttention_ngc2505_cuda12:
5861
needs: [detect_changes]
5962
if: |
6063
always() &&
61-
needs.detect_changes.outputs.MagiAttention == 'true'
64+
needs.detect_changes.outputs.MagiAttentionCsrc == 'true'
6265
runs-on: [self-hosted]
6366
container:
6467
image: registry.cn-sh-01.sensecore.cn/sandai-ccr/magi-base:25.05.4
@@ -101,11 +104,11 @@ jobs:
101104
always() &&
102105
(
103106
needs.detect_changes.outputs.MagiAttention == 'true' &&
104-
needs.install_MagiAttention_ngc2505_cuda12.result == 'success'
107+
(needs.detect_changes.outputs.MagiAttentionCsrc != 'true' || needs.install_MagiAttention_ngc2505_cuda12.result == 'success')
105108
)
106109
runs-on: [self-hosted]
107110
container:
108-
image: registry.cn-sh-01.sensecore.cn/sandai-ccr/magi-base:25.10.2
111+
image: registry.cn-sh-01.sensecore.cn/sandai-ccr/magi-base:25.10.3
109112
options: --gpus all --ipc host
110113
credentials:
111114
username: ${{ secrets.DOCKER_USER_NAME }}

magi_attention/comm/primitive/grpcoll/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _calc_range_gather_kwargs_from_ranges_with_rank(
102102
total_size = sum(range_sizes)
103103

104104
# calculate row_map from row idx to range idx
105-
range_sizes = torch.tensor([0] + range_sizes, dtype=torch.int32, device=device)
105+
range_sizes = torch.tensor([0] + range_sizes, dtype=torch.int64, device=device)
106106
row_map = torch.repeat_interleave(
107107
torch.arange(0, len(ranges), device=device),
108108
range_sizes[1:],
@@ -141,7 +141,7 @@ def _calc_unperm_range_gather_kwargs_from_split_size_list(
141141
range_sizes = [end - start for start, end in ranges]
142142
range_sizes = torch.tensor(
143143
[0] + range_sizes,
144-
dtype=torch.int32,
144+
dtype=torch.int64,
145145
device=device,
146146
)
147147

@@ -186,7 +186,7 @@ def _calc_range_reduce_kwargs_from_ranges(
186186
total_size += reduce_end - reduce_start
187187

188188
range_reduce_kwargs: dict[str, Any] = {"deterministic": deterministic}
189-
input_ranges = torch.tensor(input_ranges, dtype=torch.int32, device=device)
189+
input_ranges = torch.tensor(input_ranges, dtype=torch.int64, device=device)
190190
range_reduce_kwargs["input_ranges"] = input_ranges
191191

192192
if deterministic:
@@ -214,7 +214,7 @@ def _calc_range_reduce_kwargs_from_ranges(
214214
range_reduce_kwargs["out2inp_range_map"] = out2inp_range_map
215215
range_reduce_kwargs["unique_ordered_out_ranges"] = unique_ordered_out_ranges
216216
else:
217-
range_sizes = torch.tensor([0] + range_sizes, dtype=torch.int32, device=device)
217+
range_sizes = torch.tensor([0] + range_sizes, dtype=torch.int64, device=device)
218218
cu_range_sizes = torch.cumsum(range_sizes, dim=0)
219219
row_map = torch.repeat_interleave(
220220
torch.arange(0, input_ranges.shape[0], device=device),
@@ -227,7 +227,7 @@ def _calc_range_reduce_kwargs_from_ranges(
227227
range_reduce_kwargs["total_size"] = total_size
228228
range_reduce_kwargs["row_map"] = row_map
229229

230-
output_ranges = torch.tensor(output_ranges, dtype=torch.int32, device=device)
230+
output_ranges = torch.tensor(output_ranges, dtype=torch.int64, device=device)
231231
range_reduce_kwargs["output_ranges"] = output_ranges
232232

233233
return range_reduce_kwargs

magi_attention/common/range_op/_range_gather.py

Lines changed: 134 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Literal, TypeAlias
16+
1517
import torch
1618
import triton
1719
import triton.language as tl
@@ -23,8 +25,61 @@
2325
__all__ = ["range_gather"]
2426

2527

28+
RangeGatherKernelBackend: TypeAlias = Literal["per_row", "per_range"]
29+
30+
31+
@triton.jit
32+
def range_gather_per_range_kernel(
33+
input_ptr,
34+
output_ptr,
35+
ranges_ptr,
36+
cu_range_sizes_ptr,
37+
input_stride,
38+
output_stride,
39+
N_PER_ROW: tl.constexpr,
40+
ROWS_PER_BLOCK: tl.constexpr,
41+
UNROLL_FACTOR: tl.constexpr = 4,
42+
):
43+
range_idx = tl.program_id(0)
44+
cu_range_size = tl.load(cu_range_sizes_ptr + range_idx)
45+
range_start = tl.load(ranges_ptr + range_idx * 2)
46+
range_end = tl.load(ranges_ptr + range_idx * 2 + 1)
47+
range_size = range_end - range_start
48+
49+
num_row_blocks = (range_size + ROWS_PER_BLOCK - 1) // ROWS_PER_BLOCK
50+
row_offs = tl.arange(0, ROWS_PER_BLOCK)[:, None]
51+
col_offs = tl.arange(0, N_PER_ROW)[None, :]
52+
input_offs = (row_offs * input_stride) + col_offs
53+
output_offs = (row_offs * output_stride) + col_offs
54+
col_mask = (col_offs < input_stride) & (col_offs < output_stride)
55+
56+
inp_idx = range_start * input_stride
57+
out_idx = cu_range_size * output_stride
58+
curr_inp_ptr = input_ptr + inp_idx
59+
curr_out_ptr = output_ptr + out_idx
60+
61+
for row_block_idx in tl.range(num_row_blocks, loop_unroll_factor=UNROLL_FACTOR):
62+
row_start = row_block_idx * ROWS_PER_BLOCK
63+
inp_ptr_this_block = curr_inp_ptr + row_start * input_stride
64+
out_ptr_this_block = curr_out_ptr + row_start * output_stride
65+
66+
row_mask = row_offs + row_start < range_size
67+
mask = row_mask & col_mask
68+
69+
inp = tl.load(
70+
inp_ptr_this_block + input_offs,
71+
mask=mask,
72+
)
73+
tl.store(
74+
out_ptr_this_block + output_offs,
75+
inp,
76+
mask=mask,
77+
cache_modifier=".cs", # cache streaming, since accessed once
78+
)
79+
80+
2681
@triton.jit
27-
def range_gather_kernel(
82+
def range_gather_per_row_kernel(
2883
input_ptr,
2984
output_ptr,
3085
ranges_ptr,
@@ -110,14 +165,28 @@ def range_gather(
110165
# sanity check
111166
assert cu_range_sizes.size(0) == ranges.size(0) + 1
112167

113-
# Calculate row_map if not provided
114-
row_map = kwargs.pop("row_map", None)
115-
if row_map is None:
116-
row_map = _calc_ranges_row_map(ranges, total_size)
117-
else:
118-
row_map = row_map.contiguous()
119-
# sanity check
120-
assert row_map.size(0) == total_size
168+
# Determine which kernel to use
169+
kernel_backend: RangeGatherKernelBackend | None = kwargs.pop("kernel_backend", None)
170+
if kernel_backend is None: # auto dispatch
171+
# heuristic: default use per-row kernel when hidden size per row is non-trivially small
172+
# TODO: refine the heuristic for better performance
173+
hidden_size_per_row = (
174+
input.numel() // input.shape[0] if input.shape[0] > 0 else 0
175+
)
176+
if hidden_size_per_row >= 128:
177+
kernel_backend = "per_row"
178+
else:
179+
kernel_backend = "per_range"
180+
181+
# Calculate row_map if not provided but required
182+
if kernel_backend == "per_row":
183+
row_map = kwargs.pop("row_map", None)
184+
if row_map is None:
185+
row_map = _calc_ranges_row_map(ranges, total_size)
186+
else:
187+
row_map = row_map.contiguous()
188+
# sanity check
189+
assert row_map.size(0) == total_size
121190

122191
# --- pre-process input/output --- #
123192

@@ -145,30 +214,62 @@ def range_gather(
145214
input_stride = input.stride(0)
146215
output_stride = output.stride(0)
147216

148-
# --- calculate grid size --- #
149-
150-
M = total_size
151-
N = input.numel() // input.shape[0]
152-
153-
ELEM_PER_BLOCK = 2048 // input.element_size()
154-
N_BLOCK = triton.cdiv(N, ELEM_PER_BLOCK)
155-
156-
grid = (M, N_BLOCK)
157-
158-
# --- launch kernel --- #
159-
160-
range_gather_kernel[grid](
161-
input,
162-
output,
163-
ranges,
164-
cu_range_sizes,
165-
row_map,
166-
input_stride,
167-
output_stride,
168-
N,
169-
N_BLOCK,
170-
ELEM_PER_BLOCK,
171-
)
217+
match kernel_backend:
218+
case "per_row":
219+
# --- calculate grid size --- #
220+
221+
M = total_size
222+
N = input.numel() // input.shape[0]
223+
224+
ELEM_PER_BLOCK = 2048 // input.element_size() # heuristic
225+
N_BLOCK = triton.cdiv(N, ELEM_PER_BLOCK)
226+
227+
grid = (M, N_BLOCK)
228+
229+
# --- launch kernel --- #
230+
231+
range_gather_per_row_kernel[grid](
232+
input,
233+
output,
234+
ranges,
235+
cu_range_sizes,
236+
row_map,
237+
input_stride,
238+
output_stride,
239+
N,
240+
N_BLOCK,
241+
ELEM_PER_BLOCK,
242+
num_warps=4, # block_size=128
243+
)
244+
case "per_range":
245+
# --- calculate grid size --- #
246+
247+
M = ranges.shape[0]
248+
grid = (M,) # type: ignore[assignment]
249+
250+
N_PER_ROW = triton.next_power_of_2(
251+
max(input_stride, output_stride)
252+
) # heuristic
253+
avg_range_size = (total_size + M - 1) // M
254+
ROWS_PER_BLOCK = max(
255+
1, min(triton.next_power_of_2(avg_range_size // 2), 4096)
256+
) # heuristic
257+
258+
# --- launch kernel --- #
259+
260+
range_gather_per_range_kernel[grid](
261+
input,
262+
output,
263+
ranges,
264+
cu_range_sizes,
265+
input_stride,
266+
output_stride,
267+
N_PER_ROW,
268+
ROWS_PER_BLOCK,
269+
num_warps=8, # block_size=256
270+
)
271+
case _:
272+
raise ValueError(f"Unsupported kernel_backend: {kernel_backend}")
172273

173274
# --- post-process output --- #
174275

magi_attention/common/range_op/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _calc_cu_range_sizes(
3030
total_size += end - start
3131
cu_range_sizes.append(total_size)
3232

33-
cu_range_sizes = torch.tensor(cu_range_sizes, dtype=torch.int32, device=device)
33+
cu_range_sizes = torch.tensor(cu_range_sizes, dtype=torch.int64, device=device)
3434

3535
return cu_range_sizes, total_size
3636

@@ -40,7 +40,7 @@ def _calc_ranges_row_map(
4040
total_size: int,
4141
) -> torch.Tensor:
4242
if ranges.shape[0] == 0:
43-
return torch.empty(0, dtype=torch.int32, device=ranges.device)
43+
return torch.empty(0, dtype=torch.int64, device=ranges.device)
4444

4545
row_map = torch.arange(0, ranges.shape[0], device=ranges.device)
4646
range_sizes = ranges[:, 1] - ranges[:, 0]
@@ -82,10 +82,10 @@ def _calc_out2inp_range_map(
8282
out2inp_range_map.append(inp_range_list)
8383

8484
out2inp_range_map = torch.tensor(
85-
out2inp_range_map, dtype=torch.int32, device=device
85+
out2inp_range_map, dtype=torch.int64, device=device
8686
)
8787
unique_ordered_out_ranges = torch.tensor(
88-
unique_ordered_out_ranges, dtype=torch.int32, device=device
88+
unique_ordered_out_ranges, dtype=torch.int64, device=device
8989
)
9090

9191
return out2inp_range_map, unique_ordered_out_ranges, max_inp_indices_size

magi_attention/testing/dist_common.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def init_pg(self) -> None:
7979
]:
8080
raise RuntimeError(f"Backend {self.backend} not supported!")
8181

82+
# Initialize the process group
8283
dist.init_process_group(
8384
backend=self.backend,
8485
world_size=self.world_size,
@@ -87,10 +88,13 @@ def init_pg(self) -> None:
8788
timeout=datetime.timedelta(minutes=30),
8889
)
8990

90-
# set device for nccl pg for collectives
91+
# Set the device for this process
9192
if "nccl" in self.backend:
9293
torch.cuda.set_device(self.rank)
9394

95+
# Set random seed with rank offset
96+
self._set_random_seed()
97+
9498
def destroy_pg(self) -> None:
9599
# Wait for all ranks to reach here before starting shutdown.
96100
# FIXME dist.barrier deadlocks with multiple threads and NCCL: https://github.com/pytorch/pytorch/issues/95895
@@ -112,7 +116,6 @@ def setUp(self) -> None:
112116
TIMEOUT_OVERRIDE.update({self.id().split(".")[-1]: timeout})
113117

114118
self._spawn_processes()
115-
self._set_random_seed()
116119

117120

118121
TestFunc = Callable[..., Any]

magi_attention/testing/precision.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import re
1616

1717
import torch
18+
import torch.distributed as dist
1819
from packaging import version
1920

2021
from magi_attention.functional.utils import safe_subtract
@@ -95,17 +96,35 @@ def assert_close(
9596
rtol: float = 1e-5,
9697
mismatch_threshold: float = 0,
9798
test_case: str = "",
99+
print_rank: int = 0,
98100
) -> None:
101+
"""Assert that two tensors are close within given tolerances,
102+
with a mismatch threshold to allow some degree of mismatch.
103+
104+
Args:
105+
a (torch.Tensor): tensor a.
106+
b (torch.Tensor): tensor b.
107+
atol (float, optional): absolute tolerance. Defaults to ``1e-5``.
108+
rtol (float, optional): relative tolerance. Defaults to ``1e-5``.
109+
mismatch_threshold (float, optional): allowed mismatch threshold. Defaults to ``0``.
110+
test_case (str, optional): test case description. Defaults to "".
111+
print_rank (int, optional): rank to print from. Defaults to ``0``.
112+
And set to ``-1`` to print from all ranks.
113+
"""
99114
assert (
100115
0 <= mismatch_threshold <= 1
101116
), f"{mismatch_threshold=} must be between 0 and 1"
117+
118+
if dist.is_initialized():
119+
rank = dist.get_rank()
120+
is_this_print_rank = print_rank == -1 or rank == print_rank
121+
else:
122+
is_this_print_rank = True
123+
102124
try:
103125
torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
104126
no_mismatch_info = f"[{test_case}]: has no mismatch"
105-
if torch.distributed.is_initialized():
106-
if torch.distributed.get_rank() == 0:
107-
print(no_mismatch_info)
108-
else:
127+
if is_this_print_rank:
109128
print(no_mismatch_info)
110129
except AssertionError as e:
111130
error_msg = str(e)
@@ -119,10 +138,7 @@ def assert_close(
119138
)
120139

121140
if mismatch_ratio <= mismatch_threshold:
122-
if torch.distributed.is_initialized():
123-
if torch.distributed.get_rank() == 0:
124-
print(mismatch_info)
125-
else:
141+
if is_this_print_rank:
126142
print(mismatch_info)
127143
return
128144
else:

0 commit comments

Comments
 (0)