Skip to content

Commit 63ac77b

Browse files
Merge OpenAI Triton commit d25fc5f (#4420)
This PR change the Triton base from 5cf16d7 to d25fc5f (May 28). Pass rate: 97.23%
2 parents 9054446 + 0633a33 commit 63ac77b

File tree

55 files changed

+1576
-783
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+1576
-783
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
300300
add_compile_definitions(TRITON_BACKENDS_TUPLE=${TRITON_BACKENDS_TUPLE})
301301
add_library(triton SHARED ${PYTHON_SRC_PATH}/main.cc
302302
${PYTHON_SRC_PATH}/ir.cc
303+
${PYTHON_SRC_PATH}/gluon_ir.cc
303304
${PYTHON_SRC_PATH}/passes.cc
304305
${PYTHON_SRC_PATH}/interpreter.cc
305306
${PYTHON_SRC_PATH}/llvm.cc)

Makefile

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ test-lit:
2828
test-cpp:
2929
ninja -C $(BUILD_DIR) check-triton-unit-tests
3030

31-
.PHONY: test-python
31+
.PHONY: test-unit
3232
test-unit: all
3333
cd python/test/unit && $(PYTEST) -s -n 8 --ignore=cuda/test_flashattention.py \
3434
--ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
@@ -40,6 +40,11 @@ test-unit: all
4040
$(PYTEST) -vs python/tutorials/06-fused-attention.py
4141
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
4242
$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
43+
$(PYTEST) -s -n 8 python/test/gluon
44+
45+
.PHONY: test-gluon
46+
test-gluon: all
47+
$(PYTEST) -s -n 8 python/test/gluon
4348

4449
.PHONY: test-regression
4550
test-regression: all

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,6 @@ SetVector<Value> getNestedOperands(Operation *op);
247247
// Erase the given loop carried values from the loop, where `loop` is replaced
248248
// with a new loop.
249249
void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices);
250-
251-
// Get a boolean if the Value is an arith::ConstantOp
252-
std::optional<bool> getBoolFromConstant(Value cst);
253250
} // namespace mlir
254251

255252
namespace mlir::triton {

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <algorithm>
55
#include <assert.h>
66
#include <cstdlib>
7+
#include <mutex>
78
#include <set>
89
#include <sstream>
910
#include <string>
@@ -75,7 +76,10 @@ inline void assertIsRecognized(const std::string &env) {
7576
assert((is_invalidating || is_neutral) && errmsg.c_str());
7677
}
7778

79+
static std::mutex getenv_mutex;
80+
7881
inline std::string getStrEnv(const std::string &env) {
82+
std::lock_guard<std::mutex> lock(getenv_mutex);
7983
assertIsRecognized(env);
8084
const char *cstr = std::getenv(env.c_str());
8185
if (!cstr)
@@ -86,6 +90,7 @@ inline std::string getStrEnv(const std::string &env) {
8690

8791
// return value of a cache-invalidating boolean environment variable
8892
inline bool getBoolEnv(const std::string &env) {
93+
std::lock_guard<std::mutex> lock(getenv_mutex);
8994
assertIsRecognized(env);
9095
const char *s = std::getenv(env.c_str());
9196
std::string str(s ? s : "");

lib/Dialect/TritonGPU/Transforms/HoistTMEMAlloc.cpp

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#include "mlir/Dialect/Arith/IR/Arith.h"
2-
#include "mlir/Dialect/SCF/IR/SCF.h"
31
#include "mlir/IR/Dominance.h"
42
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
53
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -94,53 +92,6 @@ class RemoveUnusedTMEMLoad : public OpRewritePattern<TMEMTokenLoadOp> {
9492
}
9593
};
9694

97-
class RemoveUnusedTMEMStore : public OpRewritePattern<TMEMTokenStoreOp> {
98-
public:
99-
using OpRewritePattern::OpRewritePattern;
100-
101-
LogicalResult matchAndRewrite(TMEMTokenStoreOp store,
102-
PatternRewriter &rewriter) const override {
103-
auto pred = getBoolFromConstant(store.getPred());
104-
if (!pred || pred.value() == false)
105-
return failure(); // we've already processed this
106-
auto tok = store.getToken();
107-
if (!tok.hasOneUse())
108-
return failure();
109-
auto loop = dyn_cast<scf::ForOp>(*tok.getUsers().begin());
110-
if (!loop)
111-
return failure();
112-
auto loopTok = loop.getBody()->getArgument(
113-
tok.getUses().begin()->getOperandNumber() - 2);
114-
if (!loopTok.hasOneUse())
115-
return failure();
116-
auto mma =
117-
dyn_cast<nvidia_gpu::MMAv5OpInterface>(*loopTok.getUsers().begin());
118-
if (!mma)
119-
return failure();
120-
auto useD = dyn_cast<BlockArgument>(mma.useAccumulator());
121-
if (!useD)
122-
return failure();
123-
auto parent = useD.getParentBlock()->getParentOp();
124-
if (parent != loop)
125-
return failure();
126-
auto loopInit = loop.getInitArgs()[useD.getArgNumber() - 1];
127-
auto val = getBoolFromConstant(loopInit);
128-
if (!val)
129-
return failure();
130-
if (val.value() == true)
131-
return failure();
132-
auto loc = store.getLoc();
133-
rewriter.setInsertionPoint(store);
134-
Value diff = rewriter.create<arith::SubIOp>(loc, loop.getUpperBound(),
135-
loop.getLowerBound());
136-
Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, diff.getType());
137-
Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
138-
diff, zero);
139-
store.getPredMutable().assign(cond);
140-
return success();
141-
}
142-
};
143-
14495
// Load-store forwarding pattern.
14596
class CombineTMEMLoadAndStore : public OpRewritePattern<TMEMTokenStoreOp> {
14697
public:
@@ -460,8 +411,7 @@ struct HoistTMEMAlloc
460411
mlir::RewritePatternSet patterns(&getContext());
461412
patterns.add<RotateTMEMStoreInLoop, RotateTMEMLoadInLoop,
462413
CombineTMEMLoadAndStore, CombineTMEMStoreAndSelect,
463-
SinkTMEMLoad, RemoveUnusedTMEMLoad, RemoveUnusedTMEMStore>(
464-
&getContext());
414+
SinkTMEMLoad, RemoveUnusedTMEMLoad>(&getContext());
465415
scf::ForOp::getCanonicalizationPatterns(patterns, &getContext());
466416
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
467417
llvm_unreachable("Failed to hoist tmem_store");

lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,18 @@ findZeroInitOp(Value accUse, scf::ForOp forOp, bool &loopArgIsZero) {
171171
return std::nullopt;
172172
}
173173

174+
std::optional<bool> getBoolFromConstant(Value cst) {
175+
auto constantOp = cst.getDefiningOp<arith::ConstantOp>();
176+
if (!constantOp) {
177+
return std::nullopt;
178+
}
179+
assert(constantOp.getValue());
180+
if (auto boolAttr = dyn_cast<BoolAttr>(constantOp.getValue())) {
181+
return boolAttr.getValue();
182+
}
183+
return std::nullopt;
184+
}
185+
174186
} // namespace
175187

176188
class OptimizeAccumulatorInitPass

lib/Dialect/TritonGPU/Transforms/Pipeliner/ScheduleLoops.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,32 @@ CoarseSchedule getInitialSchedule(scf::ForOp forOp,
167167
CoarseSchedule schedule;
168168
if (forOp->hasAttr(kWarpSpecializeAttrName) &&
169169
succeeded(schedule.deSerialize(forOp))) {
170+
// The loop was partitioned from a warp-specialized loop, meaning it can
171+
// have a partial view of the original loop stages. Re-schedule the loop
172+
// root at the stages of the latency ops to prune unnecessary stages.
173+
auto isLatencyOp = [&](Operation &op) {
174+
return opLatency.count(&op) ||
175+
isa<LocalStoreOp, LocalLoadOp, ttng::TMEMLoadOp, ttng::TMEMStoreOp,
176+
AsyncCopyGlobalToLocalOp, ttng::AsyncTMACopyGlobalToLocalOp,
177+
ttng::AsyncTMAGatherOp, ttng::MMAv5OpInterface,
178+
ttng::WaitBarrierOp, ttng::ArriveBarrierOp>(op);
179+
};
180+
181+
// If there are no latency ops or all latency ops are in the same stage, we
182+
// don't need to pipeline the loop. Return a new schedule with everything
183+
// assigned to the same stage.
184+
DenseSet<int> latencyStages;
185+
auto ops = forOp.getBody()->without_terminator();
186+
for (Operation &op : llvm::make_filter_range(ops, isLatencyOp))
187+
latencyStages.insert(schedule[&op].first);
188+
if (latencyStages.size() <= 1) {
189+
CoarseSchedule normalized(/*numStages=*/1);
190+
auto cluster = normalized.clusters.newAtFront();
191+
for (Operation &op : ops)
192+
normalized.insert(&op, 0, cluster);
193+
return normalized;
194+
}
195+
170196
schedule.shrinkToFit();
171197
return schedule;
172198
}

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,18 +1400,6 @@ void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices) {
14001400
loop = newLoop;
14011401
}
14021402

1403-
std::optional<bool> getBoolFromConstant(Value cst) {
1404-
auto constantOp = cst.getDefiningOp<arith::ConstantOp>();
1405-
if (!constantOp) {
1406-
return std::nullopt;
1407-
}
1408-
assert(constantOp.getValue());
1409-
if (auto boolAttr = dyn_cast<BoolAttr>(constantOp.getValue())) {
1410-
return boolAttr.getValue();
1411-
}
1412-
return std::nullopt;
1413-
}
1414-
14151403
} // namespace mlir
14161404

14171405
namespace mlir::triton {

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionLoops.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,10 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
192192
// explicit captures are the leaves of the subgraph.
193193
SetVector<Operation *> opsToClone;
194194
SmallVector<Value> explicitCaptures;
195+
SmallVector<IRMapping> mappings(wsOp.getPartitionNumWarps().size());
196+
SmallVector<OpBuilder> builders;
197+
for (Region *region : wsOp.getPartitionRegions())
198+
builders.push_back(OpBuilder::atBlockBegin(&region->front()));
195199
for (unsigned i = 0; i < captures.size(); ++i) {
196200
Value capture = captures[i];
197201

@@ -215,10 +219,11 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
215219
tensorTy.getShape(), tensorTy.getElementType(), sharedEnc,
216220
SharedMemorySpaceAttr::get(tensorTy.getContext()));
217221
auto alloc = b.create<LocalAllocOp>(memdescTy, capture);
218-
for (Region *region : wsOp.getPartitionRegions()) {
219-
b.setInsertionPointToStart(&region->front());
220-
Value value = b.create<LocalLoadOp>(tensorTy, alloc);
222+
for (auto [i, region] : llvm::enumerate(wsOp.getPartitionRegions())) {
223+
Value value =
224+
builders[i].create<LocalLoadOp>(capture.getLoc(), tensorTy, alloc);
221225
replaceAllUsesInRegionWith(capture, value, *region);
226+
mappings[i].map(capture, value);
222227
}
223228
capture = alloc;
224229
}
@@ -228,9 +233,9 @@ LogicalResult triton::gpu::partitionLoop(scf::ForOp loop) {
228233

229234
// Clone the ops into each region in topological order.
230235
opsToClone = topologicalSort(opsToClone);
231-
for (Region *region : wsOp.getPartitionRegions()) {
232-
b.setInsertionPointToStart(&region->front());
233-
IRMapping mapping;
236+
for (auto [i, region] : llvm::enumerate(wsOp.getPartitionRegions())) {
237+
OpBuilder &b = builders[i];
238+
IRMapping &mapping = mappings[i];
234239
for (Operation *op : opsToClone) {
235240
Value copy = b.clone(*op, mapping)->getResult(0);
236241
mapping.map(op->getResult(0), copy);

python/src/gluon_ir.cc

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#include "ir.h"
2+
#include "pybind11/pybind11.h"
3+
#include <pybind11/stl.h>
4+
5+
#include "mlir/IR/BuiltinTypes.h"
6+
#include "mlir/IR/Types.h"
7+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
8+
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
9+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
10+
11+
using namespace mlir;
12+
namespace py = pybind11;
13+
namespace ttg = triton::gpu;
14+
15+
struct GluonOpBuilder : public TritonOpBuilder {};
16+
17+
void init_gluon_ir(py::module &&m) {
18+
py::class_<GluonOpBuilder, TritonOpBuilder>(
19+
m, "GluonOpBuilder", py::module_local(), py::dynamic_attr())
20+
.def(py::init<MLIRContext *>())
21+
.def("get_distributed_ty",
22+
[](GluonOpBuilder &self, Type &elementType,
23+
std::vector<int64_t> &shape, Attribute layout) -> Type {
24+
return RankedTensorType::get(shape, elementType, layout);
25+
})
26+
.def("get_shared_mem_desc_ty",
27+
[](GluonOpBuilder &self, Type &elementType,
28+
std::vector<int64_t> &shape, Attribute layout,
29+
std::vector<int64_t> &allocShape) -> Type {
30+
auto ctx = self.getContext();
31+
return ttg::MemDescType::get(shape, elementType, layout,
32+
ttg::SharedMemorySpaceAttr::get(ctx),
33+
/*mutableMemory=*/true,
34+
/*allocShape=*/allocShape);
35+
})
36+
.def("get_blocked_layout",
37+
[](GluonOpBuilder &self, std::vector<unsigned> &sizePerThread,
38+
std::vector<unsigned> &threadsPerWarp,
39+
std::vector<unsigned> &warpsPerCta, std::vector<unsigned> &order,
40+
std::vector<unsigned> &ctasPerCga,
41+
std::vector<unsigned> &ctaSplitNum,
42+
std::vector<unsigned> &ctaOrder) -> Attribute {
43+
auto ctx = self.getContext();
44+
auto ctaLayout = ttg::CTALayoutAttr::get(ctx, ctasPerCga,
45+
ctaSplitNum, ctaOrder);
46+
return ttg::BlockedEncodingAttr::get(ctx, sizePerThread,
47+
threadsPerWarp, warpsPerCta,
48+
order, ctaLayout);
49+
})
50+
.def("get_slice_layout",
51+
[](GluonOpBuilder &self, unsigned dim,
52+
Attribute parent) -> Attribute {
53+
auto ctx = self.getContext();
54+
auto dist = cast<ttg::DistributedEncodingTrait>(parent);
55+
return ttg::SliceEncodingAttr::get(ctx, dim, dist);
56+
})
57+
.def("get_nvmma_shared_layout",
58+
[](GluonOpBuilder &self, unsigned swizzleByteWidth,
59+
unsigned elementBitwidth, bool transposed, bool fp4Padded,
60+
std::vector<unsigned> &ctasPerCga,
61+
std::vector<unsigned> &ctaSplitNum,
62+
std::vector<unsigned> &ctaOrder) -> Attribute {
63+
auto ctx = self.getContext();
64+
auto ctaLayout = ttg::CTALayoutAttr::get(ctx, ctasPerCga,
65+
ctaSplitNum, ctaOrder);
66+
return ttg::NVMMASharedEncodingAttr::get(
67+
ctx, swizzleByteWidth, transposed, elementBitwidth, fp4Padded,
68+
ctaLayout);
69+
})
70+
.def("create_convert_layout",
71+
[](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
72+
return self.create<ttg::ConvertLayoutOp>(resultTy, value);
73+
})
74+
.def("create_local_alloc",
75+
[](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
76+
return self.create<ttg::LocalAllocOp>(resultTy, value);
77+
})
78+
.def("create_local_store",
79+
[](GluonOpBuilder &self, Value memDesc, Value value) {
80+
self.create<ttg::LocalStoreOp>(value, memDesc);
81+
})
82+
.def("create_local_load",
83+
[](GluonOpBuilder &self, Type resultTy, Value memDesc) -> Value {
84+
return self.create<ttg::LocalLoadOp>(resultTy, memDesc);
85+
});
86+
}

0 commit comments

Comments
 (0)