From 8a9abf6714427af3d82b13b28961eb6f65c18b4a Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Fri, 8 Nov 2024 10:44:02 +0000
Subject: [PATCH 1/4] [mlir][vector][nfc] Add tests + update docs for
 narrow-type emulation

The documentation for narrow-type emulation is a bit inaccurate. In
particular, we don't really support/generate masks like this:

  %mask = [0, 1, 1, 1, 1, 1, 0, 0]

I updated the comment for `ConvertVectorMaskedStore` accordingly. I also
added a few clarifications (e.g. that the comment is discussing i4 -> i8
emulation).

Separately, I've noticed an inconsistency in testing for narrow-type
emulation. In particular, there are a few cases that are tested for
"loading" but missing for "storing". I've added
  * comments in the test file so that it's easy to see what's tested,
  * missing tests for `vector.maskedstore`.

Finally, I've added a top-level comment in VectorEmulateNarrowType.cpp so
that the overall intent and design are clearer.
---
 .../Transforms/VectorEmulateNarrowType.cpp    |  41 ++++--
 .../Vector/vector-emulate-narrow-type.mlir    | 126 ++++++++++++++++++
 2 files changed, 154 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 58841f29698e0..0f88ff21e847e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1,11 +1,19 @@
-//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++
-//-*-===//
+//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to emulate
+// narrow types that are not supported by the target hardware, e.g. i4, using
+// wider types, e.g. i8.
+//
+/// Currently, only power-of-two integer types are supported. These are
+/// converted to wider integers that are either 8 bits wide or wider.
+//
+//===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -315,21 +323,28 @@ struct ConvertVectorMaskedStore final
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
 
     // Load the whole data and use arith.select to handle the corner cases.
- // E.g., given these input values: + // E.g., given these input i4 values: + // + // %res = vector.maskedload %0[%c0, %c0], %mask, %val_to_store : + // + // %mask = [1, 1, 1, 1, 1, 1, 1, 0] (8 * i1) + // %0[%c0, %c0] = + // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4) + // %val_to_store = + // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4) // - // %mask = [0, 1, 1, 1, 1, 1, 0, 0] - // %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] - // %value_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] + // we'll have the following i4 output: // - // we'll have + // expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8] // - // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x7, 0x8] + // Emulating the above using i8 will give: // - // %new_mask = [1, 1, 1, 0] - // %maskedload = [0x12, 0x34, 0x56, 0x00] - // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] - // %select_using_shifted_mask = [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x0, 0x0] - // %packed_data = [0x1A, 0xBC, 0xDE, 0x00] + // %compressed_mask = [1, 1, 1, 1] (4 * i1) + // %maskedload = [0x12, 0x34, 0x56, 0x78] (4 * i8) + // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4) + // %select_using_shifted_mask = + // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8] (8 * i4) + // %packed_data = [0x9A, 0xBC, 0xDE, 0xF8] (4 * i8) // // Using the new mask to store %packed_data results in expected output. FailureOr newMask = diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir index cba299b2a1d95..c98b4dd50a702 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir @@ -1,6 +1,10 @@ // RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s // RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32 +///---------------------------------------------------------------------------------------- +/// vector.load +///---------------------------------------------------------------------------------------- + func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> { %0 = memref.alloc() : memref<3x4xi8> %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8> @@ -82,6 +86,10 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, % // ----- +///---------------------------------------------------------------------------------------- +/// vector.transfer_read +///---------------------------------------------------------------------------------------- + func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> { %c0 = arith.constant 0 : i4 %0 = memref.alloc() : memref<3x8xi4> @@ -111,6 +119,10 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> { // ----- +///---------------------------------------------------------------------------------------- +/// vector.maskedload +///---------------------------------------------------------------------------------------- + func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<4xi8>) -> vector<4xi8> { %0 = memref.alloc() : memref<3x4xi8> %mask = vector.create_mask %arg3 : vector<4xi1> @@ -263,6 +275,10 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto // ----- 
+///---------------------------------------------------------------------------------------- +/// vector.extract -> vector.masked_load +///---------------------------------------------------------------------------------------- + func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> { %0 = memref.alloc() : memref<8x8x16xi4> %c0 = arith.constant 0 : index @@ -353,6 +369,10 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> { // ----- +///---------------------------------------------------------------------------------------- +/// vector.store +///---------------------------------------------------------------------------------------- + func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) { %0 = memref.alloc() : memref<4x8xi8> vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8> @@ -431,6 +451,10 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind // ----- +///---------------------------------------------------------------------------------------- +/// vector.maskedstore +///---------------------------------------------------------------------------------------- + func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) { %0 = memref.alloc() : memref<3x8xi8> %mask = vector.create_mask %arg2 : vector<8xi1> @@ -469,6 +493,61 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu // ----- +func.func @vector_maskedstore_i4( + %idx1: index, + %idx2: index, + %num_elements_to_store: index, + %value: vector<8xi4>) { + + %0 = memref.alloc() : memref<3x8xi4> + %cst = arith.constant dense<0> : vector<3x8xi4> + %mask = vector.create_mask %num_elements_to_store : vector<8xi1> + vector.maskedstore %0[%idx1, %idx2], %mask, %value : + memref<3x8xi4>, vector<8xi1>, vector<8xi4> + return +} +// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)> +// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)> + +// CHECK-LABEL: func.func @vector_maskedstore_i4( +// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) { +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8> +// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1> +// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]] +// CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]](){{\[}}%[[NUM_EL_TO_STORE]]] +// CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1> +// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8> +// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8> +// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4> +// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4> +// CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8> +// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8> + +// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)> +// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)> + +// CHECK32-LABEL: func.func 
@vector_maskedstore_i4( +// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index, +// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index, +// CHECK32-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index, +// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) { +// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32> +// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1> +// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_17]](){{\[}}%[[IDX_1]], %[[IDX_2]]] +// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]](){{\[}}%[[NUM_EL_TO_STORE]]] +// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1> +// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32> +// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32> +// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4> +// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4> +// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32> +// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32> + +// ----- + func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) { %0 = memref.alloc() : memref<3x8xi8> %mask = vector.constant_mask [4] : vector<8xi1> @@ -500,3 +579,50 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector< // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8> // CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32> // CHECK32: vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] + +// ----- + +func.func @vector_cst_maskedstore_i4( + %idx_1: index, + %idx_2: index, + %val_to_store: vector<8xi4>) { + + %0 = memref.alloc() : memref<3x8xi4> + %cst = arith.constant dense<0> : vector<3x8xi4> + %mask = vector.constant_mask [4] : vector<8xi1> + vector.maskedstore %0[%idx_1, %idx_2], %mask, %val_to_store : + memref<3x8xi4>, vector<8xi1>, vector<8xi4> + return +} + +// CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)> +// CHECK-LABEL: func.func @vector_cst_maskedstore_i4( +// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) { +// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8> +// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1> +// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_12]](){{\[}}%[[IDX_1]], %[[IDX_2]]] +// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1> +// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8> +// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8> +// CHECK: %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4> +// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[VAL_9]] : vector<8xi1>, vector<8xi4> +// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8> +// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<12xi8>, vector<4xi1>, vector<4xi8> + +// CHECK32: #[[$ATTR_20:.+]] = 
affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32-LABEL: func.func @vector_cst_maskedstore_i4(
+// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
+// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
+// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_20]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
+// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32: %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
+// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_2]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
+// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
+// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<3xi32>, vector<1xi1>, vector<1xi32>

From eef52e2784ac8bf7d7b011d8db84f284f8f14576 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Fri, 8 Nov 2024 13:45:58 +0000
Subject: [PATCH 2/4] fixup! [mlir][vector][nfc] Add tests + update docs for
 narrow-type emulation

* Fix failing test
* Tweak/fix the comment
* Rename: @vector_cst_maskedload_i8 -> @vector_maskedload_i8_constant_mask
  (same for other similar tests)
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 24 ++++++-----
 .../Vector/vector-emulate-narrow-type.mlir    | 42 +++++++++----------
 2 files changed, 34 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 0f88ff21e847e..b9f5c71fa4805 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -323,11 +323,14 @@ struct ConvertVectorMaskedStore final
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
 
     // Load the whole data and use arith.select to handle the corner cases.
- // E.g., given these input i4 values: // - // %res = vector.maskedload %0[%c0, %c0], %mask, %val_to_store : + // As an example, for this masked store: // - // %mask = [1, 1, 1, 1, 1, 1, 1, 0] (8 * i1) + // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store + // + // and given these input i4 values: + // + // %mask = [1, 1, 1, 1, 1, 0, 0, 0] (8 * i1) // %0[%c0, %c0] = // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4) // %val_to_store = @@ -335,18 +338,19 @@ struct ConvertVectorMaskedStore final // // we'll have the following i4 output: // - // expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8] + // expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8] // // Emulating the above using i8 will give: // - // %compressed_mask = [1, 1, 1, 1] (4 * i1) - // %maskedload = [0x12, 0x34, 0x56, 0x78] (4 * i8) - // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4) + // %compressed_mask = [1, 1, 1, 0] (4 * i1) + // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8) + // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4) // %select_using_shifted_mask = - // [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8] (8 * i4) - // %packed_data = [0x9A, 0xBC, 0xDE, 0xF8] (4 * i8) + // [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4) + // %packed_data = [0x9A, 0xBC, 0xD6, 0x00] (4 * i8) // - // Using the new mask to store %packed_data results in expected output. + // Using the compressed mask to store %packed_data results in expected + // output. FailureOr newMask = getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale); if (failed(newMask)) diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir index c98b4dd50a702..5e139b04d7ee6 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir @@ -202,7 +202,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt // ----- -func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> { +func.func @vector_maskedload_i8_constant_mask(%arg1: index, %arg2: index, %passthru: vector<4xi8>) -> vector<4xi8> { %0 = memref.alloc() : memref<3x4xi8> %mask = vector.constant_mask [2] : vector<4xi1> %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru : @@ -210,7 +210,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto return %1 : vector<4xi8> } // Expect no conversions, i8 is supported. 
-// CHECK: func @vector_cst_maskedload_i8( +// CHECK: func @vector_maskedload_i8_constant_mask( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<4xi8>) // CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8> @@ -220,7 +220,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto // CHECK-NEXT: return // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)> -// CHECK32: func @vector_cst_maskedload_i8( +// CHECK32: func @vector_maskedload_i8_constant_mask( // CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index, // CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: vector<4xi8>) // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32> @@ -236,7 +236,7 @@ func.func @vector_cst_maskedload_i8(%arg1: index, %arg2: index, %passthru: vecto // ----- -func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> { +func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passthru: vector<8xi4>) -> vector<3x8xi4> { %0 = memref.alloc() : memref<3x8xi4> %cst = arith.constant dense<0> : vector<3x8xi4> %mask = vector.constant_mask [4] : vector<8xi1> @@ -246,7 +246,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto return %2 : vector<3x8xi4> } // CHECK-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)> -// CHECK: func @vector_cst_maskedload_i4( +// CHECK: func @vector_maskedload_i4_constant_mask( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>) // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8> @@ -260,7 +260,7 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto // CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4> // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)> -// CHECK32: func @vector_cst_maskedload_i4( +// CHECK32: func @vector_maskedload_i4_constant_mask( // CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index, // CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: vector<8xi4>) // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32> @@ -500,7 +500,6 @@ func.func @vector_maskedstore_i4( %value: vector<8xi4>) { %0 = memref.alloc() : memref<3x8xi4> - %cst = arith.constant dense<0> : vector<3x8xi4> %mask = vector.create_mask %num_elements_to_store : vector<8xi1> vector.maskedstore %0[%idx1, %idx2], %mask, %value : memref<3x8xi4>, vector<8xi1>, vector<8xi4> @@ -548,14 +547,14 @@ func.func @vector_maskedstore_i4( // ----- -func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) { +func.func @vector_maskedstore_i8_constant_mask(%arg0: index, %arg1: index, %value: vector<8xi8>) { %0 = memref.alloc() : memref<3x8xi8> %mask = vector.constant_mask [4] : vector<8xi1> vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8> return } // Expect no conversions, i8 is supported. 
-// CHECK: func @vector_cst_maskedstore_i8( +// CHECK: func @vector_maskedstore_i8_constant_mask( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] // CHECK-SAME: %[[VAL:[a-zA-Z0-9]+]] @@ -565,7 +564,7 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector< // CHECK-NEXT: return // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)> -// CHECK32: func @vector_cst_maskedstore_i8( +// CHECK32: func @vector_maskedstore_i8_constant_mask( // CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]] // CHECK32-SAME: %[[VAL:[a-zA-Z0-9]+]] @@ -582,13 +581,12 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector< // ----- -func.func @vector_cst_maskedstore_i4( +func.func @vector_maskedstore_i4_constant_mask( %idx_1: index, %idx_2: index, %val_to_store: vector<8xi4>) { %0 = memref.alloc() : memref<3x8xi4> - %cst = arith.constant dense<0> : vector<3x8xi4> %mask = vector.constant_mask [4] : vector<8xi1> vector.maskedstore %0[%idx_1, %idx_2], %mask, %val_to_store : memref<3x8xi4>, vector<8xi1>, vector<8xi4> @@ -596,7 +594,7 @@ func.func @vector_cst_maskedstore_i4( } // CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)> -// CHECK-LABEL: func.func @vector_cst_maskedstore_i4( +// CHECK-LABEL: func.func @vector_maskedstore_i4_constant_mask( // CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index, // CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) { @@ -606,13 +604,13 @@ func.func @vector_cst_maskedstore_i4( // CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1> // CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8> // CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8> -// CHECK: %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4> -// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[VAL_9]] : vector<8xi1>, vector<8xi4> -// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8> -// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<12xi8>, vector<4xi1>, vector<4xi8> +// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4> +// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4> +// CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8> +// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8> // CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)> -// CHECK32-LABEL: func.func @vector_cst_maskedstore_i4( +// CHECK32-LABEL: func.func @vector_maskedstore_i4_constant_mask( // CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index, // CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index, // CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) { @@ -622,7 +620,7 @@ func.func @vector_cst_maskedstore_i4( // CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1> // CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32> // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32> -// CHECK32: %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : 
vector<1xi32> to vector<8xi4> -// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_2]], %[[VAL_9]] : vector<8xi1>, vector<8xi4> -// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32> -// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<3xi32>, vector<1xi1>, vector<1xi32> +// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4> +// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4> +// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32> +// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32> From 61e79844aa3696c4e4b13d6ecb6c19a70dbe66ff Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Mon, 11 Nov 2024 16:53:56 +0000 Subject: [PATCH 3/4] fixup! fixup! [mlir][vector][nfc] Add tests + update docs for narrow-type emulation Refine the comment --- .../Vector/Transforms/VectorEmulateNarrowType.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index b9f5c71fa4805..de9b01de563df 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -324,13 +324,13 @@ struct ConvertVectorMaskedStore final // Load the whole data and use arith.select to handle the corner cases. // - // As an example, for this masked store: + // As an example, for this masked store of i4 values: // // vector.maskedstore %0[%c0, %c0], %mask, %val_to_store // - // and given these input i4 values: + // and given these input values: // - // %mask = [1, 1, 1, 1, 1, 0, 0, 0] (8 * i1) + // %mask = [0, 1, 1, 1, 1, 0, 0, 0] (8 * i1) // %0[%c0, %c0] = // [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4) // %val_to_store = @@ -338,7 +338,7 @@ struct ConvertVectorMaskedStore final // // we'll have the following i4 output: // - // expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8] + // expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x7, 0x8] // // Emulating the above using i8 will give: // @@ -346,8 +346,8 @@ struct ConvertVectorMaskedStore final // %maskedload = [0x12, 0x34, 0x56, 0x00] (4 * i8) // %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0] (8 * i4) // %select_using_shifted_mask = - // [0x9, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4) - // %packed_data = [0x9A, 0xBC, 0xD6, 0x00] (4 * i8) + // [0x1, 0xA, 0xB, 0xC, 0xD, 0x6, 0x0, 0x0] (8 * i4) + // %packed_data = [0x1A, 0xBC, 0xD6, 0x00] (4 * i8) // // Using the compressed mask to store %packed_data results in expected // output. From 933694b0a66bc52ece61b2a7da850a9a46a2ec26 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Tue, 12 Nov 2024 13:16:58 +0000 Subject: [PATCH 4/4] fixup! fixup! fixup! 
[mlir][vector][nfc] Add tests + update docs for narrow-type emulation Final tweaks --- .../Transforms/VectorEmulateNarrowType.cpp | 6 +++++- .../Vector/vector-emulate-narrow-type.mlir | 20 +++++++++---------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index de9b01de563df..2e8080cc34095 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -12,7 +12,8 @@ // /// Currently, only power-of-two integer types are supported. These are /// converted to wider integers that are either 8 bits wide or wider. -// +/// +/// TODO: Support for non-powers-of-two. //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -351,6 +352,9 @@ struct ConvertVectorMaskedStore final // // Using the compressed mask to store %packed_data results in expected // output. + // + // FIXME: Make an example based on the comment above work (see #115460 for + // reproducer). FailureOr newMask = getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale); if (failed(newMask)) diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir index 5e139b04d7ee6..034bd47f6163e 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir @@ -515,11 +515,11 @@ func.func @vector_maskedstore_i4( // CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) { // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8> // CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1> -// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]] -// CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]](){{\[}}%[[NUM_EL_TO_STORE]]] +// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]]()[%[[IDX_1]], %[[IDX_2]]] +// CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]]()[%[[NUM_EL_TO_STORE]]] // CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1> // CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8> -// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8> +// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8> // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4> // CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4> // CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8> @@ -535,11 +535,11 @@ func.func @vector_maskedstore_i4( // CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) { // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32> // CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1> -// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_17]](){{\[}}%[[IDX_1]], %[[IDX_2]]] -// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]](){{\[}}%[[NUM_EL_TO_STORE]]] +// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_17]]()[%[[IDX_1]], %[[IDX_2]]] +// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]]()[%[[NUM_EL_TO_STORE]]] // CHECK32: %[[NEW_MASK:.+]] 
= vector.create_mask %[[MASK_IDX]] : vector<1xi1> // CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32> -// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32> +// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32> // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4> // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4> // CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32> @@ -600,10 +600,10 @@ func.func @vector_maskedstore_i4_constant_mask( // CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) { // CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8> // CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1> -// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_12]](){{\[}}%[[IDX_1]], %[[IDX_2]]] +// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_12]]()[%[[IDX_1]], %[[IDX_2]]] // CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1> // CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8> -// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8> +// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8> // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4> // CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4> // CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8> @@ -616,10 +616,10 @@ func.func @vector_maskedstore_i4_constant_mask( // CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) { // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32> // CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1> -// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_20]](){{\[}}%[[IDX_1]], %[[IDX_2]]] +// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_20]]()[%[[IDX_1]], %[[IDX_2]]] // CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1> // CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32> -// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32> +// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32> // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4> // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4> // CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
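
For readers following the CHECK lines above, here is a minimal before/after sketch of the rewrite exercised by @vector_maskedstore_i4_constant_mask, assuming the `memref-load-bitwidth=8` RUN configuration; the function and SSA names below are illustrative only, not the pass's exact output.

Before emulation:

  func.func @illustration(%idx_1: index, %idx_2: index, %val_to_store: vector<8xi4>) {
    %src = memref.alloc() : memref<3x8xi4>
    %mask = vector.constant_mask [4] : vector<8xi1>
    vector.maskedstore %src[%idx_1, %idx_2], %mask, %val_to_store
        : memref<3x8xi4>, vector<8xi1>, vector<8xi4>
    return
  }

After emulation (roughly), the i4 data lives in a linearized i8 memref; the compressed mask drives the i8 load/store, while the original i4 mask drives an arith.select so that unselected i4 elements keep the bytes read from memory:

  #map = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
  func.func @illustration(%idx_1: index, %idx_2: index, %val_to_store: vector<8xi4>) {
    %alloc = memref.alloc() : memref<12xi8>
    %mask = vector.constant_mask [4] : vector<8xi1>
    %lidx = affine.apply #map()[%idx_1, %idx_2]
    %new_mask = vector.constant_mask [2] : vector<4xi1>
    %pass_thru = arith.constant dense<0> : vector<4xi8>
    %load = vector.maskedload %alloc[%lidx], %new_mask, %pass_thru
        : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
    %bitcast = vector.bitcast %load : vector<4xi8> to vector<8xi4>
    %select = arith.select %mask, %val_to_store, %bitcast : vector<8xi1>, vector<8xi4>
    %new_val = vector.bitcast %select : vector<8xi4> to vector<4xi8>
    vector.maskedstore %alloc[%lidx], %new_mask, %new_val
        : memref<12xi8>, vector<4xi1>, vector<4xi8>
    return
  }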