4 changes: 3 additions & 1 deletion mlir/include/mlir/Conversion/Passes.td
@@ -1457,7 +1457,9 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"operations instead of the alignment of the element type of the "
"memref. This flag is intended for use with hardware which requires"
"vector alignment, or in application contexts where it is known all "
"vector access are naturally aligned. ">,
"vector access are naturally aligned. If operations have an "
"alignment attribute set, the alignment attribute takes priority "
"over this option ">,
Option<"amx", "enable-amx",
"bool", /*default=*/"false",
"Enables the use of AMX dialect while lowering the vector "
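For illustration, a minimal sketch of how the two alignment sources interact (hypothetical IR; the function name and invocation are assumptions, and the actual lowering is exercised in use-vector-alignment.mlir below). The op with an explicit alignment attribute keeps it regardless of the flag:

// Invoked as, e.g.:
//   mlir-opt --convert-vector-to-llvm='use-vector-alignment=1' input.mlir
func.func @priority_sketch(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
  // The explicit attribute takes priority: this lowers to llvm.load with
  // alignment = 8, even though use-vector-alignment=1 would otherwise pick
  // the alignment of the vector type.
  %0 = vector.load %base[%i, %j] {alignment = 8} : memref<200x100xf32>, vector<8xf32>
  return %0 : vector<8xf32>
}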
28 changes: 22 additions & 6 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -247,6 +247,7 @@ class VectorLoadStoreConversion : public ConvertOpToLLVMPattern<LoadOrStoreOp> {
MemRefType memRefTy = loadOrStoreOp.getMemRefType();

// Resolve alignment.
+ // Explicit alignment takes priority over use-vector-alignment.
unsigned align = loadOrStoreOp.getAlignment().value_or(0);
if (!align &&
failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vectorTy,
@@ -299,8 +300,10 @@ class VectorGatherOpConversion
}

// Resolve alignment.
- unsigned align;
- if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+ // Explicit alignment takes priority over use-vector-alignment.
+ unsigned align = gather.getAlignment().value_or(0);
+ if (!align &&
+     failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
Comment on lines +304 to +306
Contributor: Note that this means that explicit alignment takes priority over e.g. --convert-vector-to-llvm='use-vector-alignment=1'. This feels like the right design decision, but we should make sure that it is documented (perhaps just add a comment here?) and tested (e.g. in "use-vector-alignment.mlir").

Similar comment for scatter. Thanks!

Contributor (Author): Thanks for the review @banach-space! See here for the changes:

b6e5aff
746412a
fdfc873

Contributor: Thanks! Sorry I didn't reply earlier, I was on sick leave.

memRefType, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(gather, "could not resolve alignment");

@@ -354,8 +357,10 @@ class VectorScatterOpConversion
}

// Resolve alignment.
- unsigned align;
- if (failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
+ // Explicit alignment takes priority over use-vector-alignment.
+ unsigned align = scatter.getAlignment().value_or(0);
+ if (!align &&
+     failed(getVectorToLLVMAlignment(*this->getTypeConverter(), vType,
memRefType, align, useVectorAlignment)))
return rewriter.notifyMatchFailure(scatter,
"could not resolve alignment");
@@ -399,8 +404,14 @@ class VectorExpandLoadOpConversion
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());

+ // From:
+ // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+ // The pointer alignment defaults to 1.
+ uint64_t alignment = expand.getAlignment().value_or(1);
+
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
-     expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru());
+     expand, vtype, ptr, adaptor.getMask(), adaptor.getPassThru(),
+     alignment);
return success();
}
};
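A sketch of what this change does at the IR level (hypothetical values %base, %c0, %mask, %passthru, mirroring the expandload test added below): an op without an alignment attribute now lowers with the LangRef default of 1, while an explicit attribute is forwarded to the intrinsic.

// No attribute: llvm.intr.masked.expandload is emitted with the default
// pointer alignment of 1.
%0 = vector.expandload %base[%c0], %mask, %passthru
    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// Explicit attribute: the value 8 is forwarded as the pointer alignment.
%1 = vector.expandload %base[%c0], %mask, %passthru { alignment = 8 }
    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>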
@@ -421,8 +432,13 @@ class VectorCompressStoreOpConversion
Value ptr = getStridedElementPtr(rewriter, loc, memRefType,
adaptor.getBase(), adaptor.getIndices());

+ // From:
+ // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+ // The pointer alignment defaults to 1.
+ uint64_t alignment = compress.getAlignment().value_or(1);
+
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
-     compress, adaptor.getValueToStore(), ptr, adaptor.getMask());
+     compress, adaptor.getValueToStore(), ptr, adaptor.getMask(), alignment);
return success();
}
};
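The same pattern applies to compressstore; a sketch under the same assumptions:

// No attribute: llvm.intr.masked.compressstore is emitted with the default
// pointer alignment of 1; an { alignment = ... } attribute would be
// forwarded instead.
vector.compressstore %base[%c0], %mask, %value
    : memref<?xf32>, vector<16xi1>, vector<16xf32>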
77 changes: 77 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/use-vector-alignment.mlir
@@ -18,6 +18,18 @@ func.func @load(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {

// -----

func.func @load_with_alignment_attribute(%base : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> {
%0 = vector.load %base[%i, %j] {alignment = 8} : memref<200x100xf32>, vector<8xf32>
return %0 : vector<8xf32>
}

// ALL-LABEL: func @load_with_alignment_attribute

// VEC-ALIGN: llvm.load %{{.*}} {alignment = 8 : i64} : !llvm.ptr -> vector<8xf32>
// MEMREF-ALIGN: llvm.load %{{.*}} {alignment = 8 : i64} : !llvm.ptr -> vector<8xf32>

// -----

//===----------------------------------------------------------------------===//
// vector.store
//===----------------------------------------------------------------------===//
@@ -35,6 +47,19 @@ func.func @store(%base : memref<200x100xf32>, %i : index, %j : index) {

// -----

func.func @store_with_alignment_attribute(%base : memref<200x100xf32>, %i : index, %j : index) {
%val = arith.constant dense<11.0> : vector<4xf32>
vector.store %val, %base[%i, %j] {alignment = 8} : memref<200x100xf32>, vector<4xf32>
return
}

// ALL-LABEL: func @store_with_alignment_attribute

// VEC-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 8 : i64} : vector<4xf32>, !llvm.ptr
// MEMREF-ALIGN: llvm.store %{{.*}}, %{{.*}} {alignment = 8 : i64} : vector<4xf32>, !llvm.ptr

// -----

//===----------------------------------------------------------------------===//
// vector.maskedload
//===----------------------------------------------------------------------===//
@@ -52,6 +77,19 @@ func.func @masked_load(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: ve

// -----

func.func @masked_load_with_alignment_attribute(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
%c0 = arith.constant 0: index
%0 = vector.maskedload %base[%c0], %mask, %passthru {alignment = 8} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
return %0 : vector<16xf32>
}

// ALL-LABEL: func @masked_load_with_alignment_attribute

// VEC-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
// MEMREF-ALIGN: %[[L:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr, vector<16xi1>, vector<16xf32>) -> vector<16xf32>

// -----

//===----------------------------------------------------------------------===//
// vector.maskedstore
//===----------------------------------------------------------------------===//
@@ -69,6 +107,19 @@ func.func @masked_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: v

// -----

func.func @masked_store_with_alignment_attribute(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
%c0 = arith.constant 0: index
vector.maskedstore %base[%c0], %mask, %passthru {alignment = 8} : memref<?xf32>, vector<16xi1>, vector<16xf32>
return
}

// ALL-LABEL: func @masked_store_with_alignment_attribute

// VEC-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr
// MEMREF-ALIGN: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<16xf32>, vector<16xi1> into !llvm.ptr

// -----

//===----------------------------------------------------------------------===//
// vector.scatter
//===----------------------------------------------------------------------===//
@@ -86,6 +137,19 @@ func.func @scatter(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3x

// -----

func.func @scatter_with_alignment_attribute(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %value: vector<3xf32>) {
%0 = arith.constant 0: index
vector.scatter %base[%0][%index], %mask, %value {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
return
}

// ALL-LABEL: func @scatter_with_alignment_attribute

// VEC-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
// MEMREF-ALIGN: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>

// -----

//===----------------------------------------------------------------------===//
// vector.gather
//===----------------------------------------------------------------------===//
@@ -100,3 +164,16 @@ func.func @gather(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi

// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 16 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>

// -----

func.func @gather_with_alignment_attribute(%base: memref<?xf32>, %index: vector<3xi32>, %mask: vector<3xi1>, %passthru: vector<3xf32>) -> vector<3xf32> {
%0 = arith.constant 0: index
%1 = vector.gather %base[%0][%index], %mask, %passthru {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
return %1 : vector<3xf32>
}

// ALL-LABEL: func @gather_with_alignment_attribute

// VEC-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// MEMREF-ALIGN: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
39 changes: 39 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2042,6 +2042,16 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x

// -----

func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) -> vector<3xf32> {
%1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 {alignment = 8} : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
return %1 : vector<3xf32>
}

// CHECK-LABEL: func @gather_with_alignment
// CHECK: llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>

// -----

//===----------------------------------------------------------------------===//
// vector.scatter
//===----------------------------------------------------------------------===//
@@ -2118,6 +2128,17 @@ func.func @scatter_1d_into_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into vector<[4]x!llvm.ptr>

// -----

func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>, %0: index) {
vector.scatter %arg0[%0][%arg1], %arg2, %arg3 { alignment = 8 } : memref<?xf32>, vector<3xi32>, vector<3xi1>, vector<3xf32>
return
}

// CHECK-LABEL: func @scatter_with_alignment
// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>


// -----

//===----------------------------------------------------------------------===//
@@ -2149,6 +2170,15 @@ func.func @expand_load_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %a

// -----

func.func @expand_load_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) -> vector<11xindex> {
%0 = vector.expandload %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex> into vector<11xindex>
return %0 : vector<11xindex>
}
// CHECK-LABEL: func @expand_load_op_with_alignment
// CHECK: %{{.*}} = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}> : (!llvm.ptr, vector<11xi1>, vector<11xi64>) -> vector<11xi64>

// -----

//===----------------------------------------------------------------------===//
// vector.compressstore
//===----------------------------------------------------------------------===//
@@ -2177,6 +2207,15 @@ func.func @compress_store_op_index(%arg0: memref<?xindex>, %arg1: vector<11xi1>,

// -----

func.func @compress_store_op_with_alignment(%arg0: memref<?xindex>, %arg1: vector<11xi1>, %arg2: vector<11xindex>, %c0: index) {
vector.compressstore %arg0[%c0], %arg1, %arg2 { alignment = 8 } : memref<?xindex>, vector<11xi1>, vector<11xindex>
return
}
// CHECK-LABEL: func @compress_store_op_with_alignment
// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}> : (vector<11xi64>, !llvm.ptr, vector<11xi1>) -> ()

// -----

//===----------------------------------------------------------------------===//
// vector.splat
//===----------------------------------------------------------------------===//