Skip to content

Commit df042a2

Browse files
authored
Merge branch 'main' into lesh/conda-oct
2 parents 8df3398 + 529ca78 commit df042a2

File tree

41 files changed

+2345
-678
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

41 files changed

+2345
-678
lines changed

.github/pins/pytorch-upstream.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
487873f7cafeb0fd390eaefe40496b804bceabbd
1+
0efa590d435d2b4aefcbad9014dd5fa75dcf8405

.github/workflows/auto-update-translator-cid.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ jobs:
8686
- name: Search the latest valid Translator cid
8787
if: ${{ env.TARGET_PRID == null }}
8888
run: |
89-
env
9089
./scripts/check-update-translator-cid.sh $CID_LATEST $CID_CURRENT
9190
if git status --porcelain ./lib/Target/SPIRV/spirv-llvm-translator.conf | grep '^ M'; then
9291
echo "MODIFIED=true" >> $GITHUB_ENV

.github/workflows/integration-tests.yml

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -239,14 +239,14 @@ jobs:
239239
cd python
240240
LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
241241
if [ ! -d "${LIT_TEST_DIR}" ]; then
242-
echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
242+
echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
243243
fi
244244
lit -v "${LIT_TEST_DIR}"
245245
- name: Run python tests on CUDA
246246
run: |
247247
INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
248248
if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
249-
echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
249+
echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
250250
fi
251251
cd python/test/unit
252252
python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
@@ -268,14 +268,16 @@ jobs:
268268
language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \
269269
runtime/test_autotuner.py::test_kwargs[False]\
270270
../../tutorials/06-fused-attention.py::test_op --device cpu
271+
- name: Run regression tests
272+
run: |
273+
cd python/test/regression
274+
python3 -m pytest -s -n 8 .
271275
- name: Run C++ unittests
272276
run: |
273277
cd python
274278
cd "build/$(ls build | grep -i cmake)"
275279
ctest -j32
276280
- name: Run Proton tests
277-
env:
278-
LD_LIBRARY_PATH: "/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
279281
run: |
280282
cd third_party/proton
281283
python3 -m pytest -s test
@@ -395,14 +397,14 @@ jobs:
395397
cd python
396398
LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
397399
if [ ! -d "${LIT_TEST_DIR}" ]; then
398-
echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
400+
echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
399401
fi
400402
lit -v "${LIT_TEST_DIR}"
401403
- name: Run python tests on HIP
402404
run: |
403405
INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
404406
if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
405-
echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
407+
echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
406408
fi
407409
pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
408410
cd python/test/unit
@@ -416,10 +418,15 @@ jobs:
416418
417419
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
418420
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
421+
- name: Run regression tests
422+
run: |
423+
# Reenable test_functional_regression.py once it's fixed
424+
cd python/test/regression
425+
python3 -m pytest -s -n 8 ./test_cast_matmul.py
419426
- name: Run Proton tests
420427
run: |
421428
cd third_party/proton
422-
python3 -m pytest test
429+
python3 -m pytest -s test
423430
- name: Run C++ unittests
424431
run: |
425432
cd python

.github/workflows/integration-tests.yml.in

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -272,15 +272,15 @@ jobs:
272272
cd python
273273
LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test"
274274
if [ ! -d "${LIT_TEST_DIR}" ]; then
275-
echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1
275+
echo "Could not find '${LIT_TEST_DIR}'" ; exit -1
276276
fi
277277
lit -v "${LIT_TEST_DIR}"
278278

279279
- name: Run python tests on CUDA
280280
run: |
281281
INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
282282
if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
283-
echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
283+
echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
284284
fi
285285
cd python/test/unit
286286
python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
@@ -304,16 +304,20 @@ jobs:
304304
runtime/test_autotuner.py::test_kwargs[False]\
305305
../../tutorials/06-fused-attention.py::test_op --device cpu
306306

307+
- name: Run regression tests
308+
run: |
309+
cd python/test/regression
310+
python3 -m pytest -s -n 8 .
311+
307312
- &run-cpp-unittests-step
308313
name: Run C++ unittests
309314
run: |
310315
cd python
311316
cd "build/$(ls build | grep -i cmake)"
312317
ctest -j32
313318

314-
- name: Run Proton tests
315-
env:
316-
LD_LIBRARY_PATH: "/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
319+
- &run-proton-tests-step
320+
name: Run Proton tests
317321
run: |
318322
cd third_party/proton
319323
python3 -m pytest -s test
@@ -398,7 +402,7 @@ jobs:
398402
run: |
399403
INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
400404
if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
401-
echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
405+
echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
402406
fi
403407
pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
404408
cd python/test/unit
@@ -413,11 +417,13 @@ jobs:
413417
# Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0
414418
TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s -n 8 language/test_line_info.py
415419

416-
- name: Run Proton tests
420+
- name: Run regression tests
417421
run: |
418-
cd third_party/proton
419-
python3 -m pytest test
422+
# Reenable test_functional_regression.py once it's fixed
423+
cd python/test/regression
424+
python3 -m pytest -s -n 8 ./test_cast_matmul.py
420425

426+
- *run-proton-tests-step
421427
- *run-cpp-unittests-step
422428
- *save-build-artifacts-step
423429
- *inspect-cache-directories-step

benchmarks/triton_kernels_benchmark/benchmark_driver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def serialize_kernel_metadata(arg, args_dict):
405405
args_dict["shared_memory"] = arg.shared
406406
args_dict["kernel_name"] = arg.name
407407
args_dict["spv_name"] = f"{arg.name}.spv"
408+
args_dict["build_flags"] = arg.build_flags
408409

409410

410411
def serialize_args(args, constants, signature):

bin/triton-lsp.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,5 @@ int main(int argc, char **argv) {
66
mlir::DialectRegistry registry;
77
registerTritonDialects(registry);
88

9-
mlir::MLIRContext context(registry);
109
return mlir::failed(mlir::MlirLspServerMain(argc, argv, registry));
1110
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
The conference slides are available [here](https://drive.google.com/drive/folders/1osK9hwcX_lC1EjdZGB-v4w5oKx23UnU2?usp=drive_link).
2+
3+
The conference videos are available [here](https://www.youtube.com/playlist?list=PLc_vA1r0qoiTjlrINKUuFrI8Ptoopm8Vz).

include/triton/Dialect/Triton/IR/Types.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ Type getI32SameShape(Type type);
3434

3535
Type getPointerTypeSameShape(Type type);
3636

37+
Type getPointerTypeToElement(Type type);
38+
3739
} // namespace triton
3840

3941
} // namespace mlir

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,20 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
116116
RankedTensorType dstTy = op.getType();
117117
Attribute srcLayout = srcTy.getEncoding();
118118
Attribute dstLayout = dstTy.getEncoding();
119+
// FIXME [Dot LL]
120+
// Do for all DotOperandEncodingAttr once we have LLs for all of them
121+
auto isAmpereLargeKWidth = [](Attribute layout) {
122+
if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
123+
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
124+
return mma.isAmpere() && dot.getKWidth() == 8;
125+
}
126+
}
127+
return false;
128+
};
119129
if (isa<SharedEncodingAttr>(srcLayout) &&
120-
isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
121-
dstLayout)) {
130+
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
131+
dstLayout) ||
132+
isAmpereLargeKWidth(dstLayout))) {
122133
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
123134
rewriter);
124135
}
@@ -170,6 +181,37 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
170181
SmallVector<Value> outVals = loadSharedToDistributed(
171182
dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo);
172183

184+
// FIXME [Dot LL]
185+
// Ampere case
186+
// In this case, we need to pack the outputs into i32
187+
if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
188+
if (elemLlvmTy.isInteger(8)) {
189+
auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
190+
return or_(or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))),
191+
or_(shl(zext(i32_ty, a3), i32_val(16)),
192+
shl(zext(i32_ty, a4), i32_val(24))));
193+
};
194+
SmallVector<Value> outVals32(outVals.size() / 4);
195+
for (int i = 0; i < outVals32.size(); ++i) {
196+
outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1],
197+
outVals[4 * i + 2], outVals[4 * i + 3]);
198+
}
199+
outVals = outVals32;
200+
} else {
201+
assert(elemLlvmTy.isBF16() && "Unexpected element type");
202+
auto concat = [&](Value a, Value b) {
203+
return or_(zext(i32_ty, bitcast(a, i16_ty)),
204+
shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
205+
};
206+
207+
SmallVector<Value> outVals32(outVals.size() / 2);
208+
for (int i = 0; i < outVals32.size(); ++i) {
209+
outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
210+
}
211+
outVals = outVals32;
212+
}
213+
}
214+
173215
Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
174216
rewriter.replaceOp(op, result);
175217

lib/Dialect/Triton/IR/Types.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "triton/Dialect/Triton/IR/Types.h"
22

33
#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc`
4+
#include "mlir/IR/TypeUtilities.h"
45
#include "mlir/Support/LLVM.h"
56
#include "triton/Dialect/Triton/IR/Dialect.h"
67
#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc`
@@ -157,6 +158,12 @@ Type getPointerTypeSameShape(Type type) {
157158
}
158159
}
159160

161+
Type getPointerTypeToElement(Type type) {
162+
Type elementType = getElementTypeOrSelf(type);
163+
PointerType ptrType = PointerType::get(elementType, 1);
164+
return ptrType;
165+
}
166+
160167
// upstream Triton only uses address space 1 for Pointer Type
161168
Type getPointerType(Type type, int addressSpace) {
162169
return PointerType::get(type, addressSpace);

0 commit comments

Comments
 (0)