Skip to content

Commit e8efee9

Browse files
Merge OpenAI Triton commit dcad270 (#5598)
This PR changes the Triton base from e4e68d1 to dcad270 (Nov 23). Pass rate: 95.42%->95.86%
2 parents 2d5c967 + 83eb05c commit e8efee9

File tree

26 files changed

+1890
-1969
lines changed

26 files changed

+1890
-1969
lines changed

lib/Plugins/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,16 @@ foreach( plugin ${TRITON_PLUGIN_PASSES} )
3535
"$<$<PLATFORM_ID:Darwin>:-undefined dynamic_lookup>"
3636
)
3737

38-
set_target_properties(${plugin} PROPERTIES
38+
# CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python
39+
# build. It is empty if building directly from the root
40+
# CMakeLists.txt file. Therefore if not building from Python just
41+
# use the default CMake shared lib path otherwise this causes a hard
42+
# build error
43+
if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
44+
set_target_properties(${plugin} PROPERTIES
3945
LIBRARY_OUTPUT_DIRECTORY
4046
"${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../plugins")
47+
endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
4148

4249
target_compile_options(${plugin} PRIVATE -fvisibility=hidden)
4350
endforeach()

python/test/unit/language/test_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6762,4 +6762,4 @@ def kernel(X, Y, Z, RANK: tl.constexpr, TRANS_A: tl.constexpr, TRANS_B: tl.const
67626762

67636763
d = a.to(torch.float32) @ b.to(torch.float32)
67646764

6765-
assert torch.equal(c, d)
6765+
assert torch.allclose(c, d, rtol=1e-3, atol=1e-2)

python/triton/tools/ragged_tma.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def create_ragged_descriptor(T, block_shape, ragged_dim=0):
1212
of potentially unequal size.
1313
1414
The load_ragged and store_ragged device functions can be used to read
15-
and write from subarrays T[batch_offset : batch_offset + batch_size]
15+
and write from subarrays T[slice_off : slice_off + slice_size]
1616
with hardware bounds-checking preventing any sort of leakage outside
1717
the subarray.
1818
"""
@@ -46,22 +46,22 @@ def create_ragged_descriptor(T, block_shape, ragged_dim=0):
4646

4747

4848
@triton.jit
49-
def to_ragged_indices(batch_offset, batch_size, row):
49+
def to_ragged_indices(slice_off, slice_size, row):
5050
"""
5151
Helper function for load_ragged and store_ragged.
5252
"""
5353

5454
billion = 0x40000000 # == 2**30
55-
x = billion - batch_size + row
56-
y = batch_offset + batch_size
55+
x = billion - slice_size + row
56+
y = slice_off + slice_size
5757

5858
return billion, y, x
5959

6060

6161
@triton.jit
62-
def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr = 0):
62+
def load_ragged(TMA, slice_off, slice_size, coords, ragged_dim: tl.constexpr = 0):
6363
"""
64-
Read from a subarray T[batch_offset : batch_offset + batch_size] with
64+
Read from a subarray T[slice_off : slice_off + slice_size] with
6565
hardware bounds-checking, where reading outside the subarray gives zeros.
6666
6767
Coords should be an appropriately-sized list of integers, just like in
@@ -70,39 +70,39 @@ def load_ragged(TMA, batch_offset, batch_size, coords, ragged_dim: tl.constexpr
7070

7171
tl.static_assert(len(TMA.shape) == len(coords) + 2, "TMA must be a read-write ragged descriptor")
7272

73-
c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
73+
c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim])
7474
data = TMA.load([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:])
7575
data = tl.reshape(data, data.shape[2:])
7676
return data
7777

7878

7979
@triton.jit
80-
def store_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
80+
def store_ragged(TMA, slice_off, slice_size, coords, data, ragged_dim: tl.constexpr = 0):
8181
"""
82-
Write to a subarray T[batch_offset : batch_offset + batch_size] with
82+
Write to a subarray T[slice_off : slice_off + slice_size] with
8383
hardware bounds-checking, where writes outside the subarray are masked
8484
correctly.
8585
8686
Coords should be an appropriately-sized list of integers, just like in
8787
TMA.store().
8888
"""
8989

90-
c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
90+
c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim])
9191
data = tl.reshape(data, [1, 1] + data.shape)
9292
TMA.store([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)
9393

9494

9595
@triton.jit
96-
def atomic_add_ragged(TMA, batch_offset, batch_size, coords, data, ragged_dim: tl.constexpr = 0):
96+
def atomic_add_ragged(TMA, slice_off, slice_size, coords, data, ragged_dim: tl.constexpr = 0):
9797
"""
98-
Atomic add into a subarray T[batch_offset : batch_offset + batch_size] with
98+
Atomic add into a subarray T[slice_off : slice_off + slice_size] with
9999
hardware bounds-checking, where adds outside the subarray are masked
100100
correctly.
101101
102102
Coords should be an appropriately-sized list of integers, just like in
103103
TMA.atomic_add().
104104
"""
105105

106-
c0, c1, c2 = to_ragged_indices(batch_offset, batch_size, coords[ragged_dim])
106+
c0, c1, c2 = to_ragged_indices(slice_off, slice_size, coords[ragged_dim])
107107
data = tl.reshape(data, [1, 1] + data.shape)
108108
TMA.atomic_add([c0, c1] + coords[:ragged_dim] + [c2] + coords[ragged_dim + 1:], data)

python/triton_kernels/bench/bench_mlp.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import triton_kernels
88
import triton_kernels.roofline as roofline
99
import triton_kernels.swiglu
10-
from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
10+
from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
1111
from triton_kernels.target_info import get_cdna_version
1212
import distributed as triton_dist
1313
from triton_kernels.tensor_details import layout
@@ -71,7 +71,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
7171

7272
input_x = torch.randn((batch // DP, dim1), device=dev)
7373
expt_assignment = triton_dist.create_expt_assignment(EP, n_expts_tot, torch.device(dev))
74-
triton_dist.initialize_matmul_ogs(batch, dim1, dim2, n_expts_act, n_expts_tot, input_x.dtype)
74+
triton_dist.initialize_matmul(batch, dim1, dim2, n_expts_act, n_expts_tot, input_x.dtype)
7575

7676
# run layer
7777
fpath = Path(tempfile.mktemp())
@@ -80,17 +80,16 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
8080
xg = input_x.to(wg.dtype if n_expts_tot > 1 else input_x.dtype)
8181
for i in range(100):
8282
if n_expts_tot > 1: # sparse
83-
logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
83+
logits = matmul(xg, wg, bg, precision_config=pcg)
8484
x, rdata, gather_indx, scatter_indx, metadata = triton_dist.routing(input_x, logits, n_expts_act, EP=EP,
8585
TP=TP, expt_assignment=expt_assignment,
8686
mode="ep_sharding")
8787
else: # dense
8888
x = triton_dist.all_gather(input_x, dim=0)
8989
rdata, gather_indx, scatter_indx, metadata = None, None, None, None
9090
if x.nelement() > 0:
91-
x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)
92-
x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx,
93-
precision_config=pc2)
91+
x = matmul(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)
92+
x = matmul(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx, precision_config=pc2)
9493
x = triton_dist.reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment)
9594
proton.finalize()
9695
return roofline.parse_profile(fpath.with_suffix(".hatchet"), useful_op_regex=".*matmul.*")

python/triton_kernels/bench/distributed.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@
1010
import triton_kernels
1111
import triton_kernels.swiglu
1212
from triton_kernels.reduce import reduce
13-
from triton_kernels.matmul_ogs import RoutingData, GatherIndx, ScatterIndx
1413
from triton_kernels.topk import topk
15-
from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
14+
from triton_kernels.matmul import matmul, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
1615
from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq
1716
from triton_kernels.tensor_details import layout
18-
from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata
17+
from triton_kernels.tensor import RaggedTensorMetadata, make_ragged_tensor_metadata, remap_ragged_tensor_metadata
1918
from triton_kernels.distributed import make_expt_dict_uniform, make_expt_assignment, convert_dp_to_ep, convert_ep_to_dp, ExptAssignment, symm_mem_pool
2019

2120
from bench_utils import quantize_weight
@@ -40,7 +39,7 @@ def create_expt_assignment(EP: int, n_expts_tot: int, device: torch.device) -> O
4039
return make_expt_assignment(EP, n_expts_tot, expt_dict, device)
4140

4241

43-
def initialize_matmul_ogs(
42+
def initialize_matmul(
4443
batch: int,
4544
dim1: int,
4645
dim2: int,
@@ -52,7 +51,7 @@ def initialize_matmul_ogs(
5251
return
5352
world_size = dist.get_world_size()
5453
device = torch.cuda.current_device()
55-
symm_mem_pool.initialize_matmul_ogs(
54+
symm_mem_pool.initialize_matmul(
5655
n_tokens_global=batch,
5756
d_input=dim1,
5857
d_model=dim2,
@@ -146,8 +145,7 @@ def routing(
146145
TP: int = 1,
147146
expt_assignment: Optional[ExptAssignment] = None,
148147
mode: Optional[str] = None,
149-
) -> Tuple[torch.Tensor, RoutingData, GatherIndx, ScatterIndx, Optional[ReduceScatterMetadata]]:
150-
n_expts_tot = logits.shape[-1]
148+
) -> Tuple[torch.Tensor, RaggedTensorMetadata, torch.Tensor, torch.Tensor, Optional[ReduceScatterMetadata]]:
151149
if _is_distributed_launch() and mode:
152150
if mode == "ep_sharding":
153151
if not expt_assignment:
@@ -170,29 +168,24 @@ def routing(
170168
logits_global_metadata = make_ragged_tensor_metadata(expt_sizes, dispatch_indx.shape[0])
171169
x = convert_dp_to_ep(x, expt_assignment, active_indx, dispatch_indx)
172170
logits_local_metadata = remap_ragged_tensor_metadata(logits_global_metadata, expt_map)
173-
gate_scal = logits_global.vals.flatten()[combine_indx]
174-
rdata = RoutingData(gate_scal, expt_sizes, n_expts_tot // EP, n_expts_act, logits_local_metadata)
175171
reduce_scatter_metadata = ReduceScatterMetadata(
176172
mode=mode,
177173
active_indx=active_indx,
178174
dispatch_indx=dispatch_indx,
179175
combine_indx=combine_indx,
180176
)
181-
return x, rdata, None, None, reduce_scatter_metadata
177+
return x, logits_local_metadata, None, None, reduce_scatter_metadata
182178
else:
183179
raise NotImplementedError(f"Distributed routing mode {mode} is not implemented yet.")
184180
else:
185181
# If mode is not specified or we have a single process, we do single-GPU routing.
186182
logits = topk(logits, n_expts_act, y_indx=y_indx, apply_softmax=not sm_first)
187183
dispatch_indx = logits.mask_metadata.row_sorted_indx
188184
combine_indx = logits.mask_metadata.col_sorted_indx
189-
ragged_batch_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0])
190-
gate_scal = logits.vals.flatten()[combine_indx]
191-
routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act,
192-
ragged_batch_metadata)
193-
gather_indx = GatherIndx(combine_indx, dispatch_indx)
194-
scatter_indx = ScatterIndx(dispatch_indx, combine_indx)
195-
return x, routing_data, gather_indx, scatter_indx, None
185+
ragged_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0])
186+
gather_indx = combine_indx // n_expts_act
187+
scatter_indx = combine_indx
188+
return x, ragged_metadata, gather_indx, scatter_indx, None
196189

197190

198191
def gather_ep(rank, world_size, param, TP, EP):
@@ -276,14 +269,14 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac
276269
w1_full = w2_full = w1_flex_full = w2_flex_full = w1_scale_full = w2_scale_full = None
277270

278271
# precision configs
279-
pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), weight_scale=wg_scale)
272+
pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=wg_flex), b_mx_scale=wg_scale)
280273
act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit"), reduction_n=2),
281274
(1.0, 1.0))
282-
pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), weight_scale=w1_scale)
283-
pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), weight_scale=w2_scale)
275+
pc1 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex), b_mx_scale=w1_scale)
276+
pc2 = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex), b_mx_scale=w2_scale)
284277
if rank == 0:
285-
pc1_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex_full), weight_scale=w1_scale_full)
286-
pc2_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex_full), weight_scale=w2_scale_full)
278+
pc1_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex_full), b_mx_scale=w1_scale_full)
279+
pc2_full = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex_full), b_mx_scale=w2_scale_full)
287280
else:
288281
pc1_full = pc2_full = None
289282

@@ -296,7 +289,7 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac
296289
xd = torch.randn((batch // world_size, dim1), device=dev).to(dtype_map[x_dtype])
297290
x0 = all_gather(xd, dim=0)
298291
expt_assignment = create_expt_assignment(EP, n_expts_tot, torch.device(dev))
299-
symm_mem_pool.initialize_matmul_ogs(
292+
symm_mem_pool.initialize_matmul(
300293
n_tokens_global=batch,
301294
d_input=dim1,
302295
d_model=dim2,
@@ -312,25 +305,25 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac
312305
def single(x):
313306
xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype)
314307
if n_expts_tot > 1:
315-
logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
308+
logits = matmul(xg, wg, bg, precision_config=pcg)
316309
x, rdata, gi, si, _ = routing(x, logits, n_expts_act)
317310
else:
318311
rdata = gi = si = None
319-
x = matmul_ogs(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act)
320-
return matmul_ogs(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full)
312+
x = matmul(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act)
313+
return matmul(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full)
321314

322315
# distributed pass
323316
def distributed(x):
324317
xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype)
325318
if n_expts_tot > 1: # sparse
326-
logits = matmul_ogs(xg, wg, bg, precision_config=pcg)
319+
logits = matmul(xg, wg, bg, precision_config=pcg)
327320
x, rdata, gi, si, metadata = routing(x, logits, n_expts_act, EP=EP, TP=TP, expt_assignment=expt_assignment,
328321
mode="ep_sharding")
329322
else: # dense
330323
x = all_gather(x, dim=0)
331324
rdata = gi = si = metadata = None
332-
x = matmul_ogs(x, w1, b1, rdata, gather_indx=gi, precision_config=pc1, fused_activation=act)
333-
x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=si, precision_config=pc2)
325+
x = matmul(x, w1, b1, rdata, gather_indx=gi, precision_config=pc1, fused_activation=act)
326+
x = matmul(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=si, precision_config=pc2)
334327
x = reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment)
335328
# gather the result from all GPUs, just for verification
336329
return all_gather(x, dim=0)

python/triton_kernels/tests/test_distributed.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from triton_kernels.distributed import convert_dp_to_ep, convert_ep_to_dp, make_expt_dict_uniform, make_expt_dict_random, make_expt_assignment, symm_mem_pool
1010
from triton_kernels.reduce import reduce
1111
from triton_kernels.topk import topk
12-
from triton_kernels.matmul_ogs import matmul_ogs, RoutingData, GatherIndx, ScatterIndx
12+
from triton_kernels.matmul import matmul
1313
from triton_kernels.target_info import is_hip
1414
from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata
1515
import pytest
@@ -122,17 +122,18 @@ def routing(logits, n_expts_act, all_gather=False, y_indx=None):
122122
dispatch_indx = sparse_logits.mask_metadata.row_sorted_indx
123123
combine_indx = sparse_logits.mask_metadata.col_sorted_indx
124124
ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0])
125-
gate_scal = sparse_logits.vals.flatten()[combine_indx]
126-
routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, logits.shape[-1], n_expts_act,
127-
ragged_batch_metadata)
128-
gather_idx = GatherIndx(combine_indx, dispatch_indx)
129-
scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
130-
return routing_data, gather_idx, scatter_idx, sparse_logits.indx
125+
gather_idx = torch.div(combine_indx, n_expts_act, rounding_mode="trunc")
126+
scatter_idx = combine_indx
127+
return ragged_batch_metadata, gather_idx, scatter_idx, sparse_logits.indx
131128

132129

133130
def mixture_of_expt_nosharded(x_global, l_global, w_global, b_global, n_expts_act, y_indx=None):
134131
rdata, combine_indx, dispatch_indx, _ = routing(l_global, n_expts_act, y_indx=y_indx)
135-
y_global = matmul_ogs(x_global, w_global, b_global, rdata, gather_indx=combine_indx, scatter_indx=dispatch_indx)
132+
y_global = matmul(x_global, w_global, b_global, rdata, gather_indx=combine_indx, scatter_indx=dispatch_indx)
133+
y_mask = (dispatch_indx != -1).view(y_global.shape[-2] // n_expts_act, n_expts_act, 1)
134+
y_global = y_global.view(y_global.shape[-2] // n_expts_act, n_expts_act, -1)
135+
y_mask = y_mask.expand_as(y_global)
136+
y_global, _ = reduce(y_global, dim=1, mask=y_mask)
136137
return y_global
137138

138139

@@ -153,9 +154,7 @@ def mixture_of_expt_epsharded(x_dp_local, l_dp_local, w_ep_local, b_ep_local, ex
153154
y_ep_local = convert_dp_to_ep(x_dp_local, expt_assignment, active_indx, dispatch_indx)
154155
y_ep_local_metadata = remap_ragged_tensor_metadata(x_global_metadata, expt_map)
155156
# matrix multiply
156-
# TODO: clean-up API. `RoutingData` should not exist; we should be passing `y_ep_local_metadata`.
157-
rdata_ep_local = RoutingData(None, expt_sizes, w_ep_local.shape[0], n_expts_act, y_ep_local_metadata)
158-
y_ep_local = matmul_ogs(y_ep_local, w_ep_local, b_ep_local, rdata_ep_local)
157+
y_ep_local = matmul(y_ep_local, w_ep_local, b_ep_local, a_ragged_metadata=y_ep_local_metadata)
159158
# convert x from expert-sorted, ep-local to token-sorted, dp-local
160159
y_dp_local = convert_ep_to_dp(y_ep_local, expt_assignment, active_indx, combine_indx)
161160
# weighted average of the output token from experts
@@ -208,7 +207,7 @@ def run_mixture():
208207
y_indx=y_indx_global,
209208
)
210209

211-
symm_mem_pool.initialize_matmul_ogs(
210+
symm_mem_pool.initialize_matmul(
212211
n_tokens_global=n_tokens_global,
213212
d_input=d_model,
214213
d_model=d_model,

0 commit comments

Comments
 (0)