
Commit 586824a

Merge commit '67af519ec69331a3d4e2fc2cd9d45e0165a849a1'
2 parents: ba0c584 + 67af519

10 files changed: 144 additions, 62 deletions


Makefile

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ PYTHON ?= python
 BUILD_DIR := $(shell cd python; $(PYTHON) -c 'from build_helpers import get_cmake_dir; print(get_cmake_dir())')
 TRITON_OPT := $(BUILD_DIR)/bin/triton-opt
 PYTEST := $(PYTHON) -m pytest
-LLVM_BUILD_PATH ?= $(realpath .llvm-project/build)
+LLVM_BUILD_PATH ?= "$(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))/.llvm-project/build"
 NUM_PROCS ?= 8

 # Incremental builds

lib/Conversion/TritonInstrumentToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -8,4 +8,5 @@ add_triton_library(TritonInstrumentToLLVM
   TritonGPUIR
   TritonInstrumentIR
   TritonNvidiaGPUIR
+  NVGPUIR
 )

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/Partition.cpp

Lines changed: 6 additions & 9 deletions
@@ -2,6 +2,7 @@
 #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "llvm/ADT/SCCIterator.h"
+#include "llvm/IR/Use.h"

 using namespace mlir;
 using namespace triton;
@@ -119,16 +120,14 @@ bool WarpSchedule::trySchedule(Partition *partition, Operation *op) {

 FailureOr<WarpSchedule> WarpSchedule::deserialize(scf::ForOp loop) {
   auto stages = loop->getAttrOfType<ArrayAttr>(kPartitionStagesAttrName);
-  if (!stages) {
-    return mlir::emitWarning(loop.getLoc(), "missing '")
-           << kPartitionStagesAttrName << "' attribute";
-  }
+  if (!stages)
+    return failure();

   WarpSchedule result;
   for (auto [idx, attr] : llvm::enumerate(stages)) {
     auto stage = dyn_cast<IntegerAttr>(attr);
     if (!stage || stage.getInt() < 0) {
-      return mlir::emitWarning(loop.getLoc(), "partition stages attribute '")
+      return mlir::emitError(loop.getLoc(), "partition stages attribute '")
              << kPartitionStagesAttrName << "' has invalid element " << attr;
     }

@@ -140,10 +139,8 @@ FailureOr<WarpSchedule> WarpSchedule::deserialize(scf::ForOp loop) {
     Partition *partition = result.getRootPartition();
     if (auto attr = op.getAttrOfType<IntegerAttr>(kPartitionAttrName)) {
       int64_t idx = attr.getInt();
-      if (idx < 0 || idx >= result.partitions.size()) {
-        return mlir::emitWarning(op.getLoc(), "invalid partition index ")
-               << idx;
-      }
+      if (idx < 0 || idx >= result.partitions.size())
+        return mlir::emitError(op.getLoc(), "invalid partition index ") << idx;
       partition = result.partitions[idx].get();
     }
     result.insert(partition, &op);

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp

Lines changed: 39 additions & 1 deletion
@@ -149,10 +149,14 @@ static void scheduleUsers(scf::ForOp loop, WarpSchedule &schedule,
 // first-order partition assignment to the operations in the scheme and its
 // users and/or dependencies. This sets up the initial partitioning of the ops.
 static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
-  WarpSchedule schedule;
+  // Check for an existing schedule.
+  if (FailureOr<WarpSchedule> scheduleOr = WarpSchedule::deserialize(loop);
+      succeeded(scheduleOr))
+    return {std::move(*scheduleOr)};

   // Start by creating the default partition, a partition for for all loads, and
   // a partition for all MMAs.
+  WarpSchedule schedule;
   Partition *defaultPartition = schedule.addPartition(0);
   Partition *mmaPartition = schedule.addPartition(1);
   Partition *loadPartition = schedule.addPartition(0);
@@ -479,6 +483,39 @@ void propagatePartitions(scf::ForOp loop, WarpSchedule &schedule) {
   }
 }

+// Rematerialize chains of broadcasts where the user is in a different partition
+// than the broadcast to reduce the amount of data that needs to be transferred.
+void rematerializeBroadcasts(WarpSchedule &schedule, OpOperand *use) {
+  static_assert(
+      std::is_base_of_v<OpTrait::OneResult<BroadcastOp>, BroadcastOp> &&
+      std::is_base_of_v<OpTrait::OneResult<ExpandDimsOp>, ExpandDimsOp>);
+
+  Operation *defOp = use->get().getDefiningOp();
+  while (isa_and_nonnull<BroadcastOp, ExpandDimsOp>(defOp)) {
+    Operation *clone = OpBuilder(defOp).clone(*defOp);
+    Partition *userPartition = schedule.getPartition(use->getOwner());
+    assert(userPartition && "user not scheduled");
+    schedule.insert(userPartition, clone);
+    use->set(clone->getResult(0));
+
+    defOp = clone->getOperand(0).getDefiningOp();
+    use = &clone->getOpOperand(0);
+  }
+}
+
+void optimizeSchedule(scf::ForOp loop, WarpSchedule &schedule) {
+  for (Partition &partition : schedule.getPartitions()) {
+    SmallVector<OpOperand *> uses;
+    schedule.iterateOutputs(loop, &partition,
+                            [&](Operation *defOp, OpOperand &use) {
+                              if (!isa<scf::YieldOp>(use.getOwner()))
+                                uses.push_back(&use);
+                            });
+    for (OpOperand *use : uses)
+      rematerializeBroadcasts(schedule, use);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Pass Definition
 //===----------------------------------------------------------------------===//
@@ -507,6 +544,7 @@ void PartitionScheduling::runOnOperation() {
   for (scf::ForOp loop : loops) {
     if (std::optional<WarpSchedule> schedule = getInitialSchedule(loop)) {
       propagatePartitions(loop, *schedule);
+      optimizeSchedule(loop, *schedule);
       schedule->serialize(loop);
     }
   }

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/RewritePartitionDependencies.cpp

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ struct UseInfo {
 int UseInfo::getMaxUseDistance(const Partition &partition) {
   int maxDistance = 0;
   for (auto [usePartition, distance] : llvm::make_first_range(consumers)) {
-    int dist = 2 + distance;
+    int dist = 1 + distance;
     maxDistance = std::max(maxDistance, dist);
   }
   return maxDistance;

python/test/unit/language/test_core.py

Lines changed: 30 additions & 12 deletions
@@ -7276,11 +7276,13 @@ def mul_add(data):
 # -----------------------


-@pytest.mark.parametrize("arch", ["sm70", "sm80", "sm90"])
+@pytest.mark.parametrize("arch", ["sm70", "sm80", "sm90", "gfx942", "gfx950", "gfx1200"])
 @pytest.mark.parametrize("env_var_override", [False, True])
 def test_override_arch(arch, env_var_override, device):
-    if not is_cuda():
-        pytest.xfail('arch only for CUDA')
+    if arch.startswith("sm") and not is_cuda():
+        pytest.xfail(f"{arch} arch only for CUDA")
+    elif arch.startswith("gfx") and not is_hip():
+        pytest.xfail(f"{arch} arch only for HIP")

     @triton.jit
     def simple(data, out):
@@ -7291,15 +7293,31 @@ def simple(data, out):
     data = torch.randn((128, ), device=device, dtype=torch.float32)
     out = torch.empty_like(data)

-    if env_var_override:
-        os.environ["TRITON_OVERRIDE_ARCH"] = str(arch)
-        h = simple[(1, )](data, out)
-        os.environ.pop("TRITON_OVERRIDE_ARCH")
-    else:
-        h = simple[(1, )](data, out, arch=arch)
-    torch.testing.assert_close(data * 1.5 + 1.0, out)
-    ttgir_cc = re.search(r'cuda:(\d+)', h.asm["ttgir"])
-    assert ttgir_cc.group(1) == arch[2:]
+    if is_cuda():
+        if env_var_override:
+            os.environ["TRITON_OVERRIDE_ARCH"] = str(arch)
+            h = simple[(1, )](data, out)
+            os.environ.pop("TRITON_OVERRIDE_ARCH")
+        else:
+            h = simple[(1, )](data, out, arch=arch)
+        torch.testing.assert_close(data * 1.5 + 1.0, out)
+        ttgir_cc = re.search(r'cuda:(\d+)', h.asm["ttgir"])
+        assert ttgir_cc.group(1) == arch[2:]
+    elif is_hip():
+        # For HIP, the generated kernel is a binary containing the final ISA. So we cannot run
+        # them like CUDA side if the chip doesn't match. Here we just check generated ISA.
+        if env_var_override:
+            os.environ["TRITON_OVERRIDE_ARCH"] = str(arch)
+            h = simple.warmup(data, out, grid=(1, ))
+            os.environ.pop("TRITON_OVERRIDE_ARCH")
+        else:
+            h = simple.warmup(data, out, arch=arch, grid=(1, ))
+        ttgir_gfx = re.search(r'hip:(\w+)', h.asm["ttgir"])
+        ttgir_warp = re.search(r'"ttg.threads-per-warp" = (\d+)', h.asm["ttgir"])
+        amdgcn_gfx = re.search(r'.amdgcn_target "amdgcn-amd-amdhsa--(\w+)"', h.asm["amdgcn"])
+        assert ttgir_gfx.group(1) == arch
+        assert int(ttgir_warp.group(1)) == (32 if arch == "gfx1200" else 64)
+        assert amdgcn_gfx.group(1) == arch


 # -----------------------
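
The HIP branch above only compiles the kernel, so the same mechanism can be used to cross-compile and inspect IR on any machine with an AMD (HIP) build of Triton. A minimal sketch under that assumption; the kernel, tensor, and target here are illustrative, while TRITON_OVERRIDE_ARCH, warmup(), and h.asm come from the test above:

import os
import re

import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(ptr, BLOCK: tl.constexpr):  # illustrative kernel, not from this commit
    offs = tl.arange(0, BLOCK)
    tl.store(ptr + offs, tl.load(ptr + offs) * 1.5)


# ROCm builds of PyTorch expose HIP devices under the "cuda" device string.
data = torch.randn(128, device="cuda")

# warmup() compiles without launching, so the requested arch need not match the installed GPU.
os.environ["TRITON_OVERRIDE_ARCH"] = "gfx942"
h = scale_kernel.warmup(data, BLOCK=128, grid=(1, ))
os.environ.pop("TRITON_OVERRIDE_ARCH")

# The target is recorded in the TTGIR, e.g. "hip:gfx942".
print(re.search(r'hip:(\w+)', h.asm["ttgir"]).group(1))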

python/triton_kernels/tests/test_specialize.py

Lines changed: 33 additions & 29 deletions
@@ -53,32 +53,36 @@ def cache_hook(*args, **kwargs):
         fn_name = kwargs["fn"].name
         module_name = kwargs["fn"].module

-    triton.knobs.runtime.jit_cache_hook = cache_hook
-    o = torch.empty((1, ), dtype=torch.float32, device=device)
-    k = specialized_kernel[(1, )](o, )
-    hash = k.hash
-    assert o.item() == 1.0
-    assert module_name == "tests.test_specialize"
-    assert fn_name == "cacheable_kernel"
-
-    compile_count = 0
-
-    def count_hook(*args, **kwargs):
-        nonlocal compile_count
-        compile_count += 1
-
-    triton.knobs.runtime.jit_cache_hook = count_hook
-    # clear the cache
-    specialized_kernel.device_caches.clear()
-
-    # retrieve the kernel from name and preload it.
-    fn = retrieve_fn(module_name, fn_name)
-    assert fn == specialized_kernel
-    preload = fn.preload(specialization_data)
-    assert compile_count == 1
-    assert preload.hash == hash
-
-    # verify that we hit the cache.
-    compile_count = 0
-    specialized_kernel[(1, )](o, )
-    assert compile_count == 0
+    prev_hook = triton.knobs.runtime.jit_cache_hook
+    try:
+        triton.knobs.runtime.jit_cache_hook = cache_hook
+        o = torch.empty((1, ), dtype=torch.float32, device=device)
+        k = specialized_kernel[(1, )](o, )
+        hash = k.hash
+        assert o.item() == 1.0
+        assert module_name == "tests.test_specialize"
+        assert fn_name == "cacheable_kernel"
+
+        compile_count = 0
+
+        def count_hook(*args, **kwargs):
+            nonlocal compile_count
+            compile_count += 1
+
+        triton.knobs.runtime.jit_cache_hook = count_hook
+        # clear the cache
+        specialized_kernel.device_caches.clear()
+
+        # retrieve the kernel from name and preload it.
+        fn = retrieve_fn(module_name, fn_name)
+        assert fn == specialized_kernel
+        preload = fn.preload(specialization_data)
+        assert compile_count == 1
+        assert preload.hash == hash
+
+        # verify that we hit the cache.
+        compile_count = 0
+        specialized_kernel[(1, )](o, )
+        assert compile_count == 0
+    finally:
+        triton.knobs.runtime.jit_cache_hook = prev_hook
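
The save/restore added above could equally be packaged as a context manager so any test that swaps the hook restores the previous one even on failure. A minimal sketch; the helper name is hypothetical and not part of Triton, only triton.knobs.runtime.jit_cache_hook comes from the test above:

import contextlib

import triton


@contextlib.contextmanager
def temporary_jit_cache_hook(hook):
    # Hypothetical helper: install `hook`, then put back whatever hook was
    # registered before, even if the body raises.
    prev = triton.knobs.runtime.jit_cache_hook
    triton.knobs.runtime.jit_cache_hook = hook
    try:
        yield
    finally:
        triton.knobs.runtime.jit_cache_hook = prev

Used as "with temporary_jit_cache_hook(cache_hook): ...", it gives the same guarantee as the try/finally in the diff while keeping the test body flat.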

python/triton_kernels/triton_kernels/numerics_details/mxfp.py

Lines changed: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis
     kernel_scale = out_scale.view(-1, out_scale.shape[-1])

     BLOCK_OUT_DIM = 128
-    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE
+    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
     grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
     grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)

@@ -93,7 +93,7 @@ def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dty
     reshaped_tensor = tensor.view(-1, tensor.shape[-1])
     reshaped_scale = scale.view(-1, scale.shape[-1])
     BLOCK_OUT_DIM = 128
-    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE
+    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
     blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
     blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
     _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
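
The .value unwrap matters because the same block-size constant is shared between device code and host-side grid math. A minimal sketch, assuming MXFP_BLOCK_SIZE is a triton.language.constexpr; the constant defined here is a stand-in, not the module's actual definition:

import triton
import triton.language as tl

# Stand-in for MXFP_BLOCK_SIZE. Inside @triton.jit kernels a constexpr can be used
# directly; host-side launch-grid arithmetic wants the plain Python int instead.
MXFP_BLOCK_SIZE = tl.constexpr(32)

BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value           # -> 32, a plain int
grid_quant = triton.cdiv(4096, BLOCK_QUANT_DIM)   # host-side computation of the launch grid
print(type(BLOCK_QUANT_DIM), grid_quant)          # <class 'int'> 128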

test/TritonGPU/partition-scheduling.mlir

Lines changed: 25 additions & 1 deletion
@@ -28,7 +28,6 @@ tt.func public @attention_forward(
   %zero = arith.constant dense<0.0> : tensor<256x64xf32, #blocked>
   %one = arith.constant dense<1.0> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>

-  %QK_tmem, %QK_tok = ttng.tmem_alloc : () -> (!ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>, !ttg.async.token)

   %loop_outs:4 = scf.for %i = %c0_i32 to %n_tiles step %c64_i32 iter_args(
     %l_i = %one,
@@ -46,6 +45,7 @@
     %K_shared = ttg.local_alloc %K : (tensor<64x64xf16, #load_blocked>) -> !ttg.memdesc<64x64xf16, #shared, #smem>

     %K_trans = ttg.memdesc_trans %K_shared {order = array<i32: 1, 0>} : !ttg.memdesc<64x64xf16, #shared, #smem> -> !ttg.memdesc<64x64xf16, #shared_T, #smem>
+    %QK_tmem, %QK_tok = ttng.tmem_alloc : () -> (!ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>, !ttg.async.token)
     %QK_mma_tok = ttng.tc_gen5_mma %Q_shared, %K_trans, %QK_tmem[%QK_tok], %false, %true : !ttg.memdesc<256x64xf16, #shared, #smem>, !ttg.memdesc<64x64xf16, #shared_T, #smem>, !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable>

     %QK, %QK_load_tok = ttng.tmem_load %QK_tmem[%QK_mma_tok] : !ttg.memdesc<256x64xf32, #tmem_acc, #ttng.tensor_memory, mutable> -> tensor<256x64xf32, #blocked>
@@ -138,4 +138,28 @@ tt.func public @mma_operand_view(
   tt.return
 }

+// CHECK-LABEL: @optimize_broadcast
+tt.func @optimize_broadcast(%arg0: i32) {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  // CHECK: scf.for
+  scf.for %i = %c0_i32 to %arg0 step %c1_i32 : i32 {
+    // CHECK: [[X:%.*]] = "producer"{{.*}}partition = 0
+    %x = "producer"() {ttg.partition = 0 : i32} : () -> tensor<128xf32>
+
+    // CHECK-DAG: [[X0_P0:%.*]] = tt.expand_dims [[X]] {{.*}}partition = 0
+    // CHECK-DAG: [[X0_P1:%.*]] = tt.expand_dims [[X]] {{.*}}partition = 1
+    %x0 = tt.expand_dims %x {axis = 0 : i32} : tensor<128xf32> -> tensor<1x128xf32>
+    // CHECK-DAG: [[X1_P0:%.*]] = tt.broadcast [[X0_P0]] {{.*}}partition = 0
+    // CHECK-DAG: [[X1_P1:%.*]] = tt.broadcast [[X0_P1]] {{.*}}partition = 1
+    %x1 = tt.broadcast %x0 : tensor<1x128xf32> -> tensor<128x128xf32>
+
+    // CHECK: "use"([[X1_P0]]) {{.*}}partition = 0
+    "use"(%x1) {ttg.partition = 0 : i32} : (tensor<128x128xf32>) -> ()
+    // CHECK: "use"([[X1_P1]]) {{.*}}partition = 1
+    "use"(%x1) {ttg.partition = 1 : i32} : (tensor<128x128xf32>) -> ()
+  } {tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32]}
+  tt.return
+}
+
 }

test/TritonGPU/rewrite-partition-dependencies.mlir

Lines changed: 6 additions & 6 deletions
@@ -10,7 +10,7 @@ module attributes {"ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: @two_consumers
 tt.func @two_consumers(%lb: i32, %ub: i32, %step: i32) {
   // CHECK: [[C0:%.*]] = arith.constant 0 : i32
-  // CHECK-NEXT: [[ABUF:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x1xi32, {{.*}}>
+  // CHECK-NEXT: [[ABUF:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x1xi32, {{.*}}>
   // CHECK-NEXT: [[AREF:%.*]] = nvws.aref.create [[ABUF]]
   scf.for %i = %lb to %ub step %step iter_args() -> () : i32 {
     %0 = "op_a"() {ttg.partition = 0} : () -> !ty
@@ -40,7 +40,7 @@ tt.func @two_consumers(%lb: i32, %ub: i32, %step: i32) {
 // CHECK-LABEL: @distance_one
 tt.func @distance_one(%lb: i32, %ub: i32, %step: i32) {
   // CHECK: [[C0:%.*]] = arith.constant 0 : i32
-  // CHECK: [[ABUF:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x1xi32, {{.*}}>
+  // CHECK: [[ABUF:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x1xi32, {{.*}}>
   // CHECK-NEXT: [[AREF:%.*]] = nvws.aref.create [[ABUF]]
   %cst = arith.constant dense<0> : !ty
   // CHECK: scf.for [[IV:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[STEP:%.*]] iter_args([[K:%.*]] = {{.*}})
@@ -63,9 +63,9 @@ tt.func @distance_one(%lb: i32, %ub: i32, %step: i32) {
 }

 tt.func @complex_case(%lb: i32, %ub: i32, %step: i32) {
-  // CHECK: [[ABUF1:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x1xi32, {{.*}}>
+  // CHECK: [[ABUF1:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x1xi32, {{.*}}>
   // CHECK-NEXT: [[AREF1:%.*]] = nvws.aref.create [[ABUF1]]
-  // CHECK-NEXT: [[ABUF2:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<2x1xi32, {{.*}}>
+  // CHECK-NEXT: [[ABUF2:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x1xi32, {{.*}}>
   // CHECK-NEXT: [[AREF2:%.*]] = nvws.aref.create [[ABUF2]]
   %cst = arith.constant dense<0> : !ty
   // CHECK: scf.for [[IV:%.*]] = [[LB:%.*]] to [[UB:%.*]] step [[STEP:%.*]] iter_args([[K:%.*]] = {{.*}}, [[L:%.*]] = {{.*}})
@@ -337,7 +337,7 @@ tt.func @no_def_op(%lb: i32, %ub: i32, %step: i32) {
 module attributes {"ttg.num-warps" = 4 : i32} {

 tt.func @invalid_attribute(%lb: i32, %ub: i32, %step: i32) {
-  // expected-warning @below {{partition stages attribute 'ttg.partition.stages' has invalid element "a"}}
+  // expected-error @below {{partition stages attribute 'ttg.partition.stages' has invalid element "a"}}
   scf.for %i = %lb to %ub step %step : i32 {
     scf.yield
   } {ttg.partition.stages = ["a"]}
@@ -359,7 +359,7 @@ module attributes {"ttg.num-warps" = 4 : i32} {

 tt.func @invalid_attribute(%lb: i32, %ub: i32, %step: i32) {
   scf.for %k = %lb to %ub step %step : i32 {
-    // expected-warning @below {{invalid partition index -1}}
+    // expected-error @below {{invalid partition index -1}}
     "op"() {ttg.partition = -1} : () -> ()
     scf.yield
   } {ttg.partition.stages = [2, 2]}
