diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e44d3b70eb657..84acf4d536d0c 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -49,6 +49,7 @@ static bool isIntrinsicExpansion(Function &F) {
   case Intrinsic::log:
   case Intrinsic::log10:
   case Intrinsic::pow:
+  case Intrinsic::powi:
   case Intrinsic::dx_all:
   case Intrinsic::dx_any:
   case Intrinsic::dx_cross:
@@ -251,7 +252,7 @@ static Value *expandExpIntrinsic(CallInst *Orig) {
 }
 
 static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
-                                      Intrinsic::ID intrinsicId) {
+                                      Intrinsic::ID IntrinsicId) {
   Value *X = Orig->getOperand(0);
   IRBuilder<> Builder(Orig);
   Type *Ty = X->getType();
@@ -285,7 +286,7 @@ static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
     Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
     for (unsigned I = 1; I < XVec->getNumElements(); I++) {
       Value *Elt = Builder.CreateExtractElement(Cond, I);
-      Result = ApplyOp(intrinsicId, Result, Elt);
+      Result = ApplyOp(IntrinsicId, Result, Elt);
     }
   }
   return Result;
@@ -410,13 +411,28 @@ static Value *expandAtan2Intrinsic(CallInst *Orig) {
   return Result;
 }
 
-static Value *expandPowIntrinsic(CallInst *Orig) {
+// Expand llvm.pow / llvm.powi as exp2(Y * log2(X)). For powi the integer
+// exponent is first converted to the floating-point type of the base.
+static Value *expandPowIntrinsic(CallInst *Orig, Intrinsic::ID IntrinsicId) {
 
   Value *X = Orig->getOperand(0);
   Value *Y = Orig->getOperand(1);
   Type *Ty = X->getType();
   IRBuilder<> Builder(Orig);
 
+  if (IntrinsicId == Intrinsic::powi) {
+    // llvm.powi keeps a scalar integer exponent even when the base is a
+    // vector, so convert at element width and splat to the base's shape;
+    // a direct scalar-to-vector sitofp is not a valid cast.
+    auto *VecTy = dyn_cast<FixedVectorType>(Ty);
+    if (VecTy && !Y->getType()->isVectorTy()) {
+      Y = Builder.CreateSIToFP(Y, VecTy->getElementType());
+      Y = Builder.CreateVectorSplat(VecTy->getElementCount(), Y);
+    } else {
+      Y = Builder.CreateSIToFP(Y, Ty);
+    }
+  }
+
   auto *Log2Call =
       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
   auto *Mul = Builder.CreateFMul(Log2Call, Y);
@@ -542,7 +558,8 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
     Result = expandLog10Intrinsic(Orig);
     break;
   case Intrinsic::pow:
-    Result = expandPowIntrinsic(Orig);
+  case Intrinsic::powi:
+    Result = expandPowIntrinsic(Orig, IntrinsicId);
     break;
   case Intrinsic::dx_all:
   case Intrinsic::dx_any:
diff --git a/llvm/test/CodeGen/DirectX/pow.ll b/llvm/test/CodeGen/DirectX/pow.ll
index 378649ec119d5..ebc0382415801 100644
--- a/llvm/test/CodeGen/DirectX/pow.ll
+++ b/llvm/test/CodeGen/DirectX/pow.ll
@@ -25,5 +25,31 @@ entry:
   ret half %elt.pow
 }
 
+define noundef float @powi_float(float noundef %a, i32 noundef %b) {
+entry:
+; CHECK: [[CAST:%.*]] = sitofp i32 %b to float
+; DOPCHECK: call float @dx.op.unary.f32(i32 23, float %a)
+; EXPCHECK: call float @llvm.log2.f32(float %a)
+; CHECK: fmul float %{{.*}}, [[CAST]]
+; DOPCHECK: call float @dx.op.unary.f32(i32 21, float %{{.*}})
+; EXPCHECK: call float @llvm.exp2.f32(float %{{.*}})
+  %elt.powi = call float @llvm.powi.f32.i32(float %a, i32 %b)
+  ret float %elt.powi
+}
+
+define noundef half @powi_half(half noundef %a, i32 noundef %b) {
+entry:
+; CHECK: [[CAST:%.*]] = sitofp i32 %b to half
+; DOPCHECK: call half @dx.op.unary.f16(i32 23, half %a)
+; EXPCHECK: call half @llvm.log2.f16(half %a)
+; CHECK: fmul half %{{.*}}, [[CAST]]
+; DOPCHECK: call half @dx.op.unary.f16(i32 21, half %{{.*}})
+; EXPCHECK: call half @llvm.exp2.f16(half %{{.*}})
+  %elt.powi = call half @llvm.powi.f16.i32(half %a, i32 %b)
+  ret half %elt.powi
+}
+
 declare half @llvm.pow.f16(half,half)
 declare float @llvm.pow.f32(float,float)
+declare half @llvm.powi.f16.i32(half,i32)
+declare float @llvm.powi.f32.i32(float,i32)