Skip to content
Closed
10 changes: 10 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4335,6 +4335,16 @@ def HLSLLoopHint: StmtAttr {
let Documentation = [HLSLLoopHintDocs, HLSLUnrollHintDocs];
}

def HLSLControlFlowHint: StmtAttr {
/// [branch]
/// [flatten]
let Spellings = [Microsoft<"branch">, Microsoft<"flatten">];
let Subjects = SubjectList<[IfStmt],
ErrorDiag, "'if' statements">;
let LangOpts = [HLSL];
let Documentation = [InternalOnly];
}

def CapturedRecord : InheritableAttr {
// This attribute has no spellings as it is only ever created implicitly.
let Spellings = [];
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CodeGen/CGStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,8 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
bool noinline = false;
bool alwaysinline = false;
bool noconvergent = false;
HLSLControlFlowHintAttr::Spelling flattenOrBranch =
HLSLControlFlowHintAttr::SpellingNotCalculated;
const CallExpr *musttail = nullptr;

for (const auto *A : S.getAttrs()) {
Expand Down Expand Up @@ -761,13 +763,17 @@ void CodeGenFunction::EmitAttributedStmt(const AttributedStmt &S) {
Builder.CreateAssumption(AssumptionVal);
}
} break;
case attr::HLSLControlFlowHint: {
flattenOrBranch = cast<HLSLControlFlowHintAttr>(A)->getSemanticSpelling();
} break;
}
}
SaveAndRestore save_nomerge(InNoMergeAttributedStmt, nomerge);
SaveAndRestore save_noinline(InNoInlineAttributedStmt, noinline);
SaveAndRestore save_alwaysinline(InAlwaysInlineAttributedStmt, alwaysinline);
SaveAndRestore save_noconvergent(InNoConvergentAttributedStmt, noconvergent);
SaveAndRestore save_musttail(MustTailCall, musttail);
SaveAndRestore save_flattenOrBranch(HLSLControlFlowAttr, flattenOrBranch);
EmitStmt(S.getSubStmt(), S.getAttrs());
}

Expand Down
23 changes: 22 additions & 1 deletion clang/lib/CodeGen/CodeGenFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2081,7 +2081,28 @@ void CodeGenFunction::EmitBranchOnBoolExpr(
Weights = createProfileWeights(TrueCount, CurrentCount - TrueCount);
}

Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights, Unpredictable);
auto *BrInst = Builder.CreateCondBr(CondV, TrueBlock, FalseBlock, Weights,
Unpredictable);
switch (HLSLControlFlowAttr) {
case HLSLControlFlowHintAttr::Microsoft_branch:
case HLSLControlFlowHintAttr::Microsoft_flatten: {
llvm::MDBuilder MDHelper(CGM.getLLVMContext());

llvm::ConstantInt *BranchHintConstant =
HLSLControlFlowAttr ==
HLSLControlFlowHintAttr::Spelling::Microsoft_branch
? llvm::ConstantInt::get(CGM.Int32Ty, 1)
: llvm::ConstantInt::get(CGM.Int32Ty, 2);

SmallVector<llvm::Metadata *, 2> Vals(
{MDHelper.createString("hlsl.controlflow.hint"),
MDHelper.createConstant(BranchHintConstant)});
BrInst->setMetadata("hlsl.controlflow.hint",
llvm::MDNode::get(CGM.getLLVMContext(), Vals));
} break;
case HLSLControlFlowHintAttr::SpellingNotCalculated:
break;
}
}

/// ErrorUnsupported - Print out an error that codegen doesn't support the
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/CodeGen/CodeGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,10 @@ class CodeGenFunction : public CodeGenTypeCache {
/// True if the current statement has noconvergent attribute.
bool InNoConvergentAttributedStmt = false;

/// HLSL Branch attribute.
HLSLControlFlowHintAttr::Spelling HLSLControlFlowAttr =
HLSLControlFlowHintAttr::SpellingNotCalculated;

// The CallExpr within the current statement that the musttail attribute
// applies to. nullptr if there is no 'musttail' on the current statement.
const CallExpr *MustTailCall = nullptr;
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/Sema/SemaStmtAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,12 @@ static Attr *handleHLSLLoopHintAttr(Sema &S, Stmt *St, const ParsedAttr &A,
return ::new (S.Context) HLSLLoopHintAttr(S.Context, A, UnrollFactor);
}

static Attr *handleHLSLControlFlowHint(Sema &S, Stmt *St, const ParsedAttr &A,
SourceRange Range) {

return ::new (S.Context) HLSLControlFlowHintAttr(S.Context, A);
}

static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
SourceRange Range) {
if (A.isInvalid() || A.getKind() == ParsedAttr::IgnoredAttribute)
Expand Down Expand Up @@ -655,6 +661,8 @@ static Attr *ProcessStmtAttribute(Sema &S, Stmt *St, const ParsedAttr &A,
return handleLoopHintAttr(S, St, A, Range);
case ParsedAttr::AT_HLSLLoopHint:
return handleHLSLLoopHintAttr(S, St, A, Range);
case ParsedAttr::AT_HLSLControlFlowHint:
return handleHLSLControlFlowHint(S, St, A, Range);
case ParsedAttr::AT_OpenCLUnrollHint:
return handleOpenCLUnrollHint(S, St, A, Range);
case ParsedAttr::AT_Suppress:
Expand Down
43 changes: 43 additions & 0 deletions clang/test/AST/HLSL/HLSLControlFlowHint.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -ast-dump %s | FileCheck %s

// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used branch 'int (int)'
// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
// CHECK-NEXT: -HLSLControlFlowHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> branch
export int branch(int X){
int resp;
[branch] if (X > 0) {
resp = -X;
} else {
resp = X * 2;
}

return resp;
}

// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used flatten 'int (int)'
// CHECK: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
// CHECK-NEXT: -HLSLControlFlowHintAttr 0x{{[0-9A-Fa-f]+}} <{{.*}}> flatten
export int flatten(int X){
int resp;
[flatten] if (X > 0) {
resp = -X;
} else {
resp = X * 2;
}

return resp;
}

// CHECK: FunctionDecl 0x{{[0-9A-Fa-f]+}} <{{.*}}> {{.*}} used no_attr 'int (int)'
// CHECK-NO: AttributedStmt 0x{{[0-9A-Fa-f]+}} <<invalid sloc>
// CHECK-NO: -HLSLControlFlowHintAttr
export int no_attr(int X){
int resp;
if (X > 0) {
resp = -X;
} else {
resp = X * 2;
}

return resp;
}
48 changes: 48 additions & 0 deletions clang/test/CodeGenHLSL/HLSLControlFlowHint.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple spirv-vulkan-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s

// CHECK: define {{.*}} i32 {{.*}}test_branch{{.*}}(i32 {{.*}} [[VALD:%.*]])
// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !hlsl.controlflow.hint [[HINT_BRANCH:![0-9]+]]
export int test_branch(int X){
int resp;
[branch] if (X > 0) {
resp = -X;
} else {
resp = X * 2;
}

return resp;
}

// CHECK: define {{.*}} i32 {{.*}}test_flatten{{.*}}(i32 {{.*}} [[VALD:%.*]])
// CHECK: [[PARAM:%.*]] = load i32, ptr [[VALD]].addr, align 4
// CHECK: [[CMP:%.*]] = icmp sgt i32 [[PARAM]], 0
// CHECK: br i1 [[CMP]], label %if.then, label %if.else, !hlsl.controlflow.hint [[HINT_FLATTEN:![0-9]+]]
export int test_flatten(int X){
int resp;
[flatten] if (X > 0) {
resp = -X;
} else {
resp = X * 2;
}

return resp;
}

// CHECK: define {{.*}} i32 {{.*}}test_no_attr{{.*}}(i32 {{.*}} [[VALD:%.*]])
// CHECK-NO: !hlsl.controlflow.hint
export int test_no_attr(int X){
int resp;
if (X > 0) {
resp = -X;
} else {
resp = X * 2;
}

return resp;
}

//CHECK: [[HINT_BRANCH]] = !{!"hlsl.controlflow.hint", i32 1}
//CHECK: [[HINT_FLATTEN]] = !{!"hlsl.controlflow.hint", i32 2}
2 changes: 1 addition & 1 deletion llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ let TargetPrefix = "spv" in {
def int_spv_ptrcast : Intrinsic<[llvm_any_ty], [llvm_any_ty, llvm_metadata_ty, llvm_i32_ty], [ImmArg<ArgIndex<2>>]>;
def int_spv_switch : Intrinsic<[], [llvm_any_ty, llvm_vararg_ty]>;
def int_spv_loop_merge : Intrinsic<[], [llvm_vararg_ty]>;
def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
def int_spv_selection_merge : Intrinsic<[], [llvm_any_ty, llvm_i32_ty], [ImmArg<ArgIndex<1>>]>;
def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
def int_spv_unreachable : Intrinsic<[], []>;
def int_spv_alloca : Intrinsic<[llvm_any_ty], []>;
Expand Down
36 changes: 36 additions & 0 deletions llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
Expand Down Expand Up @@ -295,6 +296,39 @@ static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD,
return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx);
}

// TODO: We might need to refactor this to be more generic,
// in case we need more metadata to be replaced.
static void translateBranchMetadata(Module &M) {
for (auto &F : M) {
for (auto &BB : F) {
auto *BBTerminatorInst = BB.getTerminator();

auto *HlslControlFlowMD =
BBTerminatorInst->getMetadata("hlsl.controlflow.hint");

if (!HlslControlFlowMD)
continue;

assert(HlslControlFlowMD->getNumOperands() == 2 &&
"invalid operands for hlsl.controlflow.hint");

MDBuilder MDHelper(M.getContext());
auto *Op1 =
mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1));

SmallVector<llvm::Metadata *, 2> Vals(
ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"),
MDHelper.createConstant(Op1)});

auto *MDNode = llvm::MDNode::get(M.getContext(), Vals);

BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode);
BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
}
F.clearMetadata();
}
}

static void translateMetadata(Module &M, const DXILResourceMap &DRM,
const Resources &MDResources,
const ModuleShaderFlags &ShaderFlags,
Expand Down Expand Up @@ -364,6 +398,7 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);

translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI);
translateBranchMetadata(M);

return PreservedAnalyses::all();
}
Expand Down Expand Up @@ -397,6 +432,7 @@ class DXILTranslateMetadataLegacy : public ModulePass {
getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();

translateMetadata(M, DRM, MDResources, ShaderFlags, MMDI);
translateBranchMetadata(M);
return true;
}
};
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRV.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ void initializeSPIRVModuleAnalysisPass(PassRegistry &);
void initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PassRegistry &);
void initializeSPIRVPreLegalizerPass(PassRegistry &);
void initializeSPIRVPostLegalizerPass(PassRegistry &);
void initializeSPIRVStructurizerPass(PassRegistry &);
void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
void initializeSPIRVEmitNonSemanticDIPass(PassRegistry &);
} // namespace llvm
Expand Down
28 changes: 22 additions & 6 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2799,19 +2799,35 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
}
return MIB.constrainAllUses(TII, TRI, RBI);
}
case Intrinsic::spv_loop_merge:
case Intrinsic::spv_selection_merge: {
const auto Opcode = IID == Intrinsic::spv_selection_merge
? SPIRV::OpSelectionMerge
: SPIRV::OpLoopMerge;
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode));
case Intrinsic::spv_loop_merge: {
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLoopMerge));
for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) {
assert(I.getOperand(i).isMBB());
MIB.addMBB(I.getOperand(i).getMBB());
}
MIB.addImm(SPIRV::SelectionControl::None);
return MIB.constrainAllUses(TII, TRI, RBI);
}
case Intrinsic::spv_selection_merge: {

int64_t SelectionControl = SPIRV::SelectionControl::None;
auto LastOp = I.getOperand(I.getNumOperands() - 1);

auto BranchHint = LastOp.getImm();
if (BranchHint == 2)
SelectionControl = SPIRV::SelectionControl::Flatten;
else if (BranchHint == 1)
SelectionControl = SPIRV::SelectionControl::DontFlatten;

auto MIB =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSelectionMerge));
for (unsigned i = 1; i < I.getNumExplicitOperands() - 1; ++i) {
assert(I.getOperand(i).isMBB());
MIB.addMBB(I.getOperand(i).getMBB());
}
MIB.addImm(SelectionControl);
return MIB.constrainAllUses(TII, TRI, RBI);
}
case Intrinsic::spv_cmpxchg:
return selectAtomicCmpXchg(ResVReg, ResType, I);
case Intrinsic::spv_unreachable:
Expand Down
Loading
Loading