diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
index 7da5a71ab729b..be77a70fa46ba 100644
--- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
+++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -13,6 +13,7 @@
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
 #include "llvm/Pass.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include <functional>
@@ -174,16 +175,22 @@ static void upcastI8AllocasAndUses(Instruction &I,
 
   Type *SmallestType = nullptr;
 
-  // Gather all cast targets
   for (User *U : AI->users()) {
     auto *Load = dyn_cast<LoadInst>(U);
     if (!Load)
       continue;
     for (User *LU : Load->users()) {
-      auto *Cast = dyn_cast<CastInst>(LU);
-      if (!Cast)
+      Type *Ty = nullptr;
+      if (auto *Cast = dyn_cast<CastInst>(LU))
+        Ty = Cast->getType();
+      if (CallInst *CI = dyn_cast<CallInst>(LU)) {
+        if (CI->getIntrinsicID() == Intrinsic::memset)
+          Ty = Type::getInt32Ty(CI->getContext());
+      }
+
+      if (!Ty)
         continue;
-      Type *Ty = Cast->getType();
+
       if (!SmallestType ||
           Ty->getPrimitiveSizeInBits() < SmallestType->getPrimitiveSizeInBits())
         SmallestType = Ty;
@@ -239,6 +246,93 @@ downcastI64toI32InsertExtractElements(Instruction &I,
   }
 }
 
+// Expand a constant-size memset of an array alloca into one store per array
+// element, emitted at Builder's insertion point.
+//
+// Dst must be an AllocaInst of array type and SizeCI must equal the
+// allocation size in bytes; both are only checked by asserts. When Val's type
+// differs from the element type, the value recorded for it in ReplacedValues
+// is stored instead, or failing that Val zero-extended/truncated to the
+// element type via CreateIntCast.
+//
+// NOTE(review): every element receives the fill value widened once, not the
+// fill byte repeated across the element, so this looks equivalent to memset's
+// byte fill only when the value is zero - confirm non-zero fills cannot reach
+// this expansion.
+static void emitMemsetExpansion(IRBuilder<> &Builder, Value *Dst, Value *Val,
+                                ConstantInt *SizeCI,
+                                DenseMap<Value *, Value *> &ReplacedValues) {
+  LLVMContext &Ctx = Builder.getContext();
+  [[maybe_unused]] const DataLayout &DL =
+      Builder.GetInsertBlock()->getModule()->getDataLayout();
+  [[maybe_unused]] uint64_t OrigSize = SizeCI->getZExtValue();
+
+  AllocaInst *Alloca = dyn_cast<AllocaInst>(Dst);
+
+  assert(Alloca && "Expected memset on an Alloca");
+  assert(OrigSize == Alloca->getAllocationSize(DL)->getFixedValue() &&
+         "Expected for memset size to match DataLayout size");
+
+  Type *AllocatedTy = Alloca->getAllocatedType();
+  ArrayType *ArrTy = dyn_cast<ArrayType>(AllocatedTy);
+  assert(ArrTy && "Expected Alloca for an Array Type");
+
+  Type *ElemTy = ArrTy->getElementType();
+  uint64_t Size = ArrTy->getArrayNumElements();
+
+  [[maybe_unused]] uint64_t ElemSize = DL.getTypeStoreSize(ElemTy);
+
+  assert(ElemSize > 0 && "Size must be set");
+  assert(OrigSize == ElemSize * Size && "Size in bytes must match");
+
+  Value *TypedVal = Val;
+
+  if (Val->getType() != ElemTy) {
+    if (Value *Replacement = ReplacedValues.lookup(Val)) {
+      // Note for i8 replacements if we know them we should use them.
+      // lookup() returns null for values with no recorded replacement (e.g.
+      // constants) without inserting an entry, so those take the cast below.
+      TypedVal = Replacement;
+
+    } else {
+      // This case Val is a ConstantInt so the cast folds away.
+      // However if we don't do the cast the store below ends up being
+      // an i8.
+      TypedVal = Builder.CreateIntCast(Val, ElemTy, false);
+    }
+  }
+
+  for (uint64_t I = 0; I < Size; ++I) {
+    Value *Offset = ConstantInt::get(Type::getInt32Ty(Ctx), I);
+    Value *Ptr = Builder.CreateGEP(ElemTy, Dst, Offset, "gep");
+    Builder.CreateStore(TypedVal, Ptr);
+  }
+}
+
+// Legalization step: if I is a call to llvm.memset, emit the per-element store
+// expansion before it and queue the call in ToRemove. Any other instruction is
+// left untouched. The length operand must be a ConstantInt (assert only).
+static void removeMemSet(Instruction &I,
+                         SmallVectorImpl<Instruction *> &ToRemove,
+                         DenseMap<Value *, Value *> &ReplacedValues) {
+
+  CallInst *CI = dyn_cast<CallInst>(&I);
+  if (!CI)
+    return;
+
+  Intrinsic::ID ID = CI->getIntrinsicID();
+  if (ID != Intrinsic::memset)
+    return;
+
+  IRBuilder<> Builder(&I);
+  Value *Dst = CI->getArgOperand(0);
+  Value *Val = CI->getArgOperand(1);
+  ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(2));
+  assert(Size && "Expected Size to be a ConstantInt");
+  emitMemsetExpansion(Builder, Dst, Val, Size, ReplacedValues);
+  ToRemove.push_back(CI);
+}
+
 namespace {
 class DXILLegalizationPipeline {
 
@@ -270,6 +364,7 @@ class DXILLegalizationPipeline {
     LegalizationPipeline.push_back(fixI8UseChain);
     LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
     LegalizationPipeline.push_back(legalizeFreeze);
+    LegalizationPipeline.push_back(removeMemSet);
   }
 };
 
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 398abd66dda16..10f4b4ee76619 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -98,7 +98,6 @@ class DirectXPassConfig : public TargetPassConfig {
 
   FunctionPass *createTargetRegisterAllocator(bool) override { return nullptr; }
   void addCodeGenPrepare() override {
-    addPass(createDXILFinalizeLinkageLegacyPass());
     addPass(createDXILIntrinsicExpansionLegacyPass());
     addPass(createDXILCBufferAccessLegacyPass());
     addPass(createDXILDataScalarizationLegacyPass());
@@ -109,6 +108,7 @@ class DirectXPassConfig : public TargetPassConfig {
     addPass(createScalarizerPass(DxilScalarOptions));
     addPass(createDXILForwardHandleAccessesLegacyPass());
     addPass(createDXILLegalizeLegacyPass());
+    addPass(createDXILFinalizeLinkageLegacyPass());
     addPass(createDXILTranslateMetadataLegacyPass());
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILPrepareModulePass());
diff --git a/llvm/test/CodeGen/DirectX/legalize-memset.ll b/llvm/test/CodeGen/DirectX/legalize-memset.ll
new file mode 100644
index 0000000000000..e97817ba824ed
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/legalize-memset.ll
@@ -0,0 +1,125 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -dxil-legalize -dxil-finalize-linkage -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s
+
+define void @replace_float_memset_test() #0 {
+; CHECK-LABEL: define void @replace_float_memset_test(
+; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [2 x float], align 4
+; CHECK-NEXT:    call void @llvm.lifetime.start.p0(i64 8, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr float, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store float 0.000000e+00, ptr [[GEP]], align 4
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr float, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT:    store float 0.000000e+00, ptr [[GEP1]], align 4
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 8, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [2 x float], align 4
+  call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %accum.i.flat)
+  call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 8, i1 false)
+  call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %accum.i.flat)
+  ret void
+}
+
+define void @replace_half_memset_test() #0 {
+; CHECK-LABEL: define void @replace_half_memset_test(
+; CHECK-SAME: ) #[[ATTR0]] {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [2 x half], align 4
+; CHECK-NEXT:    call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr half, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store half 0xH0000, ptr [[GEP]], align 2
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr half, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT:    store half 0xH0000, ptr [[GEP1]], align 2
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [2 x half], align 4
+  call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
+  call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 4, i1 false)
+  call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
+  ret void
+}
+
+define void @replace_double_memset_test() #0 {
+; CHECK-LABEL: define void @replace_double_memset_test(
+; CHECK-SAME: ) #[[ATTR0]] {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [2 x double], align 4
+; CHECK-NEXT:    call void @llvm.lifetime.start.p0(i64 16, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr double, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store double 0.000000e+00, ptr [[GEP]], align 8
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr double, ptr [[ACCUM_I_FLAT]], i32 1
+; CHECK-NEXT:    store double 0.000000e+00, ptr [[GEP1]], align 8
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 16, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [2 x double], align 4
+  call void @llvm.lifetime.start.p0(i64 16, ptr nonnull %accum.i.flat)
+  call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 16, i1 false)
+  call void @llvm.lifetime.end.p0(i64 16, ptr nonnull %accum.i.flat)
+  ret void
+}
+
+define void @replace_int16_memset_test() #0 {
+; CHECK-LABEL: define void @replace_int16_memset_test(
+; CHECK-SAME: ) #[[ATTR0]] {
+; CHECK-NEXT:    [[CACHE_I:%.*]] = alloca [2 x i16], align 2
+; CHECK-NEXT:    call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[CACHE_I]])
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i16, ptr [[CACHE_I]], i32 0
+; CHECK-NEXT:    store i16 0, ptr [[GEP]], align 2
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr i16, ptr [[CACHE_I]], i32 1
+; CHECK-NEXT:    store i16 0, ptr [[GEP1]], align 2
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[CACHE_I]])
+; CHECK-NEXT:    ret void
+;
+  %cache.i = alloca [2 x i16], align 2
+  call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %cache.i)
+  call void @llvm.memset.p0.i32(ptr nonnull align 2 dereferenceable(4) %cache.i, i8 0, i32 4, i1 false)
+  call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %cache.i)
+  ret void
+}
+
+define void @replace_int_memset_test() #0 {
+; CHECK-LABEL: define void @replace_int_memset_test(
+; CHECK-SAME: ) #[[ATTR0]] {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT:    call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i32 0, ptr [[GEP]], align 4
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [1 x i32], align 4
+  call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
+  call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 0, i32 4, i1 false)
+  call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
+  ret void
+}
+
+define void @replace_int_memset_to_var_test() #0 {
+; CHECK-LABEL: define void @replace_int_memset_to_var_test(
+; CHECK-SAME: ) #[[ATTR0]] {
+; CHECK-NEXT:    [[ACCUM_I_FLAT:%.*]] = alloca [1 x i32], align 4
+; CHECK-NEXT:    [[I:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    store i32 1, ptr [[I]], align 4
+; CHECK-NEXT:    [[I8_LOAD:%.*]] = load i32, ptr [[I]], align 4
+; CHECK-NEXT:    call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[ACCUM_I_FLAT]], i32 0
+; CHECK-NEXT:    store i32 [[I8_LOAD]], ptr [[GEP]], align 4
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[ACCUM_I_FLAT]])
+; CHECK-NEXT:    ret void
+;
+  %accum.i.flat = alloca [1 x i32], align 4
+  %i = alloca i8, align 4
+  store i8 1, ptr %i
+  %i8.load = load i8, ptr %i
+  call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %accum.i.flat)
+  call void @llvm.memset.p0.i32(ptr nonnull align 4 dereferenceable(8) %accum.i.flat, i8 %i8.load, i32 4, i1 false)
+  call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %accum.i.flat)
+  ret void
+}
+
+attributes #0 = {"hlsl.export"}
+
+
+declare void @llvm.lifetime.end.p0(i64 immarg, ptr captures(none))
+declare void @llvm.lifetime.start.p0(i64 immarg, ptr captures(none))
+declare void @llvm.memset.p0.i32(ptr writeonly captures(none), i8, i32, i1 immarg)
diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
index a2412b6324a05..55dd86c9fad1d 100644
--- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll
+++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll
@@ -13,7 +13,6 @@
 ; CHECK-OBJ-NEXT: Create Garbage Collector Module Metadata
 
 ; CHECK-NEXT: ModulePass Manager
-; CHECK-NEXT:   DXIL Finalize Linkage
 ; CHECK-NEXT:   DXIL Intrinsic Expansion
 ; CHECK-NEXT:   DXIL CBuffer Access
 ; CHECK-NEXT:   DXIL Data Scalarization
@@ -24,6 +23,7 @@
 ; CHECK-NEXT:     Scalarize vector operations
 ; CHECK-NEXT:   DXIL Forward Handle Accesses
 ; CHECK-NEXT:   DXIL Legalizer
+; CHECK-NEXT:   DXIL Finalize Linkage
 ; CHECK-NEXT:   DXIL Resources Analysis
 ; CHECK-NEXT:   DXIL Module Metadata analysis
 ; CHECK-NEXT:   DXIL Shader Flag Analysis