Skip to content

Commit 8a9abf6

Browse files
committed
[mlir][vector][nfc] Add tests + update docs for narrow-type emulation
The documentation for narrow-type emulation is a bit inaccurate. In particular, we don't really support/generate masks like this: %mask = [0, 1, 1, 1, 1, 1, 0, 0] I updated the comment for `ConvertVectorMaskedStore` accordingly. I also added a few clarification (e.g. that the comment is discussing i4 -> i8 emulation). Separately, I've noticed inconsistency in testing for narrow-type-emulation. In particular, there's a few cases that are tested for "loading" and which are missing for "storing". I've added * comments in the test file so that it's easy to see what's tested, * missing tests for `vector.maskedstor`. Finally, I've added a top level comment in VectorEmulateNarrowType.cpp so that the overall intent and design are clearer.
1 parent 58a17e1 commit 8a9abf6

File tree

2 files changed

+154
-13
lines changed

2 files changed

+154
-13
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1-
//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----*- C++
2-
//-*-===//
1+
//===- VectorEmulateNarrowType.cpp - Narrow type emulation ----------------===//
32
//
43
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
54
// See https://llvm.org/LICENSE.txt for license information.
65
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
76
//
87
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements target-independent rewrites and utilities to emulate
10+
// narrow types that are not supported by the target hardware, e.g. i4, using
11+
// wider types, e.g. i8.
12+
//
13+
/// Currently, only power-of-two integer types are supported. These are
14+
/// converted to wider integers that are either 8 bits wide or wider.
15+
//
16+
//===----------------------------------------------------------------------===//
917

1018
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1119
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -315,21 +323,28 @@ struct ConvertVectorMaskedStore final
315323
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndicesOfr);
316324

317325
// Load the whole data and use arith.select to handle the corner cases.
318-
// E.g., given these input values:
326+
// E.g., given these input i4 values:
327+
//
328+
// %res = vector.maskedload %0[%c0, %c0], %mask, %val_to_store :
329+
//
330+
// %mask = [1, 1, 1, 1, 1, 1, 1, 0] (8 * i1)
331+
// %0[%c0, %c0] =
332+
// [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
333+
// %val_to_store =
334+
// [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0] (8 * i4)
319335
//
320-
// %mask = [0, 1, 1, 1, 1, 1, 0, 0]
321-
// %0[%c0, %c0] contains [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8]
322-
// %value_to_store = [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x0]
336+
// we'll have the following i4 output:
323337
//
324-
// we'll have
338+
// expected output: [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8]
325339
//
326-
// expected output: [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x7, 0x8]
340+
// Emulating the above using i8 will give:
327341
//
328-
// %new_mask = [1, 1, 1, 0]
329-
// %maskedload = [0x12, 0x34, 0x56, 0x00]
330-
// %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x0, 0x0]
331-
// %select_using_shifted_mask = [0x1, 0xA, 0xB, 0xC, 0xD, 0xE, 0x0, 0x0]
332-
// %packed_data = [0x1A, 0xBC, 0xDE, 0x00]
342+
// %compressed_mask = [1, 1, 1, 1] (4 * i1)
343+
// %maskedload = [0x12, 0x34, 0x56, 0x78] (4 * i8)
344+
// %bitcast = [0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8] (8 * i4)
345+
// %select_using_shifted_mask =
346+
// [0x9, 0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x8] (8 * i4)
347+
// %packed_data = [0x9A, 0xBC, 0xDE, 0xF8] (4 * i8)
333348
//
334349
// Using the new mask to store %packed_data results in expected output.
335350
FailureOr<Operation *> newMask =

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
22
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
33

4+
///----------------------------------------------------------------------------------------
5+
/// vector.load
6+
///----------------------------------------------------------------------------------------
7+
48
func.func @vector_load_i8(%arg1: index, %arg2: index) -> vector<4xi8> {
59
%0 = memref.alloc() : memref<3x4xi8>
610
%1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<4xi8>
@@ -82,6 +86,10 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
8286

8387
// -----
8488

89+
///----------------------------------------------------------------------------------------
90+
/// vector.transfer_read
91+
///----------------------------------------------------------------------------------------
92+
8593
func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
8694
%c0 = arith.constant 0 : i4
8795
%0 = memref.alloc() : memref<3x8xi4>
@@ -111,6 +119,10 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
111119

112120
// -----
113121

122+
///----------------------------------------------------------------------------------------
123+
/// vector.maskedload
124+
///----------------------------------------------------------------------------------------
125+
114126
func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<4xi8>) -> vector<4xi8> {
115127
%0 = memref.alloc() : memref<3x4xi8>
116128
%mask = vector.create_mask %arg3 : vector<4xi1>
@@ -263,6 +275,10 @@ func.func @vector_cst_maskedload_i4(%arg1: index, %arg2: index, %passthru: vecto
263275

264276
// -----
265277

278+
///----------------------------------------------------------------------------------------
279+
/// vector.extract -> vector.masked_load
280+
///----------------------------------------------------------------------------------------
281+
266282
func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
267283
%0 = memref.alloc() : memref<8x8x16xi4>
268284
%c0 = arith.constant 0 : index
@@ -353,6 +369,10 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
353369

354370
// -----
355371

372+
///----------------------------------------------------------------------------------------
373+
/// vector.store
374+
///----------------------------------------------------------------------------------------
375+
356376
func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
357377
%0 = memref.alloc() : memref<4x8xi8>
358378
vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<8xi8>
@@ -431,6 +451,10 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
431451

432452
// -----
433453

454+
///----------------------------------------------------------------------------------------
455+
/// vector.maskedstore
456+
///----------------------------------------------------------------------------------------
457+
434458
func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
435459
%0 = memref.alloc() : memref<3x8xi8>
436460
%mask = vector.create_mask %arg2 : vector<8xi1>
@@ -469,6 +493,61 @@ func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %valu
469493

470494
// -----
471495

496+
func.func @vector_maskedstore_i4(
497+
%idx1: index,
498+
%idx2: index,
499+
%num_elements_to_store: index,
500+
%value: vector<8xi4>) {
501+
502+
%0 = memref.alloc() : memref<3x8xi4>
503+
%cst = arith.constant dense<0> : vector<3x8xi4>
504+
%mask = vector.create_mask %num_elements_to_store : vector<8xi1>
505+
vector.maskedstore %0[%idx1, %idx2], %mask, %value :
506+
memref<3x8xi4>, vector<8xi1>, vector<8xi4>
507+
return
508+
}
509+
// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
510+
// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> ((s0 + 1) floordiv 2)>
511+
512+
// CHECK-LABEL: func.func @vector_maskedstore_i4(
513+
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
514+
// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
515+
// CHECK-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
516+
// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
517+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
518+
// CHECK: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
519+
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_10]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
520+
// CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]](){{\[}}%[[NUM_EL_TO_STORE]]]
521+
// CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
522+
// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
523+
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
524+
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
525+
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
526+
// CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
527+
// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
528+
529+
// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
530+
// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> ((s0 + 7) floordiv 8)>
531+
532+
// CHECK32-LABEL: func.func @vector_maskedstore_i4(
533+
// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
534+
// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
535+
// CHECK32-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
536+
// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
537+
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
538+
// CHECK32: %[[ORIG_MASK:.+]] = vector.create_mask %[[NUM_EL_TO_STORE]] : vector<8xi1>
539+
// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_17]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
540+
// CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]](){{\[}}%[[NUM_EL_TO_STORE]]]
541+
// CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
542+
// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
543+
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
544+
// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
545+
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
546+
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
547+
// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
548+
549+
// -----
550+
472551
func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<8xi8>) {
473552
%0 = memref.alloc() : memref<3x8xi8>
474553
%mask = vector.constant_mask [4] : vector<8xi1>
@@ -500,3 +579,50 @@ func.func @vector_cst_maskedstore_i8(%arg0: index, %arg1: index, %value: vector<
500579
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL]], %[[BITCAST]] : vector<8xi1>, vector<8xi8>
501580
// CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi8> to vector<2xi32>
502581
// CHECK32: vector.maskedstore %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]]
582+
583+
// -----
584+
585+
func.func @vector_cst_maskedstore_i4(
586+
%idx_1: index,
587+
%idx_2: index,
588+
%val_to_store: vector<8xi4>) {
589+
590+
%0 = memref.alloc() : memref<3x8xi4>
591+
%cst = arith.constant dense<0> : vector<3x8xi4>
592+
%mask = vector.constant_mask [4] : vector<8xi1>
593+
vector.maskedstore %0[%idx_1, %idx_2], %mask, %val_to_store :
594+
memref<3x8xi4>, vector<8xi1>, vector<8xi4>
595+
return
596+
}
597+
598+
// CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
599+
// CHECK-LABEL: func.func @vector_cst_maskedstore_i4(
600+
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
601+
// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
602+
// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
603+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
604+
// CHECK: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
605+
// CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_12]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
606+
// CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
607+
// CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
608+
// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
609+
// CHECK: %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
610+
// CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
611+
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
612+
// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
613+
614+
// CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
615+
// CHECK32-LABEL: func.func @vector_cst_maskedstore_i4(
616+
// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
617+
// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
618+
// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
619+
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
620+
// CHECK32: %[[ORIG_MASK:.+]] = vector.constant_mask [4] : vector<8xi1>
621+
// CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_20]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
622+
// CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
623+
// CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
624+
// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
625+
// CHECK32: %[[VAL_9:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
626+
// CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_2]], %[[VAL_9]] : vector<8xi1>, vector<8xi4>
627+
// CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
628+
// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[BITCAST]] : memref<3xi32>, vector<1xi1>, vector<1xi32>

0 commit comments

Comments
 (0)