
Commit f213106

Merge commit 'fa229d1c4bee16c094be9427334575ec1e79f66c'
2 parents e48642c + fa229d1 commit f213106

31 files changed: +1132 −129 lines changed

.github/workflows/integration-tests.yml

Lines changed: 8 additions & 8 deletions
@@ -245,9 +245,9 @@ jobs:
           lit -v "${LIT_TEST_DIR}"
       - name: Run python tests on CUDA
         run: |
-          SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/_C"
-          if [ ! -d "${SHARED_LIB_DIR}" ]; then
-            echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
+          INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
+          if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
+            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
           python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
@@ -257,7 +257,7 @@ jobs:
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
           # Run hopper/test_flashattention.py separately to avoid out of gpu memory
           python3 -m pytest -s hopper/test_flashattention.py
-          TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
+          TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
             python3 -m pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py
       - name: Run interpreter tests
         if: ${{ matrix.runner[0] == 'h100-runner-set' }}
@@ -401,9 +401,9 @@ jobs:
           lit -v "${LIT_TEST_DIR}"
       - name: Run python tests on HIP
         run: |
-          SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/_C"
-          if [ ! -d "${SHARED_LIB_DIR}" ]; then
-            echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
+          INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
+          if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
+            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit
@@ -412,7 +412,7 @@ jobs:
             --ignore=test_debug.py
           # TODO: uncomment
           # pytest --capture=tee-sys -rfs test_debug.py
-          TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
+          TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
             pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0

.github/workflows/integration-tests.yml.in

Lines changed: 8 additions & 8 deletions
@@ -279,9 +279,9 @@ jobs:

       - name: Run python tests on CUDA
         run: |
-          SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/_C"
-          if [ ! -d "${SHARED_LIB_DIR}" ]; then
-            echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
+          INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/build/$(ls python/build | grep -i lib)/triton/instrumentation"
+          if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
+            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           cd python/test/unit
           python3 -m pytest -s -n 8 --ignore=hopper/test_flashattention.py --ignore=language/test_line_info.py --ignore=language/test_subprocess.py --ignore=test_debug.py
@@ -291,7 +291,7 @@ jobs:
           TRITON_DISABLE_LINE_INFO=0 python3 -m pytest -s language/test_line_info.py
           # Run hopper/test_flashattention.py separately to avoid out of gpu memory
           python3 -m pytest -s hopper/test_flashattention.py
-          TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
+          TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
             python3 -m pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

       - name: Run interpreter tests
@@ -397,9 +397,9 @@ jobs:

       - name: Run python tests on HIP
         run: |
-          SHARED_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/_C"
-          if [ ! -d "${SHARED_LIB_DIR}" ]; then
-            echo "Coult not find '${SHARED_LIB_DIR}'" ; exit -1
+          INSTRUMENTATION_LIB_DIR="${GITHUB_WORKSPACE}/python/triton/instrumentation"
+          if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
+            echo "Coult not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           cd python/test/unit
@@ -408,7 +408,7 @@ jobs:
             --ignore=test_debug.py
           # TODO: uncomment
           # pytest --capture=tee-sys -rfs test_debug.py
-          TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
+          TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
             pytest --capture=tee-sys -rfs -vvv instrumentation/test_gpuhello.py

           # Run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 0 deletions
@@ -88,6 +88,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUStreamPipeline();
   mlir::registerTritonAMDGPUStreamPipelineV2();
   mlir::registerTritonAMDGPUCanonicalizePointers();
+  mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
+  mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

   // TODO: register Triton & TritonGPU passes
   registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 22 additions & 4 deletions
@@ -27,15 +27,33 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
 constexpr int patternBenefitClampOptimizedPattern = 20;
 constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;

+struct BackendCallbacks {
+  /**
+   * A backend-specific callback for appending auxiliary data during
+   * `LocalStoreOp` conversion.
+   *
+   * @param[in] op The reference to the re-written `LocalStoreOp`.
+   * @param[in] count The number of issued LLVM instructions.
+   * @param[in] type The input type of issued LLVM instructions.
+   */
+  std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
+                     Type llvmOpType)>
+      localStoreOpConversion = nullptr;
+};
+
 void populateElementwiseOpToLLVMPatterns(
     LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
     ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
     PatternBenefit benefit);

-void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
-                                   const TargetInfoBase &targetInfo,
-                                   RewritePatternSet &patterns,
-                                   PatternBenefit benefit);
+// The given callback is invoked at the end of a successful rewrite. The
+// callback receives 1) the current source op, 2) the number of issued LLVM
+// instructions and 3) their input types. Each MLIR backend can provide a
+// callback and, thus, handle backend-specific behaviors.
+void populateMemoryOpToLLVMPattern(
+    LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
+    RewritePatternSet &patterns, PatternBenefit benefit,
+    std::optional<BackendCallbacks> backendCallbacks = std::nullopt);

 void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                    RewritePatternSet &patterns,
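
For reference, a minimal sketch of how a backend could plug into the new hook. The lambda body, the stats map, and the surrounding pass boilerplate are illustrative assumptions; only BackendCallbacks, its localStoreOpConversion member, and the extra populateMemoryOpToLLVMPattern argument come from this commit.

    // Hypothetical call site inside a backend's TritonGPU-to-LLVM conversion setup.
    // The map must outlive the conversion run, since the callback fires while the
    // rewrite patterns are being applied.
    llvm::DenseMap<mlir::Operation *, std::pair<size_t, mlir::Type>> localStoreStats;

    mlir::triton::BackendCallbacks callbacks;
    callbacks.localStoreOpConversion = [&](mlir::triton::gpu::LocalStoreOp op,
                                           size_t llvmOpCount, mlir::Type llvmOpType) {
      // Record how many LLVM ops were issued for this LocalStoreOp and their
      // vector type, e.g. as input for backend-specific scheduling decisions.
      localStoreStats[op] = {llvmOpCount, llvmOpType};
    };

    // Existing callers are unaffected: the new argument defaults to std::nullopt.
    mlir::triton::populateMemoryOpToLLVMPattern(typeConverter, targetInfo, patterns,
                                                benefit, callbacks);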

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 5 deletions
@@ -1366,11 +1366,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
                                            Location loc, RewriterBase &rewriter,
                                            const TargetInfoBase &target);

-void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
-                              Type elemLlvmTy, ArrayRef<Value> srcVals,
-                              Value smemBase, ArrayRef<Value> dstStrides,
-                              Location loc, RewriterBase &rewriter,
-                              const TargetInfoBase &target);
+void storeDistributedToShared(
+    MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
+    ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
+    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
+    std::pair<size_t, Type> *const llvmOpCount = nullptr);

 inline Value getStructFromSharedMemoryObject(Location loc,
                                              const SharedMemoryObject &smemObj,
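
A rough caller-side sketch of the new out-parameter; the surrounding variables are assumed to be set up by the caller, and the semantics mirror the implementation change in lib/Conversion/TritonGPUToLLVM/Utility.cpp below.

    // Assumes dstTy, srcTy, elemLlvmTy, srcVals, smemBase, dstStrides, loc,
    // rewriter and target are already prepared by the surrounding lowering.
    std::pair<size_t, mlir::Type> llvmOpCount{0, mlir::Type()};
    storeDistributedToShared(dstTy, srcTy, elemLlvmTy, srcVals, smemBase,
                             dstStrides, loc, rewriter, target, &llvmOpCount);
    // llvmOpCount.first  -> number of vectorized LLVM stores that were emitted
    // llvmOpCount.second -> the vector type used by those stores
    // Passing nullptr (the default) preserves the old behaviour.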

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 25 additions & 11 deletions
@@ -15,12 +15,11 @@ using namespace mlir::triton::gpu;
 // blocked -> shared.
 // Swizzling in shared memory to avoid bank conflict. Normally used for
 // A/B operands of dots.
-void lowerDistributedToShared(Location loc, Value src, Value dst,
-                              Value adaptorSrc,
-                              const SharedMemoryObject &smemObj,
-                              const LLVMTypeConverter *typeConverter,
-                              ConversionPatternRewriter &rewriter,
-                              const TargetInfoBase &targetInfo) {
+void lowerDistributedToShared(
+    Location loc, Value src, Value dst, Value adaptorSrc,
+    const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
+    ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
+    std::pair<size_t, Type> *const llvmOpCount = nullptr) {
   auto srcTy = cast<RankedTensorType>(src.getType());
   auto dstTy = cast<MemDescType>(dst.getType());
   auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
@@ -33,7 +32,7 @@ void lowerDistributedToShared(Location loc, Value src, Value dst,
   auto dstStrides = smemObj.getStrides();
   auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
   storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
-                           loc, rewriter, targetInfo);
+                           loc, rewriter, targetInfo, llvmOpCount);
 }

 struct LocalAllocOpConversion
@@ -185,12 +184,15 @@ struct LocalStoreOpConversion
 public:
   using ConvertOpToLLVMPattern<
       triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern;
+  using BackendCallbackType =
+      decltype(BackendCallbacks::localStoreOpConversion);

   LocalStoreOpConversion(const LLVMTypeConverter &converter,
                          const TargetInfoBase &targetInfo,
+                         BackendCallbackType backendCallback,
                          PatternBenefit benefit = 1)
       : ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
-        targetInfo(targetInfo) {}
+        targetInfo(targetInfo), backendCallback(backendCallback) {}

   LogicalResult
   matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
@@ -200,24 +202,36 @@ struct LocalStoreOpConversion
         getTypeConverter()->convertType(op.getDst().getType().getElementType());
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
         op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);
+
+    std::pair<size_t, Type> llvmOpCount;
     lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(),
                              adaptor.getSrc(), smemObj, getTypeConverter(),
-                             rewriter, targetInfo);
+                             rewriter, targetInfo, &llvmOpCount);
+
+    if (backendCallback)
+      (backendCallback)(op, llvmOpCount.first, llvmOpCount.second);
+
     rewriter.eraseOp(op);
     return success();
   }

 private:
   const TargetInfoBase &targetInfo;
+  BackendCallbackType backendCallback;
 };

 } // namespace

 void mlir::triton::populateMemoryOpToLLVMPattern(
     LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+    RewritePatternSet &patterns, PatternBenefit benefit,
+    std::optional<BackendCallbacks> backendCallbacks) {
   patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
   patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
   patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
-  patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, benefit);
+
+  auto backendCall =
+      backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr;
+  patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, backendCall,
+                                       benefit);
 }

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 7 additions & 1 deletion
@@ -404,7 +404,8 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
                               Type elemLlvmTy, ArrayRef<Value> srcVals,
                               Value smemBase, ArrayRef<Value> dstStrides,
                               Location loc, RewriterBase &rewriter,
-                              const TargetInfoBase &target) {
+                              const TargetInfoBase &target,
+                              std::pair<size_t, Type> *const llvmOpCount) {
   bool success = emitTransferBetweenRegistersAndShared(
       srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
       dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
@@ -418,7 +419,12 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
         store(vec, vecAddr)
             .setAlignment(vecTy.getNumElements() *
                           elemLlvmTy.getIntOrFloatBitWidth() / 8);
+        if (llvmOpCount) {
+          ++(llvmOpCount->first);
+          llvmOpCount->second = vecTy;
+        }
       });
+
   if (!success)
     llvm::report_fatal_error("Failed to emit transfer from register to shared");
 }
lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,11 @@ class RewriteTensorPointerPass
370370
}
371371

372372
// update rewritedInfo
373+
auto opResults = op.getResults();
373374
unsigned oldResIdx = 0, newResIdx = 0;
374375
while (oldResIdx < results.size()) {
375376
if (!triton::isTensorPointerType(results[oldResIdx].getType())) {
377+
opResults[oldResIdx].replaceAllUsesWith(newOp.getResult(newResIdx));
376378
oldResIdx++;
377379
newResIdx++;
378380
} else {

python/triton/language/core.py

Lines changed: 0 additions & 6 deletions
@@ -1909,26 +1909,20 @@ def where(condition, x, y, _builder=None):
 def add(x, y, sanitize_overflow: constexpr = True, _builder=None):
     x = _unwrap_if_constexpr(x)
     y = _unwrap_if_constexpr(y)
-    x = semantic.to_tensor(x, _builder)
-    y = semantic.to_tensor(y, _builder)
     return semantic.add(x, y, sanitize_overflow, _builder)


 @builtin
 def sub(x, y, sanitize_overflow: constexpr = True, _builder=None):
     x = _unwrap_if_constexpr(x)
     y = _unwrap_if_constexpr(y)
-    x = semantic.to_tensor(x, _builder)
-    y = semantic.to_tensor(y, _builder)
     return semantic.sub(x, y, sanitize_overflow, _builder)


 @builtin
 def mul(x, y, sanitize_overflow: constexpr = True, _builder=None):
     x = _unwrap_if_constexpr(x)
     y = _unwrap_if_constexpr(y)
-    x = semantic.to_tensor(x, _builder)
-    y = semantic.to_tensor(y, _builder)
     return semantic.mul(x, y, sanitize_overflow, _builder)


scripts/test-triton.sh

Lines changed: 2 additions & 2 deletions
@@ -298,12 +298,12 @@ run_instrumentation_tests() {
     return
   fi

-  SHARED_LIB_DIR=$(ls -1d $TRITON_PROJ/python/build/*lib*/triton/_C) || err "Could not find $TRITON_PROJ/python/build/*lib*/triton/_C, build Triton first"
+  INSTRUMENTATION_LIB_DIR=$(ls -1d $TRITON_PROJ/python/build/*lib*/triton/instrumentation) || err "Could not find $TRITON_PROJ/python/build/*lib*/triton/instrumentation, build Triton first"

   cd $TRITON_PROJ/python/test/unit

   TRITON_TEST_SUITE=instrumentation \
-    TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${SHARED_LIB_DIR}/libGPUHello.so \
+    TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=${INSTRUMENTATION_LIB_DIR}/libGPUInstrumentationTestLib.so \
     pytest -vvv --device xpu instrumentation/test_gpuhello.py
 }
