Skip to content

Commit abda8be

Browse files
authored
[InferAlignment] Increase alignment in masked load / store intrinsics if known (#156057)
Summary: The masked load / store LLVM intrinsics take the alignment as an argument. A user who is pessimistic about alignment can pass a value of `1`, which denotes an unaligned access. This patch updates infer-alignment to raise that alignment argument when the pointer's known alignment is greater than the provided value. The gather / scatter versions are ignored for now since they contain many pointers.
1 parent b96fa9f commit abda8be

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

llvm/lib/Transforms/Scalar/InferAlignment.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "llvm/Analysis/AssumptionCache.h"
1616
#include "llvm/Analysis/ValueTracking.h"
1717
#include "llvm/IR/Instructions.h"
18+
#include "llvm/IR/IntrinsicInst.h"
1819
#include "llvm/Support/KnownBits.h"
1920
#include "llvm/Transforms/Scalar.h"
2021
#include "llvm/Transforms/Utils/Local.h"
@@ -35,8 +36,38 @@ static bool tryToImproveAlign(
3536
return true;
3637
}
3738
}
38-
// TODO: Also handle memory intrinsics.
39-
return false;
39+
40+
IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
41+
if (!II)
42+
return false;
43+
44+
// TODO: Handle more memory intrinsics.
45+
switch (II->getIntrinsicID()) {
46+
case Intrinsic::masked_load:
47+
case Intrinsic::masked_store: {
48+
int AlignOpIdx = II->getIntrinsicID() == Intrinsic::masked_load ? 1 : 2;
49+
Value *PtrOp = II->getIntrinsicID() == Intrinsic::masked_load
50+
? II->getArgOperand(0)
51+
: II->getArgOperand(1);
52+
Type *Type = II->getIntrinsicID() == Intrinsic::masked_load
53+
? II->getType()
54+
: II->getArgOperand(0)->getType();
55+
56+
Align OldAlign =
57+
cast<ConstantInt>(II->getArgOperand(AlignOpIdx))->getAlignValue();
58+
Align PrefAlign = DL.getPrefTypeAlign(Type);
59+
Align NewAlign = Fn(PtrOp, OldAlign, PrefAlign);
60+
if (NewAlign <= OldAlign)
61+
return false;
62+
63+
Value *V =
64+
ConstantInt::get(Type::getInt32Ty(II->getContext()), NewAlign.value());
65+
II->setOperand(AlignOpIdx, V);
66+
return true;
67+
}
68+
default:
69+
return false;
70+
}
4071
}
4172

4273
bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) {
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt < %s -passes=infer-alignment -S | FileCheck %s
3+
4+
; Masked load issued with a pessimistic alignment argument of 1. The "align"
; assume bundle on %ptr states 64-byte alignment, and the CHECK lines expect
; infer-alignment to rewrite the intrinsic's alignment argument from 1 to 64.
define <2 x i32> @load(<2 x i1> %mask, ptr %ptr) {
5+
; CHECK-LABEL: define <2 x i32> @load(
6+
; CHECK-SAME: <2 x i1> [[MASK:%.*]], ptr [[PTR:%.*]]) {
7+
; CHECK-NEXT: [[ENTRY:.*:]]
8+
; CHECK-NEXT: call void @llvm.assume(i1 true) [ "align"(ptr [[PTR]], i64 64) ]
9+
; CHECK-NEXT: [[MASKED_LOAD:%.*]] = call <2 x i32> @llvm.masked.load.v2i32.p0(ptr [[PTR]], i32 64, <2 x i1> [[MASK]], <2 x i32> poison)
10+
; CHECK-NEXT: ret <2 x i32> [[MASKED_LOAD]]
11+
;
12+
entry:
13+
call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 64) ]
14+
%masked_load = call <2 x i32> @llvm.masked.load.v2i32.p0(ptr %ptr, i32 1, <2 x i1> %mask, <2 x i32> poison)
15+
ret <2 x i32> %masked_load
16+
}
17+
18+
; Masked store issued with a pessimistic alignment argument of 1. The "align"
; assume bundle on %ptr states 64-byte alignment, and the CHECK lines expect
; infer-alignment to rewrite the intrinsic's alignment argument from 1 to 64.
define void @store(<2 x i1> %mask, <2 x i32> %val, ptr %ptr) {
19+
; CHECK-LABEL: define void @store(
20+
; CHECK-SAME: <2 x i1> [[MASK:%.*]], <2 x i32> [[VAL:%.*]], ptr [[PTR:%.*]]) {
21+
; CHECK-NEXT: [[ENTRY:.*:]]
22+
; CHECK-NEXT: call void @llvm.assume(i1 true) [ "align"(ptr [[PTR]], i64 64) ]
23+
; CHECK-NEXT: tail call void @llvm.masked.store.v2i32.p0(<2 x i32> [[VAL]], ptr [[PTR]], i32 64, <2 x i1> [[MASK]])
24+
; CHECK-NEXT: ret void
25+
;
26+
entry:
27+
call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 64) ]
28+
tail call void @llvm.masked.store.v2i32.p0(<2 x i32> %val, ptr %ptr, i32 1, <2 x i1> %mask)
29+
ret void
30+
}
31+
32+
declare void @llvm.assume(i1)
33+
declare <2 x i32> @llvm.masked.load.v2i32.p0(ptr, i32, <2 x i1>, <2 x i32>)
34+
declare void @llvm.masked.store.v2i32.p0(<2 x i32>, ptr, i32, <2 x i1>)

0 commit comments

Comments
 (0)