Skip to content

Commit 455ffb6

Browse files
committed
[mlir][llvm] Add alignment to masked intrinsics.
This is a squash of PR #153063.
1 parent a0cc39a commit 455ffb6

File tree

5 files changed

+116
-14
lines changed

5 files changed

+116
-14
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,21 +87,21 @@ class LLVM_TernarySameArgsIntrOpF<string func, list<Trait> traits = []> :
8787
class LLVM_CountZerosIntrOp<string func, list<Trait> traits = []> :
8888
LLVM_OneResultIntrOp<func, [], [0],
8989
!listconcat([Pure, SameOperandsAndResultType], traits),
90-
/*requiresFastmath=*/0,
90+
/*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
9191
/*immArgPositions=*/[1], /*immArgAttrNames=*/["is_zero_poison"]> {
9292
let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$in,
9393
I1Attr:$is_zero_poison);
9494
}
9595

9696
def LLVM_AbsOp : LLVM_OneResultIntrOp<"abs", [], [0], [Pure],
97-
/*requiresFastmath=*/0,
97+
/*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
9898
/*immArgPositions=*/[1], /*immArgAttrNames=*/["is_int_min_poison"]> {
9999
let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$in,
100100
I1Attr:$is_int_min_poison);
101101
}
102102

103103
def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure],
104-
/*requiresFastmath=*/0,
104+
/*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
105105
/*immArgPositions=*/[1], /*immArgAttrNames=*/["bit"]> {
106106
let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$in, I32Attr:$bit);
107107
}
@@ -360,8 +360,8 @@ def LLVM_LifetimeEndOp : LLVM_LifetimeBaseOp<"lifetime.end">;
360360

361361
def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1],
362362
[DeclareOpInterfaceMethods<PromotableOpInterface>],
363-
/*requiresFastmath=*/0, /*immArgPositions=*/[0],
364-
/*immArgAttrNames=*/["size"]> {
363+
/*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
364+
/*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> {
365365
let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr);
366366
let results = (outs LLVM_DefaultPointer:$res);
367367
let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
@@ -412,6 +412,7 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs,
412412
!gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
413413
true : []),
414414
/*requiresFastmath=*/0,
415+
/*requiresArgAndResultAttrs=*/0,
415416
/*immArgPositions=*/[],
416417
/*immArgAttrNames=*/[]> {
417418
dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
@@ -589,7 +590,7 @@ def LLVM_ExpectOp
589590
def LLVM_ExpectWithProbabilityOp
590591
: LLVM_OneResultIntrOp<"expect.with.probability", [], [0],
591592
[Pure, AllTypesMatch<["val", "expected", "res"]>],
592-
/*requiresFastmath=*/0,
593+
/*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
593594
/*immArgPositions=*/[2], /*immArgAttrNames=*/["prob"]> {
594595
let arguments = (ins AnySignlessInteger:$val,
595596
AnySignlessInteger:$expected,
@@ -825,7 +826,7 @@ class LLVM_VecReductionAccBase<string mnem, Type element>
825826
/*overloadedResults=*/[],
826827
/*overloadedOperands=*/[1],
827828
/*traits=*/[Pure, SameOperandsAndResultElementType],
828-
/*equiresFastmath=*/1>,
829+
/*requiresFastmath=*/1>,
829830
Arguments<(ins element:$start_value,
830831
LLVM_VectorOf<element>:$input,
831832
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags)>;
@@ -1069,14 +1070,36 @@ def LLVM_masked_scatter : LLVM_ZeroResultIntrOp<"masked.scatter"> {
10691070
}
10701071

10711072
/// Create a call to Masked Expand Load intrinsic.
1072-
def LLVM_masked_expandload : LLVM_IntrOp<"masked.expandload", [0], [], [], 1> {
1073-
let arguments = (ins LLVM_AnyPointer, LLVM_VectorOf<I1>, LLVM_AnyVector);
1073+
def LLVM_masked_expandload
1074+
: LLVM_OneResultIntrOp<"masked.expandload", [0], [],
1075+
/*traits=*/[], /*requiresFastMath=*/0, /*requiresArgAndResultAttrs=*/1,
1076+
/*immArgPositions=*/[], /*immArgAttrNames=*/[]> {
1077+
dag args = (ins LLVM_AnyPointer:$ptr,
1078+
LLVM_VectorOf<I1>:$mask,
1079+
LLVM_AnyVector:$passthru);
1080+
1081+
let arguments = !con(args, baseArgs);
1082+
1083+
let builders = [
1084+
OpBuilder<(ins "TypeRange":$resTy, "Value":$ptr, "Value":$mask, "Value":$passthru, CArg<"uint64_t", "1">:$align)>
1085+
];
10741086
}
10751087

10761088
/// Create a call to Masked Compress Store intrinsic.
10771089
def LLVM_masked_compressstore
1078-
: LLVM_IntrOp<"masked.compressstore", [], [0], [], 0> {
1079-
let arguments = (ins LLVM_AnyVector, LLVM_AnyPointer, LLVM_VectorOf<I1>);
1090+
: LLVM_ZeroResultIntrOp<"masked.compressstore", [0],
1091+
/*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
1092+
/*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0,
1093+
/*immArgPositions=*/[], /*immArgAttrNames=*/[]> {
1094+
dag args = (ins LLVM_AnyVector:$value,
1095+
LLVM_AnyPointer:$ptr,
1096+
LLVM_VectorOf<I1>:$mask);
1097+
1098+
let arguments = !con(args, baseArgs);
1099+
1100+
let builders = [
1101+
OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask, CArg<"uint64_t", "1">:$align)>
1102+
];
10801103
}
10811104

10821105
//
@@ -1155,7 +1178,7 @@ def LLVM_vector_insert
11551178
PredOpTrait<"it is not inserting scalable into fixed-length vectors.",
11561179
CPred<"!isScalableVectorType($srcvec.getType()) || "
11571180
"isScalableVectorType($dstvec.getType())">>],
1158-
/*requiresFastmath=*/0,
1181+
/*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
11591182
/*immArgPositions=*/[2], /*immArgAttrNames=*/["pos"]> {
11601183
let arguments = (ins LLVM_AnyVector:$dstvec, LLVM_AnyVector:$srcvec,
11611184
I64Attr:$pos);
@@ -1189,7 +1212,7 @@ def LLVM_vector_extract
11891212
PredOpTrait<"it is not extracting scalable from fixed-length vectors.",
11901213
CPred<"!isScalableVectorType($res.getType()) || "
11911214
"isScalableVectorType($srcvec.getType())">>],
1192-
/*requiresFastmath=*/0,
1215+
/*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
11931216
/*immArgPositions=*/[1], /*immArgAttrNames=*/["pos"]> {
11941217
let arguments = (ins LLVM_AnyVector:$srcvec, I64Attr:$pos);
11951218
let results = (outs LLVM_AnyVector:$res);

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -475,11 +475,12 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
475475
list<int> overloadedOperands = [],
476476
list<Trait> traits = [],
477477
bit requiresFastmath = 0,
478+
bit requiresArgAndResultAttrs = 0,
478479
list<int> immArgPositions = [],
479480
list<string> immArgAttrNames = []>
480481
: LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
481482
/*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
482-
requiresFastmath, /*requiresArgAndResultAttrs=*/0,
483+
requiresFastmath, requiresArgAndResultAttrs,
483484
/*requiresOpBundles=*/0, immArgPositions,
484485
immArgAttrNames>;
485486

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,37 @@ static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
141141
return success();
142142
}
143143

144+
static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder,
145+
bool isExpandLoad,
146+
uint64_t alignment = 1) {
147+
// From
148+
// https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
149+
// https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
150+
//
151+
// The pointer alignment defaults to 1.
152+
if (alignment == 1) {
153+
return nullptr;
154+
}
155+
156+
auto emptyDictAttr = builder.getDictionaryAttr({});
157+
auto alignmentAttr = builder.getI64IntegerAttr(alignment);
158+
auto namedAttr =
159+
builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr);
160+
SmallVector<mlir::NamedAttribute> attrs = {namedAttr};
161+
auto alignDictAttr = builder.getDictionaryAttr(attrs);
162+
// From
163+
// https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
164+
// https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
165+
//
166+
// The align parameter attribute can be provided for [expandload]'s first
167+
// argument. The align parameter attribute can be provided for
168+
// [compressstore]'s second argument.
169+
int pos = isExpandLoad ? 0 : 1;
170+
return pos == 0
171+
? builder.getArrayAttr({alignDictAttr, emptyDictAttr, emptyDictAttr})
172+
: builder.getArrayAttr({emptyDictAttr, alignDictAttr, emptyDictAttr});
173+
}
174+
144175
//===----------------------------------------------------------------------===//
145176
// Operand bundle helpers.
146177
//===----------------------------------------------------------------------===//
@@ -4116,6 +4147,33 @@ LogicalResult LLVM::masked_scatter::verify() {
41164147
return success();
41174148
}
41184149

4150+
//===----------------------------------------------------------------------===//
4151+
// masked_expandload (intrinsic)
4152+
//===----------------------------------------------------------------------===//
4153+
4154+
void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state,
4155+
mlir::TypeRange resTys, Value ptr,
4156+
Value mask, Value passthru,
4157+
uint64_t align) {
4158+
ArrayAttr callArgs = getLLVMAlignParamForCompressExpand(builder, true, align);
4159+
build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/callArgs,
4160+
/*res_attrs=*/nullptr);
4161+
}
4162+
4163+
//===----------------------------------------------------------------------===//
4164+
// masked_compressstore (intrinsic)
4165+
//===----------------------------------------------------------------------===//
4166+
4167+
void LLVM::masked_compressstore::build(OpBuilder &builder,
4168+
OperationState &state, Value value,
4169+
Value ptr, Value mask, uint64_t align) {
4170+
4171+
ArrayAttr callArgs =
4172+
getLLVMAlignParamForCompressExpand(builder, false, align);
4173+
build(builder, state, value, ptr, mask, /*arg_attrs=*/callArgs,
4174+
/*res_attrs=*/nullptr);
4175+
}
4176+
41194177
//===----------------------------------------------------------------------===//
41204178
// InlineAsmOp
41214179
//===----------------------------------------------------------------------===//

mlir/test/Target/LLVMIR/Import/intrinsic.ll

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,15 @@ define void @masked_expand_compress_intrinsics(ptr %0, <7 x i1> %1, <7 x float>
545545
ret void
546546
}
547547

548+
; CHECK-LABEL: llvm.func @masked_expand_compress_intrinsics_with_alignment
549+
define void @masked_expand_compress_intrinsics_with_alignment(ptr %0, <7 x i1> %1, <7 x float> %2) {
550+
; CHECK: %[[val1:.+]] = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}> : (!llvm.ptr, vector<7xi1>, vector<7xf32>) -> vector<7xf32>
551+
%4 = call <7 x float> @llvm.masked.expandload.v7f32(ptr align 8 %0, <7 x i1> %1, <7 x float> %2)
552+
; CHECK: "llvm.intr.masked.compressstore"(%[[val1]], %{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}> : (vector<7xf32>, !llvm.ptr, vector<7xi1>) -> ()
553+
call void @llvm.masked.compressstore.v7f32(<7 x float> %4, ptr align 8 %0, <7 x i1> %1)
554+
ret void
555+
}
556+
548557
; CHECK-LABEL: llvm.func @annotate_intrinsics
549558
define void @annotate_intrinsics(ptr %var, ptr %ptr, i16 %int, ptr %annotation, ptr %fileName, i32 %line, ptr %args) {
550559
; CHECK: "llvm.intr.var.annotation"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr) -> ()

mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,17 @@ llvm.func @masked_expand_compress_intrinsics(%ptr: !llvm.ptr, %mask: vector<7xi1
577577
llvm.return
578578
}
579579

580+
// CHECK-LABEL: @masked_expand_compress_intrinsics_with_alignment
581+
llvm.func @masked_expand_compress_intrinsics_with_alignment(%ptr: !llvm.ptr, %mask: vector<7xi1>, %passthru: vector<7xf32>) {
582+
// CHECK: call <7 x float> @llvm.masked.expandload.v7f32(ptr align 8 %{{.*}}, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
583+
%0 = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru) {arg_attrs = [{llvm.align = 8 : i32}, {}, {}]}
584+
: (!llvm.ptr, vector<7xi1>, vector<7xf32>) -> (vector<7xf32>)
585+
// CHECK: call void @llvm.masked.compressstore.v7f32(<7 x float> %{{.*}}, ptr align 8 %{{.*}}, <7 x i1> %{{.*}})
586+
"llvm.intr.masked.compressstore"(%0, %ptr, %mask) {arg_attrs = [{}, {llvm.align = 8 : i32}, {}]}
587+
: (vector<7xf32>, !llvm.ptr, vector<7xi1>) -> ()
588+
llvm.return
589+
}
590+
580591
// CHECK-LABEL: @annotate_intrinsics
581592
llvm.func @annotate_intrinsics(%var: !llvm.ptr, %int: i16, %ptr: !llvm.ptr, %annotation: !llvm.ptr, %fileName: !llvm.ptr, %line: i32, %attr: !llvm.ptr) {
582593
// CHECK: call void @llvm.var.annotation.p0.p0(ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, ptr %{{.*}})

0 commit comments

Comments
 (0)