Skip to content

Commit 92c789c

Browse files
authored
Fix getPredMask to handle loads through pointers to tensors. (#4582)
Fixes issue #4580. Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent d8bde0a commit 92c789c

File tree

3 files changed

+29
-6
lines changed

3 files changed

+29
-6
lines changed

lib/Dialect/Triton/IR/Utility.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
33
#include "mlir/Dialect/SCF/IR/SCF.h"
44
#include "triton/Dialect/Triton/IR/Dialect.h"
5+
#include "triton/Dialect/Triton/IR/Types.h"
56

67
using namespace mlir;
78
namespace tt = mlir::triton;
89

910
Value tt::getPredMask(RewriterBase &rewriter, Type typeLike, Value currentMask,
1011
Value pred) {
11-
Type maskType = tt::getI1SameShape(typeLike);
12+
Type maskType = tt::getI1SameShape(tt::getPointeeType(typeLike));
1213
Location loc = pred.getLoc();
1314
Value mask = pred;
1415
if (isa<RankedTensorType>(maskType)) {

test/Triton/loop-invariant-code-motion.mlir

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// RUN: triton-opt --split-input-file %s -triton-licm | FileCheck %s
22

3-
tt.func @hoist_load_without_mask(%arg0: tensor<1024x!tt.ptr<f32>>, %arg1: tensor<1024xi32>, %arg2: tensor<1024xi32>, %arg3: i32, %arg4 : i32, %arg5: tensor<1024x!tt.ptr<f32>>) {
3+
tt.func @hoist_load_without_mask1(%arg0: tensor<1024x!tt.ptr<f32>>, %arg1: tensor<1024xi32>, %arg2: tensor<1024xi32>, %arg3: i32, %arg4 : i32, %arg5: tensor<1024x!tt.ptr<f32>>) {
44
%cst = arith.constant dense<0.000000e+00> : tensor<1024xf32>
55
%c1_i32 = arith.constant 1 : i32
66
// Check if the load is hoisted
7-
// CHECK-LABEL: hoist_load_without_mask
7+
// CHECK-LABEL: hoist_load_without_mask1
88
// CHECK: %[[TRIP_COUNT_CMP:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]]
99
// CHECK: %[[SPLAT:.*]] = tt.splat %[[TRIP_COUNT_CMP]]
1010
// CHECK: %[[LOAD:.*]] = tt.load %[[_:.*]], %[[SPLAT]]
@@ -23,6 +23,29 @@ tt.func @hoist_load_without_mask(%arg0: tensor<1024x!tt.ptr<f32>>, %arg1: tensor
2323

2424
// -----
2525

26+
tt.func @hoist_load_without_mask2(%arg0: !tt.ptr<tensor<1024xf32>>, %arg3: i32, %arg4 : i32, %arg5: !tt.ptr<tensor<1024xf32>>) {
27+
%cst = arith.constant dense<0.000000e+00> : tensor<1024xf32>
28+
%c1_i32 = arith.constant 1 : i32
29+
// Check if the load is hoisted
30+
// CHECK-LABEL: hoist_load_without_mask2
31+
// CHECK: %[[TRIP_COUNT_CMP:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]]
32+
// CHECK: %[[SPLAT:.*]] = tt.splat %[[TRIP_COUNT_CMP]]
33+
// CHECK: %[[LOAD:.*]] = tt.load %[[_:.*]], %[[SPLAT]]
34+
// CHECK: arith.addf %[[LOAD]], %[[LOAD]]
35+
// CHECK: scf.for
36+
// CHECK-NOT: tt.load
37+
%1 = scf.for %arg7 = %arg3 to %arg4 step %c1_i32 iter_args(%arg6 = %cst) -> (tensor<1024xf32>) : i32 {
38+
%2 = tt.load %arg0 : !tt.ptr<tensor<1024xf32>>
39+
%3 = arith.addf %2, %2 : tensor<1024xf32>
40+
%4 = arith.addf %arg6, %3 : tensor<1024xf32>
41+
scf.yield %4 : tensor<1024xf32>
42+
}
43+
tt.store %arg5, %1 : !tt.ptr<tensor<1024xf32>>
44+
tt.return
45+
}
46+
47+
// -----
48+
2649
tt.func @hoist_two_loads_without_mask(%arg0: tensor<1024x!tt.ptr<f32>>, %arg1: tensor<1024xi32>, %arg2: tensor<1024xi32>, %arg3: i32, %arg4 : i32, %arg5: tensor<1024x!tt.ptr<f32>>, %arg6: tensor<1024x!tt.ptr<f32>>) {
2750
%cst = arith.constant dense<0.000000e+00> : tensor<1024xf32>
2851
%c1_i32 = arith.constant 1 : i32

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,8 @@ static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
161161
return TypeSwitch<Operation *, Operation *>(op)
162162
.Case<tt::LoadOp, ttgi::PrefetchOp>([&](auto op) {
163163
rewriter.setInsertionPoint(op);
164-
Value mask =
165-
tt::getPredMask(rewriter, tt::getPointeeType(op.getPtr().getType()),
166-
op.getMask(), pred);
164+
Value mask = tt::getPredMask(rewriter, op.getPtr().getType(),
165+
op.getMask(), pred);
167166
op.getMaskMutable().assign(mask);
168167
return op;
169168
});

0 commit comments

Comments
 (0)