Skip to content

Commit a23c221

Browse files
committed
Legalize memcpy
1 parent 90d8e4d commit a23c221

File tree

2 files changed

+252
-0
lines changed

2 files changed

+252
-0
lines changed

llvm/lib/Target/DirectX/DXILLegalizePass.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "llvm/IR/Instructions.h"
1616
#include "llvm/IR/Module.h"
1717
#include "llvm/Pass.h"
18+
#include "llvm/Support/ErrorHandling.h"
1819
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
1920
#include <functional>
2021

@@ -246,6 +247,58 @@ downcastI64toI32InsertExtractElements(Instruction &I,
246247
}
247248
}
248249

250+
static void emitMemcpyExpansion(IRBuilder<> &Builder, Value *Dst, Value *Src,
251+
ConstantInt *Length) {
252+
253+
uint64_t ByteLength = Length->getZExtValue();
254+
if (ByteLength == 0)
255+
return;
256+
257+
LLVMContext &Ctx = Builder.getContext();
258+
const DataLayout &DL = Builder.GetInsertBlock()->getModule()->getDataLayout();
259+
260+
auto GetArrTyFromVal = [](Value *Val) {
261+
if (auto *Alloca = dyn_cast<AllocaInst>(Val))
262+
return dyn_cast<ArrayType>(Alloca->getAllocatedType());
263+
if (auto *GlobalVar = dyn_cast<GlobalVariable>(Val))
264+
return dyn_cast<ArrayType>(GlobalVar->getValueType());
265+
llvm_unreachable(
266+
"Expected an Alloca or GlobalVariable in memcpy Src and Dst");
267+
};
268+
269+
ArrayType *ArrTy = GetArrTyFromVal(Dst);
270+
assert(ArrTy && "Expected Dst of memcpy to be a Pointer to an Array Type");
271+
if (auto *DstGlobalVar = dyn_cast<GlobalVariable>(Dst))
272+
assert(!DstGlobalVar->isConstant() &&
273+
"The Dst of memcpy must not be a constant Global Variable");
274+
275+
[[maybe_unused]] ArrayType *SrcArrTy = GetArrTyFromVal(Src);
276+
assert(SrcArrTy && "Expected Src of memcpy to be a Pointer to an Array Type");
277+
278+
// This assumption simplifies implementation and covers currently-known
279+
// use-cases for DXIL. It may be relaxed in the future if required.
280+
assert(ArrTy == SrcArrTy && "Array Types of Src and Dst in memcpy must match");
281+
282+
Type *ElemTy = ArrTy->getElementType();
283+
uint64_t ElemSize = DL.getTypeStoreSize(ElemTy);
284+
assert(ElemSize > 0 && "Size must be set");
285+
286+
[[maybe_unused]] uint64_t Size = ArrTy->getArrayNumElements();
287+
assert(ElemSize * Size >= ByteLength &&
288+
"Array size must be at least as large as the memcpy length");
289+
290+
uint64_t NumElemsToCopy = ByteLength / ElemSize;
291+
assert(ByteLength % ElemSize == 0 &&
292+
"memcpy length must be divisible by array element type");
293+
for (uint64_t I = 0; I < NumElemsToCopy; ++I) {
294+
Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I);
295+
Value *SrcPtr = Builder.CreateGEP(ElemTy, Src, Offset, "gep");
296+
Value *SrcVal = Builder.CreateLoad(ElemTy, SrcPtr);
297+
Value *DstPtr = Builder.CreateGEP(ElemTy, Dst, Offset, "gep");
298+
Builder.CreateStore(SrcVal, DstPtr);
299+
}
300+
}
301+
249302
static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
250303
ConstantInt *SizeCI,
251304
DenseMap<Value *, Value *> &ReplacedValues) {
@@ -296,6 +349,30 @@ static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
296349
}
297350
}
298351

352+
static void removeMemCpy(Instruction &I,
353+
SmallVectorImpl<Instruction *> &ToRemove,
354+
DenseMap<Value *, Value *> &ReplacedValues) {
355+
356+
CallInst *CI = dyn_cast<CallInst>(&I);
357+
if (!CI)
358+
return;
359+
360+
Intrinsic::ID ID = CI->getIntrinsicID();
361+
if (ID != Intrinsic::memcpy)
362+
return;
363+
364+
IRBuilder<> Builder(&I);
365+
Value *Dst = CI->getArgOperand(0);
366+
Value *Src = CI->getArgOperand(1);
367+
ConstantInt *Length = dyn_cast<ConstantInt>(CI->getArgOperand(2));
368+
assert(Length && "Expected Length to be a ConstantInt");
369+
ConstantInt *IsVolatile = dyn_cast<ConstantInt>(CI->getArgOperand(3));
370+
assert(IsVolatile && "Expected IsVolatile to be a ConstantInt");
371+
assert(IsVolatile->getZExtValue() == 0 && "Expected IsVolatile to be false");
372+
emitMemcpyExpansion(Builder, Dst, Src, Length);
373+
ToRemove.push_back(CI);
374+
}
375+
299376
static void removeMemSet(Instruction &I,
300377
SmallVectorImpl<Instruction *> &ToRemove,
301378
DenseMap<Value *, Value *> &ReplacedValues) {
@@ -348,6 +425,7 @@ class DXILLegalizationPipeline {
348425
LegalizationPipeline.push_back(fixI8UseChain);
349426
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
350427
LegalizationPipeline.push_back(legalizeFreeze);
428+
LegalizationPipeline.push_back(removeMemCpy);
351429
LegalizationPipeline.push_back(removeMemSet);
352430
}
353431
};
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
2+
; RUN: opt -S -dxil-legalize -dxil-finalize-linkage -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
3+
4+
@outputStrides = external local_unnamed_addr addrspace(2) global [2 x <4 x i32>], align 4
5+
6+
define void @replace_2x4xint_global_memcpy_test() #0 {
7+
; CHECK-LABEL: define void @replace_2x4xint_global_memcpy_test(
8+
; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
9+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x <4 x i32>], align 16
10+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 32, ptr nonnull [[TMP1]])
11+
; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i32>, ptr addrspace(2) @outputStrides, align 16
12+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr <4 x i32>, ptr [[TMP1]], i32 0
13+
; CHECK-NEXT: store <4 x i32> [[TMP2]], ptr [[GEP]], align 16
14+
; CHECK-NEXT: [[TMP3:%.*]] = load <4 x i32>, ptr addrspace(2) getelementptr (<4 x i32>, ptr addrspace(2) @outputStrides, i32 1), align 16
15+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr <4 x i32>, ptr [[TMP1]], i32 1
16+
; CHECK-NEXT: store <4 x i32> [[TMP3]], ptr [[GEP1]], align 16
17+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 32, ptr nonnull [[TMP1]])
18+
; CHECK-NEXT: ret void
19+
;
20+
%1 = alloca [2 x <4 x i32>], align 16
21+
call void @llvm.lifetime.start.p0(i64 32, ptr nonnull %1)
22+
call void @llvm.memcpy.p0.p2.i32(ptr nonnull align 16 dereferenceable(32) %1, ptr addrspace(2) align 16 dereferenceable(32) @outputStrides, i32 32, i1 false)
23+
call void @llvm.lifetime.end.p0(i64 32, ptr nonnull %1)
24+
ret void
25+
}
26+
27+
define void @replace_int_memcpy_test() #0 {
28+
; CHECK-LABEL: define void @replace_int_memcpy_test(
29+
; CHECK-SAME: ) #[[ATTR0]] {
30+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [1 x i32], align 4
31+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [1 x i32], align 4
32+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[TMP1]])
33+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[TMP2]])
34+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i32, ptr [[TMP1]], i32 0
35+
; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[GEP]], align 4
36+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i32, ptr [[TMP2]], i32 0
37+
; CHECK-NEXT: store i32 [[TMP3]], ptr [[GEP1]], align 4
38+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[TMP2]])
39+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[TMP1]])
40+
; CHECK-NEXT: ret void
41+
;
42+
%1 = alloca [1 x i32], align 4
43+
%2 = alloca [1 x i32], align 4
44+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %1)
45+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %2)
46+
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(4) %2, ptr align 4 dereferenceable(4) %1, i32 4, i1 false)
47+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %2)
48+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %1)
49+
ret void
50+
}
51+
52+
define void @replace_int16_memcpy_test() #0 {
53+
; CHECK-LABEL: define void @replace_int16_memcpy_test(
54+
; CHECK-SAME: ) #[[ATTR0]] {
55+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x i16], align 2
56+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x i16], align 2
57+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[TMP1]])
58+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[TMP2]])
59+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr i16, ptr [[TMP1]], i32 0
60+
; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr [[GEP]], align 2
61+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i16, ptr [[TMP2]], i32 0
62+
; CHECK-NEXT: store i16 [[TMP3]], ptr [[GEP1]], align 2
63+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i16, ptr [[TMP1]], i32 1
64+
; CHECK-NEXT: [[TMP4:%.*]] = load i16, ptr [[GEP2]], align 2
65+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i16, ptr [[TMP2]], i32 1
66+
; CHECK-NEXT: store i16 [[TMP4]], ptr [[GEP3]], align 2
67+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[TMP2]])
68+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[TMP1]])
69+
; CHECK-NEXT: ret void
70+
;
71+
%1 = alloca [2 x i16], align 2
72+
%2 = alloca [2 x i16], align 2
73+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %1)
74+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %2)
75+
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 2 dereferenceable(4) %2, ptr align 2 dereferenceable(4) %1, i32 4, i1 false)
76+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %2)
77+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %1)
78+
ret void
79+
}
80+
81+
define void @replace_float_memcpy_test() #0 {
82+
; CHECK-LABEL: define void @replace_float_memcpy_test(
83+
; CHECK-SAME: ) #[[ATTR0]] {
84+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x float], align 4
85+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x float], align 4
86+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 8, ptr nonnull [[TMP1]])
87+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 8, ptr nonnull [[TMP2]])
88+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr float, ptr [[TMP1]], i32 0
89+
; CHECK-NEXT: [[TMP3:%.*]] = load float, ptr [[GEP]], align 4
90+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr float, ptr [[TMP2]], i32 0
91+
; CHECK-NEXT: store float [[TMP3]], ptr [[GEP1]], align 4
92+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr float, ptr [[TMP1]], i32 1
93+
; CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[GEP2]], align 4
94+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr float, ptr [[TMP2]], i32 1
95+
; CHECK-NEXT: store float [[TMP4]], ptr [[GEP3]], align 4
96+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr nonnull [[TMP2]])
97+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr nonnull [[TMP1]])
98+
; CHECK-NEXT: ret void
99+
;
100+
%1 = alloca [2 x float], align 4
101+
%2 = alloca [2 x float], align 4
102+
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %1)
103+
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %2)
104+
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(8) %2, ptr align 4 dereferenceable(8) %1, i32 8, i1 false)
105+
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %2)
106+
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %1)
107+
ret void
108+
}
109+
110+
define void @replace_double_memcpy_test() #0 {
111+
; CHECK-LABEL: define void @replace_double_memcpy_test(
112+
; CHECK-SAME: ) #[[ATTR0]] {
113+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x double], align 4
114+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x double], align 4
115+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 16, ptr nonnull [[TMP1]])
116+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 16, ptr nonnull [[TMP2]])
117+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr double, ptr [[TMP1]], i32 0
118+
; CHECK-NEXT: [[TMP3:%.*]] = load double, ptr [[GEP]], align 8
119+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr double, ptr [[TMP2]], i32 0
120+
; CHECK-NEXT: store double [[TMP3]], ptr [[GEP1]], align 8
121+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr double, ptr [[TMP1]], i32 1
122+
; CHECK-NEXT: [[TMP4:%.*]] = load double, ptr [[GEP2]], align 8
123+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr double, ptr [[TMP2]], i32 1
124+
; CHECK-NEXT: store double [[TMP4]], ptr [[GEP3]], align 8
125+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 16, ptr nonnull [[TMP2]])
126+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 16, ptr nonnull [[TMP1]])
127+
; CHECK-NEXT: ret void
128+
;
129+
%1 = alloca [2 x double], align 4
130+
%2 = alloca [2 x double], align 4
131+
call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %1)
132+
call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %2)
133+
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 4 dereferenceable(8) %2, ptr align 4 dereferenceable(8) %1, i32 16, i1 false)
134+
call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %2)
135+
call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %1)
136+
ret void
137+
}
138+
139+
define void @replace_half_memcpy_test() #0 {
140+
; CHECK-LABEL: define void @replace_half_memcpy_test(
141+
; CHECK-SAME: ) #[[ATTR0]] {
142+
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x half], align 2
143+
; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x half], align 2
144+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[TMP1]])
145+
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[TMP2]])
146+
; CHECK-NEXT: [[GEP:%.*]] = getelementptr half, ptr [[TMP1]], i32 0
147+
; CHECK-NEXT: [[TMP3:%.*]] = load half, ptr [[GEP]], align 2
148+
; CHECK-NEXT: [[GEP1:%.*]] = getelementptr half, ptr [[TMP2]], i32 0
149+
; CHECK-NEXT: store half [[TMP3]], ptr [[GEP1]], align 2
150+
; CHECK-NEXT: [[GEP2:%.*]] = getelementptr half, ptr [[TMP1]], i32 1
151+
; CHECK-NEXT: [[TMP4:%.*]] = load half, ptr [[GEP2]], align 2
152+
; CHECK-NEXT: [[GEP3:%.*]] = getelementptr half, ptr [[TMP2]], i32 1
153+
; CHECK-NEXT: store half [[TMP4]], ptr [[GEP3]], align 2
154+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[TMP2]])
155+
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[TMP1]])
156+
; CHECK-NEXT: ret void
157+
;
158+
%1 = alloca [2 x half], align 2
159+
%2 = alloca [2 x half], align 2
160+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %1)
161+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %2)
162+
call void @llvm.memcpy.p0.p0.i32(ptr nonnull align 2 dereferenceable(4) %2, ptr align 2 dereferenceable(4) %1, i32 4, i1 false)
163+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %2)
164+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %1)
165+
ret void
166+
}
167+
168+
attributes #0 = {"hlsl.export"}
169+
170+
171+
declare void @llvm.lifetime.end.p0(i64 immarg, ptr captures(none))
172+
declare void @llvm.lifetime.start.p0(i64 immarg, ptr captures(none))
173+
declare void @llvm.memcpy.p0.p2.i32(ptr noalias, ptr addrspace(2) noalias readonly, i32, i1)
174+
declare void @llvm.memcpy.p0.p0.i32(ptr noalias, ptr noalias readonly, i32, i1)

0 commit comments

Comments
 (0)