Skip to content

Commit c5e6d56

Browse files
authored
[mlir][vector] Propagate alignment when emulating masked{load,stores}. (#155648)
Propagate alignment from `vector.maskedload` and `vector.maskedstore` to `memref.load` and `memref.store` during `VectorEmulateMaskedLoadStore` pass.
1 parent 250d251 commit c5e6d56

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ struct VectorMaskedLoadOpConverter final
6464
Value mask = maskedLoadOp.getMask();
6565
Value base = maskedLoadOp.getBase();
6666
Value iValue = maskedLoadOp.getPassThru();
67+
std::optional<uint64_t> alignment = maskedLoadOp.getAlignment();
6768
auto indices = llvm::to_vector_of<Value>(maskedLoadOp.getIndices());
6869
Value one = arith::ConstantOp::create(rewriter, loc, indexType,
6970
IntegerAttr::get(indexType, 1));
@@ -73,8 +74,9 @@ struct VectorMaskedLoadOpConverter final
7374
auto ifOp = scf::IfOp::create(
7475
rewriter, loc, maskBit,
7576
[&](OpBuilder &builder, Location loc) {
76-
auto loadedValue =
77-
memref::LoadOp::create(builder, loc, base, indices);
77+
auto loadedValue = memref::LoadOp::create(
78+
builder, loc, base, indices, /*nontemporal=*/false,
79+
alignment.value_or(0));
7880
auto combinedValue =
7981
vector::InsertOp::create(builder, loc, loadedValue, iValue, i);
8082
scf::YieldOp::create(builder, loc, combinedValue.getResult());
@@ -132,6 +134,8 @@ struct VectorMaskedStoreOpConverter final
132134
Value mask = maskedStoreOp.getMask();
133135
Value base = maskedStoreOp.getBase();
134136
Value value = maskedStoreOp.getValueToStore();
137+
bool nontemporal = false;
138+
std::optional<uint64_t> alignment = maskedStoreOp.getAlignment();
135139
auto indices = llvm::to_vector_of<Value>(maskedStoreOp.getIndices());
136140
Value one = arith::ConstantOp::create(rewriter, loc, indexType,
137141
IntegerAttr::get(indexType, 1));
@@ -141,7 +145,8 @@ struct VectorMaskedStoreOpConverter final
141145
auto ifOp = scf::IfOp::create(rewriter, loc, maskBit, /*else=*/false);
142146
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
143147
auto extractedValue = vector::ExtractOp::create(rewriter, loc, value, i);
144-
memref::StoreOp::create(rewriter, loc, extractedValue, base, indices);
148+
memref::StoreOp::create(rewriter, loc, extractedValue, base, indices,
149+
nontemporal, alignment.value_or(0));
145150

146151
rewriter.setInsertionPointAfter(ifOp);
147152
indices.back() =

mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,22 @@ func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
5454
return %0: vector<4xf32>
5555
}
5656

57+
// CHECK-LABEL: @vector_maskedload_with_alignment
58+
// CHECK: memref.load
59+
// CHECK-SAME: {alignment = 8 : i64}
60+
// CHECK: memref.load
61+
// CHECK-SAME: {alignment = 8 : i64}
62+
func.func @vector_maskedload_with_alignment(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
63+
%idx_0 = arith.constant 0 : index
64+
%idx_1 = arith.constant 1 : index
65+
%idx_4 = arith.constant 4 : index
66+
%mask = vector.create_mask %idx_1 : vector<4xi1>
67+
%s = arith.constant 0.0 : f32
68+
%pass_thru = vector.splat %s : vector<4xf32>
69+
%0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru {alignment = 8}: memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
70+
return %0: vector<4xf32>
71+
}
72+
5773
// CHECK-LABEL: @vector_maskedstore
5874
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x5xf32>, %[[ARG1:.*]]: vector<4xf32>) {
5975
// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
@@ -93,3 +109,17 @@ func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
93109
vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
94110
return
95111
}
112+
113+
// CHECK-LABEL: @vector_maskedstore_with_alignment
114+
// CHECK: memref.store
115+
// CHECK-SAME: {alignment = 8 : i64}
116+
// CHECK: memref.store
117+
// CHECK-SAME: {alignment = 8 : i64}
118+
func.func @vector_maskedstore_with_alignment(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) {
119+
%idx_0 = arith.constant 0 : index
120+
%idx_1 = arith.constant 1 : index
121+
%idx_4 = arith.constant 4 : index
122+
%mask = vector.create_mask %idx_1 : vector<4xi1>
123+
vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 { alignment = 8 } : memref<4x5xf32>, vector<4xi1>, vector<4xf32>
124+
return
125+
}

0 commit comments

Comments
 (0)