Skip to content

Commit f223b65

Browse files
Merge OpenAI Triton commit 343bd8e (#4519)
This PR change the Triton base from 418c127 to 343bd8e (Jun 12). Pass rate: 97.11%
2 parents 3bb0b5f + 8eb3861 commit f223b65

File tree

63 files changed

+1115
-747
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+1115
-747
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ dev-install-llvm:
107107

108108
.PHONY: golden-samples
109109
golden-samples: triton-opt
110-
$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-pipeline -canonicalize | \
110+
$(TRITON_OPT) test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-pipeline -canonicalize | \
111111
$(PYTHON) utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in --source_delim_regex="\bmodule" \
112112
-o test/TritonGPU/samples/simulated-grouped-gemm.mlir
113113
$(TRITON_OPT) test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-assign-latencies -tritongpu-schedule-loops -tritongpu-pipeline -canonicalize | \

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ void registerTestAxisInfoPass();
4848

4949
void registerTestAliasPass();
5050
void registerTestAlignmentPass();
51+
void registerAMDTestAlignmentPass();
5152
void registerTestAllocationPass();
5253
void registerTestLivenessPass();
5354
void registerTestMembarPass();
@@ -65,6 +66,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6566
mlir::test::intel::registerTestAxisInfoPass();
6667
mlir::test::registerTestAliasPass();
6768
mlir::test::registerTestAlignmentPass();
69+
mlir::test::registerAMDTestAlignmentPass();
6870
mlir::test::registerTestAllocationPass();
6971
mlir::test::registerTestLivenessPass();
7072
mlir::test::registerTestMembarPass();

include/triton/Analysis/AxisInfo.h

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,49 @@ class AxisInfo {
204204
std::optional<int64_t> constantValue;
205205
};
206206

207+
class AxisInfoVisitor {
208+
public:
209+
AxisInfoVisitor() = default;
210+
virtual ~AxisInfoVisitor() = default;
211+
212+
bool isContiguousDim(const AxisInfo &info, ArrayRef<int64_t> shape, int dim) {
213+
return info.getContiguity(dim) == shape[dim];
214+
}
215+
216+
bool isConstantDim(const AxisInfo &info, ArrayRef<int64_t> shape, int dim) {
217+
return info.getConstancy(dim) == shape[dim];
218+
}
219+
220+
virtual AxisInfo
221+
getAxisInfo(Operation *op,
222+
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) = 0;
223+
224+
virtual bool match(Operation *op) = 0;
225+
};
226+
227+
class AxisInfoVisitorList {
228+
public:
229+
template <typename... Ts, typename = std::enable_if_t<sizeof...(Ts) != 0>>
230+
void append() {
231+
(visitors.emplace_back(std::make_unique<Ts>()), ...);
232+
}
233+
234+
AxisInfo apply(Operation *op,
235+
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
236+
for (auto &visitor : visitors)
237+
if (visitor->match(op))
238+
return visitor->getAxisInfo(op, operands);
239+
return AxisInfo();
240+
}
241+
242+
private:
243+
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
244+
};
245+
246+
namespace axisinfo {
247+
using CallbackType = std::function<void(AxisInfoVisitorList &)>;
248+
} // namespace axisinfo
249+
207250
// Module level axis info analysis based on the call graph, assuming that we do
208251
// not have recursive functions.
209252
//
@@ -214,7 +257,8 @@ class AxisInfo {
214257
using AxisInfoMapT = DenseMap<Value, AxisInfo>;
215258
class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
216259
public:
217-
explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)
260+
explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp,
261+
axisinfo::CallbackType callback = nullptr)
218262
: CallGraph<AxisInfoMapT>(moduleOp) {
219263
SmallVector<FunctionOpInterface> funcs;
220264
for (auto root : getRoots()) {
@@ -230,7 +274,7 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
230274
SetVector<FunctionOpInterface> sortedFuncs(funcs.begin(), funcs.end());
231275
SymbolTableCollection symbolTable;
232276
for (auto funcOp : llvm::reverse(sortedFuncs)) {
233-
initialize(funcOp);
277+
initialize(funcOp, callback);
234278
funcOp.walk([&](CallOpInterface callOp) {
235279
auto callee = dyn_cast<FunctionOpInterface>(
236280
callOp.resolveCallableInTable(&symbolTable));
@@ -272,10 +316,10 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
272316
unsigned getMaskAlignment(Value mask);
273317

274318
private:
275-
void initialize(FunctionOpInterface funcOp);
319+
void initialize(FunctionOpInterface funcOp,
320+
axisinfo::CallbackType callback = nullptr);
276321
void update(CallOpInterface callOp, FunctionOpInterface funcOp);
277322
};
278-
279323
} // namespace mlir::triton
280324

281325
#endif

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,10 +423,6 @@ size_t linearize(ArrayRef<unsigned> multiDim, ArrayRef<unsigned> shape,
423423
Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
424424
StringRef content);
425425

426-
inline bool isKernel(FunctionOpInterface funcOp) {
427-
return funcOp.getVisibility() == SymbolTable::Visibility::Public;
428-
}
429-
430426
Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp);
431427

432428
Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
@@ -547,6 +543,13 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
547543
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
548544
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
549545

546+
[[nodiscard]] bool emitTransferBetweenRegistersAndShared(
547+
LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
548+
std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
549+
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
550+
Value laneId, Value warpId,
551+
std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
552+
550553
SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
551554
Type elemLlvmTy,
552555
const SharedMemoryObject &smemObj,

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ Value getLastInductionValue(OpBuilder &b, scf::ForOp loop);
182182

183183
MakeTensorPtrOp getMakeTensorPtrOp(Value v);
184184

185+
bool isHostSideDescriptor(Value v);
186+
187+
bool isKernel(FunctionOpInterface funcOp);
188+
185189
} // namespace triton
186190
} // namespace mlir
187191

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewT
298298
}];
299299

300300
let hasVerifier = 1;
301+
let hasFolder = 1;
301302
}
302303

303304
def TTG_LocalLoadOp : TTG_Op<"local_load"> {

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ static const char *kWarpSpecializeAttrName = "tt.warp_specialize";
1919
static const char *kLoopStageAttrName = "loop.stage";
2020
static const char *kLoopClusterAttrName = "loop.cluster";
2121
static const char *kScheduledMaxStageAttrName = "tt.scheduled_max_stage";
22-
22+
class CoarseSchedule;
23+
class ModuleAxisInfoAnalysis;
2324
//===----------------------------------------------------------------------===//
2425
// Hoisting Utilities
2526
//===----------------------------------------------------------------------===//
@@ -86,6 +87,9 @@ std::pair<Operation *, int64_t> getDefiningOpAndDistance(scf::ForOp forOp,
8687
int getCopyVecBytes(RankedTensorType registerTy,
8788
gpu::SharedEncodingTrait sharedEnc);
8889

90+
bool canBeConvertedToAsyncLoad(
91+
triton::LoadOp loadOp, triton::ModuleAxisInfoAnalysis &axisInfoAnalysis);
92+
8993
// Serialize the latencies of the operations in the loops into the latency
9094
// attribute.
9195
void serializeLatencies(ModuleOp module, DenseMap<Operation *, int> &opLatency);
@@ -138,6 +142,12 @@ createSingleBufferView(OpBuilder &builder, Value alloc, Value idx);
138142
TypedValue<triton::gpu::MemDescType>
139143
createSingleBufferView(OpBuilder &builder, Value alloc, int idx);
140144

145+
Value createIncrementModulo(OpBuilder &builder, Location loc, Value counter,
146+
Value modulus, Value zero, Value one,
147+
Value *outWrapCond = nullptr);
148+
149+
scf::ForOp lowerTMADescriptors(scf::ForOp forOp, CoarseSchedule &schedule);
150+
141151
} // namespace triton
142152
} // namespace mlir
143153

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
33

44
#include "mlir/Dialect/SCF/IR/SCF.h"
5+
#include "mlir/IR/ImplicitLocOpBuilder.h"
56
#include "mlir/Support/LLVM.h"
67
#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"
78
#include "llvm/ADT/ArrayRef.h"
@@ -164,6 +165,32 @@ class CoarseSchedule {
164165
// the same stage and ordering cluster as the anchor op.
165166
void scheduleDependencies(scf::ForOp forOp, CoarseSchedule &schedule);
166167

168+
class OpBuilderForStage : public mlir::ImplicitLocOpBuilder,
169+
public OpBuilder::Listener {
170+
public:
171+
explicit OpBuilderForStage(Location loc, Operation *op,
172+
CoarseSchedule &schedule)
173+
: ImplicitLocOpBuilder(loc, op, this), schedule(schedule) {
174+
if (auto it = schedule.find(op); it != schedule.end())
175+
std::tie(stage, cluster) = it->second;
176+
}
177+
178+
void setStageCluster(std::pair<int, CoarseSchedule::Cluster> stageCluster) {
179+
stage = stageCluster.first;
180+
cluster = stageCluster.second;
181+
}
182+
183+
void notifyOperationInserted(Operation *op, InsertPoint previous) {
184+
if (stage && cluster)
185+
schedule.insert(op, *stage, *cluster);
186+
}
187+
188+
private:
189+
std::optional<int> stage;
190+
std::optional<CoarseSchedule::Cluster> cluster;
191+
CoarseSchedule &schedule;
192+
};
193+
167194
} // namespace triton
168195
} // namespace mlir
169196
#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_

lib/Analysis/AxisInfo.cpp

Lines changed: 13 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1+
#include "triton/Analysis/AxisInfo.h"
12
#include "mlir/Analysis/DataFlowFramework.h"
23
#include "mlir/Dialect/UB/IR/UBOps.h"
3-
#include "llvm/Support/Debug.h"
4-
#include "llvm/Support/raw_ostream.h"
5-
6-
#include "triton/Analysis/AxisInfo.h"
74
#include "triton/Dialect/Triton/IR/Dialect.h"
85
#include "triton/Dialect/Triton/IR/Utility.h"
96
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
7+
#include "llvm/Support/Debug.h"
8+
#include "llvm/Support/raw_ostream.h"
109

1110
#define DEBUG_TYPE "axis-info"
1211
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -52,28 +51,6 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
5251
return lhs * rhs;
5352
}
5453

55-
class AxisInfoVisitor {
56-
public:
57-
AxisInfoVisitor() = default;
58-
virtual ~AxisInfoVisitor() = default;
59-
60-
static bool isContiguousDim(const AxisInfo &info, ArrayRef<int64_t> shape,
61-
int dim) {
62-
return info.getContiguity(dim) == shape[dim];
63-
}
64-
65-
static bool isConstantDim(const AxisInfo &info, ArrayRef<int64_t> shape,
66-
int dim) {
67-
return info.getConstancy(dim) == shape[dim];
68-
}
69-
70-
virtual AxisInfo
71-
getAxisInfo(Operation *op,
72-
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) = 0;
73-
74-
virtual bool match(Operation *op) = 0;
75-
};
76-
7754
// Base class for all operations
7855
template <typename OpTy> class AxisInfoVisitorImpl : public AxisInfoVisitor {
7956
public:
@@ -147,25 +124,6 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
147124
}
148125
};
149126

150-
class AxisInfoVisitorList {
151-
public:
152-
template <typename... Ts, typename = std::enable_if_t<sizeof...(Ts) != 0>>
153-
void append() {
154-
(visitors.emplace_back(std::make_unique<Ts>()), ...);
155-
}
156-
157-
AxisInfo apply(Operation *op,
158-
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
159-
for (auto &visitor : visitors)
160-
if (visitor->match(op))
161-
return visitor->getAxisInfo(op, operands);
162-
return AxisInfo();
163-
}
164-
165-
private:
166-
std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
167-
};
168-
169127
class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
170128
dataflow::Lattice<AxisInfo>> {
171129
private:
@@ -193,7 +151,8 @@ class AxisInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis<
193151
}
194152

195153
public:
196-
AxisInfoAnalysis(DataFlowSolver &solver);
154+
AxisInfoAnalysis(DataFlowSolver &solver,
155+
axisinfo::CallbackType callback = nullptr);
197156
using dataflow::SparseForwardDataFlowAnalysis<
198157
dataflow::Lattice<AxisInfo>>::getLatticeElement;
199158
using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;
@@ -1031,7 +990,8 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
1031990
// AxisInfoAnalysis
1032991
//===----------------------------------------------------------------------===//
1033992

1034-
AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
993+
AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
994+
axisinfo::CallbackType callback)
1035995
: dataflow::SparseForwardDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(
1036996
solver) {
1037997
// UnrealizedConversionCast:
@@ -1070,6 +1030,9 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10701030
MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
10711031
MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
10721032
visitors.append<LoadOpAxisInfoVisitor>();
1033+
1034+
if (callback)
1035+
callback(visitors);
10731036
}
10741037

10751038
LogicalResult AxisInfoAnalysis::visitOperation(
@@ -1339,9 +1302,10 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
13391302
return alignment;
13401303
}
13411304

1342-
void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) {
1305+
void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp,
1306+
axisinfo::CallbackType callback) {
13431307
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
1344-
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
1308+
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>(callback);
13451309
// Walk pre-order so analysis results can be propagated into nested isolated
13461310
// regions.
13471311
WalkResult result =

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
6666
// 1. Modify the function type to add the new arguments.
6767
auto funcTy = funcOp.getFunctionType();
6868
auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs());
69-
bool isKernel = LLVM::isKernel(funcOp);
69+
bool isKernel = triton::isKernel(funcOp);
7070
if (isKernel) {
7171
for (auto i : llvm::seq(amendedInputTy.size())) {
7272
if (isa<TensorDescType>(amendedInputTy[i])) {
@@ -111,7 +111,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
111111
// Map the MLIR attribute `tt.nv_tma_desc` to the appropriate LLVM and NVVM
112112
// attributes.
113113
static void handleByvalTmaDescArgs(LLVM::LLVMFuncOp &llvmFuncOp) {
114-
const bool isKernel = LLVM::isKernel(llvmFuncOp);
114+
const bool isKernel = triton::isKernel(llvmFuncOp);
115115
for (unsigned i = 0; i < llvmFuncOp.getNumArguments(); ++i) {
116116
const auto attrs = llvmFuncOp.getArgAttrDict(i);
117117
if (!attrs) {
@@ -161,7 +161,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
161161

162162
auto ctx = funcOp->getContext();
163163

164-
if (LLVM::isKernel(funcOp)) {
164+
if (triton::isKernel(funcOp)) {
165165
// Set an attribute to indicate this function is a kernel entry.
166166
newFuncOp->setAttr(NVVM::NVVMDialect::getKernelFuncAttrName(),
167167
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));

0 commit comments

Comments
 (0)