Skip to content

Commit 8aade1e

Browse files
committed
updated from AnyMemRef to MemRefOf[BF16, F16] and few clean-ups
1 parent d8205f9 commit 8aade1e

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -426,20 +426,18 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
426426
Example:
427427
```mlir
428428
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
429-
```
430-
```mlir
431429
%dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
432430
```
433431
}];
434-
let arguments = (ins AnyMemRef:$a);
432+
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
435433
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
436434
let assemblyFormat =
437435
"$a attr-dict`:` type($a)`->` type($dst)";
438436

439437
let extraClassDefinition = [{
440438
std::string $cppClass::getIntrinsicName() {
441439
auto elementType =
442-
(cast<MemRefType>(getA().getType())).getElementType();
440+
getA().getType().getElementType();
443441
std::string intr = "llvm.x86.";
444442
if (elementType.isBF16())
445443
intr += "vbcstnebf162ps";
@@ -453,7 +451,7 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
453451
}
454452
}];
455453

456-
let extraClassDeclaration = [{
454+
let extraClassDeclaration = [{
457455
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
458456
}];
459457

@@ -476,20 +474,18 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
476474
Example:
477475
```mlir
478476
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
479-
```
480-
```mlir
481477
%dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
482478
```
483479
}];
484-
let arguments = (ins AnyMemRef:$a);
480+
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
485481
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
486482
let assemblyFormat =
487483
"$a attr-dict`:` type($a)`->` type($dst)";
488484

489485
let extraClassDefinition = [{
490486
std::string $cppClass::getIntrinsicName() {
491487
auto elementType =
492-
(cast<MemRefType>(getA().getType())).getElementType();
488+
getA().getType().getElementType();
493489
std::string intr = "llvm.x86.";
494490
if (elementType.isBF16())
495491
intr += "vcvtneebf162ps";
@@ -521,20 +517,18 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
521517
Example:
522518
```mlir
523519
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
524-
```
525-
```mlir
526520
%dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
527521
```
528522
}];
529-
let arguments = (ins AnyMemRef:$a);
523+
let arguments = (ins MemRefOf<[BF16, F16]>:$a);
530524
let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
531525
let assemblyFormat =
532526
"$a attr-dict`:` type($a)`->` type($dst)";
533527

534528
let extraClassDefinition = [{
535529
std::string $cppClass::getIntrinsicName() {
536530
auto elementType =
537-
(cast<MemRefType>(getA().getType())).getElementType();
531+
getA().getType().getElementType();
538532
std::string intr = "llvm.x86.";
539533
if (elementType.isBF16())
540534
intr += "vcvtneobf162ps";

0 commit comments

Comments
 (0)