@@ -426,20 +426,18 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
426426 Example:
427427 ```mlir
428428 %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
429- ```
430- ```mlir
431429 %dst = x86vector.avx.bcst_to_f32.packed %a : memref<1xf16> -> vector<8xf32>
432430 ```
433431 }];
434- let arguments = (ins AnyMemRef :$a);
432+ let arguments = (ins MemRefOf<[BF16, F16]> :$a);
435433 let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
436434 let assemblyFormat =
437435 "$a attr-dict`:` type($a)`->` type($dst)";
438436
439437 let extraClassDefinition = [{
440438 std::string $cppClass::getIntrinsicName() {
441439 auto elementType =
442- (cast<MemRefType>( getA().getType()) ).getElementType();
440+ getA().getType().getElementType();
443441 std::string intr = "llvm.x86.";
444442 if (elementType.isBF16())
445443 intr += "vbcstnebf162ps";
@@ -453,7 +451,7 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
453451 }
454452 }];
455453
456- let extraClassDeclaration = [{
454+ let extraClassDeclaration = [{
457455 SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
458456 }];
459457
@@ -476,20 +474,18 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
476474 Example:
477475 ```mlir
478476 %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
479- ```
480- ```mlir
481477 %dst = x86vector.avx.cvt.packed.even.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
482478 ```
483479 }];
484- let arguments = (ins AnyMemRef :$a);
480+ let arguments = (ins MemRefOf<[BF16, F16]> :$a);
485481 let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
486482 let assemblyFormat =
487483 "$a attr-dict`:` type($a)`->` type($dst)";
488484
489485 let extraClassDefinition = [{
490486 std::string $cppClass::getIntrinsicName() {
491487 auto elementType =
492- (cast<MemRefType>( getA().getType()) ).getElementType();
488+ getA().getType().getElementType();
493489 std::string intr = "llvm.x86.";
494490 if (elementType.isBF16())
495491 intr += "vcvtneebf162ps";
@@ -521,20 +517,18 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
521517 Example:
522518 ```mlir
523519 %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xbf16> -> vector<8xf32>
524- ```
525- ```mlir
526520 %dst = x86vector.avx.cvt.packed.odd.indexed_to_f32 %a : memref<16xf16> -> vector<8xf32>
527521 ```
528522 }];
529- let arguments = (ins AnyMemRef :$a);
523+ let arguments = (ins MemRefOf<[BF16, F16]> :$a);
530524 let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
531525 let assemblyFormat =
532526 "$a attr-dict`:` type($a)`->` type($dst)";
533527
534528 let extraClassDefinition = [{
535529 std::string $cppClass::getIntrinsicName() {
536530 auto elementType =
537- (cast<MemRefType>( getA().getType()) ).getElementType();
531+ getA().getType().getElementType();
538532 std::string intr = "llvm.x86.";
539533 if (elementType.isBF16())
540534 intr += "vcvtneobf162ps";
0 commit comments