Skip to content

Commit 90ebc48

Browse files
matthias-springernicolasvasilache
authored andcommitted
Add vp2intersect to AVX512 dialect.
Adds vp2intersect to the AVX512 dialect and defines a lowering to the LLVM dialect. Author: Matthias Springer <[email protected]> Differential Revision: https://reviews.llvm.org/D95301
1 parent d705c2f commit 90ebc48

File tree

7 files changed

+118
-7
lines changed

7 files changed

+118
-7
lines changed

mlir/include/mlir/Dialect/AVX512/AVX512.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,41 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
9696
"$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
9797
}
9898

99+
def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect,
100+
AllTypesMatch<["a", "b"]>,
101+
TypesMatchWith<"k1 has the same number of bits as elements in a",
102+
"a", "k1",
103+
"IntegerType::get($_self.getContext(), "
104+
"($_self.cast<VectorType>().getShape()[0]))">,
105+
TypesMatchWith<"k2 has the same number of bits as elements in b",
106+
// Should use `b` instead of `a`, but that would require
107+
// adding `type($b)` to assemblyFormat.
108+
"a", "k2",
109+
"IntegerType::get($_self.getContext(), "
110+
"($_self.cast<VectorType>().getShape()[0]))">]> {
111+
let summary = "Vp2Intersect op";
112+
let description = [{
113+
The `vp2intersect` op is an AVX512 specific op that can lower to the proper
114+
LLVMAVX512 operation: `llvm.vp2intersect.d.512` or
115+
`llvm.vp2intersect.q.512` depending on the type of MLIR vectors it is
116+
applied to.
117+
118+
#### From the Intel Intrinsics Guide:
119+
120+
Compute intersection of packed integer vectors `a` and `b`, and store
121+
indication of match in the corresponding bit of two mask registers
122+
specified by `k1` and `k2`. A match in corresponding elements of `a` and
123+
`b` is indicated by a set bit in the corresponding bit of the mask
124+
registers.
125+
}];
126+
let arguments = (ins VectorOfLengthAndType<[16, 8], [I32, I64]>:$a,
127+
VectorOfLengthAndType<[16, 8], [I32, I64]>:$b
128+
);
129+
let results = (outs AnyTypeOf<[I16, I8]>:$k1,
130+
AnyTypeOf<[I16, I8]>:$k2
131+
);
132+
let assemblyFormat =
133+
"$a `,` $b attr-dict `:` type($a)";
134+
}
135+
99136
#endif // AVX512_OPS

mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/IR/BuiltinTypes.h"
1717
#include "mlir/IR/Dialect.h"
1818
#include "mlir/IR/OpDefinition.h"
19+
#include "mlir/IR/OpImplementation.h"
1920
#include "mlir/Interfaces/SideEffectInterfaces.h"
2021

2122
#include "mlir/Dialect/AVX512/AVX512Dialect.h.inc"

mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,33 @@ def LLVMAVX512_Dialect : Dialect {
2828
// MLIR LLVM AVX512 intrinsics using the MLIR LLVM Dialect type system
2929
//----------------------------------------------------------------------------//
3030

31-
class LLVMAVX512_IntrOp<string mnemonic, list<OpTrait> traits = []> :
31+
class LLVMAVX512_IntrOp<string mnemonic, int numResults, list<OpTrait> traits = []> :
3232
LLVM_IntrOpBase<LLVMAVX512_Dialect, mnemonic,
3333
"x86_avx512_" # !subst(".", "_", mnemonic),
34-
[], [], traits, 1>;
34+
[], [], traits, numResults>;
3535

3636
def LLVM_x86_avx512_mask_rndscale_ps_512 :
37-
LLVMAVX512_IntrOp<"mask.rndscale.ps.512">,
37+
LLVMAVX512_IntrOp<"mask.rndscale.ps.512", 1>,
3838
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
3939

4040
def LLVM_x86_avx512_mask_rndscale_pd_512 :
41-
LLVMAVX512_IntrOp<"mask.rndscale.pd.512">,
41+
LLVMAVX512_IntrOp<"mask.rndscale.pd.512", 1>,
4242
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
4343

4444
def LLVM_x86_avx512_mask_scalef_ps_512 :
45-
LLVMAVX512_IntrOp<"mask.scalef.ps.512">,
45+
LLVMAVX512_IntrOp<"mask.scalef.ps.512", 1>,
4646
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
4747

4848
def LLVM_x86_avx512_mask_scalef_pd_512 :
49-
LLVMAVX512_IntrOp<"mask.scalef.pd.512">,
49+
LLVMAVX512_IntrOp<"mask.scalef.pd.512", 1>,
5050
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
5151

52+
def LLVM_x86_avx512_vp2intersect_d_512 :
53+
LLVMAVX512_IntrOp<"vp2intersect.d.512", 2>,
54+
Arguments<(ins LLVM_Type, LLVM_Type)>;
55+
56+
def LLVM_x86_avx512_vp2intersect_q_512 :
57+
LLVMAVX512_IntrOp<"vp2intersect.q.512", 2>,
58+
Arguments<(ins LLVM_Type, LLVM_Type)>;
59+
5260
#endif // AVX512_OPS

mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,38 @@ struct ScaleFOp512Conversion : public ConvertToLLVMPattern {
7777
return failure();
7878
}
7979
};
80+
81+
struct Vp2IntersectOp512Conversion
82+
: public ConvertOpToLLVMPattern<Vp2IntersectOp> {
83+
explicit Vp2IntersectOp512Conversion(MLIRContext *context,
84+
LLVMTypeConverter &typeConverter)
85+
: ConvertOpToLLVMPattern<Vp2IntersectOp>(typeConverter) {}
86+
87+
LogicalResult
88+
matchAndRewrite(Vp2IntersectOp op, ArrayRef<Value> operands,
89+
ConversionPatternRewriter &rewriter) const override {
90+
Type elementType =
91+
op.a().getType().template cast<VectorType>().getElementType();
92+
if (elementType.isInteger(32))
93+
return LLVM::detail::oneToOneRewrite(
94+
op, LLVM::x86_avx512_vp2intersect_d_512::getOperationName(), operands,
95+
*getTypeConverter(), rewriter);
96+
if (elementType.isInteger(64))
97+
return LLVM::detail::oneToOneRewrite(
98+
op, LLVM::x86_avx512_vp2intersect_q_512::getOperationName(), operands,
99+
*getTypeConverter(), rewriter);
100+
return failure();
101+
}
102+
};
80103
} // namespace
81104

82105
/// Populate the given list with patterns that convert from AVX512 to LLVM.
83106
void mlir::populateAVX512ToLLVMConversionPatterns(
84107
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
85108
// clang-format off
86109
patterns.insert<MaskRndScaleOp512Conversion,
87-
ScaleFOp512Conversion>(&converter.getContext(), converter);
110+
ScaleFOp512Conversion,
111+
Vp2IntersectOp512Conversion>(&converter.getContext(),
112+
converter);
88113
// clang-format on
89114
}

mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,13 @@ func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i1
1616
// Keep results alive.
1717
return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
1818
}
19+
20+
func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
21+
-> (i16, i16, i8, i8)
22+
{
23+
// CHECK: llvm_avx512.vp2intersect.d.512
24+
%0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
25+
// CHECK: llvm_avx512.vp2intersect.q.512
26+
%2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
27+
return %0, %1, %2, %3 : i16, i16, i8, i8
28+
}

mlir/test/Dialect/AVX512/roundtrip.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,13 @@ func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16,
1919
%1 = avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64>
2020
return %0, %1: vector<16xf32>, vector<8xf64>
2121
}
22+
23+
func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
24+
-> (i16, i16, i8, i8)
25+
{
26+
// CHECK: avx512.vp2intersect {{.*}} : vector<16xi32>
27+
%0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
28+
// CHECK: avx512.vp2intersect {{.*}} : vector<8xi64>
29+
%2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
30+
return %0, %1, %2, %3 : i16, i16, i8, i8
31+
}

mlir/test/Target/avx512.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,23 @@ llvm.func @LLVM_x86_avx512_mask_pd_512(%a: vector<8xf64>,
2929
(vector<8xf64>, vector<8xf64>, vector<8xf64>, i8, i32) -> vector<8xf64>
3030
llvm.return %1: vector<8xf64>
3131
}
32+
33+
// CHECK-LABEL: define <{ i16, i16 }> @LLVM_x86_vp2intersect_d_512
34+
llvm.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>)
35+
-> !llvm.struct<packed (i16, i16)>
36+
{
37+
// CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32>
38+
%0 = "llvm_avx512.vp2intersect.d.512"(%a, %b) :
39+
(vector<16xi32>, vector<16xi32>) -> !llvm.struct<packed (i16, i16)>
40+
llvm.return %0 : !llvm.struct<packed (i16, i16)>
41+
}
42+
43+
// CHECK-LABEL: define <{ i8, i8 }> @LLVM_x86_vp2intersect_q_512
44+
llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
45+
-> !llvm.struct<packed (i8, i8)>
46+
{
47+
// CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64>
48+
%0 = "llvm_avx512.vp2intersect.q.512"(%a, %b) :
49+
(vector<8xi64>, vector<8xi64>) -> !llvm.struct<packed (i8, i8)>
50+
llvm.return %0 : !llvm.struct<packed (i8, i8)>
51+
}

0 commit comments

Comments
 (0)