Skip to content

Commit a9df22e

Browse files
committed
The input type changed to accept memref
1 parent 0a80bbc commit a9df22e

File tree

8 files changed

+111
-57
lines changed

8 files changed

+111
-57
lines changed

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
8383
}
8484
}];
8585
let extraClassDeclaration = [{
86-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
86+
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
8787
}];
8888
}
8989

@@ -404,7 +404,7 @@ def DotOp : AVX_LowOp<"dot", [Pure,
404404
}
405405
}];
406406
let extraClassDeclaration = [{
407-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
407+
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
408408
}];
409409
}
410410

@@ -413,7 +413,7 @@ def DotOp : AVX_LowOp<"dot", [Pure,
413413
// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
414414
//----------------------------------------------------------------------------//
415415

416-
def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [Pure,
416+
def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [
417417
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
418418
let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
419419
let description = [{
@@ -428,7 +428,7 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3
428428
%dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
429429
```
430430
}];
431-
let arguments = (ins LLVM_AnyPointer:$a);
431+
let arguments = (ins AnyMemRef:$a);
432432
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
433433
let assemblyFormat =
434434
"$a attr-dict`:` type($a)`->` type($dst)";
@@ -443,9 +443,13 @@ def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f3
443443
return intr;
444444
}
445445
}];
446+
447+
let extraClassDeclaration = [{
448+
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
449+
}];
446450
}
447451

448-
def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [Pure,
452+
def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [
449453
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
450454
let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
451455
let description = [{
@@ -460,7 +464,7 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32"
460464
%dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
461465
```
462466
}];
463-
let arguments = (ins LLVM_AnyPointer:$a);
467+
let arguments = (ins AnyMemRef:$a);
464468
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
465469
let assemblyFormat =
466470
"$a attr-dict`:` type($a)`->` type($dst)";
@@ -475,13 +479,17 @@ def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32"
475479
return intr;
476480
}
477481
}];
482+
483+
let extraClassDeclaration = [{
484+
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
485+
}];
478486
}
479487

480488
//----------------------------------------------------------------------------//
481489
// AVX: Convert BF16 to F32 and broadcast into packed F32
482490
//----------------------------------------------------------------------------//
483491

484-
def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure,
492+
def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [
485493
DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
486494
let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
487495
let description = [{
@@ -497,7 +505,7 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure,
497505
%dst = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
498506
```
499507
}];
500-
let arguments = (ins LLVM_AnyPointer:$a);
508+
let arguments = (ins AnyMemRef:$a);
501509
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
502510
let assemblyFormat =
503511
"$a attr-dict`:` type($a)`->` type($dst)";
@@ -512,6 +520,11 @@ def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [Pure,
512520
return intr;
513521
}
514522
}];
523+
524+
let extraClassDeclaration = [{
525+
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
526+
}];
527+
515528
}
516529

517530
#endif // X86VECTOR_OPS

mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
1515

1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
17+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
1718
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1819
#include "mlir/IR/BuiltinTypes.h"
1920
#include "mlir/IR/Dialect.h"

mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
5858
}],
5959
/*retType=*/"SmallVector<Value>",
6060
/*methodName=*/"getIntrinsicOperands",
61-
/*args=*/(ins "::mlir::RewriterBase &":$rewriter),
61+
/*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter),
6262
/*methodBody=*/"",
6363
/*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
6464
>,

mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
14+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
15+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
1416
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1517
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1618
#include "mlir/IR/Builders.h"
@@ -31,6 +33,26 @@ void x86vector::X86VectorDialect::initialize() {
3133
>();
3234
}
3335

36+
static SmallVector<Value>
37+
getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
38+
RewriterBase &rewriter,
39+
const LLVMTypeConverter &typeConverter) {
40+
SmallVector<Value> operands;
41+
auto opType = memrefVal.getType();
42+
43+
Type llvmStructType = typeConverter.convertType(opType);
44+
Value llvmStruct =
45+
rewriter
46+
.create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
47+
.getResult(0);
48+
MemRefDescriptor memRefDescriptor(llvmStruct);
49+
50+
Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
51+
operands.push_back(ptr);
52+
53+
return operands;
54+
}
55+
3456
LogicalResult x86vector::MaskCompressOp::verify() {
3557
if (getSrc() && getConstantSrc())
3658
return emitError("cannot use both src and constant_src");
@@ -45,8 +67,8 @@ LogicalResult x86vector::MaskCompressOp::verify() {
4567
return success();
4668
}
4769

48-
SmallVector<Value>
49-
x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) {
70+
SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
71+
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
5072
auto loc = getLoc();
5173

5274
auto opType = getA().getType();
@@ -64,7 +86,8 @@ x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) {
6486
}
6587

6688
SmallVector<Value>
67-
x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) {
89+
x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
90+
const LLVMTypeConverter &typeConverter) {
6891
SmallVector<Value> operands(getOperands());
6992
// Dot product of all elements, broadcasted to all elements.
7093
Value scale =
@@ -74,5 +97,22 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) {
7497
return operands;
7598
}
7699

100+
SmallVector<Value> x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands(
101+
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
102+
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
103+
}
104+
105+
SmallVector<Value>
106+
x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands(
107+
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
108+
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
109+
}
110+
111+
SmallVector<Value>
112+
x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
113+
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
114+
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
115+
}
116+
77117
#define GET_OP_CLASSES
78118
#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"

mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ struct OneToOneIntrinsicOpConversion
9696
LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
9797
PatternRewriter &rewriter) const override {
9898
return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
99-
op.getIntrinsicOperands(rewriter), typeConverter,
100-
rewriter);
99+
op.getIntrinsicOperands(rewriter, typeConverter),
100+
typeConverter, rewriter);
101101
}
102102

103103
private:

mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,55 +97,55 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
9797

9898
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
9999
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
100-
%a: !llvm.ptr) -> vector<4xf32>
100+
%a: memref<8xbf16>) -> vector<4xf32>
101101
{
102102
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
103-
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
103+
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
104104
return %0 : vector<4xf32>
105105
}
106106

107107
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
108108
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
109-
%a: !llvm.ptr) -> vector<8xf32>
109+
%a: memref<16xbf16>) -> vector<8xf32>
110110
{
111111
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
112-
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
112+
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
113113
return %0 : vector<8xf32>
114114
}
115115

116116
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
117117
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
118-
%a: !llvm.ptr) -> vector<4xf32>
118+
%a: memref<8xbf16>) -> vector<4xf32>
119119
{
120120
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
121-
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
121+
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
122122
return %0 : vector<4xf32>
123123
}
124124

125125
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
126126
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
127-
%a: !llvm.ptr) -> vector<8xf32>
127+
%a: memref<16xbf16>) -> vector<8xf32>
128128
{
129129
// CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
130-
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
130+
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
131131
return %0 : vector<8xf32>
132132
}
133133

134134
// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_128
135135
func.func @avxbf16_bsct_bf16_to_f32_packed_128(
136-
%a: !llvm.ptr) -> vector<4xf32>
136+
%a: memref<1xbf16>) -> vector<4xf32>
137137
{
138138
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
139-
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32>
139+
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
140140
return %0 : vector<4xf32>
141141
}
142142

143143
// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_256
144144
func.func @avxbf16_bsct_bf16_to_f32_packed_256(
145-
%a: !llvm.ptr) -> vector<8xf32>
145+
%a: memref<1xbf16>) -> vector<8xf32>
146146
{
147147
// CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
148-
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
148+
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
149149
return %0 : vector<8xf32>
150150
}
151151

mlir/test/Dialect/X86Vector/roundtrip.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -96,61 +96,61 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
9696

9797
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
9898
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
99-
%a: !llvm.ptr) -> vector<4xf32>
99+
%a: memref<8xbf16>) -> vector<4xf32>
100100
{
101101
// CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
102-
// CHECK-SAME: !llvm.ptr -> vector<4xf32>
103-
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
102+
// CHECK-SAME: memref<8xbf16> -> vector<4xf32>
103+
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
104104
return %0 : vector<4xf32>
105105
}
106106

107107
// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
108108
func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
109-
%a: !llvm.ptr) -> vector<8xf32>
109+
%a: memref<16xbf16>) -> vector<8xf32>
110110
{
111111
// CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
112-
// CHECK-SAME: !llvm.ptr -> vector<8xf32>
113-
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
112+
// CHECK-SAME: memref<16xbf16> -> vector<8xf32>
113+
%0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
114114
return %0 : vector<8xf32>
115115
}
116116

117117
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
118118
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
119-
%a: !llvm.ptr) -> vector<4xf32>
119+
%a: memref<8xbf16>) -> vector<4xf32>
120120
{
121121
// CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
122-
// CHECK-SAME: !llvm.ptr -> vector<4xf32>
123-
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<4xf32>
122+
// CHECK-SAME: memref<8xbf16> -> vector<4xf32>
123+
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
124124
return %0 : vector<4xf32>
125125
}
126126

127127
// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
128128
func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
129-
%a: !llvm.ptr) -> vector<8xf32>
129+
%a: memref<16xbf16>) -> vector<8xf32>
130130
{
131131
// CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
132-
// CHECK-SAME: !llvm.ptr -> vector<8xf32>
133-
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : !llvm.ptr -> vector<8xf32>
132+
// CHECK-SAME: memref<16xbf16> -> vector<8xf32>
133+
%0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
134134
return %0 : vector<8xf32>
135135
}
136136

137137
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_128
138138
func.func @avxbf16_bcst_bf16_to_f32_128(
139-
%a: !llvm.ptr) -> vector<4xf32>
139+
%a: memref<1xbf16>) -> vector<4xf32>
140140
{
141141
// CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
142-
// CHECK-SAME: !llvm.ptr -> vector<4xf32>
143-
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<4xf32>
142+
// CHECK-SAME: memref<1xbf16> -> vector<4xf32>
143+
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
144144
return %0 : vector<4xf32>
145145
}
146146

147147
// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_256
148148
func.func @avxbf16_bcst_bf16_to_f32_256(
149-
%a: !llvm.ptr) -> vector<8xf32>
149+
%a: memref<1xbf16>) -> vector<8xf32>
150150
{
151151
// CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
152-
// CHECK-SAME: !llvm.ptr -> vector<8xf32>
153-
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : !llvm.ptr -> vector<8xf32>
152+
// CHECK-SAME: memref<1xbf16> -> vector<8xf32>
153+
%0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
154154
return %0 : vector<8xf32>
155155
}
156156

0 commit comments

Comments
 (0)