Skip to content

Commit b2f70ff

Browse files
committed
[mlir] Use alignment in VectorToLLVM conversion.
Now that an alignment attribute has been added to vector.load, vector.store, vector.maskedload, and vector.maskedstore, the shared conversion pattern used by these operations can consult that attribute to specify the alignment of the vector being loaded or stored.
1 parent 9958c24 commit b2f70ff

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,8 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
247247
MemRefType memRefTy = loadOrStoreOp.getMemRefType();
248248

249249
// Resolve alignment.
250-
unsigned align;
251-
if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
250+
unsigned align = loadOrStoreOp.getAlignment().value_or(0);
251+
if (!align && failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
252252
memRefTy, align, useVectorAlignment)))
253253
return rewriter.notifyMatchFailure(loadOrStoreOp,
254254
"could not resolve alignment");

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1679,6 +1679,20 @@ func.func @load_0d(%memref : memref<200x100xf32>, %i : index, %j : index) -> vec
16791679

16801680
// -----
16811681

1682+
func.func @load_alignment(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
1683+
%0 = vector.load %memref[%i, %j] { alignment = 8 } : memref<200x100xf32>, vector<8xf32>
1684+
return %0 : vector<8xf32>
1685+
}
1686+
1687+
// CHECK-LABEL: func @load_alignment
1688+
// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
1689+
// CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]] : i64
1690+
// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
1691+
// CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
1692+
// CHECK: llvm.load %[[GEP]] {alignment = 8 : i64} : !llvm.ptr -> vector<8xf32>
1693+
1694+
// -----
1695+
16821696
//===----------------------------------------------------------------------===//
16831697
// vector.store
16841698
//===----------------------------------------------------------------------===//
@@ -1785,6 +1799,21 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
17851799

17861800
// -----
17871801

1802+
func.func @store_alignment(%memref : memref<200x100xf32>, %i : index, %j : index) {
1803+
%val = arith.constant dense<11.0> : vector<4xf32>
1804+
vector.store %val, %memref[%i, %j] {alignment = 8} : memref<200x100xf32>, vector<4xf32>
1805+
return
1806+
}
1807+
1808+
// CHECK-LABEL: func @store_alignment
1809+
// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
1810+
// CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %[[C100]] : i64
1811+
// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
1812+
// CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
1813+
// CHECK: llvm.store %{{.*}}, %[[GEP]] {alignment = 8 : i64} : vector<4xf32>, !llvm.ptr
1814+
1815+
// -----
1816+
17881817
//===----------------------------------------------------------------------===//
17891818
// vector.maskedload
17901819
//===----------------------------------------------------------------------===//
@@ -1839,6 +1868,21 @@ func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]
18391868

18401869
// -----
18411870

1871+
func.func @masked_load_alignment(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
1872+
%c0 = arith.constant 0: index
1873+
%0 = vector.maskedload %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1874+
return %0 : vector<16xf32>
1875+
}
1876+
1877+
// CHECK-LABEL: func @masked_load_alignment
1878+
// CHECK: %[[CO:.*]] = arith.constant 0 : index
1879+
// CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
1880+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
1881+
// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
1882+
// CHECK: return %[[L]] : vector<16xf32>
1883+
1884+
// -----
1885+
18421886
//===----------------------------------------------------------------------===//
18431887
// vector.maskedstore
18441888
//===----------------------------------------------------------------------===//
@@ -1891,6 +1935,20 @@ func.func @masked_store_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16
18911935

18921936
// -----
18931937

1938+
func.func @masked_store_alignment(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
1939+
%c0 = arith.constant 0: index
1940+
vector.maskedstore %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xf32>, vector<16xi1>, vector<16xf32>
1941+
return
1942+
}
1943+
1944+
// CHECK-LABEL: func @masked_store_alignment
1945+
// CHECK: %[[CO:.*]] = arith.constant 0 : index
1946+
// CHECK: %[[C:.*]] = builtin.unrealized_conversion_cast %[[CO]] : index to i64
1947+
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
1948+
// CHECK: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
1949+
1950+
// -----
1951+
18941952
//===----------------------------------------------------------------------===//
18951953
// vector.gather
18961954
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)