41 changes: 28 additions & 13 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1,11 +1,19 @@
-//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++
-//-*-===//
+//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to emulate
+// narrow types that are not supported by the target hardware, e.g. i4, using
+// wider types, e.g. i8.
+//
+/// Currently, only power-of-two integer types are supported. These are
Contributor:

Note that there is actually nothing fundamental about powers of 2 here. It is an implementation detail, and non-power-of-two types are just NYI. It might be better to make that clear.

Contributor:

+1, perhaps we can explicitly say in the description that non-power-of-two types are not implemented yet. That would be clearer for newcomers to the narrow type emulation.

Contributor Author:

I will add a TODO at the top - a clear indication for anyone brave enough to tackle this :)

+/// converted to wider integers that are either 8 bits wide or wider.
+//
+//===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -315,21 +323,28 @@ struct ConvertVectorMaskedStore final
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
 
     // Load the whole data and use arith.select to handle the corner cases.
-    // E.g., given these input values:
+    // E.g., given these input i4 values:
+    //
+    //   vector.maskedstore %0[%c0, %c0], %mask, %val_to_store
+    //
+    //   %mask = [1, 1, 1, 1, 1, 1, 1, 0] (8 * i1)
+    //   %0[%c0, %c0] =
+    //      [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
+    //   %val_to_store =
+    //      [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
     //
-    //   %mask = [0, 1, 1, 1, 1, 1, 0, 0]
-    //   %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]
-    //   %value_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]
-    // we'll have
+    // we'll have the following i4 output:
     //
-    //    expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x7, 0x8]
+    //   expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8]
     //
+    // Emulating the above using i8 will give:
+    //
-    //   %new_mask = [1, 1, 1, 0]
-    //   %maskedload = [0x12, 0x34, 0x56, 0x00]
-    //   %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]
-    //   %select_using_shifted_mask = [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x0, 0x0]
-    //   %packed_data = [0x1A, 0xBC, 0xDE, 0x00]
+    //   %compressed_mask = [1, 1, 1, 1] (4 * i1)
+    //   %maskedload = [0x12, 0x34, 0x56, 0x78] (4 * i8)
+    //   %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
+    //   %select_using_shifted_mask =
+    //      [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8] (8 * i4)
+    //   %packed_data = [0x9A, 0xBC, 0xDE, 0xF8] (4 * i8)
     //
     // Using the new mask to store %packed_data results in expected output.
     FailureOr<Operation *> newMask =
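To follow the four steps of the rewrite above (masked load of the whole wide region, bitcast back to the narrow type, arith.select with the original mask, repack and masked store through the compressed mask), a scalar model may help. This is a sketch under the same 8 x i4 / 4 x i8 assumptions as the comment, with illustrative names rather than the pass's actual API:

```cpp
#include <array>
#include <cstdint>

// Scalar model of the emulated vector.maskedstore: 8 i4 lanes backed by
// 4 i8 bytes, first lane of each pair in the high nibble.
void emulatedMaskedStoreI4(std::array<uint8_t, 4> &mem,
                           const std::array<uint8_t, 8> &valToStore,
                           const std::array<bool, 8> &mask) {
  // Steps 1-2: load every byte covered by the compressed mask and "bitcast"
  // the bytes back into 8 i4 lanes.
  std::array<uint8_t, 8> lanes;
  for (int i = 0; i < 8; ++i)
    lanes[i] = (mem[i / 2] >> (i % 2 == 0 ? 4 : 0)) & 0xF;
  // Step 3: arith.select with the *original* i4 mask, so unmasked lanes keep
  // the previously loaded contents.
  for (int i = 0; i < 8; ++i)
    if (mask[i])
      lanes[i] = valToStore[i] & 0xF;
  // Step 4: pack back to i8 and store through the compressed mask (which
  // covers all four bytes in the comment's example).
  for (int i = 0; i < 4; ++i)
    mem[i] = uint8_t((lanes[2 * i] << 4) | lanes[2 * i + 1]);
}
```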
126 changes: 126 additions & 0 deletions mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -1,6 +1,10 @@
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32

///----------------------------------------------------------------------------------------
/// vector.load
///----------------------------------------------------------------------------------------

func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
@@ -82,6 +86,10 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %

// -----

///----------------------------------------------------------------------------------------
/// vector.transfer_read
///----------------------------------------------------------------------------------------

func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
%c0 = arith.constant 0 : i4
%0 = memref.alloc() : memref<3x8xi4>
@@ -111,6 +119,10 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {

// -----

///----------------------------------------------------------------------------------------
/// vector.maskedload
///----------------------------------------------------------------------------------------

func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<4xi8>) -> vector<4xi8> {
%0 = memref.alloc() : memref<3x4xi8>
%mask = vector.create_mask %arg3 : vector<4xi1>
@@ -263,6 +275,10 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto

// -----

///----------------------------------------------------------------------------------------
/// vector.extract -> vector.maskedload
///----------------------------------------------------------------------------------------

func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
%0 = memref.alloc() : memref<8x8x16xi4>
%c0 = arith.constant 0 : index
@@ -353,6 +369,10 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {

// -----

///----------------------------------------------------------------------------------------
/// vector.store
///----------------------------------------------------------------------------------------

func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
%0 = memref.alloc() : memref<4x8xi8>
vector.store %arg0, %0[%arg1, %arg2] : memref<4x8xi8>, vector<8xi8>
@@ -431,6 +451,10 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind

// -----

///----------------------------------------------------------------------------------------
/// vector.maskedstore
///----------------------------------------------------------------------------------------

func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
%0 = memref.alloc() : memref<3x8xi8>
%mask = vector.create_mask %arg2 : vector<8xi1>
@@ -469,6 +493,61 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu

// -----

func.func @vector_maskedstore_i4(
%idx1: index,
%idx2: index,
%num_elements_to_store: index,
%value: vector<8xi4>) {

%0 = memref.alloc() : memref<3x8xi4>
%cst = arith.constant dense<0> : vector<3x8xi4>
%mask = vector.create_mask %num_elements_to_store : vector<8xi1>
vector.maskedstore %0[%idx1, %idx2], %mask, %value :
memref<3x8xi4>, vector<8xi1>, vector<8xi4>
return
}
// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>

// CHECK-LABEL: func.func @vector_maskedstore_i4(
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
Contributor:

nit: we don't need to escape [ because we don't capture % in the variables. That is one of the advantages of this trick, and it looks cleaner. Can you update it and fix the other checks in the new lit tests?

Suggested change
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]]()[%[[IDX_1]], %[[IDX_2]]]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, very nice, thanks for pointing this out 🙏🏻

// CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]](){{\[}}%[[NUM_EL_TO_STORE]]]
// CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
// CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>

// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>

// CHECK32-LABEL: func.func @vector_maskedstore_i4(
// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_17]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]](){{\[}}%[[NUM_EL_TO_STORE]]]
// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
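
The two affine maps checked above are plain arithmetic: the first linearizes a (row, element) i4 index into a byte offset in the i8 view, the second shrinks the dynamic mask bound with a ceiling division. A small sketch of the same computations (illustrative only, not emitted by the pass):

```cpp
#include <cassert>

// #map  = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
// One row of memref<3x8xi4> is 8 i4 values = 4 i8 bytes.
int linearizedByteIndex(int idx1, int idx2) { return idx1 * 4 + idx2 / 2; }

// #map1 = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
// n i4 lanes touch ceildiv(n, 2) i8 bytes, so the i1 mask shrinks likewise.
int compressedMaskBound(int numElementsToStore) {
  return (numElementsToStore + 1) / 2;
}

int main() {
  assert(linearizedByteIndex(2, 5) == 10); // row 2, element 5 -> byte 10
  assert(compressedMaskBound(7) == 4);     // 7 i4 lanes span 4 bytes
  assert(compressedMaskBound(8) == 4);
  return 0;
}
```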

// -----

func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
%0 = memref.alloc() : memref<3x8xi8>
%mask = vector.constant_mask [4] : vector<8xi1>
@@ -500,3 +579,50 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
// CHECK32: vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]

// -----

func.func @vector_cst_maskedstore_i4(
%idx_1: index,
%idx_2: index,
%val_to_store: vector<8xi4>) {

%0 = memref.alloc() : memref<3x8xi4>
%cst = arith.constant dense<0> : vector<3x8xi4>
%mask = vector.constant_mask [4] : vector<8xi1>
vector.maskedstore %0[%idx_1, %idx_2], %mask, %val_to_store :
memref<3x8xi4>, vector<8xi1>, vector<8xi4>
return
}

// CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
// CHECK-LABEL: func.func @vector_cst_maskedstore_i4(
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_12]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
// CHECK: %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<12xi8>, vector<4xi1>, vector<4xi8>

// CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
// CHECK32-LABEL: func.func @vector_cst_maskedstore_i4(
// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_20]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
// CHECK32: %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
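
For the constant-mask variant above, no affine map is needed: the mask bound folds at compile time using the same ceiling division, scaled by the number of narrow elements per wide element. A sketch of that fold (an illustration, not the pass's code):

```cpp
// constant_mask [4] : vector<8xi1> on i4 data becomes:
//   [2] : vector<4xi1> when emulating with i8  (scale = 8 / 4 = 2)
//   [1] : vector<1xi1> when emulating with i32 (scale = 32 / 4 = 8)
int foldConstantMaskBound(int trueLanes, int narrowBits, int wideBits) {
  int scale = wideBits / narrowBits;
  return (trueLanes + scale - 1) / scale; // ceildiv(trueLanes, scale)
}
```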