
Commit c238af8

[AMD] Enable masked load and pointer canonicalization pass (triton-lang#4638)
This PR does two things:
- We now use the new `llvm.masked{load,store}` intrinsics, so the backend takes responsibility for lowering the loads/stores.
- We enable the pointer canonicalization pass on the Triton IR.

I ran extensive testing and fixed a couple of minor issues that were still present in the implementation. The reason for enabling both at the same time is that I saw a minor regression with `llvm.masked{load,store}` which goes away when pointer canonicalization is used. This combination also seems to reduce the number of VGPRs used (at least for GEMM kernels).
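For illustration only (not part of this commit's diff), here is a hand-written sketch of the kind of LLVM-dialect IR the lowering now targets, based on the updated CHECK lines in the tests below; the function name, SSA names, and the `alignment` attribute are made up:

```mlir
// Hypothetical example: a predicated vector<4xf32> access expressed with the
// masked intrinsics instead of branchy scalar llvm.load/llvm.store sequences.
llvm.func @masked_io_sketch(%ptr: !llvm.ptr, %mask: vector<4xi1>,
                            %other: vector<4xf32>, %val: vector<4xf32>) {
  // Masked load: lanes with a false mask bit take their value from %other.
  %loaded = llvm.intr.masked.load %ptr, %mask, %other {alignment = 16 : i32}
      : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
  // Masked store: lanes with a false mask bit leave memory untouched.
  llvm.intr.masked.store %val, %ptr, %mask {alignment = 16 : i32}
      : vector<4xf32>, vector<4xi1> into !llvm.ptr
  llvm.return
}
```

The AMDGPU backend then decides how to lower these intrinsics, rather than Triton emitting the predication control flow itself.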
1 parent 368c864 commit c238af8

7 files changed (+121, -41 lines)

python/test/unit/language/test_line_info.py

Lines changed: 0 additions & 2 deletions
@@ -172,12 +172,10 @@ def test_line_info(func: str):
         assert (check_file_lines(file_lines, "test_line_info.py", 16))
     elif func == "call":
         assert (check_file_lines(file_lines, "test_line_info.py", 28))
-        assert (check_file_lines(file_lines, "test_line_info.py", 21))
         assert (check_file_lines(file_lines, "test_line_info.py", 30))
     elif func == "call_noinline":
         assert (check_file_lines(file_lines, "test_line_info.py", 42))
         assert (check_file_lines(file_lines, "test_line_info.py", 35))
-        assert (check_file_lines(file_lines, "test_line_info.py", 36))
         assert (check_file_lines(file_lines, "test_line_info.py", 37))
     elif func == "autotune":
         assert (check_file_lines(file_lines, "test_line_info.py", 53))

test/Conversion/amd/load_store.mlir

Lines changed: 2 additions & 2 deletions
@@ -15,10 +15,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     %7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
     %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
     // Load 8 elements from A with two vectorized load instruction
-    // CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr -> vector<4xf32>
+    // CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
     %9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #blocked0>
     // Load 8 elements from B with two vectorized load instruction
-    // CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr -> vector<4xf32>
+    // CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
     %10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #blocked0>
     %11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
     %12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 5 additions & 5 deletions
@@ -9,7 +9,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // CHECK: llvm.br
     // CHECK: rocdl.barrier
     // CHECK: llvm.load
-    // CHECK: llvm.store
+    // CHECK: llvm.intr.masked.store
     %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (!tt.ptr<f32>, f32, i1) -> f32
     tt.store %arg0, %0 : !tt.ptr<f32>
     tt.return
@@ -25,10 +25,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // CHECK: llvm.cond_br
     // CHECK: llvm.atomicrmw
     // CHECK: llvm.atomicrmw
-    // CHECK: %[[ADDR1:.*]] = llvm.extractvalue
-    // CHECK: %[[ADDR2:.*]] = llvm.extractvalue
-    // CHECK: llvm.store %{{.*}}, %[[ADDR1]]
-    // CHECK: llvm.store %{{.*}}, %[[ADDR2]]
+    // CHECK: %[[ADDR1:.*]] = llvm.addrspacecast
+    // CHECK: llvm.intr.masked.store %{{.*}}, %[[ADDR1]]
+    // CHECK: %[[ADDR2:.*]] = llvm.addrspacecast
+    // CHECK: llvm.intr.masked.store %{{.*}}, %[[ADDR2]]
     %0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
     tt.store %arg0, %0 : tensor<256x!tt.ptr<f32>, #blocked0>
     tt.return

third_party/amd/backend/compiler.py

Lines changed: 2 additions & 0 deletions
@@ -178,6 +178,8 @@ def make_ttgir(mod, metadata, options):
         passes.ttgpuir.add_reduce_data_duplication(pm)
         if use_new_pipeliner or options.num_stages != 0:
             amd.passes.ttgpuir.add_reorder_instructions(pm)
+        amd.passes.ttgpuir.add_canonicalize_pointers(pm)
+        passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
         pm.run(mod)

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 24 additions & 28 deletions
@@ -1,6 +1,11 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "TargetInfo.h"
 #include "Utility.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"

 using namespace mlir;
@@ -86,6 +91,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
   }
   return mask;
 }
+
 // Contains some helper functions for both Load and Store conversions.
 struct LoadStoreConversionBase {
   explicit LoadStoreConversionBase(const AMD::TargetInfo &targetInfo,
@@ -192,7 +198,6 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
     auto cacheMod = op.getCache();
     SmallVector<Value> loadedVals;
     for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
-      // TODO: optimization when ptr is GEP with constant offset
       size_t in_off = 0;

       const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
@@ -218,8 +223,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
         Value v = undef(vecTy);
         for (size_t s = 0; s < vec; ++s) {
           Value otherElem = otherElems[vecStart + s];
-          Value indexVal = createIndexAttrConstant(
-              rewriter, loc, this->getTypeConverter()->getIndexType(), s);
+          Value indexVal = LLVM::createIndexConstant(
+              rewriter, loc, this->getTypeConverter(), s);
           v = insert_element(vecTy, v, otherElem, indexVal);
         }
         falseVal = v;
@@ -259,6 +264,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
                   ConversionPatternRewriter &rewriter) const override {
     Value ptr = op.getPtr();
     Value value = op.getValue();
+    Value mask = op.getMask();

     Value llPtr = adaptor.getPtr();
     Value llMask = adaptor.getMask();
@@ -281,24 +287,24 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
     // Determine the vectorization size
     SmallVector<Value> maskElems;
     if (llMask) {
-      Value mask = op.getMask();
       maskElems = unpackLLElements(loc, llMask, rewriter);
       assert(valueElems.size() == maskElems.size());

       unsigned maskAlign = getMaskAlignment(mask);
       vec = std::min(vec, maskAlign);
     }

-    Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
     const size_t dtsize =
         std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
     const size_t valueElemNBits = dtsize * 8;

     auto cacheMod = op.getCache();
     const int numVecs = elemsPerThread / vec;
+    Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
     for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) {
-      // TODO: optimization when ptr is AddPtr with constant offset
       size_t in_off = 0;
+      Value pred = mask ? and_(maskElems[vecStart], rDataMask) : rDataMask;
+      auto vecTy = LLVM::getFixedVectorType(valueElemTy, vec);

       const size_t maxWordWidth = std::max<size_t>(32, valueElemNBits);
       const size_t totalWidth = valueElemNBits * vec;
@@ -307,33 +313,23 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
       const size_t wordNElems = width / valueElemNBits;
       assert(wordNElems * nWords * numVecs == elemsPerThread);

-      // TODO(Superjomn) Add cache policy fields to StoreOp.
-      // TODO(Superjomn) Deal with cache policy here.
-
       Type valArgTy = IntegerType::get(ctx, width);
       auto wordTy = vec_ty(valueElemTy, wordNElems);

       SmallVector<std::pair<Value, std::string>> asmArgs;
-      for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
-        // llWord is a width-len composition
-        Value llWord = undef(wordTy);
-        // Insert each value element to the composition
-        for (size_t elemIdx = 0; elemIdx < wordNElems; ++elemIdx) {
-          const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
-          assert(elemOffset < valueElems.size());
-          Value elem = valueElems[elemOffset];
-          if (elem.getType().isInteger(1))
-            elem = sext(i8_ty, elem);
-          elem = bitcast(elem, valueElemTy);
-
-          llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
-        }
-        llWord = bitcast(llWord, valArgTy);
-        Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;
-        auto address = ptrElems[vecStart + wordIdx * wordNElems];
-        llStore(rewriter, loc, address, llWord, maskVal, cacheMod);
+      Value elem = valueElems[vecStart];
+      Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]);
+
+      // Create the store val
+      Value storeVal = undef(vecTy);
+      for (size_t s = 0; s < vec; ++s) {
+        Value otherElem = valueElems[vecStart + s];
+        Value indexVal = createIndexAttrConstant(
+            rewriter, loc, this->getTypeConverter()->getIndexType(), s);
+        storeVal = insert_element(vecTy, storeVal, otherElem, indexVal);
       }
-    }
+      llStore(rewriter, loc, ptr, storeVal, pred, cacheMod);
+    } // end vec
     rewriter.eraseOp(op);
     return success();
   }

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 68 additions & 1 deletion
@@ -1,8 +1,12 @@
 #include "Utility.h"
 #include "PatternTritonGPUOpToLLVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/PatternMatch.h"
 #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"

 using mlir::triton::gpu::appendOrGetExternFuncOp;
 using mlir::triton::gpu::getFunctionType;
@@ -35,6 +39,35 @@ std::string mangleFunc(std::string name, Type type) {
   }
   return mangled;
 }
+
+// Utility function to create a constant vector mask of length `vecSize` with
+// the same `pred` value
+Value createVectorMaskFromPredicate(RewriterBase &rewriter, Location loc,
+                                    Value pred, int64_t vecSize) {
+  auto vecMaskTy = LLVM::getFixedVectorType(rewriter.getI1Type(), vecSize);
+  Value maskVal = undef(vecMaskTy);
+  for (size_t s = 0; s < vecSize; ++s) {
+    Value indexVal =
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64IntegerAttr(s));
+    maskVal = insert_element(vecMaskTy, maskVal, pred, indexVal);
+  }
+  return maskVal;
+}
+
+// Utility function to get the number of elements of a vector or a scalar
+int64_t getNumElements(Type ty) {
+  if (auto vecType = dyn_cast<VectorType>(ty))
+    return vecType.getNumElements();
+  return 1;
+}
+
+// Utility function to cast the given scalar or vector type to a vector type
+Type castToVectorType(Type ty) {
+  if (isa<VectorType>(ty))
+    return ty;
+  return LLVM::getFixedVectorType(ty, 1);
+}
+
 } // namespace

 namespace mlir::LLVM::AMD {
@@ -157,6 +190,25 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,

 Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
              Value pred, Value falseVal, triton::CacheModifier cm) {
+
+  // Try to emit llvm.intr.masked.load if we can. In theory the backend should
+  // be happier because we emit less branchy code to optimize. The backend will
+  // lower it down however it wants at some point.
+  if (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE) {
+    // `llvm.intr.masked.load` only accepts vectors. If we see a scalar we need
+    // to bitcast to `vector<1xelemTy>` (and back)
+    int64_t vecSize = getNumElements(elemTy);
+    Type vecType = castToVectorType(elemTy);
+    falseVal = bitcast(falseVal, vecType);
+    Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize);
+    bool nt = (cm == triton::CacheModifier::CG);
+    Value vecData = rewriter.create<LLVM::MaskedLoadOp>(
+        loc, vecType, ptr, maskVal, falseVal, vecSize, nt);
+    // If it is not a vector, remember to bitcast back to a scalar
+    vecData = bitcast(vecData, elemTy);
+    return vecData;
+  }
+
   Type funcType = getFunctionType(elemTy, ValueRange({ptr, pred, falseVal}));
   auto parent = ptr.getParentRegion()->getParentOfType<LLVM::LLVMFuncOp>();
   auto getLoadNameRaw = [](triton::CacheModifier cm) {
@@ -173,7 +225,6 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
  };

  auto funcName = mangleFunc(getLoadNameRaw(cm), funcType);
-
  LLVM::LLVMFuncOp funcOp =
      appendOrGetExternFuncOp(rewriter, parent, funcName, funcType);
  auto loadVal =
@@ -185,6 +236,22 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,

 void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
              Value pred, triton::CacheModifier cm) {
+  // Try to emit llvm.intr.masked.store if we can. In theory the backend should
+  // be happier because we emit less branchy code to optimize. The backend will
+  // lower it down however it wants at some point.
+  if (cm == triton::CacheModifier::NONE) {
+    // `llvm.intr.masked.store` only accepts vectors. If we see a scalar we need
+    // to bitcast to `vector<1xelemTy>`
+    Type elemTy = val.getType();
+    int64_t vecSize = getNumElements(elemTy);
+    Type vecType = castToVectorType(elemTy);
+    val = bitcast(val, vecType);
+    Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vecSize);
+    auto op =
+        rewriter.create<LLVM::MaskedStoreOp>(loc, val, ptr, maskVal, vecSize);
+    return;
+  }
+
  auto ctx = ptr.getContext();
  Type funcType = getFunctionType(void_ty(ctx), ValueRange({ptr, val, pred}));
  auto parent = ptr.getParentRegion()->getParentOfType<LLVM::LLVMFuncOp>();

third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp

Lines changed: 20 additions & 3 deletions
@@ -16,6 +16,7 @@
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include <utility>

@@ -225,17 +226,25 @@ Value getScalarConstant(IRRewriter &rewriter, Location loc, Value expr) {
   Operation *op = expr.getDefiningOp();

   // Check for splatness
-  if (auto splatOp = dyn_cast<triton::SplatOp>(op))
+  if (auto splatOp = dyn_cast_or_null<triton::SplatOp>(op))
     return splatOp.getSrc();

   // Check for constant
   DenseIntElementsAttr constVal;
-  if (auto constOp = dyn_cast<arith::ConstantOp>(op)) {
+  if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op)) {
     Value val = constOp.getResult();
     if (matchPattern(val, m_Constant(&constVal)) && constVal.isSplat())
       return rewriter.create<arith::ConstantOp>(
           loc, constVal.getSplatValue<IntegerAttr>());
   }
+
+  // Check for block arguments
+  if (auto blockArg = dyn_cast_or_null<BlockArgument>(expr)) {
+    Type type = blockArg.getType();
+    if (!isa<RankedTensorType>(type))
+      return blockArg;
+  }
+
   return Value();
 }

@@ -318,6 +327,14 @@ PointerCanonicalizer::decomposeOffsetFromExpr(Location loc, Value expr,
     return {scalarConst, tensorZero};
   }

+  // Base case 2: block argument. Since it is not a scalar constant, it must be
+  // a tensor. Note that this means we won't be able to decompose across loop
+  // boundaries (TODO: giuseros).
+  if (auto blockArg = dyn_cast<BlockArgument>(expr)) {
+    Value scalarZero = rewriter.create<arith::ConstantIntOp>(loc, 0, bitness);
+    return std::make_pair(scalarZero, expr);
+  }
+
   auto offsets =
       llvm::TypeSwitch<Operation *, std::pair<Value, Value>>(
           expr.getDefiningOp())
@@ -342,7 +359,7 @@
             return decomposeOffsetFromMul(loc, expr, bitness);
           })
           .Default([&](Operation *op) {
-            // Base case 2: it is not a supported operation. We assume no
+            // Base case 3: it is not a supported operation. We assume no
             // uniform part
             Value scalarZero =
                 rewriter.create<arith::ConstantIntOp>(loc, 0, bitness);
