diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 6093664c908dc..fe6a552433550 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -38,7 +38,8 @@ def int_dx_typedBufferStore [IntrWriteMem]>; def int_dx_updateCounter - : DefaultAttrsIntrinsic<[], [llvm_any_ty, llvm_i8_ty]>; + : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_i8_ty], + [IntrInaccessibleMemOrArgMemOnly]>; // Cast between target extension handle types and dxil-style opaque handles def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 078f0591a6751..7b0cb8c41b12a 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -757,7 +757,7 @@ def BufferStore : DXILOp<69, bufferStore> { def UpdateCounter : DXILOp<70, bufferUpdateCounter> { let Doc = "increments/decrements a buffer counter"; let arguments = [HandleTy, Int8Ty]; - let result = VoidTy; + let result = Int32Ty; let stages = [Stages]; } diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index 02b441126cfd0..9f124394363a3 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -488,6 +488,7 @@ class OpLowerer { [[nodiscard]] bool lowerUpdateCounter(Function &F) { IRBuilder<> &IRB = OpBuilder.getIRB(); + Type *Int32Ty = IRB.getInt32Ty(); return replaceFunction(F, [&](CallInst *CI) -> Error { IRB.SetInsertPoint(CI); @@ -497,12 +498,13 @@ class OpLowerer { std::array Args{Handle, Op1}; - Expected OpCall = - OpBuilder.tryCreateOp(OpCode::UpdateCounter, Args, CI->getName()); + Expected OpCall = OpBuilder.tryCreateOp( + OpCode::UpdateCounter, Args, CI->getName(), Int32Ty); if (Error E = OpCall.takeError()) return E; + CI->replaceAllUsesWith(*OpCall); CI->eraseFromParent(); return Error::success(); }); diff --git a/llvm/test/CodeGen/DirectX/updateCounter.ll b/llvm/test/CodeGen/DirectX/updateCounter.ll index 68ea1e9eac9d5..6bfb4d8670f55 100644 --- a/llvm/test/CodeGen/DirectX/updateCounter.ll +++ b/llvm/test/CodeGen/DirectX/updateCounter.ll @@ -11,8 +11,8 @@ define void @update_counter_decrement_vector() { i32 0, i32 0, i32 1, i32 0, i1 false) ; CHECK-NEXT: [[BUFFANOT:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]] - ; CHECK-NEXT: call void @dx.op.bufferUpdateCounter(i32 70, %dx.types.Handle [[BUFFANOT]], i8 -1) - call void @llvm.dx.updateCounter(target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i8 -1) + ; CHECK-NEXT: [[REG:%.*]] = call i32 @dx.op.bufferUpdateCounter(i32 70, %dx.types.Handle [[BUFFANOT]], i8 -1) + %1 = call i32 @llvm.dx.updateCounter(target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i8 -1) ret void } @@ -23,8 +23,8 @@ define void @update_counter_increment_vector() { @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0( i32 0, i32 0, i32 1, i32 0, i1 false) ; CHECK-NEXT: [[BUFFANOT:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]] - ; CHECK-NEXT: call void @dx.op.bufferUpdateCounter(i32 70, %dx.types.Handle [[BUFFANOT]], i8 1) - call void @llvm.dx.updateCounter(target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i8 1) + ; CHECK-NEXT: [[REG:%.*]] = call i32 @dx.op.bufferUpdateCounter(i32 70, %dx.types.Handle [[BUFFANOT]], i8 1) + %1 = call i32 @llvm.dx.updateCounter(target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i8 1) ret void } @@ -35,7 +35,7 @@ define void @update_counter_decrement_scalar() { @llvm.dx.handle.fromBinding.tdx.RawBuffer_i8_0_0t( i32 1, i32 8, i32 1, i32 0, i1 false) ; CHECK-NEXT: [[BUFFANOT:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 216, %dx.types.Handle [[BIND]] - ; CHECK-NEXT: call void @dx.op.bufferUpdateCounter(i32 70, %dx.types.Handle [[BUFFANOT]], i8 -1) - call void @llvm.dx.updateCounter(target("dx.RawBuffer", i8, 0, 0) %buffer, i8 -1) + ; CHECK-NEXT: [[REG:%.*]] = call i32 @dx.op.bufferUpdateCounter(i32 70, %dx.types.Handle [[BUFFANOT]], i8 -1) + %1 = call i32 @llvm.dx.updateCounter(target("dx.RawBuffer", i8, 0, 0) %buffer, i8 -1) ret void }