Skip to content

Commit 7d23ec4

Browse files
Fix hardcoding of shared address space in target independent code (#4833)
Currently, we hardcode the address space of shared memory in target-independent code. This wouldn't work for a backend that uses a different address space for shared memory, so this PR does two things: 1. Removes unnecessary casts (that go from a shared memory ptr to a shared memory ptr) and instances of hardcoding that can be avoided. It does these cleanups only to target-independent code - if needed, I can do it to the backends as well. 2. Introduces getSharedAddressSpace to TargetInfo and uses it for the rest. The core Triton team is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because it is a non-functional change and any existing tests that exercise shared memory should exercise the changes in this PR. - Select one of the following. - [X] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 6c3e953 commit 7d23ec4

26 files changed

+125
-92
lines changed

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ void populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
8686

8787
void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter,
8888
RewritePatternSet &patterns,
89+
const TargetInfoBase &targetInfo,
8990
PatternBenefit benefit);
9091

9192
void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter,
@@ -95,6 +96,7 @@ void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter,
9596

9697
void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter,
9798
RewritePatternSet &patterns, int numWarps,
99+
const TargetInfoBase &targetInfo,
98100
PatternBenefit benefit);
99101

100102
void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ class TargetInfoBase {
8080
StringRef message, StringRef file, StringRef func,
8181
int line) const = 0;
8282

83+
virtual int getSharedAddressSpace() const = 0;
84+
8385
virtual ~TargetInfoBase() {}
8486
};
8587
} // namespace mlir::triton

include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
55
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
66
#include "triton/Conversion/MLIRTypes.h"
7+
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
78
#include "triton/Dialect/TritonGPU/IR/Types.h"
89

910
using namespace mlir;
@@ -14,12 +15,14 @@ class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
1415
using TypeConverter::convertType;
1516

1617
TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option,
18+
const TargetInfoBase &targetInfo,
1719
const DataLayoutAnalysis *analysis = nullptr);
1820

1921
Type getElementTypeForStruct(TensorOrMemDesc type);
2022
Type convertTritonPointerType(triton::PointerType type);
21-
Type convertTritonTensorType(RankedTensorType type);
22-
Type convertMemDescType(MemDescType type);
23+
Type convertTritonTensorType(RankedTensorType type,
24+
const TargetInfoBase &targetInfo);
25+
Type convertMemDescType(MemDescType type, const TargetInfoBase &targetInfo);
2326
Type convertAsyncToken(triton::gpu::AsyncTokenType type);
2427
};
2528

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,8 +372,9 @@ inline Value getStackPointer(RewriterBase &rewriter,
372372
}
373373

374374
inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
375-
Operation *op) {
376-
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
375+
const TargetInfoBase &target, Operation *op) {
376+
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
377+
target.getSharedAddressSpace());
377378
FunctionOpInterface func =
378379
op->template getParentOfType<FunctionOpInterface>();
379380
assert(op->hasAttr("allocation.offset"));
@@ -1222,7 +1223,7 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs(
12221223
// then (x + y) XOR z = 0byyyyxxxx XOR 0b00000zzzz = (x XOR z) + y
12231224
// This means that we can use some immediate offsets for shared memory
12241225
// operations.
1225-
auto dstPtrTy = ptr_ty(rewriter.getContext(), 3);
1226+
auto dstPtrTy = smemObj.base.getType();
12261227
auto dstOffset = dot(rewriter, loc, offsetVals, smemObj.strides);
12271228
Value dstPtrBase = gep(dstPtrTy, resElemTy, smemObj.base, dstOffset);
12281229

lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
5151
// CallOpInterfaceLowering is adapted from
5252
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485
5353
struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
54-
CallOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit)
55-
: ConvertOpToLLVMPattern<triton::CallOp>(converter, benefit) {}
54+
CallOpConversion(LLVMTypeConverter &converter,
55+
const TargetInfoBase &targetInfo, PatternBenefit benefit)
56+
: ConvertOpToLLVMPattern<triton::CallOp>(converter, benefit),
57+
targetInfo(targetInfo) {}
5658

5759
LogicalResult
5860
matchAndRewrite(triton::CallOp callOp,
@@ -85,8 +87,8 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
8587
promotedOperands.push_back(base);
8688
return promotedOperands;
8789
}
88-
promotedOperands.push_back(
89-
LLVM::getSharedMemoryBase(callOp->getLoc(), rewriter, callOp));
90+
promotedOperands.push_back(LLVM::getSharedMemoryBase(
91+
callOp->getLoc(), rewriter, targetInfo, callOp));
9092
return promotedOperands;
9193
}
9294

@@ -129,13 +131,14 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
129131
}
130132
return results;
131133
}
134+
const TargetInfoBase &targetInfo;
132135
};
133136

134137
} // namespace
135138

136139
void mlir::triton::populateControlFlowOpToLLVMPattern(
137140
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
138-
PatternBenefit benefit) {
141+
const TargetInfoBase &targetInfo, PatternBenefit benefit) {
139142
patterns.add<ReturnOpConversion>(typeConverter, benefit);
140-
patterns.add<CallOpConversion>(typeConverter, benefit);
143+
patterns.add<CallOpConversion>(typeConverter, targetInfo, benefit);
141144
}

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ struct ConvertLayoutOpConversion
4646
Attribute srcLayout = srcTy.getEncoding();
4747
Attribute dstLayout = dstTy.getEncoding();
4848
if (isSupported(srcLayout, dstLayout)) {
49-
return lowerDistributedToDistributed(op, adaptor, rewriter);
49+
return lowerDistributedToDistributed(op, adaptor, rewriter, targetInfo);
5050
}
5151
return failure();
5252
}
@@ -115,10 +115,9 @@ struct ConvertLayoutOpConversion
115115
shapePerCTA);
116116
Value offset = linearize(rewriter, loc, multiDimOffsetWrapped,
117117
paddedRepShape, outOrd);
118-
auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
118+
auto elemPtrTy = smemBase.getType();
119119
Value ptr = gep(elemPtrTy, llvmElemTy, smemBase, offset);
120120
auto vecTy = vec_ty(llvmElemTy, vec);
121-
ptr = bitcast(ptr, ptr_ty(rewriter.getContext(), 3));
122121
if (stNotRd) {
123122
Value valVec = undef(vecTy);
124123
for (unsigned v = 0; v < vec; ++v) {
@@ -150,7 +149,8 @@ struct ConvertLayoutOpConversion
150149
// Data padding in shared memory to avoid bank conflict.
151150
LogicalResult
152151
lowerDistributedToDistributed(ConvertLayoutOp op, OpAdaptor adaptor,
153-
ConversionPatternRewriter &rewriter) const {
152+
ConversionPatternRewriter &rewriter,
153+
const TargetInfoBase &targetInfo) const {
154154
auto loc = op.getLoc();
155155
auto typeConverter = getTypeConverter();
156156
RankedTensorType srcTy = op.getSrc().getType();
@@ -168,9 +168,7 @@ struct ConvertLayoutOpConversion
168168
}
169169

170170
Value smemBase =
171-
LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation());
172-
auto elemPtrTy = ptr_ty(rewriter.getContext(), 3);
173-
smemBase = bitcast(smemBase, elemPtrTy);
171+
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
174172
auto shape = dstTy.getShape();
175173
unsigned rank = dstTy.getRank();
176174
SmallVector<unsigned> numReplicates(rank);
@@ -447,8 +445,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
447445
MLIRContext *ctx = op.getContext();
448446
auto loc = op.getLoc();
449447

450-
auto sharedPtrTy = ptr_ty(ctx, /*addressSpace=*/3);
451-
452448
StringAttr kRegister = str_attr("register");
453449
StringAttr kLane = str_attr("lane");
454450
StringAttr kWarp = str_attr("warp");
@@ -508,7 +504,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
508504
collectRegsForIter(ctx, shmemLoadLayout);
509505

510506
Value smemBase =
511-
LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation());
507+
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
508+
auto sharedPtrTy = smemBase.getType();
512509
Type elemTy = inVals[0].getType();
513510
auto outSize = shmemLoadLayout.getInDimSize(kRegister);
514511
auto iterations = sharedLayout.getInDimSize(kIteration);

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ Value loadAFMA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
131131
}
132132
auto elemTy = typeConverter->convertType(aTensorTy.getElementType());
133133

134-
Type ptrTy = ptr_ty(rewriter.getContext(), 3);
134+
Type ptrTy = aSmem.base.getType();
135135
SmallVector<Value> aPtrs(aNumPtr);
136136
for (int i = 0; i < aNumPtr; ++i)
137137
aPtrs[i] = gep(ptrTy, elemTy, aSmem.base, aOff[i]);
@@ -197,7 +197,7 @@ Value loadBFMA(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
197197
}
198198
auto elemTy = typeConverter->convertType(bTensorTy.getElementType());
199199

200-
Type ptrTy = ptr_ty(rewriter.getContext(), 3);
200+
Type ptrTy = bSmem.base.getType();
201201
SmallVector<Value> bPtrs(bNumPtr);
202202
for (int i = 0; i < bNumPtr; ++i)
203203
bPtrs[i] = gep(ptrTy, elemTy, bSmem.base, bOff[i]);

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ using namespace mlir::triton;
1818
/// information.
1919
struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
2020
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
21-
PatternBenefit benefit)
22-
: ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps) {}
21+
const TargetInfoBase &targetInfo, PatternBenefit benefit)
22+
: ConvertOpToLLVMPattern(converter, benefit), numWarps(numWarps),
23+
targetInfo(targetInfo) {}
2324

2425
/// Only retain those attributes that are not constructed by
2526
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
@@ -38,12 +39,14 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
3839
}
3940

4041
triton::FuncOp amendFuncOp(triton::FuncOp funcOp,
41-
ConversionPatternRewriter &rewriter) const {
42+
ConversionPatternRewriter &rewriter,
43+
const TargetInfoBase &targetInfo) const {
4244
// Push back a variable that indicates the current stack pointer of shared
4345
// memory to the function arguments.
4446
auto loc = funcOp.getLoc();
4547
auto ctx = funcOp->getContext();
46-
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
48+
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
49+
targetInfo.getSharedAddressSpace());
4750
// 1. Modify the function type to add the new argument.
4851
auto funcTy = funcOp.getFunctionType();
4952
auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs());
@@ -109,7 +112,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
109112
// Prevent LLVM's inliner to inline this function
110113
auto amendedFuncOp = funcOp;
111114
if (!LLVM::isKernel(funcOp))
112-
amendedFuncOp = amendFuncOp(funcOp, rewriter);
115+
amendedFuncOp = amendFuncOp(funcOp, rewriter, targetInfo);
113116

114117
FailureOr<LLVM::LLVMFuncOp> maybeNewFuncOp =
115118
mlir::convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter,
@@ -150,12 +153,13 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
150153

151154
private:
152155
int numWarps{0};
156+
const TargetInfoBase &targetInfo;
153157
};
154158

155159
} // namespace
156160

157161
void mlir::triton::populateFuncOpConversionPattern(
158162
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, int numWarps,
159-
PatternBenefit benefit) {
160-
patterns.add<FuncOpConversion>(typeConverter, numWarps, benefit);
163+
const TargetInfoBase &targetInfo, PatternBenefit benefit) {
164+
patterns.add<FuncOpConversion>(typeConverter, numWarps, targetInfo, benefit);
161165
}

lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ struct HistogramOpConversion
182182
// generate the right layout. Currently the warp level histogram generates
183183
// data in the default blocked layout.
184184
Value baseSharedMemPtr =
185-
LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation());
185+
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
186186
auto dstType = op.getType();
187187
Attribute dstEncoding = dstType.getEncoding();
188188
auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding,

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct LocalAllocOpConversion
5151
return failure();
5252
Location loc = op->getLoc();
5353
Value smemBase =
54-
LLVM::getSharedMemoryBase(loc, rewriter, op.getOperation());
54+
LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
5555
auto resultTy = cast<MemDescType>(op.getType());
5656
auto typeConverter = getTypeConverter();
5757
auto sharedLayout =

0 commit comments

Comments
 (0)