
Commit fdefe85

Merge commit 'd57cbee8633eaa8f691d87503670b00562d21c5d'
2 parents: 9d7bc59 + d57cbee


61 files changed: +4262 -3625 lines


.github/workflows/integration-tests-nvidia.yml

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ jobs:
         if [ "${{ matrix.runner[0] }}" == "nvidia-gb200" ]; then
           source /venv/bin/activate
         fi
-        make test-unit
+        make NUM_PROCS=24 test-unit
       - name: Run interpreter tests
         if: ${{ matrix.runner[0] == 'nvidia-h100' }}
         run: make test-interpret

CMakeLists.txt

Lines changed: 14 additions & 0 deletions
@@ -321,6 +321,20 @@ if(TRITON_BUILD_PYTHON_MODULE)
       target_link_libraries(triton PRIVATE z)
     endif()
     target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
+
+    if (NOT DEFINED LLVM_SYSPATH)
+      message(FATAL_ERROR "LLVM_SYSPATH must be set.")
+    endif()
+
+    if (NOT DEFINED TRITON_WHEEL_DIR)
+      message(FATAL_ERROR "TRITON_WHEEL_DIR must be set.")
+    endif()
+
+    configure_file(
+        "${LLVM_SYSPATH}/bin/FileCheck"
+        "${TRITON_WHEEL_DIR}/FileCheck"
+        COPYONLY)
 endif()
 
 if (UNIX AND NOT APPLE)

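Both variables are now hard requirements when building the Python module: configuring without them fails with the FATAL_ERROR messages added above, and FileCheck is copied from the LLVM install into the wheel staging directory at configure time. A minimal sketch of a manual configure with the two variables set (the paths are placeholders, not from this commit; in the normal wheel build these are presumably supplied by the Python build scripts):

cmake -S . -B build \
  -DTRITON_BUILD_PYTHON_MODULE=ON \
  -DLLVM_SYSPATH=/opt/llvm-install \
  -DTRITON_WHEEL_DIR=/tmp/triton-wheel
# configure_file(... COPYONLY) then places /opt/llvm-install/bin/FileCheck
# at /tmp/triton-wheel/FileCheck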
Makefile

Lines changed: 7 additions & 6 deletions
@@ -7,6 +7,7 @@ BUILD_DIR := $(shell cd python; $(PYTHON) -c 'from build_helpers import get_cmak
 TRITON_OPT := $(BUILD_DIR)/bin/triton-opt
 PYTEST := $(PYTHON) -m pytest
 LLVM_BUILD_PATH ?= ".llvm-project/build"
+NUM_PROCS ?= 8
 
 # Incremental builds
 
@@ -30,25 +31,25 @@ test-cpp:
 
 .PHONY: test-unit
 test-unit: all
-	cd python/test/unit && $(PYTEST) -s -n 8 --ignore=language/test_line_info.py \
+	cd python/test/unit && $(PYTEST) -s -n $(NUM_PROCS) --ignore=language/test_line_info.py \
 		--ignore=language/test_subprocess.py --ignore=test_debug.py
-	$(PYTEST) -s -n 8 python/test/unit/language/test_subprocess.py
-	$(PYTEST) -s -n 8 python/test/unit/test_debug.py --forked
+	$(PYTEST) -s -n $(NUM_PROCS) python/test/unit/language/test_subprocess.py
+	$(PYTEST) -s -n $(NUM_PROCS) python/test/unit/test_debug.py --forked
 	$(PYTEST) -s -n 8 python/triton_kernels/tests/
 	TRITON_DISABLE_LINE_INFO=0 $(PYTEST) -s python/test/unit/language/test_line_info.py
 	# Run attention separately to avoid out of gpu memory
 	$(PYTEST) -vs python/tutorials/06-fused-attention.py
 	TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
 		$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
-	$(PYTEST) -s -n 8 python/test/gluon
+	$(PYTEST) -s -n $(NUM_PROCS) python/test/gluon
 
 .PHONY: test-gluon
 test-gluon: all
-	$(PYTEST) -s -n 8 python/test/gluon
+	$(PYTEST) -s -n $(NUM_PROCS) python/test/gluon
 
 .PHONY: test-regression
 test-regression: all
-	$(PYTEST) -s -n 8 python/test/regression
+	$(PYTEST) -s -n $(NUM_PROCS) python/test/regression
 
 .PHONY: test-interpret
 test-interpret: all

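With the worker count now behind NUM_PROCS (default 8), the pytest parallelism can be overridden per invocation instead of being hard-wired, which is what the nvidia-gb200 workflow change above relies on. For example:

make test-unit                 # uses the default of 8 pytest workers
make NUM_PROCS=24 test-unit    # matches the CI invocation for gb200 runners

Note that the python/triton_kernels/tests/ line in test-unit still hardcodes -n 8 in this commit.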
include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 7 additions & 0 deletions
@@ -263,6 +263,13 @@ void replaceUsesWithLocalLoad(
     OpBuilder &builder, OpResult old,
     TypedValue<triton::gpu::MemDescType> alloc,
     TypedValue<triton::gpu::AsyncTokenType> token = {});
+
+// Return true if the value comes from a load or a block argument.
+// This will skip convert layouts and memdesc views.
+// This is a helper useful to know if value is likely to come from shared memory
+// after converting loads into async loads.
+bool comesFromLoadOrBlockArg(Value v);
+
 } // namespace mlir::triton
 
 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 2 additions & 2 deletions
@@ -368,8 +368,8 @@ OpFoldResult MakeRangeOp::fold(FoldAdaptor adaptor) {
 LogicalResult MakeRangeOp::verify() {
   int64_t start = getStartAttr().getInt();
   int64_t end = getEndAttr().getInt();
-  if (start > end) {
-    return this->emitOpError() << "start must be less than or equal to end";
+  if (start >= end) {
+    return this->emitOpError() << "start must be less than end";
   }
   auto ty = getType();
   if (ty.getShape().size() != 1) {

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 0 additions & 23 deletions
@@ -322,29 +322,6 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
     if (!(versionMajor >= 1 && versionMajor <= 3))
       return failure();
 
-    // If both of the operands are not loads, we fallback to MMAv2
-    // otherwise the reg-smem roundtrip will tank the MMAv3 performance
-    auto comesFromLoadOrBlockArg = [](Value v) -> bool {
-      // Peel out the original cvt dot_op<..., #blocked>
-      // and any other potential cvt/trans ops
-      while (true) {
-        if (auto cvtOp = v.getDefiningOp<ConvertLayoutOp>()) {
-          v = cvtOp.getSrc();
-          continue;
-        }
-        if (auto transOp = v.getDefiningOp<TransOp>()) {
-          v = transOp.getSrc();
-          continue;
-        }
-        break;
-      }
-      // We also accept block arguments as they appear in many MLIR tests
-      // If this is problematic we can totally drop them
-      return isa<BlockArgument>(v) ||
-             (v.getDefiningOp() &&
-              isa<LoadOp, DescriptorLoadOp>(v.getDefiningOp()));
-    };
-
     bool aFromLoad = comesFromLoadOrBlockArg(dotOp.getA());
     bool bFromLoad = comesFromLoadOrBlockArg(dotOp.getB());
     auto origDotOp = dotOp;

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 25 additions & 0 deletions
@@ -1554,4 +1554,29 @@ void replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old,
     alloc.erase();
   }
 }
+
+bool comesFromLoadOrBlockArg(Value v) {
+  // Peel out the original cvt dot_op<..., #blocked>
+  // and any other potential cvt/trans ops
+  while (true) {
+    Operation *def = v.getDefiningOp();
+    if (!def)
+      break;
+    if (auto cvtOp = dyn_cast<ttg::ConvertLayoutOp>(def)) {
+      v = cvtOp.getSrc();
+      continue;
+    }
+    if (def->hasTrait<OpTrait::MemDescViewTrait>()) {
+      v = def->getOperand(0);
+      continue;
+    }
+    break;
+  }
+  // We also accept block arguments as they appear in many MLIR tests
+  // If this is problematic we can totally drop them
+  return isa<BlockArgument>(v) ||
+         (v.getDefiningOp() &&
+          isa<LoadOp, DescriptorLoadOp, DescriptorGatherOp>(v.getDefiningOp()));
+}
+
 } // namespace mlir::triton

lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp

Lines changed: 13 additions & 6 deletions
@@ -226,11 +226,17 @@ static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
     return std::nullopt;
 
   // Propagate defs of exp.
-  for (auto expOp : loop.getOps<math::Exp2Op>()) {
-    auto tensorTy = dyn_cast<RankedTensorType>(expOp.getType());
-    if (tensorTy && tensorTy.getNumElements() > 256) {
-      schedule.trySchedule(defaultPartition, expOp);
-      scheduleDependencies(loop, schedule, defaultPartition, expOp);
+  for (Operation &op : loop.getOps()) {
+    if (!isa<math::Exp2Op, ElementwiseInlineAsmOp>(op))
+      continue;
+    int elementCount = 0;
+    for (Type type : op.getResultTypes()) {
+      if (auto tensorTy = dyn_cast<RankedTensorType>(type))
+        elementCount += tensorTy.getNumElements();
+    }
+    if (elementCount > 256) {
+      schedule.trySchedule(defaultPartition, &op);
+      scheduleDependencies(loop, schedule, defaultPartition, &op);
     }
   }
 
@@ -242,7 +248,8 @@ static std::optional<WarpSchedule> getInitialSchedule(scf::ForOp loop) {
   while (userPartitions.size() < mmas.size()) {
     userPartitions.push_back(schedule.addPartition(userPartitions.size()));
   }
-  for (auto [mmaOp, userPartition] : llvm::zip(mmas, userPartitions)) {
+  for (auto [mmaOp, userPartition] :
+       llvm::reverse(llvm::zip(mmas, userPartitions))) {
     scheduleUsers(loop, schedule, userPartition, mmaOp);
   }

lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp

Lines changed: 2 additions & 1 deletion
@@ -69,7 +69,8 @@ template <class MMAOpTy> class LHSToTMem : public OpRewritePattern<MMAOpTy> {
         isDistributedLayoutTMemCompatible(tcGen5MMAOp, srcType, lhsMemDescType);
     Attribute newLayout = srcLayout;
     if (!layoutTmemCompatible) {
-      if (triton::tools::getBoolEnv("ALLOW_LHS_TMEM_LAYOUT_CONVERSION")) {
+      if (!comesFromLoadOrBlockArg(src) ||
+          triton::tools::getBoolEnv("ALLOW_LHS_TMEM_LAYOUT_CONVERSION")) {
         newLayout = getLHSTMemLayout(tcGen5MMAOp, srcType);
       } else {
         return failure();

python/src/gluon_ir.cc

Lines changed: 10 additions & 1 deletion
@@ -11,6 +11,7 @@
 
 using namespace mlir;
 namespace py = pybind11;
+namespace tt = triton;
 namespace ttg = triton::gpu;
 namespace ttng = triton::nvidia_gpu;
 
@@ -298,7 +299,15 @@ void init_gluon_ir(py::module &&m) {
             self.create<ttng::AsyncTMAScatterOp>(descPtr, xOffsets, yOffset,
                                                  src);
           })
-
+      .def("create_broadcast",
+           [](TritonOpBuilder &self, Value &arg, Type retTy) -> Value {
+             return self.create<tt::BroadcastOp>(retTy, arg);
+           })
+      .def(
+          "create_expand_dims",
+          [](TritonOpBuilder &self, Value &arg, int axis, Type retTy) -> Value {
+            return self.create<tt::ExpandDimsOp>(retTy, arg, axis);
+          })
       .def("create_warp_return",
            [](GluonOpBuilder &self) -> Operation * {
             return self.create<ttg::WarpReturnOp>();
