Skip to content

Commit abb8113

Browse files
authored
Pass volatile and nonTemporal flags when lowering tt.load instruction (#5465)
When lowering `tt.load`, the backend currently ignores attributes such as `volatile` and `cacheModifier`. This PR rectifies the situation for load operations that use a tensor of pointers. --------- Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 98f7946 commit abb8113

File tree

2 files changed

+49
-4
lines changed

2 files changed

+49
-4
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s
2+
3+
#blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
4+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
5+
// CHECK-LABEL: global_load_with_attributes
6+
tt.func @global_load_with_attributes(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
7+
%c256_i32 = arith.constant 256 : i32
8+
%0 = tt.get_program_id x : i32
9+
%1 = arith.muli %0, %c256_i32 : i32
10+
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
11+
%3 = tt.splat %1 : i32 -> tensor<256xi32, #blocked0>
12+
%4 = arith.addi %3, %2 : tensor<256xi32, #blocked0>
13+
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
14+
%6 = tt.addptr %5, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
15+
%9 = tt.load %6 {isVolatile = true} : tensor<256x!tt.ptr<f32>, #blocked0>
16+
%10 = tt.load %6 cacheModifier = ca : tensor<256x!tt.ptr<f32>, #blocked0>
17+
%12 = tt.load %6 cacheModifier = cg : tensor<256x!tt.ptr<f32>, #blocked0>
18+
%13 = tt.load %6 cacheModifier = wb : tensor<256x!tt.ptr<f32>, #blocked0>
19+
%14 = tt.load %6 cacheModifier = cs : tensor<256x!tt.ptr<f32>, #blocked0>
20+
%15 = tt.load %6 cacheModifier = wt : tensor<256x!tt.ptr<f32>, #blocked0>
21+
%16 = tt.load %6 cacheModifier = cv : tensor<256x!tt.ptr<f32>, #blocked0>
22+
// CHECK-COUNT-2: llvm.load volatile {{.*}} {alignment = 16 : i64} : !llvm.ptr<1> -> vector<4xi32>
23+
// CHECK-COUNT-2: llvm.load {{.*}} {alignment = 16 : i64} : !llvm.ptr<1> -> vector<4xi32>
24+
// CHECK-COUNT-2: llvm.load {{.*}} {alignment = 16 : i64, nontemporal} : !llvm.ptr<1> -> vector<4xi32>
25+
// CHECK-COUNT-2: llvm.load {{.*}} {alignment = 16 : i64} : !llvm.ptr<1> -> vector<4xi32>
26+
// CHECK-COUNT-2: llvm.load {{.*}} {alignment = 16 : i64, nontemporal} : !llvm.ptr<1> -> vector<4xi32>
27+
// CHECK-COUNT-2: llvm.load {{.*}} {alignment = 16 : i64} : !llvm.ptr<1> -> vector<4xi32>
28+
// CHECK-COUNT-2: llvm.load {{.*}} {alignment = 16 : i64, nontemporal} : !llvm.ptr<1> -> vector<4xi32>
29+
tt.return
30+
}
31+
}

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "Dialect/TritonIntelGPU/IR/Dialect.h"
2+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
23
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
34
#include "mlir/IR/Matchers.h"
45
#include "mlir/IR/TypeUtilities.h"
@@ -3065,8 +3066,21 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
30653066

30663067
Value addrElem = b.bitcast(ptrElems[vecStart], ptr_ty(ctx, 1 /*global*/));
30673068
uint32_t alignment = nWords * width / 8;
3068-
auto createLoadInstruction = [&]() -> SmallVector<Value, 1> {
3069-
Value ret = b.load(retTy, addrElem, alignment);
3069+
auto createLoadWithAttrs = [&]() -> SmallVector<Value, 1> {
3070+
auto getNonTemporalFlag = [](triton::LoadOp loadOp) {
3071+
switch (loadOp.getCache()) {
3072+
case triton::CacheModifier::CG:
3073+
case triton::CacheModifier::CS:
3074+
case triton::CacheModifier::CV:
3075+
return true;
3076+
case triton::CacheModifier::CA:
3077+
default:
3078+
return false;
3079+
}
3080+
};
3081+
3082+
Value ret = b.load(retTy, addrElem, alignment, op.getIsVolatile(),
3083+
getNonTemporalFlag(op));
30703084
return {ret};
30713085
};
30723086

@@ -3079,11 +3093,11 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
30793093
else {
30803094
Block &endBlock = LLVM::intel::createPredicatedBlock(
30813095
rewriter, loc, pred, SmallVector<Value, 1>{other_},
3082-
createLoadInstruction);
3096+
createLoadWithAttrs);
30833097
ret = *endBlock.args_begin();
30843098
}
30853099
} else {
3086-
ret = createLoadInstruction()[0];
3100+
ret = createLoadWithAttrs()[0];
30873101
}
30883102

30893103
// Extract and store return values

0 commit comments

Comments
 (0)