Skip to content

Commit e39c908

Browse files
authored
Pass nonTemporal flag when lowering tt.store instruction (#5472)
When lowering tt.store, the backend currently ignores attributes such as `cacheModifier`. This PR rectifies the situation for store operations that use a tensor of pointers. --------- Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 4f5bb10 commit e39c908

File tree

2 files changed

+84
-33
lines changed

2 files changed

+84
-33
lines changed

test/Conversion/intel/load_store_to_llvm.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,36 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
2929
tt.return
3030
}
3131
}
32+
33+
// -----
34+
35+
#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
36+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
37+
// CHECK-LABEL: global_store_with_attributes
38+
tt.func @global_store_with_attributes(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
39+
%cst = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0>
40+
%c256_i32 = arith.constant 256 : i32
41+
%0 = tt.get_program_id x : i32
42+
%1 = arith.muli %0, %c256_i32 : i32
43+
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
44+
%3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0>
45+
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
46+
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
47+
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
48+
tt.store %6, %cst : tensor<256x!tt.ptr<f32>, #blocked0>
49+
tt.store %6, %cst cacheModifier = ca : tensor<256x!tt.ptr<f32>, #blocked0>
50+
tt.store %6, %cst cacheModifier = cg : tensor<256x!tt.ptr<f32>, #blocked0>
51+
tt.store %6, %cst cacheModifier = wb : tensor<256x!tt.ptr<f32>, #blocked0>
52+
tt.store %6, %cst cacheModifier = cs : tensor<256x!tt.ptr<f32>, #blocked0>
53+
tt.store %6, %cst cacheModifier = wt : tensor<256x!tt.ptr<f32>, #blocked0>
54+
tt.store %6, %cst cacheModifier = cv : tensor<256x!tt.ptr<f32>, #blocked0>
55+
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64} : vector<4xi32>, !llvm.ptr<1>
56+
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64} : vector<4xi32>, !llvm.ptr<1>
57+
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64, nontemporal} : vector<4xi32>, !llvm.ptr<1>
58+
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64} : vector<4xi32>, !llvm.ptr<1>
59+
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64, nontemporal} : vector<4xi32>, !llvm.ptr<1>
60+
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64} : vector<4xi32>, !llvm.ptr<1>
61+
// CHECK-COUNT-2: llvm.store {{.*}} {alignment = 16 : i64, nontemporal} : vector<4xi32>, !llvm.ptr<1>
62+
tt.return
63+
}
64+
}

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,11 @@ struct LoadStoreConversionBase {
206206
// }
207207
// All the values are decomposed by `unpackLLElements` into a vector.
208208
// Defines the indices for the block pointer struct.
209-
unsigned blockOffset = 0, blockShape = 1 * rank, blockStride = 2 * rank,
210-
blockBase = 3 * rank;
209+
const unsigned blockOffset = 0, blockShape = 1 * rank,
210+
blockStride = 2 * rank, blockBase = 3 * rank;
211211
const SmallVector<Value> &blockPtr =
212212
unpackLLElements(loc, blockPointerStruct, rewriter);
213-
214-
unsigned numElems = getTotalElemsPerThread(tensorType);
213+
const unsigned numElems = getTotalElemsPerThread(tensorType);
215214

216215
// Get the LLVM values for indices in block
217216
auto indices = emitIndices(loc, rewriter, targetInfo,
@@ -303,6 +302,34 @@ struct LoadStoreConversionBase {
303302
return std::make_tuple(ptrElems, maskElems, otherElems);
304303
}
305304

305+
// Ensure the operation doesn't have attributes that the IGC predicated
306+
// instruction cannot handle.
307+
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
308+
OpType, LoadOp, StoreOp>::value>>
309+
bool canUsePredicatedInstructions(OpType op) const {
310+
if (!usePredicatedInstructions)
311+
return false;
312+
313+
if constexpr (std::is_same_v<OpType, LoadOp>)
314+
return !op.getIsVolatile() && op.getCache() == CacheModifier::NONE;
315+
316+
return op.getCache() == CacheModifier::NONE;
317+
}
318+
319+
template <typename OpType, typename = std::enable_if_t<llvm::is_one_of<
320+
OpType, LoadOp, StoreOp>::value>>
321+
bool getNonTemporalFlag(OpType op) const {
322+
switch (op.getCache()) {
323+
case triton::CacheModifier::CG:
324+
case triton::CacheModifier::CS:
325+
case triton::CacheModifier::CV:
326+
return true;
327+
case triton::CacheModifier::CA:
328+
default:
329+
return false;
330+
}
331+
}
332+
306333
protected:
307334
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass;
308335
const triton::intel::TargetInfo &targetInfo;
@@ -2438,11 +2465,13 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
24382465
: retTys[0];
24392466

24402467
Value other_ = b.undef(retTy);
2441-
if (otherElems.size()) {
2468+
if (otherElems.empty()) {
2469+
other_ = rewriter.create<LLVM::ConstantOp>(loc, retTy,
2470+
rewriter.getZeroAttr(retTy));
2471+
} else {
24422472
for (size_t ii = 0; ii < nWords; ++ii) {
24432473
size_t size = width / valueElemNBits;
2444-
2445-
auto vecTy = vec_ty(valueElemTy, size);
2474+
VectorType vecTy = vec_ty(valueElemTy, size);
24462475
Value v = b.undef(vecTy);
24472476
for (size_t s = 0; s < size; ++s) {
24482477
Value falseVal = otherElems[vecStart + ii * size + s];
@@ -2468,36 +2497,21 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
24682497

24692498
v;
24702499
}
2471-
} else {
2472-
other_ = rewriter.create<LLVM::ConstantOp>(loc, retTy,
2473-
rewriter.getZeroAttr(retTy));
24742500
}
2501+
assert(other_ && "Expecting a valid value");
24752502

24762503
Value addrElem = b.bitcast(ptrElems[vecStart], ptr_ty(ctx, 1 /*global*/));
24772504
uint32_t alignment = nWords * width / 8;
2478-
auto createLoadWithAttrs = [&]() -> SmallVector<Value, 1> {
2479-
auto getNonTemporalFlag = [](triton::LoadOp loadOp) {
2480-
switch (loadOp.getCache()) {
2481-
case triton::CacheModifier::CG:
2482-
case triton::CacheModifier::CS:
2483-
case triton::CacheModifier::CV:
2484-
return true;
2485-
case triton::CacheModifier::CA:
2486-
default:
2487-
return false;
2488-
}
2489-
};
2490-
2491-
Value ret = b.load(retTy, addrElem, alignment, op.getIsVolatile(),
2492-
getNonTemporalFlag(op));
2493-
return {ret};
2505+
auto createLoadWithAttrs = [&]() {
2506+
return SmallVector<Value>{b.load(retTy, addrElem, alignment,
2507+
op.getIsVolatile(),
2508+
getNonTemporalFlag(op))};
24942509
};
24952510

24962511
Value ret;
2497-
24982512
if (!pred)
24992513
ret = createLoadWithAttrs()[0];
2500-
else if (usePredicatedInstructions)
2514+
else if (canUsePredicatedInstructions(op))
25012515
ret = rewriter.create<TritonGEN::PredicatedLoadOp>(
25022516
loc, retTy, addrElem, b.i64_val(alignment), pred, other_);
25032517
else {
@@ -2519,6 +2533,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
25192533
curr, LLVM::getVectorType(valueElemTy, width / valueElemNBits));
25202534
rets.push_back(curr);
25212535
}
2536+
25222537
int tmp = width / valueElemNBits;
25232538
for (size_t ii = 0; ii < vec; ++ii) {
25242539
Value loaded =
@@ -2931,18 +2946,21 @@ struct StoreOpConversion
29312946

29322947
Value addrElem = b.bitcast(ptrElems[vecStart], ptr_ty(ctx, 1 /*global*/));
29332948
uint32_t alignment = nWords * width / 8;
2934-
auto createStore = [&]() -> ArrayRef<Value> {
2935-
b.store(vecWord, addrElem, alignment);
2949+
auto createStoreWithAttrs = [&]() {
2950+
bool isVolatile = false;
2951+
b.store(vecWord, addrElem, alignment, isVolatile,
2952+
getNonTemporalFlag(op));
29362953
return ArrayRef<Value>();
29372954
};
29382955

29392956
if (!maskVal)
2940-
auto _ = createStore();
2941-
else if (usePredicatedInstructions)
2957+
auto _ = createStoreWithAttrs();
2958+
else if (canUsePredicatedInstructions(op))
29422959
rewriter.create<TritonGEN::PredicatedStoreOp>(
29432960
loc, addrElem, vecWord, b.i64_val(alignment), maskVal);
29442961
else
2945-
LLVM::intel::createPredicatedBlock(rewriter, loc, maskVal, createStore);
2962+
LLVM::intel::createPredicatedBlock(rewriter, loc, maskVal,
2963+
createStoreWithAttrs);
29462964
}
29472965

29482966
rewriter.eraseOp(op);

0 commit comments

Comments
 (0)