Skip to content

Commit 36cd847

Browse files
SPIRV vector.mask lowering: use 64-bit type. (#1093)
So far, most of the kernels use 16 or 32 vector sizes. But there is a kernel that returns an i8 tensor and the vector size = 64. So, there was an overflow when computing the mask (1 << valid_elements). This PR updates the data type from 32-bit to 64-bit when computing the mask. It also adds an assert to catch more easily this type of issue.
1 parent b94c6e1 commit 36cd847

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,31 +93,31 @@ class VectorMaskConversionPattern final
9393
return mlir::failure();
9494

9595
auto vWidth = vTy.getNumElements();
96+
assert(vWidth <= 64 && "vector.create_mask supports vector widths <= 64");
9697
auto vWidthConst = rewriter.create<mlir::arith::ConstantOp>(
97-
vMaskOp.getLoc(), rewriter.getI32IntegerAttr(vWidth));
98+
vMaskOp.getLoc(), rewriter.getI64IntegerAttr(vWidth));
9899
auto maskVal = adaptor.getOperands()[0];
99100
maskVal = rewriter.create<mlir::arith::TruncIOp>(
100-
vMaskOp.getLoc(), rewriter.getI32Type(), maskVal);
101+
vMaskOp.getLoc(), rewriter.getI64Type(), maskVal);
101102

102103
// maskVal < vWidth
103104
auto cmp = rewriter.create<mlir::arith::CmpIOp>(
104105
vMaskOp.getLoc(), mlir::arith::CmpIPredicate::slt, maskVal,
105106
vWidthConst);
106107
auto one = rewriter.create<mlir::arith::ConstantOp>(
107-
vMaskOp.getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
108+
vMaskOp.getLoc(), rewriter.getI64IntegerAttr(1));
108109
auto shift = rewriter.create<mlir::spirv::ShiftLeftLogicalOp>(
109110
vMaskOp.getLoc(), one, maskVal);
110111
auto mask1 =
111112
rewriter.create<mlir::arith::SubIOp>(vMaskOp.getLoc(), shift, one);
112113
auto mask2 = rewriter.create<mlir::arith::ConstantOp>(
113-
vMaskOp.getLoc(), rewriter.getI32Type(),
114-
rewriter.getI32IntegerAttr(0xFFFFFFFF));
114+
vMaskOp.getLoc(), rewriter.getI64IntegerAttr(-1)); // all ones
115115
mlir::Value sel = rewriter.create<mlir::arith::SelectOp>(vMaskOp.getLoc(),
116116
cmp, mask1, mask2);
117117

118118
// maskVal < 0
119119
auto zero = rewriter.create<mlir::arith::ConstantOp>(
120-
vMaskOp.getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(0));
120+
vMaskOp.getLoc(), rewriter.getI64IntegerAttr(0));
121121
auto cmp2 = rewriter.create<mlir::arith::CmpIOp>(
122122
vMaskOp.getLoc(), mlir::arith::CmpIPredicate::slt, maskVal, zero);
123123
sel = rewriter.create<mlir::arith::SelectOp>(vMaskOp.getLoc(), cmp2, zero,

test/Conversion/GPUToSPIRV/create_mask.mlir

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,17 @@ module attributes {
1616

1717
// CHECK-LABEL: spirv.func @create_mask
1818
// CHECK-SAME: %[[MASK_VAL:[[:alnum:]]+]]: i64
19-
// CHECK-NEXT: %[[VECTOR_WIDTH:.*]] = spirv.Constant 16 : i32
20-
// CHECK-NEXT: %[[MASK_VAL_I32:.*]] = spirv.SConvert %[[MASK_VAL]] : i64 to i32
21-
// CHECK-NEXT: %[[CMP1:.*]] = spirv.SLessThan %[[MASK_VAL_I32]], %[[VECTOR_WIDTH]] : i32
22-
// CHECK-NEXT: %[[ONE:.*]] = spirv.Constant 1 : i32
23-
// CHECK-NEXT: %[[SHIFT:.*]] = spirv.ShiftLeftLogical %[[ONE]], %[[MASK_VAL_I32]] : i32, i32
24-
// CHECK-NEXT: %[[MASK:.*]] = spirv.ISub %[[SHIFT]], %[[ONE]] : i32
25-
// CHECK-NEXT: %[[MASK_ONES:.*]] = spirv.Constant -1 : i32
26-
// CHECK-NEXT: %[[SELECT1:.*]] = spirv.Select %[[CMP1]], %[[MASK]], %[[MASK_ONES]] : i1, i32
27-
// CHECK-NEXT: %[[ZERO:.*]] = spirv.Constant 0 : i32
28-
// CHECK-NEXT: %[[CMP2:.*]] = spirv.SLessThan %[[MASK_VAL_I32]], %[[ZERO]] : i32
29-
// CHECK-NEXT: %[[SELECT2:.*]] = spirv.Select %[[CMP2]], %[[ZERO]], %[[SELECT1]] : i1, i32
30-
// CHECK-NEXT: %[[CAST:.*]] = spirv.SConvert %[[SELECT2]] : i32 to i16
19+
// CHECK-NEXT: %[[VECTOR_WIDTH:.*]] = spirv.Constant 16 : i64
20+
// CHECK-NEXT: %[[CMP1:.*]] = spirv.SLessThan %[[MASK_VAL]], %[[VECTOR_WIDTH]] : i64
21+
// CHECK-NEXT: %[[ONE:.*]] = spirv.Constant 1 : i64
22+
// CHECK-NEXT: %[[SHIFT:.*]] = spirv.ShiftLeftLogical %[[ONE]], %[[MASK_VAL]] : i64, i64
23+
// CHECK-NEXT: %[[MASK:.*]] = spirv.ISub %[[SHIFT]], %[[ONE]] : i64
24+
// CHECK-NEXT: %[[MASK_ONES:.*]] = spirv.Constant -1 : i64
25+
// CHECK-NEXT: %[[SELECT1:.*]] = spirv.Select %[[CMP1]], %[[MASK]], %[[MASK_ONES]] : i1, i64
26+
// CHECK-NEXT: %[[ZERO:.*]] = spirv.Constant 0 : i64
27+
// CHECK-NEXT: %[[CMP2:.*]] = spirv.SLessThan %[[MASK_VAL]], %[[ZERO]] : i64
28+
// CHECK-NEXT: %[[SELECT2:.*]] = spirv.Select %[[CMP2]], %[[ZERO]], %[[SELECT1]] : i1, i64
29+
// CHECK-NEXT: %[[CAST:.*]] = spirv.SConvert %[[SELECT2]] : i64 to i16
3130
// CHECK-NEXT: spirv.Bitcast %[[CAST]] : i16 to vector<16xi1>
3231
// CHECK-NEXT: spirv.Return
3332

0 commit comments

Comments
 (0)