
Commit b385d79

Merge commit '1064b598385c49f03fbc73f6839f578146beb4e4'
2 parents: b9e4be1 + 1064b59

File tree: 22 files changed (+1214, -567 lines)


.github/workflows/integration-tests.yml

Lines changed: 14 additions & 7 deletions
@@ -239,14 +239,14 @@ jobs:
           cd python
           LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
           if [ ! -d "${LIT_TEST_DIR}" ]; then
-            echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
+            echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
           fi
           lit -v "${LIT_TEST_DIR}"
       - name: Run python tests on CUDA
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
-            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
+            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
           python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
@@ -268,14 +268,16 @@ jobs:
             language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \
             runtime/test_autotuner.py::test_kwargs[False]\
             ../../tutorials/06-fused-attention.py::test_op --device cpu
+      - name: Run regression tests
+        run: |
+          cd python/test/regression
+          python3 -m pytest -s -n 8 .
       - name: Run C++ unittests
         run: |
           cd python
           cd "build/$(ls build | grep -i cmake)"
           ctest -j32
       - name: Run Proton tests
-        env:
-          LD_LIBRARY_PATH: "/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
         run: |
           cd third_party/proton
           python3 -m pytest -s test
@@ -395,14 +397,14 @@ jobs:
           cd python
           LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
           if [ ! -d "${LIT_TEST_DIR}" ]; then
-            echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
+            echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
           fi
           lit -v "${LIT_TEST_DIR}"
       - name: Run python tests on HIP
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
-            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
+            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit
@@ -416,10 +418,15 @@ jobs:

           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
+      - name: Run regression tests
+        run: |
+          # Reenable test_functional_regression.py once it's fixed
+          cd python/test/regression
+          python3 -m pytest -s -n 8 ./test_cast_matmul.py
       - name: Run Proton tests
         run: |
           cd third_party/proton
-          python3 -m pytest test
+          python3 -m pytest -s test
       - name: Run C++ unittests
         run: |
           cd python

.github/workflows/integration-tests.yml.in

Lines changed: 15 additions & 9 deletions
@@ -272,15 +272,15 @@ jobs:
           cd python
           LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
           if [ ! -d "${LIT_TEST_DIR}" ]; then
-            echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
+            echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
           fi
           lit -v "${LIT_TEST_DIR}"

       - name: Run python tests on CUDA
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
-            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
+            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
           python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
@@ -304,16 +304,20 @@ jobs:
             runtime/test_autotuner.py::test_kwargs[False]\
             ../../tutorials/06-fused-attention.py::test_op --device cpu

+      - name: Run regression tests
+        run: |
+          cd python/test/regression
+          python3 -m pytest -s -n 8 .
+
       - &run-cpp-unittests-step
         name: Run C++ unittests
         run: |
           cd python
           cd "build/$(ls build | grep -i cmake)"
           ctest -j32

-      - name: Run Proton tests
-        env:
-          LD_LIBRARY_PATH: "/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
+      - &run-proton-tests-step
+        name: Run Proton tests
         run: |
           cd third_party/proton
           python3 -m pytest -s test
@@ -398,7 +402,7 @@ jobs:
         run: |
           INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
-            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
+            echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit
@@ -413,11 +417,13 @@ jobs:
           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py

-      - name: Run Proton tests
+      - name: Run regression tests
         run: |
-          cd third_party/proton
-          python3 -m pytest test
+          # Reenable test_functional_regression.py once it's fixed
+          cd python/test/regression
+          python3 -m pytest -s -n 8 ./test_cast_matmul.py

+      - *run-proton-tests-step
       - *run-cpp-unittests-step
       - *save-build-artifacts-step
       - *inspect-cache-directories-step
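For readers unfamiliar with the `&run-proton-tests-step` / `*run-proton-tests-step` notation introduced above: these are standard YAML anchors and aliases, which the .in template uses so a step defined once can be reused across jobs. A minimal sketch of how the alias resolves, assuming PyYAML is installed; the snippet is illustrative only and not part of this commit:

# Illustrative only: shows how the YAML anchor/alias pair used in
# integration-tests.yml.in resolves when loaded by a YAML parser.
import yaml  # PyYAML

doc = """
steps:
  - &run-proton-tests-step
    name: Run Proton tests
    run: |
      cd third_party/proton
      python3 -m pytest -s test
  - *run-proton-tests-step
"""

steps = yaml.safe_load(doc)["steps"]
# The alias expands to the same mapping as the anchored step.
assert steps[0] == steps[1]
assert steps[0]["name"] == "Run Proton tests"

Any compliant YAML loader resolves each `- *run-proton-tests-step` entry to a copy of the anchored step, so the rendered workflow repeats the step verbatim.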

bin/triton-lsp.cpp

Lines changed: 0 additions & 1 deletion
@@ -6,6 +6,5 @@ int main(int argc, char **argv) {
   mlir::DialectRegistry registry;
   registerTritonDialects(registry);

-  mlir::MLIRContext context(registry);
   return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry));
 }
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+The conference slides are available [here](https://drive.google.com/drive/folders/1osK9hwcX_lC1EjdZGB-v4w5oKx23UnU2?usp=drive_link)
+
+The conference videos are available [here](https://www.youtube.com/playlist?list=PLc_vA1r0qoiTjlrINKUuFrI8Ptoopm8Vz).

include/triton/Dialect/Triton/IR/Types.h

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ Type getI32SameShape(Type type);

 Type getPointerTypeSameShape(Type type);

+Type getPointerTypeToElement(Type type);
+
 } // namespace triton

 } // namespace mlir

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 44 additions & 2 deletions
@@ -116,9 +116,20 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     RankedTensorType dstTy = op.getType();
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
+    // FIXME [Dot LL]
+    // Do for all DotOperandEncodingAttr once we have LLs for all of them
+    auto isAmpereLargeKWidth = [](Attribute layout) {
+      if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
+        if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
+          return mma.isAmpere() && dot.getKWidth() == 8;
+        }
+      }
+      return false;
+    };
     if (isa<SharedEncodingAttr>(srcLayout) &&
-        isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
-            dstLayout)) {
+        (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
+             dstLayout) ||
+         isAmpereLargeKWidth(dstLayout))) {
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
     }
@@ -170,6 +181,37 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     SmallVector<Value> outVals = loadSharedToDistributed(
         dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo);

+    // FIXME [Dot LL]
+    // Ampere case
+    // In this case, we need to pack the outputs into i32
+    if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
+      if (elemLlvmTy.isInteger(8)) {
+        auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
+          return or_(or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))),
+                     or_(shl(zext(i32_ty, a3), i32_val(16)),
+                         shl(zext(i32_ty, a4), i32_val(24))));
+        };
+        SmallVector<Value> outVals32(outVals.size() / 4);
+        for (int i = 0; i < outVals32.size(); ++i) {
+          outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1],
+                                outVals[4 * i + 2], outVals[4 * i + 3]);
+        }
+        outVals = outVals32;
+      } else {
+        assert(elemLlvmTy.isBF16() && "Unexpected element type");
+        auto concat = [&](Value a, Value b) {
+          return or_(zext(i32_ty, bitcast(a, i16_ty)),
+                     shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
+        };
+
+        SmallVector<Value> outVals32(outVals.size() / 2);
+        for (int i = 0; i < outVals32.size(); ++i) {
+          outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
+        }
+        outVals = outVals32;
+      }
+    }
+
     Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
     rewriter.replaceOp(op, result);

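The added block above packs narrow elements into 32-bit values before `packLLElements`: four i8 values per i32, or two bf16 bit patterns per i32, via the usual zext/shl/or combine. A standalone Python sketch of the same arithmetic follows; the helper names are illustrative, not Triton APIs.

# Illustration of the packing performed by the new LocalLoadOp lowering;
# names here are for demonstration only.

def pack_i8x4_to_i32(a1, a2, a3, a4):
    # Mirrors or_(or_(zext(a1), shl(zext(a2), 8)),
    #             or_(shl(zext(a3), 16), shl(zext(a4), 24)))
    return (a1 & 0xFF) | ((a2 & 0xFF) << 8) | ((a3 & 0xFF) << 16) | ((a4 & 0xFF) << 24)

def pack_bf16x2_to_i32(lo_bits, hi_bits):
    # lo_bits / hi_bits are the raw 16-bit patterns of two bf16 values
    return (lo_bits & 0xFFFF) | ((hi_bits & 0xFFFF) << 16)

assert pack_i8x4_to_i32(0x01, 0x02, 0x03, 0x04) == 0x04030201
assert pack_bf16x2_to_i32(0x3F80, 0x4000) == 0x40003F80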
lib/Dialect/Triton/IR/Types.cpp

Lines changed: 7 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "triton/Dialect/Triton/IR/Types.h"

 #include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
@@ -157,6 +158,12 @@ Type getPointerTypeSameShape(Type type) {
   }
 }

+Type getPointerTypeToElement(Type type) {
+  Type elementType = getElementTypeOrSelf(type);
+  PointerType ptrType = PointerType::get(elementType, 1);
+  return ptrType;
+}
+
 // upstream Triton only uses address space 1 for Pointer Type
 Type getPointerType(Type type, int addressSpace) {
   return PointerType::get(type, addressSpace);

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 3 additions & 92 deletions
@@ -163,85 +163,6 @@ void LayoutRematerialization::cleanup() {
     op->erase();
 }

-// Look ahead to at the transitive uses and see if there is a convert to mma
-// operations.
-bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
-  SmallVector<Value> queue = {op->getResult(0)};
-  SetVector<Operation *> forwardSlice;
-  llvm::SmallDenseSet<Value> seen;
-  while (!queue.empty()) {
-    Value currentValue = queue.back();
-    queue.pop_back();
-    getForwardSlice(currentValue, &forwardSlice);
-    for (Operation *op : forwardSlice) {
-      // HACK: Stop propagation if the ReduceOp is using mma layout but is
-      // producing tensor smaller than the layout we would like to propagate.
-      // This is to avoid stepping into the known bug.
-      if (isa<mlir::triton::ReduceOp>(op)) {
-        auto tensorType =
-            dyn_cast<RankedTensorType>(op->getOperand(0).getType());
-        if (tensorType &&
-            isa<NvidiaMmaEncodingAttr>(tensorType.getEncoding())) {
-          auto mmaInstrShape =
-              cast<NvidiaMmaEncodingAttr>(encoding).getInstrShape();
-          if (tensorType.getShape()[tensorType.getRank() - 2] <
-                  mmaInstrShape[0] ||
-              tensorType.getShape()[tensorType.getRank() - 1] <
-                  mmaInstrShape[1]) {
-            return false;
-          }
-        }
-      }
-
-      if (auto convertOp = dyn_cast<ConvertLayoutOp>(op)) {
-        Attribute dstEncoding = convertOp.getType().getEncoding();
-        if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(dstEncoding))
-          return (mmaLayout.getVersionMajor() > 1) ? true
-                                                   : mmaLayout == encoding;
-        if (isa<triton::gpu::AMDMfmaEncodingAttr,
-                triton::gpu::AMDWmmaEncodingAttr>(dstEncoding))
-          return true;
-        if (isa<triton::gpu::DotOperandEncodingAttr>(dstEncoding)) {
-          if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(encoding)) {
-            return mmaLayout.getVersionMajor() > 1;
-          } else {
-            assert((mlir::isa<triton::gpu::AMDMfmaEncodingAttr,
-                              triton::gpu::AMDWmmaEncodingAttr>(encoding)));
-            return true;
-          }
-        }
-      }
-      bool isMMAV3 =
-          isa<NvidiaMmaEncodingAttr>(encoding) &&
-          cast<NvidiaMmaEncodingAttr>(encoding).getVersionMajor() == 3;
-      if (isMMAV3 && (isa<LocalAllocOp>(op) || isa<LocalStoreOp>(op)))
-        return true;
-      auto yield = dyn_cast<scf::YieldOp>(op);
-      if (!yield)
-        continue;
-      if (auto ifOp = dyn_cast<scf::IfOp>(yield->getParentOp())) {
-        for (OpOperand &operand : yield->getOpOperands()) {
-          Operation *def = operand.get().getDefiningOp();
-          if (def &&
-              (forwardSlice.count(def) || operand.get() == currentValue) &&
-              (seen.insert(operand.get()).second == true))
-            queue.push_back(ifOp.getResult(operand.getOperandNumber()));
-        }
-      }
-      auto forOp = dyn_cast<scf::ForOp>(yield.getOperation()->getParentOp());
-      if (!forOp)
-        continue;
-      for (OpOperand &operand : yield->getOpOperands()) {
-        Operation *def = operand.get().getDefiningOp();
-        if (def && (forwardSlice.count(def) || operand.get() == currentValue) &&
-            (seen.insert(operand.get()).second == true))
-          queue.push_back(forOp.getRegionIterArg(operand.getOperandNumber()));
-      }
-    }
-  }
-  return false;
-}
-
 // Return true if the op is an op with a layout we don't want to change. We will
 // propagate the layout starting from anchor ops.
 bool isLayoutAnchor(Operation *op) {
@@ -262,18 +183,8 @@ bool isLayoutAnchor(Operation *op) {
 }

 void LayoutPropagation::initAnchorLayout() {
-  auto maybeAddAnchor = [&](Value v) {
+  auto addAnchor = [&](Value v) {
     if (auto tensorType = dyn_cast<RankedTensorType>(v.getType())) {
-      // Workaround, don't popagate MMA layout unless there is a convert
-      // back to mma further down to avoid generating reduction with MMA
-      // layout that may have lower performance.
-      // This can be improved with more aggressive backward propagation.
-      if (isa<MmaEncodingTrait>(tensorType.getEncoding()) &&
-          v.getDefiningOp() &&
-          !hasConvertToMMATransisitiveUse(v.getDefiningOp(),
-                                          tensorType.getEncoding())) {
-        return;
-      }
       layouts.insert({v, LayoutInfo(tensorType.getEncoding())});
     }
   };
@@ -282,13 +193,13 @@ void LayoutPropagation::initAnchorLayout() {
   // you can pass a tensor with an encoding as an arg, instead of explicitly
   // calling tt.load.
   for (auto arg : funcOp.getArguments()) {
-    maybeAddAnchor(arg);
+    addAnchor(arg);
   }

   funcOp.walk([&](Operation *op) {
     if (isLayoutAnchor(op)) {
       for (auto result : op->getResults()) {
-        maybeAddAnchor(result);
+        addAnchor(result);
       }
     }
   });

python/test/regression/conftest.py

Lines changed: 13 additions & 3 deletions
@@ -1,12 +1,22 @@
-# content of conftest.py
-
+import os
 import pytest
+import tempfile


 def pytest_addoption(parser):
-    parser.addoption("--device", action="store", default='cuda')
+    parser.addoption("--device", action="store", default="cuda")


 @pytest.fixture
 def device(request):
     return request.config.getoption("--device")
+
+
+@pytest.fixture
+def fresh_triton_cache():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        try:
+            os.environ["TRITON_CACHE_DIR"] = tmpdir
+            yield tmpdir
+        finally:
+            os.environ.pop("TRITON_CACHE_DIR", None)
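The new `fresh_triton_cache` fixture points `TRITON_CACHE_DIR` at a throwaway temporary directory for the duration of a test and removes the variable afterwards. A regression test can opt in simply by naming the fixture as a parameter; the test below is a hypothetical usage example, not part of this commit.

# Hypothetical regression test showing how the fixture might be used.
import os

def test_compiles_into_isolated_cache(fresh_triton_cache):
    # The fixture has pointed TRITON_CACHE_DIR at a temporary directory,
    # so anything compilation writes lands there and is discarded when
    # the test finishes.
    assert os.environ["TRITON_CACHE_DIR"] == fresh_triton_cache
    assert os.path.isdir(fresh_triton_cache)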
