Commit 8eb3861

Merge commit '16961b79bdac1b774b42d44e52fd55a266ec2866'
2 parents d545f3d + 16961b7

20 files changed: +355, -190 lines

include/triton/Conversion/TritonGPUToLLVM/Utility.h
Lines changed: 7 additions & 0 deletions

@@ -543,6 +543,13 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
     Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
 
+[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
+    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
+    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
+    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+    Value laneId, Value warpId,
+    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
+
 SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
                                            Type elemLlvmTy,
                                            const SharedMemoryObject &smemObj,
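
The new overload lets a caller thread in laneId/warpId values it has already materialized, instead of having the helper emit getLaneAndWarpId on every call. A minimal sketch of a hypothetical call site (regLayout, smemObj, targetInfo, and the callback body are placeholders, not code from this commit):

    // Compute lane/warp IDs once, reuse them across several transfers.
    auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
    bool ok = emitTransferBetweenRegistersAndShared(
        regLayout, sharedTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj,
        loc, rewriter, targetInfo, laneId, warpId,
        [&](VectorType vecTy, Value shmemAddr) {
          // emit the per-vector load or store against shmemAddr here
        });
    if (!ok)
      return failure(); // [[nodiscard]] makes ignoring the result a warning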

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h
Lines changed: 4 additions & 1 deletion

@@ -20,7 +20,7 @@ static const char *kLoopStageAttrName = "loop.stage";
 static const char *kLoopClusterAttrName = "loop.cluster";
 static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
 class CoarseSchedule;
-
+class ModuleAxisInfoAnalysis;
 //===----------------------------------------------------------------------===//
 // Hoisting Utilities
 //===----------------------------------------------------------------------===//
@@ -87,6 +87,9 @@ std::pair<Operation *, int64_t> getDefiningOpAndDistance(scf::ForOp forOp,
 int getCopyVecBytes(RankedTensorType registerTy,
                     gpu::SharedEncodingTrait sharedEnc);
 
+bool canBeConvertedToAsyncLoad(
+    triton::LoadOp loadOp, triton::ModuleAxisInfoAnalysis &axisInfoAnalysis);
+
 // Serialize the latencies of the operations in the loops into the latency
 // attribute.
 void serializeLatencies(ModuleOp module, DenseMap<Operation *, int> &opLatency);
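
ModuleAxisInfoAnalysis is forward-declared rather than included because the new canBeConvertedToAsyncLoad declaration only takes it by reference, which compiles against an incomplete type and keeps this widely-included header light. A self-contained illustration of the pattern (SomeAnalysis and useAnalysis are made-up names, not Triton APIs):

    class SomeAnalysis;                          // incomplete type: fine here
    bool useAnalysis(int op, SomeAnalysis &a);   // by-reference declaration
    // Only the .cpp that defines useAnalysis needs the complete type, just as
    // PipeliningUtility.cpp now includes triton/Analysis/AxisInfo.h.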

lib/Conversion/TritonGPUToLLVM/Utility.cpp
Lines changed: 12 additions & 1 deletion

@@ -417,6 +417,18 @@ bool emitTransferBetweenRegistersAndShared(
     std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
     Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
+  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
+  return emitTransferBetweenRegistersAndShared(
+      regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
+      target, laneId, warpId, perVectorCallback);
+}
+
+bool emitTransferBetweenRegistersAndShared(
+    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
+    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
+    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+    Value laneId, Value warpId,
+    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
   MLIRContext *ctx = rewriter.getContext();
   auto b = TritonLLVMOpBuilder(loc, rewriter);
 
@@ -458,7 +470,6 @@ bool emitTransferBetweenRegistersAndShared(
       maxVecElems.value_or(std::numeric_limits<int>::max()));
 
   auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1;
-  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
   Value blockId =
       withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0);
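
The old entry point is kept as a thin wrapper that computes the defaults and forwards, so existing callers are untouched. The shape of the refactor, reduced to a self-contained sketch (names are illustrative):

    int defaultExtra();                  // stands in for getLaneAndWarpId
    bool doTransfer(int x, int extra);   // new overload: explicit argument
    bool doTransfer(int x) {             // old overload: now a thin wrapper
      return doTransfer(x, defaultExtra());
    }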

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp
Lines changed: 1 addition & 25 deletions

@@ -106,34 +106,10 @@ class AssignLoadLatencies {
     return !incompatible;
   }
 
-  bool isSmallLoad(tt::LoadOp loadOp,
-                   tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
-    assert(!isLoadFromTensorPtr(loadOp) &&
-           "Block ptr should have been lowered before this pass.");
-    auto ptr = loadOp.getPtr();
-    unsigned vec = axisInfoAnalysis.getContiguity(ptr);
-    if (auto mask = loadOp.getMask())
-      vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));
-
-    auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
-    if (!tensorTy)
-      return true;
-    auto ty = cast<tt::PointerType>(tensorTy.getElementType()).getPointeeType();
-    unsigned width = vec * ty.getIntOrFloatBitWidth();
-
-    // We do not pipeline all loads for the following reasons:
-    // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16.
-    // 2. It's likely that pipling small loads won't offer much performance
-    //    improvement and may even hurt performance by increasing register
-    //    pressure.
-    LDBG("Load " << *loadOp << " has width " << width);
-    return width < 32;
-  }
-
   bool isPipeliningBeneficial(Operation *op, Operation *finalUser,
                               tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
     if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
-      if (isSmallLoad(loadOp, axisInfoAnalysis)) {
+      if (!canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) {
         LDBG("Load " << *loadOp << " is too small for pipelining");
         return false;
       }
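
The replacement flips the predicate's polarity: isSmallLoad answered "should we reject?", while canBeConvertedToAsyncLoad answers "can we accept?". The scalar-pointer case flips with it (old: return true, i.e. small; new: return false, i.e. not convertible), so the call site behaves the same. A self-contained equivalence check using simplified stand-ins (plain values, not the real MLIR signatures):

    #include <cassert>

    // widthBits models vec * element bit width from the original code.
    bool isSmallLoad(unsigned widthBits, bool isTensorPtr) {
      if (!isTensorPtr)
        return true;          // old code: scalar loads counted as "small"
      return widthBits < 32;
    }
    bool canBeConvertedToAsyncLoad(unsigned widthBits, bool isTensorPtr) {
      if (!isTensorPtr)
        return false;         // new code: scalar loads are not convertible
      return widthBits >= 32;
    }
    int main() {
      for (bool tensor : {false, true})
        for (unsigned w : {8u, 16u, 31u, 32u, 128u})
          assert(isSmallLoad(w, tensor) == !canBeConvertedToAsyncLoad(w, tensor));
    }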

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp
Lines changed: 13 additions & 5 deletions

@@ -1,6 +1,7 @@
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "triton/Analysis/AxisInfo.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
@@ -441,7 +442,8 @@ bool loadRequiresAdditionalBuffer(Operation *loadOp) {
   return false;
 }
 
-scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
+scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule,
+                      triton::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
   llvm::MapVector<Operation *, AsyncLoad> asyncLoads;
   llvm::MapVector<int, LoadGroupInfo> loadGroups;
   // Only visit the top level ops, we do not support pipelining conditional
@@ -457,9 +459,13 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule) {
     SharedEncodingTrait sharedEncoding = getSharedEncoding(&op);
     // Do not create async loads for small loads (cp.async requires at least 4
     // bytes)
+    bool canUseAsyncCp =
+        isa<tt::LoadOp>(op) &&
+        canBeConvertedToAsyncLoad(cast<tt::LoadOp>(op), axisInfoAnalysis);
     int copyVecBytes = getCopyVecBytes(
         cast<RankedTensorType>(op.getResultTypes()[0]), sharedEncoding);
-    if (copyVecBytes >= 4 || isTMALoad(&op)) {
+    canUseAsyncCp &= copyVecBytes >= 4;
+    if (canUseAsyncCp || isTMALoad(&op)) {
       if (loadRequiresAdditionalBuffer(&op)) {
         // Allocate additional buffer required by the wgmma pipelining.
         stageDiff += 1;
@@ -1008,26 +1014,28 @@ scf::ForOp lowerMMAs(scf::ForOp forOp, CoarseSchedule &schedule) {
 // LOWER LOOP
 /////////////////////////////
 
-void lowerLoop(scf::ForOp forOp) {
+void lowerLoop(scf::ForOp forOp,
+               triton::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
   CoarseSchedule schedule;
   if (failed(schedule.deSerialize(forOp))) {
     return;
   }
   scf::ForOp newForOp = lowerMMAs(forOp, schedule);
-  newForOp = lowerLoads(newForOp, schedule);
+  newForOp = lowerLoads(newForOp, schedule, axisInfoAnalysis);
   newForOp = lowerTMADescriptors(newForOp, schedule);
   schedule.serialize(newForOp);
 }
 
 } // namespace
 
 void lowerLoops(ModuleOp moduleOp) {
+  triton::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
   SmallVector<scf::ForOp> loops;
   moduleOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); });
   if (loops.empty())
    return;
   for (auto forOp : loops) {
-    lowerLoop(forOp);
+    lowerLoop(forOp, axisInfoAnalysis);
   }
 }
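
Note the design choice: ModuleAxisInfoAnalysis is built once per module in lowerLoops and threaded down through lowerLoop into lowerLoads, rather than recomputed per loop. The per-load decision now combines three conditions, with TMA loads bypassing them all; reduced to plain booleans it looks roughly like this (a sketch, not the real types):

    // Sketch of the new gating in lowerLoads.
    bool shouldLowerToAsyncCopy(bool isLoadOp, bool convertibleToAsync,
                                int copyVecBytes, bool isTMALoad) {
      bool canUseAsyncCp = isLoadOp && convertibleToAsync; // contiguity/width
      canUseAsyncCp &= copyVecBytes >= 4;  // cp.async cp-size is 4, 8, or 16
      return canUseAsyncCp || isTMALoad;   // TMA loads take their own path
    }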

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp
Lines changed: 30 additions & 0 deletions

@@ -6,6 +6,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/LLVM.h"
+#include "triton/Analysis/AxisInfo.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -14,8 +15,13 @@
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
 #include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
 #include <queue>
 
+#define DEBUG_TYPE "triton-loop-pipeline"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 using namespace mlir;
 namespace tt = mlir::triton;
 namespace ttg = mlir::triton::gpu;
@@ -313,6 +319,30 @@ int mlir::triton::getCopyVecBytes(RankedTensorType registerTy,
   return vecElems * registerTy.getElementTypeBitWidth() / 8;
 }
 
+bool mlir::triton::canBeConvertedToAsyncLoad(
+    tt::LoadOp loadOp, tt::ModuleAxisInfoAnalysis &axisInfoAnalysis) {
+  assert(!isLoadFromTensorPtr(loadOp) &&
+         "Block ptr should have been lowered before this pass.");
+  auto ptr = loadOp.getPtr();
+  unsigned vec = axisInfoAnalysis.getContiguity(ptr);
+  if (auto mask = loadOp.getMask())
+    vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));
+
+  auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
+  if (!tensorTy)
+    return false;
+  auto ty = cast<tt::PointerType>(tensorTy.getElementType()).getPointeeType();
+  unsigned width = vec * ty.getIntOrFloatBitWidth();
+
+  // We do not pipeline all loads for the following reasons:
+  // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8, or 16.
+  // 2. It's likely that pipling small loads won't offer much performance
+  //    improvement and may even hurt performance by increasing register
+  //    pressure.
+  LDBG("Load " << *loadOp << " has width " << width);
+  return width >= 32;
+}
+
 void mlir::triton::serializeLatencies(ModuleOp module,
                                       DenseMap<Operation *, int> &opLatency) {
   auto helper = TritonDialect::getLoaded(module)->getLatencyAttrHelper();
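
The eligibility test multiplies the achievable contiguity (clamped by mask alignment when a mask is present) by the element bit width and requires at least 32 bits. A worked check over a few representative loads (the values are illustrative, not taken from the commit):

    #include <cstdio>

    int main() {
      // width = vec (contiguous elements) * element bit width
      struct Case { unsigned vec, bits; };
      Case loads[] = {{8, 16}, {2, 16}, {1, 32}, {1, 8}};
      for (auto [vec, bits] : loads)
        std::printf("vec=%u x %2u-bit -> width=%3u -> %s\n", vec, bits,
                    vec * bits, vec * bits >= 32 ? "async-eligible" : "too small");
      return 0;
    }
    // Prints: 8x16 -> 128 (eligible), 2x16 -> 32 (eligible),
    //         1x32 -> 32 (eligible), 1x8 -> 8 (too small).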

python/src/llvm.cc
Lines changed: 8 additions & 0 deletions

@@ -316,6 +316,14 @@ void init_triton_llvm(py::module &&m) {
   CGSCCAnalysisManager cgam;
   ModuleAnalysisManager mam;
 
+  if (arch.empty()) {
+    llvm::TargetLibraryInfoImpl TLII;
+    TLII.disableAllFunctions();
+    fam.registerPass([TLII = std::move(TLII)] {
+      return llvm::TargetLibraryAnalysis(TLII);
+    });
+  }
+
   PassInstrumentationCallbacks *instrCbPtr = nullptr;
   PassInstrumentationCallbacks passInstrCb;
   StandardInstrumentations standardInstr(mod->getContext(),
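
When no target architecture is specified, the new block registers a TargetLibraryAnalysis preloaded with disableAllFunctions(), which reports every C library routine as unavailable; optimizer passes then neither recognize nor synthesize libcalls (memcpy/memset idioms and the like). A standalone sketch of the same registration, assuming LLVM's new-pass-manager headers:

    #include "llvm/Analysis/TargetLibraryInfo.h"
    #include "llvm/IR/PassManager.h"

    // Preload a TargetLibraryInfo with all libcalls disabled and hand it to
    // the FunctionAnalysisManager, mirroring the guarded block above.
    void registerNoLibcalls(llvm::FunctionAnalysisManager &fam) {
      llvm::TargetLibraryInfoImpl TLII;
      TLII.disableAllFunctions();
      fam.registerPass(
          [TLII = std::move(TLII)] { return llvm::TargetLibraryAnalysis(TLII); });
    }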

python/test/gluon/test_frontend.py
Lines changed: 14 additions & 4 deletions

@@ -13,7 +13,7 @@
 from triton._filecheck import filecheck_test, run_parser
 import triton.language as tl
 from triton._internal_testing import is_cuda
-from triton.compiler.errors import CompilationError
+from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
 
 TARGET_PAT = re.compile('ttg.target = "[^"]*"')
 
@@ -604,10 +604,10 @@ def kernel():
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @kernel() attributes {noinline = false} {
     %0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable>
-    tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
+    tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> ()
     tt.return
   }
-  tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_constexpr[1]_constexpr[0]____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(constexpr_1_ ,constexpr_0_), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
+  tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0____SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1 ,0), ctas_per_cga=None, cta_split_num=None, cta_order=None)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} {
     tt.return
   }
 }
@@ -855,7 +855,7 @@ def test_tensor_permute():
 def test_split_join():
     # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
     # CHECK: [[BLOCKED1:#.*]] = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
-    layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
+    layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0], [1], [1], [0])
     a = ttgl.full([128], 1, ttgl.int32, layout)
     b = ttgl.full([128], 2, ttgl.int32, layout)
     # CHECK: tt.join {{.*}} : tensor<128xi32, [[BLOCKED]]> -> tensor<128x2xi32, [[BLOCKED1]]>
@@ -883,6 +883,16 @@ def test_tensor_reshape():
     ttgl.static_assert(v.type.layout == expect_layout)
 
 
+@gluon.jit
+def static_assert_kernel():
+    ttgl.static_assert(False)
+
+
+def test_static_assert():
+    with pytest.raises(CompileTimeAssertionFailure):
+        run_parser(static_assert_kernel)
+
+
 @filecheck_test
 @gluon.jit
 def test_zeros():

python/triton/compiler/code_generator.py
Lines changed: 5 additions & 0 deletions

@@ -637,6 +637,8 @@ def _apply_binary_method(self, method_name, lhs, rhs):
         if _is_triton_tensor(rhs):
             reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name)
             return getattr(rhs, reverse_method_name)(lhs, _semantic=self.semantic)
+        if not isinstance(lhs, (constexpr, language.tuple)) and isinstance(rhs, constexpr):
+            lhs = constexpr(lhs)
         return getattr(lhs, method_name)(rhs)
 
     def visit_BinOp(self, node):
@@ -1457,9 +1459,12 @@ def ret(self, node: ast.Call):
 
     return ret
 
+from ..experimental.gluon import language as ttgl
 statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = {
     language.core.static_assert: execute_static_assert,
     language.core.static_print: static_executor(print),
+    ttgl.static_assert: execute_static_assert,
+    ttgl.static_print: static_executor(print),
     int: static_executor(int),
     len: static_executor(len),
 }

python/triton/compiler/compiler.py
Lines changed: 1 addition & 2 deletions

@@ -10,8 +10,6 @@
 from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
 from ..runtime.driver import driver
 from ..tools.disasm import get_sass, get_spvdis
-# TODO: this shouldn't be here
-from .code_generator import ast_to_ttir
 from pathlib import Path
 import re
 import functools
@@ -81,6 +79,7 @@ def hash(self):
         return hashlib.sha256(key.encode("utf-8")).hexdigest()
 
     def make_ir(self, options, codegen_fns, module_map, context):
+        from .code_generator import ast_to_ttir
         return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
                            module_map=module_map)
