Skip to content

Commit 0564065

Browse files
authored
[SPIRV] Implement support for SPV_KHR_expect_assume (#66217)
Adds new extension SPV_KHR_expect_assume, new capability ExpectAssumeKHR as well as the new instructions: * OpExpectKHR * OpAssumeTrueKHR These are lowered from respectively llvm.expect.<ty> and llvm.assume intrinsics. Previously https://reviews.llvm.org/D157696
1 parent 6841eff commit 0564065

File tree

9 files changed

+95
-1
lines changed

9 files changed

+95
-1
lines changed

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,8 @@ let TargetPrefix = "spv" in {
3333
def int_spv_unreachable : Intrinsic<[], []>;
3434
def int_spv_alloca : Intrinsic<[llvm_any_ty], []>;
3535
def int_spv_undef : Intrinsic<[llvm_i32_ty], []>;
36+
37+
// Expect, Assume Intrinsics
38+
def int_spv_assume : Intrinsic<[], [llvm_i1_ty]>;
39+
def int_spv_expect : Intrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>]>;
3640
}

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ def OpNop: SimpleOp<"OpNop", 0>;
9191
def OpUndef: Op<1, (outs ID:$res), (ins TYPE:$type), "$res = OpUndef $type">;
9292
def OpSizeOf: Op<321, (outs ID:$res), (ins TYPE:$ty, ID:$ptr), "$res = OpSizeOf $ty $ptr">;
9393

94+
// - SPV_KHR_expect_assume : Expect assume instructions
95+
def OpAssumeTrueKHR: Op<5630, (outs), (ins ID:$cond), "OpAssumeTrueKHR $cond">;
96+
def OpExpectKHR: Op<5631, (outs ID:$res), (ins TYPE:$ty, ID:$val, ID:$expected), "$res = OpExpectKHR $ty $val $expected">;
97+
9498
// 3.42.2 Debug Instructions
9599

96100
def OpSourceContinued: Op<2, (outs), (ins StringImm:$str, variable_ops),

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//
1313
//===----------------------------------------------------------------------===//
1414

15+
#include "MCTargetDesc/SPIRVMCTargetDesc.h"
1516
#include "SPIRV.h"
1617
#include "SPIRVGlobalRegistry.h"
1718
#include "SPIRVInstrInfo.h"
@@ -1395,6 +1396,17 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
13951396
break;
13961397
case Intrinsic::spv_alloca:
13971398
return selectFrameIndex(ResVReg, ResType, I);
1399+
case Intrinsic::spv_assume:
1400+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpAssumeTrueKHR))
1401+
.addUse(I.getOperand(1).getReg());
1402+
break;
1403+
case Intrinsic::spv_expect:
1404+
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExpectKHR))
1405+
.addDef(ResVReg)
1406+
.addUse(GR.getSPIRVTypeID(ResType))
1407+
.addUse(I.getOperand(2).getReg())
1408+
.addUse(I.getOperand(3).getReg());
1409+
break;
13981410
default:
13991411
llvm_unreachable("Intrinsic selection not implemented");
14001412
}

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -903,6 +903,13 @@ void addInstrRequirements(const MachineInstr &MI,
903903
case SPIRV::OpGroupNonUniformBallotFindMSB:
904904
Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
905905
break;
906+
case SPIRV::OpAssumeTrueKHR:
907+
case SPIRV::OpExpectKHR:
908+
if (ST.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume)) {
909+
Reqs.addExtension(SPIRV::Extension::SPV_KHR_expect_assume);
910+
Reqs.addCapability(SPIRV::Capability::ExpectAssumeKHR);
911+
}
912+
break;
906913
default:
907914
break;
908915
}

llvm/lib/Target/SPIRV/SPIRVPrepareFunctions.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "llvm/CodeGen/IntrinsicLowering.h"
2525
#include "llvm/IR/IRBuilder.h"
2626
#include "llvm/IR/IntrinsicInst.h"
27+
#include "llvm/IR/Intrinsics.h"
28+
#include "llvm/IR/IntrinsicsSPIRV.h"
2729
#include "llvm/Transforms/Utils/Cloning.h"
2830
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
2931

@@ -233,6 +235,32 @@ static void buildUMulWithOverflowFunc(Function *UMulFunc) {
233235
IRB.CreateRet(Res);
234236
}
235237

238+
static void lowerExpectAssume(IntrinsicInst *II) {
239+
// If we cannot use the SPV_KHR_expect_assume extension, then we need to
240+
// ignore the intrinsic and move on. It should be removed later on by LLVM.
241+
// Otherwise we should lower the intrinsic to the corresponding SPIR-V
242+
// instruction.
243+
// For @llvm.assume we have OpAssumeTrueKHR.
244+
// For @llvm.expect we have OpExpectKHR.
245+
//
246+
// We need to lower this into a builtin and then the builtin into a SPIR-V
247+
// instruction.
248+
if (II->getIntrinsicID() == Intrinsic::assume) {
249+
Function *F = Intrinsic::getDeclaration(
250+
II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);
251+
II->setCalledFunction(F);
252+
} else if (II->getIntrinsicID() == Intrinsic::expect) {
253+
Function *F = Intrinsic::getDeclaration(
254+
II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,
255+
{II->getOperand(0)->getType()});
256+
II->setCalledFunction(F);
257+
} else {
258+
llvm_unreachable("Unknown intrinsic");
259+
}
260+
261+
return;
262+
}
263+
236264
static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {
237265
// Get a separate function - otherwise, we'd have to rework the CFG of the
238266
// current one. Then simply replace the intrinsic uses with a call to the new
@@ -270,6 +298,10 @@ bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
270298
} else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) {
271299
lowerUMulWithOverflow(II);
272300
Changed = true;
301+
} else if (II->getIntrinsicID() == Intrinsic::assume ||
302+
II->getIntrinsicID() == Intrinsic::expect) {
303+
lowerExpectAssume(II);
304+
Changed = true;
273305
}
274306
}
275307
}

llvm/lib/Target/SPIRV/SPIRVSubtarget.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "SPIRVSubtarget.h"
14-
#include "MCTargetDesc/SPIRVBaseInfo.h"
1514
#include "SPIRV.h"
1615
#include "SPIRVGlobalRegistry.h"
1716
#include "SPIRVLegalizerInfo.h"

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,7 @@ defm CooperativeMatrixNV : CapabilityOperand<5357, 0, 0, [], [Shader]>;
451451
defm ArbitraryPrecisionIntegersINTEL : CapabilityOperand<5844, 0, 0, [SPV_INTEL_arbitrary_precision_integers], [Int8, Int16]>;
452452
defm OptNoneINTEL : CapabilityOperand<6094, 0, 0, [SPV_INTEL_optnone], []>;
453453
defm BitInstructions : CapabilityOperand<6025, 0, 0, [SPV_KHR_bit_instructions], []>;
454+
defm ExpectAssumeKHR : CapabilityOperand<5629, 0, 0, [SPV_KHR_expect_assume], []>;
454455

455456
//===----------------------------------------------------------------------===//
456457
// Multiclass used to define SourceLanguage enum values and at the same time

llvm/test/CodeGen/SPIRV/assume.ll

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
; RUN: llc -mtriple=spirv32-unknown-unknown < %s | FileCheck %s
2+
; RUN: llc -mtriple=spirv64-unknown-unknown < %s | FileCheck %s
3+
4+
; CHECK: OpCapability ExpectAssumeKHR
5+
; CHECK-NEXT: OpExtension "SPV_KHR_expect_assume"
6+
7+
declare void @llvm.assume(i1)
8+
9+
; CHECK-DAG: %9 = OpIEqual %5 %6 %7
10+
; CHECK-NEXT: OpAssumeTrueKHR %9
11+
define void @assumeeq(i32 %x, i32 %y) {
12+
%cmp = icmp eq i32 %x, %y
13+
call void @llvm.assume(i1 %cmp)
14+
ret void
15+
}

llvm/test/CodeGen/SPIRV/expect.ll

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
; RUN: llc -mtriple=spirv32-unknown-unknown < %s | FileCheck %s
2+
; RUN: llc -mtriple=spirv64-unknown-unknown < %s | FileCheck %s
3+
4+
; CHECK: OpCapability ExpectAssumeKHR
5+
; CHECK-NEXT: OpExtension "SPV_KHR_expect_assume"
6+
7+
declare i32 @llvm.expect.i32(i32, i32)
8+
declare i32 @getOne()
9+
10+
; CHECK-DAG: %2 = OpTypeInt 32 0
11+
; CHECK-DAG: %6 = OpFunctionParameter %2
12+
; CHECK-DAG: %9 = OpIMul %2 %6 %8
13+
; CHECK-DAG: %10 = OpExpectKHR %2 %9 %6
14+
15+
define i32 @test(i32 %x) {
16+
%one = call i32 @getOne()
17+
%val = mul i32 %x, %one
18+
%v = call i32 @llvm.expect.i32(i32 %val, i32 %x)
19+
ret i32 %v
20+
}

0 commit comments

Comments
 (0)