Skip to content

Commit 933694b

Browse files
committed
fixup! fixup! fixup! [mlir][vector][nfc] Add tests + update docs for narrow-type emulation
Final tweaks
1 parent 61e7984 commit 933694b

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
//
1313
/// Currently, only power-of-two integer types are supported. These are
1414
/// converted to wider integers that are either 8 bits wide or wider.
15-
//
15+
///
16+
/// TODO: Support for non-powers-of-two.
1617
//===----------------------------------------------------------------------===//
1718

1819
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -351,6 +352,9 @@ struct ConvertVectorMaskedStore final
351352
//
352353
// Using the compressed mask to store %packed_data results in expected
353354
// output.
355+
//
356+
// FIXME: Make an example based on the comment above work (see #115460 for
357+
// reproducer).
354358
FailureOr<Operation *> newMask =
355359
getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
356360
if (failed(newMask))

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -515,11 +515,11 @@ func.func @vector_maskedstore_i4(
515515
// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
516516
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
517517
// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
518-
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
519-
// CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]](){{\[}}%[[NUM_EL_TO_STORE]]]
518+
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]]()[%[[IDX_1]], %[[IDX_2]]]
519+
// CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]]()[%[[NUM_EL_TO_STORE]]]
520520
// CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
521521
// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
522-
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
522+
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
523523
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
524524
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
525525
// CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
@@ -535,11 +535,11 @@ func.func @vector_maskedstore_i4(
535535
// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
536536
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
537537
// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
538-
// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_17]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
539-
// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]](){{\[}}%[[NUM_EL_TO_STORE]]]
538+
// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_17]]()[%[[IDX_1]], %[[IDX_2]]]
539+
// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]]()[%[[NUM_EL_TO_STORE]]]
540540
// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
541541
// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
542-
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
542+
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
543543
// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
544544
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
545545
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
@@ -600,10 +600,10 @@ func.func @vector_maskedstore_i4_constant_mask(
600600
// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
601601
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
602602
// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
603-
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_12]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
603+
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_12]]()[%[[IDX_1]], %[[IDX_2]]]
604604
// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
605605
// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
606-
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
606+
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
607607
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
608608
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
609609
// CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
@@ -616,10 +616,10 @@ func.func @vector_maskedstore_i4_constant_mask(
616616
// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
617617
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
618618
// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
619-
// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_20]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
619+
// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_20]]()[%[[IDX_1]], %[[IDX_2]]]
620620
// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
621621
// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
622-
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
622+
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
623623
// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
624624
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
625625
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>

0 commit comments

Comments
 (0)