
Commit 2db25c5

Merge commit 'abd3bb0679dc12bad4e5e2d560211f8034b2f00b'
2 parents: 7a156e7 + abd3bb0

24 files changed (+462, -350 lines)

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ llvm-project-*/
 dist/
 triton*.egg-info/
 *.whl
+python/triton_kernels/triton*.egg-info/
 
 python/triton/_C/*.pyd
 python/triton/_C/*.so

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerAllPasses();
   mlir::triton::registerTritonPasses();
   mlir::triton::gpu::registerTritonGPUPasses();
-  mlir::registerTritonNvidiaGPUPasses();
+  mlir::triton::nvidia_gpu::registerTritonNvidiaGPUPasses();
   mlir::test::intel::registerTestAxisInfoPass();
   mlir::test::registerTestAliasPass();
   mlir::test::registerTestAlignmentPass();
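The call site changes only because the generated registration entry point now lives in the mlir::triton::nvidia_gpu namespace: GEN_PASS_REGISTRATION (see the Passes.h hunk below) expands inside whatever namespace encloses the include. A rough, hedged sketch of the shape of that generated code follows; it is illustrative only, not the literal Passes.h.inc output.

#include "mlir/Pass/PassRegistry.h"

namespace mlir::triton::nvidia_gpu {
// Approximate shape of what GEN_PASS_REGISTRATION emits: one registerPass
// call per pass declared in Passes.td, wrapped in a group-level function.
inline void registerTritonNvidiaGPUPasses() {
  ::mlir::registerPass(
      [] { return createTritonNvidiaGPUTMALoweringPass(); });
  // ... one registerPass call per remaining pass ...
}
} // namespace mlir::triton::nvidia_gpu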

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h

Lines changed: 5 additions & 24 deletions
@@ -38,38 +38,19 @@ struct ClusterInfo {
   int clusterDimZ;
 };
 
-} // namespace nvidia_gpu
-} // namespace triton
-} // namespace mlir
-
-namespace mlir {
-
 std::unique_ptr<Pass> createTritonNvidiaGPUPlanCTAPass(
     mlir::triton::nvidia_gpu::ClusterInfo *clusterInfo = nullptr);
 
-std::unique_ptr<Pass>
-createTritonNvidiaGPUFenceInsertionPass(int computeCapability = 90);
-
-std::unique_ptr<Pass> createTritonNvidiaGPUTMALoweringPass();
-
-std::unique_ptr<Pass> createTensorMemoryAllocationPass();
-
-std::unique_ptr<Pass> createTritonNvidiaGPUMMALoweringPass();
-
-std::unique_ptr<Pass> createTritonNvidiaGPUPromoteLHSToTMemPass();
-
-std::unique_ptr<Pass> createTritonNvidiaGPURemoveTMEMTokensPass();
-
-std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeDescriptorEncodingPass();
-
-std::unique_ptr<Pass> createTritonNvidiaGPUOptimizeTMemLayoutsPass();
-
-std::unique_ptr<Pass> createTritonNvidiaGPUInterleaveTMemPass();
+#define GEN_PASS_DECL
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
 
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #define GEN_PASS_DECL_TRITONNVIDIAGPULEGALIZETMALAYOUTS
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
 
+} // namespace nvidia_gpu
+} // namespace triton
 } // namespace mlir
+
 #endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_PASSES_H_
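For context, the hand-written create*Pass() declarations are replaced by MLIR's tablegen pass machinery: GEN_PASS_DECL makes Passes.h.inc declare, for each pass, a create function (and an Options struct when the pass has options) in the enclosing namespace, which is why the namespace closers move below the includes. A hedged sketch of roughly what gets declared for the fence-insertion pass, assuming the standard tablegen backend; the real text lives in the generated Passes.h.inc:

#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir::triton::nvidia_gpu {
// Illustrative shape of the generated declarations (not the literal output).
struct TritonGPUFenceInsertionOptions {
  int computeCapability = 90; // assumed default, mirroring the old C++ argument
};
std::unique_ptr<::mlir::Pass> createTritonGPUFenceInsertion();
std::unique_ptr<::mlir::Pass>
createTritonGPUFenceInsertion(TritonGPUFenceInsertionOptions options);
} // namespace mlir::triton::nvidia_gpu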

include/triton/Dialect/TritonNvidiaGPU/Transforms/Passes.td

Lines changed: 2 additions & 12 deletions
@@ -32,7 +32,7 @@ def TritonGPUPlanCTAPass : Pass<"triton-nvidia-gpu-plan-cta", "mlir::ModuleOp">
     and StoreLikeOps operations.
   }];
 
-  let constructor = "mlir::createTritonNvidiaGPUPlanCTAPass()";
+  let constructor = "mlir::triton::nvidia_gpu::createTritonNvidiaGPUPlanCTAPass()";
 
   let dependentDialects = [
     "mlir::triton::gpu::TritonGPUDialect",
@@ -48,8 +48,6 @@ def TritonGPUFenceInsertion : Pass<"triton-nvidia-gpu-fence-insertion", "mlir::M
     properly ordered across generic and async operations.
   }];
 
-  let constructor = "mlir::createTritonNvidiaGPUFenceInsertionPass()";
-
   let dependentDialects = [
     "mlir::triton::gpu::TritonGPUDialect",
     "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
@@ -69,22 +67,18 @@ def TritonNvidiaGPUTMALoweringPass : Pass<"triton-nvidia-tma-lowering", "mlir::M
     Lower Triton experimental descriptor load to TMA load/store operations in TritonNvidiaGPUDialect.
   }];
 
-  let constructor = "mlir::createTritonNvidiaGPUTMALoweringPass()";
-
   let dependentDialects = [
     "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
   ];
 }
 
-def TritionTensorMemoryAllocationPass : Pass<"triton-tensor-memory-allocation", "mlir::ModuleOp"> {
+def TritonTensorMemoryAllocationPass : Pass<"triton-tensor-memory-allocation", "mlir::ModuleOp"> {
   let summary = "Assign tensor memory allocation";
 
   let description = [{
     Decide on tensor memory allocation and assign attributes to each allocation.
   }];
 
-  let constructor = "mlir::createTensorMemoryAllocationPass()";
-
   let dependentDialects = [
     "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
   ];
@@ -97,8 +91,6 @@ def TritonNvidiaGPUMMALoweringPass : Pass<"triton-nvidia-mma-lowering", "mlir::M
     Lower MMA ops to prepare for conversion to LLVM.
   }];
 
-  let constructor = "mlir::createTritonNvidiaGPUMMALoweringPass()";
-
   let dependentDialects = [
     "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"
   ];
@@ -111,8 +103,6 @@ def TritonNvidiaGPUPromoteLHSToTMemPass : Pass<"tritongpu-promote-lhs-to-tmem",
     Promote LHS operand of MMAv5 op to Tensor Memory.
   }];
 
-  let constructor = "mlir::createTritonNvidiaGPUPromoteLHSToTMemPass()";
-
   let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                            "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
                            "mlir::triton::TritonDialect"];

lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp

Lines changed: 9 additions & 0 deletions
@@ -101,6 +101,15 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
         return true;
       return false;
     });
+  addDynamicallyLegalOp<triton::FuncOp>([](triton::FuncOp funcOp) -> bool {
+    for (auto arg : funcOp.getArguments()) {
+      if (auto tensor = dyn_cast<RankedTensorType>(arg.getType())) {
+        if (!tensor.getEncoding())
+          return false;
+      }
+    }
+    return true;
+  });
 }
 
 bool TritonGPUConversionTarget::isDynamicallyLegal(
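The new rule marks a tt.func illegal while any ranked-tensor argument still lacks a TritonGPU layout encoding, so the dialect-conversion driver keeps applying the FuncOp pattern until the signature is rewritten. A hedged sketch of how such a dynamic-legality check is consumed by the driver, using a generic ConversionTarget and placeholder patterns rather than the real TritonGPUConversionTarget; the include path for triton::FuncOp is assumed:

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h" // assumed path for triton::FuncOp

// Hedged sketch: applyPartialConversion re-checks legality after each rewrite,
// so a tt.func with un-encoded tensor arguments forces the FuncOp pattern to run.
static mlir::LogicalResult
convertFuncs(mlir::ModuleOp module, mlir::RewritePatternSet &&patterns) {
  mlir::ConversionTarget target(*module.getContext());
  target.addDynamicallyLegalOp<mlir::triton::FuncOp>(
      [](mlir::triton::FuncOp fn) {
        return llvm::all_of(fn.getArgumentTypes(), [](mlir::Type ty) {
          if (auto tensor = mlir::dyn_cast<mlir::RankedTensorType>(ty))
            return static_cast<bool>(tensor.getEncoding());
          return true;
        });
      });
  return mlir::applyPartialConversion(module, target, std::move(patterns));
}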

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 6 additions & 3 deletions
@@ -481,14 +481,17 @@ class TritonFuncOpPattern : public OpConversionPattern<triton::FuncOp> {
   matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto converter = getTypeConverter();
+    TypeConverter::SignatureConversion result(op.getNumArguments());
     auto newOp = rewriter.replaceOpWithNewOp<triton::FuncOp>(
         op, op.getName(), op.getFunctionType());
     addNamedAttrs(newOp, adaptor.getAttributes());
     rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(),
                                 newOp.getBody().end());
-    if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter)))
-      return failure();
-
+    // Convert just the entry block. The remaining unstructured control flow is
+    // converted by br patterns.
+    if (!newOp.getBody().empty())
+      rewriter.applySignatureConversion(&newOp.getBody().front(), result,
+                                        converter);
     return success();
   }
 };
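The pattern now converts only the entry block's signature; successor blocks are left to the branch-op conversion patterns. A hedged sketch of the general entry-block signature-conversion recipe in a conversion pattern, under the assumption that the type converter is asked to compute the remapped argument types:

#include "mlir/Transforms/DialectConversion.h"

// Hedged sketch: remap the entry block's argument types through the converter
// and rewrite the block in place; the rest of the CFG is handled elsewhere.
static mlir::LogicalResult
convertEntryBlock(mlir::ConversionPatternRewriter &rewriter,
                  const mlir::TypeConverter &converter, mlir::Region &body) {
  if (body.empty())
    return mlir::success();
  mlir::Block &entry = body.front();
  mlir::TypeConverter::SignatureConversion sig(entry.getNumArguments());
  // Ask the converter how each existing argument type should be remapped.
  if (mlir::failed(converter.convertSignatureArgs(entry.getArgumentTypes(), sig)))
    return mlir::failure();
  rewriter.applySignatureConversion(&entry, sig, &converter);
  return mlir::success();
}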

lib/Dialect/Triton/Transforms/LoopAwareCSE.cpp

Lines changed: 6 additions & 11 deletions
@@ -43,7 +43,7 @@ struct LoopCSEDriver {
   bool areEqualInLoop(Value a, Value b);
 
   scf::ForOp loop;
-  ValueEquivalence equalValues;
+  SmallVector<std::pair<int, int>> argStack;
 };
 } // namespace
 
@@ -52,14 +52,15 @@ bool LoopCSEDriver::areIterArgsEqual(int i, int j) {
     return true;
   if (loop.getInitArgs()[i] != loop.getInitArgs()[j])
     return false;
+  if (llvm::is_contained(argStack, std::make_pair(i, j)))
+    return true;
   BlockArgument aArg = loop.getRegionIterArg(i);
   BlockArgument bArg = loop.getRegionIterArg(j);
   // First, assume the arguments are equal. This is how recursion is broken.
-  equalValues.setKnownEquivalence(aArg, bArg, true);
+  argStack.push_back({i, j});
   bool result =
       areEqualInLoop(loop.getYieldedValues()[i], loop.getYieldedValues()[j]);
-  // Now update the equivalence based on the actual result.
-  equalValues.setKnownEquivalence(aArg, bArg, result);
+  argStack.pop_back();
   return result;
 }
 
@@ -83,14 +84,10 @@ bool LoopCSEDriver::areEqualInLoop(Value a, Value b) {
   if (a == loop.getInductionVar() || b == loop.getInductionVar())
     return false;
 
-  if (std::optional<bool> eq = equalValues.getKnownEquivalence(a, b))
-    return *eq;
-
   if (auto aArg = dyn_cast<BlockArgument>(a)) {
     auto bArg = cast<BlockArgument>(b);
     bool result =
         areIterArgsEqual(aArg.getArgNumber() - 1, bArg.getArgNumber() - 1);
-    equalValues.setKnownEquivalence(a, b, result);
     return result;
   }
 
@@ -107,9 +104,7 @@ bool LoopCSEDriver::areEqualInLoop(Value a, Value b) {
   bool result = OperationEquivalence::isEquivalentTo(
       aDef, bDef,
       [&](Value a, Value b) { return success(areEqualInLoop(a, b)); },
-      [&](Value a, Value b) { equalValues.setKnownEquivalence(a, b, true); },
-      OperationEquivalence::IgnoreLocations);
-  equalValues.setKnownEquivalence(a, b, result);
+      /*markEquivalent=*/nullptr, OperationEquivalence::IgnoreLocations);
   return result;
 }
 
lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp

Lines changed: 17 additions & 22 deletions
@@ -13,24 +13,21 @@
 //
 //===----------------------------------------------------------------------===//
 
-using namespace mlir;
-namespace tt = ::mlir::triton;
-namespace ttg = ::mlir::triton::gpu;
-namespace ttng = ::mlir::triton::nvidia_gpu;
+namespace ttg = mlir::triton::gpu;
 
-#define GEN_PASS_CLASSES
-#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
+namespace mlir {
+namespace triton {
+namespace nvidia_gpu {
 
-namespace {
+#define GEN_PASS_DEF_TRITONGPUFENCEINSERTION
+#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
 
 struct FenceInsertionPass
-    : public TritonGPUFenceInsertionBase<FenceInsertionPass> {
+    : public impl::TritonGPUFenceInsertionBase<FenceInsertionPass> {
 
 public:
-  FenceInsertionPass() = default;
-  FenceInsertionPass(int computeCapability) {
-    this->computeCapability = computeCapability;
-  }
+  using impl::TritonGPUFenceInsertionBase<
+      FenceInsertionPass>::TritonGPUFenceInsertionBase;
   // TODO: support more general patterns to insert fences. eg. any op(generic)
   // to shared in use-def chain which refers by async proxy. We have generic(
   // convertlayout with sts/stmatix) + fence + async(wgmma) up to now
@@ -39,7 +36,7 @@ struct FenceInsertionPass
     if (computeCapability < 90)
       return;
     ModuleOp mod = getOperation();
-    mod.walk([&](tt::DotOpInterface dotOp) {
+    mod.walk([&](DotOpInterface dotOp) {
      Value a = dotOp.getA();
      Value b = dotOp.getB();
      bool aDependsOnShared = dependOnCopyRegToShared(a);
@@ -48,8 +45,8 @@ struct FenceInsertionPass
        return WalkResult::advance();

      OpBuilder builder(dotOp);
-      auto fence = builder.create<ttng::FenceAsyncSharedOp>(dotOp.getLoc(),
-                                                            /*bCluster=*/false);
+      auto fence = builder.create<FenceAsyncSharedOp>(dotOp.getLoc(),
+                                                      /*bCluster=*/false);
      // If there is all the dependencies are outside of the loop try to hoist
      // the fence.
      while (auto loopOp = fence->getParentOfType<LoopLikeOpInterface>()) {
@@ -63,8 +60,8 @@ struct FenceInsertionPass
      }

      // If the previous op is already a fence, this one isn't needed.
-      if (auto lastFence = dyn_cast_or_null<ttng::FenceAsyncSharedOp>(
-              fence->getPrevNode())) {
+      if (auto lastFence =
+              dyn_cast_or_null<FenceAsyncSharedOp>(fence->getPrevNode())) {
        if (lastFence.getBCluster() == fence.getBCluster())
          fence.erase();
      }
@@ -129,9 +126,7 @@ struct FenceInsertionPass
     return true;
   }
 };
-} // namespace
 
-std::unique_ptr<Pass>
-mlir::createTritonNvidiaGPUFenceInsertionPass(int computeCapability) {
-  return std::make_unique<FenceInsertionPass>(computeCapability);
-}
+} // namespace nvidia_gpu
+} // namespace triton
+} // namespace mlir
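This is the now-standard MLIR shape for tablegen'd passes: define GEN_PASS_DEF_<PASSNAME> before including Passes.h.inc inside the pass's namespace, derive from the generated impl:: base, and inherit its constructors so the options (here computeCapability) and the create functions come from tablegen. A hedged skeleton with a hypothetical pass "MyExample" in a hypothetical project; only the structure mirrors the commit, none of the names below exist in Triton:

#include "mlir/Pass/Pass.h"

namespace mlir::myproject {

#define GEN_PASS_DEF_MYEXAMPLE
#include "myproject/Transforms/Passes.h.inc" // hypothetical generated header

struct MyExamplePass : public impl::MyExampleBase<MyExamplePass> {
  // Inherit the generated constructors so tablegen-declared options are
  // parsed and stored by the base class.
  using impl::MyExampleBase<MyExamplePass>::MyExampleBase;

  void runOnOperation() override {
    // Options declared in Passes.td are available as members here.
  }
};

} // namespace mlir::myproject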

lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp

Lines changed: 21 additions & 21 deletions
@@ -4,17 +4,17 @@
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
 #include "llvm/ADT/AddressRanges.h"
 
-namespace {
-
-using namespace mlir;
+namespace ttg = mlir::triton::gpu;
 
-namespace ttng = triton::nvidia_gpu;
-namespace ttg = triton::gpu;
-namespace tt = triton;
+namespace mlir {
+namespace triton {
+namespace nvidia_gpu {
 
-#define GEN_PASS_CLASSES
+#define GEN_PASS_DEF_TRITONNVIDIAGPUINTERLEAVETMEMPASS
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
 
+namespace {
+
 // If we don't know the effects of the op, we add all possible effects.
 void addAllValuelessEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
@@ -77,7 +77,7 @@ std::pair<Value, AccessRange> findBufferAccess(Value a) {
 
   Operation *defOp = a.getDefiningOp();
   // Accessing the alloc accesses the whole buffer.
-  if (auto alloc = dyn_cast<ttng::TMEMAllocOp>(defOp)) {
+  if (auto alloc = dyn_cast<TMEMAllocOp>(defOp)) {
     AccessRange access;
     for (uint64_t dim : alloc.getType().getShape())
       access.ranges.push_back({{0, dim}});
@@ -128,7 +128,7 @@ std::pair<Value, AccessRange> findBufferAccess(Value a) {
   }
 
   // Subslice is a subview only on the N dimension.
-  if (auto subslice = dyn_cast<ttng::TMEMSubSliceOp>(defOp)) {
+  if (auto subslice = dyn_cast<TMEMSubSliceOp>(defOp)) {
     auto [alloc, parentAccess] = findBufferAccess(subslice.getSrc());
     if (!alloc)
       return {};
@@ -186,7 +186,7 @@ bool sinkOps(Value buffer, ArrayRef<Operation *> useChain) {
     }
     // Don't sink past barrier signals, since they may guard the liverange
    // of the buffer.
-    if (isa<ttng::ArriveBarrierOp>(next))
+    if (isa<ArriveBarrierOp>(next))
      break;
    if (!isMemoryEffectFree(next)) {
      SmallVector<MemoryEffects::EffectInstance> effects;
@@ -199,7 +199,7 @@ bool sinkOps(Value buffer, ArrayRef<Operation *> useChain) {
        dep = true;
        break;
      }
-      if (isa<ttng::TensorMemory>(effect.getResource()) &&
+      if (isa<TensorMemory>(effect.getResource()) &&
          (!effect.getValue() || tmemMayAlias(effect.getValue(), buffer))) {
        dep = true;
        break;
@@ -229,20 +229,22 @@ bool trySinkOp(Operation *op, Value buffer) {
   return sinkOps(buffer, useChain);
 }
 
+} // anonymous namespace
+
 struct TritonNvidiaGPUInterleaveTMemPass
-    : public TritonNvidiaGPUInterleaveTMemPassBase<
+    : public impl::TritonNvidiaGPUInterleaveTMemPassBase<
           TritonNvidiaGPUInterleaveTMemPass> {
-  using TritonNvidiaGPUInterleaveTMemPassBase::
-      TritonNvidiaGPUInterleaveTMemPassBase;
+  using impl::TritonNvidiaGPUInterleaveTMemPassBase<
+      TritonNvidiaGPUInterleaveTMemPass>::TritonNvidiaGPUInterleaveTMemPassBase;
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp m = getOperation();
     SmallVector<std::pair<Operation *, Value>> opsToSink;
     m.walk([&](Operation *op) {
-      if (auto load = dyn_cast<ttng::TMEMLoadOp>(op))
+      if (auto load = dyn_cast<TMEMLoadOp>(op))
        opsToSink.emplace_back(load, load.getSrc());
-      else if (auto alloc = dyn_cast<ttng::TMEMAllocOp>(op))
+      else if (auto alloc = dyn_cast<TMEMAllocOp>(op))
        opsToSink.emplace_back(alloc, alloc.getResult());
    });
    for (auto [op, buffer] : opsToSink) {
@@ -253,8 +255,6 @@ struct TritonNvidiaGPUInterleaveTMemPass
   }
 };
 
-} // namespace
-
-std::unique_ptr<Pass> mlir::createTritonNvidiaGPUInterleaveTMemPass() {
-  return std::make_unique<TritonNvidiaGPUInterleaveTMemPass>();
-}
+} // namespace nvidia_gpu
+} // namespace triton
+} // namespace mlir
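With the hand-written mlir::createTritonNvidiaGPUInterleaveTMemPass() factory removed, callers go through the tablegen-generated create function, which by the usual naming convention keeps the same name but now lives in mlir::triton::nvidia_gpu. A hedged usage sketch, assuming that generated name:

#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"

// Hedged sketch: scheduling the pass through the generated factory.
static mlir::LogicalResult runInterleaveTMem(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(
      mlir::triton::nvidia_gpu::createTritonNvidiaGPUInterleaveTMemPass());
  return pm.run(module);
}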
