Skip to content

Commit b777a60

Browse files
committed
updates
1 parent a9d7260 commit b777a60

File tree

2 files changed

+31
-28
lines changed

2 files changed

+31
-28
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
133133
static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
134134
VectorType extractType, Value vector,
135135
int64_t frontOffset, int64_t subvecSize) {
136-
// get vector's vector type:
137136
auto vectorType = dyn_cast<VectorType>(vector.getType());
138137
assert(vectorType && "expected vector type");
139138
assert(vectorType.getShape().size() == 1 && "expected 1-D vector type");

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
22

3+
// TODO: remove memref.alloc() in the tests to eliminate noises.
4+
// memref.alloc exists here because sub-byte vector data types such as i2
5+
// are currently not supported as input arguments.
6+
37
// CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
48
// CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
59

6-
func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
7-
%0 = memref.alloc() : memref<3x3xi2>
8-
%c0 = arith.constant 0 : index
9-
%c2 = arith.constant 2 : index
10-
%cst = arith.constant dense<0> : vector<3x3xi2>
11-
%1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
12-
%2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
13-
return %2 : vector<3x3xi2>
10+
func.func @vector_load_i2() -> vector<3x3xi2> {
11+
%0 = memref.alloc() : memref<3x3xi2>
12+
%c0 = arith.constant 0 : index
13+
%c2 = arith.constant 2 : index
14+
%cst = arith.constant dense<0> : vector<3x3xi2>
15+
%1 = vector.load %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
16+
%2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
17+
return %2 : vector<3x3xi2>
1418
}
1519

1620
// CHECK: func @vector_load_i2
@@ -23,12 +27,12 @@ func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
2327
//-----
2428

2529
func.func @vector_transfer_read_i2() -> vector<3xi2> {
26-
%0 = memref.alloc() : memref<3x3xi2>
27-
%c0i2 = arith.constant 0 : i2
28-
%c0 = arith.constant 0 : index
29-
%c2 = arith.constant 2 : index
30-
%1 = vector.transfer_read %0[%c2, %c0], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
31-
return %1 : vector<3xi2>
30+
%0 = memref.alloc() : memref<3x3xi2>
31+
%pad = arith.constant 0 : i2
32+
%c0 = arith.constant 0 : index
33+
%c2 = arith.constant 2 : index
34+
%1 = vector.transfer_read %0[%c2, %c0], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
35+
return %1 : vector<3xi2>
3236
}
3337

3438
// CHECK: func @vector_transfer_read_i2
@@ -41,15 +45,15 @@ func.func @vector_transfer_read_i2() -> vector<3xi2> {
4145
//-----
4246

4347
func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
44-
%0 = memref.alloc() : memref<3x5xi2>
45-
%cst = arith.constant dense<0> : vector<3x5xi2>
46-
%mask = vector.constant_mask [3] : vector<5xi1>
47-
%c0 = arith.constant 0 : index
48-
%c2 = arith.constant 2 : index
49-
%1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
50-
memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
51-
%2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
52-
return %2 : vector<3x5xi2>
48+
%0 = memref.alloc() : memref<3x5xi2>
49+
%cst = arith.constant dense<0> : vector<3x5xi2>
50+
%mask = vector.constant_mask [3] : vector<5xi1>
51+
%c0 = arith.constant 0 : index
52+
%c2 = arith.constant 2 : index
53+
%1 = vector.maskedload %0[%c2, %c0], %mask, %passthru :
54+
memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
55+
%2 = vector.insert %1, %cst [0] : vector<5xi2> into vector<3x5xi2>
56+
return %2 : vector<3x5xi2>
5357
}
5458

5559
// CHECK: func @vector_cst_maskedload_i2
@@ -71,10 +75,10 @@ func.func @vector_cst_maskedload_i2(%passthru: vector<5xi2>) -> vector<3x5xi2> {
7175

7276
//-----
7377

74-
func.func @vector_load_i2_dynamic_indexing(%arg1: index, %arg2: index) -> vector<3xi2> {
78+
func.func @vector_load_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
7579
%0 = memref.alloc() : memref<3x3xi2>
7680
%cst = arith.constant dense<0> : vector<3x3xi2>
77-
%1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
81+
%1 = vector.load %0[%idx1, %idx2] : memref<3x3xi2>, vector<3xi2>
7882
return %1 : vector<3xi2>
7983
}
8084

@@ -95,10 +99,10 @@ func.func @vector_load_i2_dynamic_indexing(%arg1: index, %arg2: index) -> vector
9599

96100
//-----
97101

98-
func.func @vector_transfer_read_i2_dynamic_indexing(%arg1: index, %arg2: index) -> vector<3xi2> {
102+
func.func @vector_transfer_read_i2_dynamic_indexing(%idx1: index, %idx2: index) -> vector<3xi2> {
99103
%0 = memref.alloc() : memref<3x3xi2>
100104
%pad = arith.constant 0 : i2
101-
%1 = vector.transfer_read %0[%arg1, %arg2], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
105+
%1 = vector.transfer_read %0[%idx1, %idx2], %pad {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
102106
return %1 : vector<3xi2>
103107
}
104108

0 commit comments

Comments
 (0)