Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tools/clang/include/clang/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,9 @@ class Expr : public Stmt {
bool isConstantInitializer(ASTContext &Ctx, bool ForRef,
const Expr **Culprit = nullptr) const;

bool isVulkanSpecConstantExpr(const ASTContext &Ctx,
APValue *Result = nullptr) const;

/// EvalStatus is a struct with detailed info about an evaluation in progress.
struct EvalStatus {
/// HasSideEffects - Whether the evaluated expression has side effects.
Expand Down
14 changes: 7 additions & 7 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def HLSLMaxTessFactor: InheritableAttr {
}
def HLSLNumThreads: InheritableAttr {
let Spellings = [CXX11<"", "numthreads", 2015>];
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
let Documentation = [Undocumented];
}
def HLSLRootSignature: InheritableAttr {
Expand Down Expand Up @@ -1004,7 +1004,7 @@ def HLSLNodeIsProgramEntry : InheritableAttr {

def HLSLNodeId : InheritableAttr {
let Spellings = [CXX11<"", "nodeid", 2017>];
let Args = [StringArgument<"Name">,DefaultIntArgument<"ArrayIndex", 0>];
let Args = [StringArgument<"Name">, ExprArgument<"ArrayIndex", 1>];
let Documentation = [Undocumented];
}

Expand All @@ -1016,25 +1016,25 @@ def HLSLNodeLocalRootArgumentsTableIndex : InheritableAttr {

def HLSLNodeShareInputOf : InheritableAttr {
let Spellings = [CXX11<"", "nodeshareinputof", 2017>];
let Args = [StringArgument<"Name">,UnsignedArgument<"ArrayIndex", 1>];
let Args = [StringArgument<"Name">,ExprArgument<"ArrayIndex", 1>];
let Documentation = [Undocumented];
}

def HLSLNodeDispatchGrid: InheritableAttr {
let Spellings = [CXX11<"", "nodedispatchgrid", 2015>];
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
let Documentation = [Undocumented];
}

def HLSLNodeMaxDispatchGrid: InheritableAttr {
let Spellings = [CXX11<"", "nodemaxdispatchgrid", 2015>];
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
let Documentation = [Undocumented];
}

def HLSLNodeMaxRecursionDepth : InheritableAttr {
let Spellings = [CXX11<"", "nodemaxrecursiondepth", 2017>];
let Args = [UnsignedArgument<"Count">];
let Args = [ExprArgument<"Count">];
let Documentation = [Undocumented];
}

Expand Down Expand Up @@ -1182,7 +1182,7 @@ def HLSLHitObject : InheritableAttr {

def HLSLMaxRecords : InheritableAttr {
let Spellings = [CXX11<"", "MaxRecords", 2015>];
let Args = [IntArgument<"maxCount">];
let Args = [ExprArgument<"maxCount">];
let Documentation = [Undocumented];
}

Expand Down
10 changes: 10 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,15 @@ class SpirvContext {
instructionsWithLoweredType.end();
}

SpirvInstruction *getSpecConstant(const VarDecl *decl) {
return specConstants[decl];
}

void registerSpecConstant(const VarDecl *decl,
SpirvInstruction *specConstant) {
specConstants[decl] = specConstant;
}

void registerDispatchGridIndex(const RecordDecl *decl, unsigned index) {
auto iter = dispatchGridIndices.find(decl);
if (iter == dispatchGridIndices.end()) {
Expand Down Expand Up @@ -536,6 +545,7 @@ class SpirvContext {
llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
llvm::DenseMap<unsigned, SpirvIntrinsicType *> spirvIntrinsicTypesById;
llvm::SmallVector<const SpirvIntrinsicType *, 8> spirvIntrinsicTypes;
llvm::MapVector<const VarDecl *, SpirvInstruction *> specConstants;
const AccelerationStructureTypeNV *accelerationStructureTypeNV;
const RayQueryTypeKHR *rayQueryTypeKHR;

Expand Down
2 changes: 0 additions & 2 deletions tools/clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,6 @@ unsigned CaculateInitListArraySizeForHLSL(clang::Sema *sema,
const clang::InitListExpr *InitList,
const clang::QualType EltTy);

bool ContainsLongVector(clang::QualType);

bool IsConversionToLessOrEqualElements(clang::Sema *self,
const clang::ExprResult &sourceExpr,
const clang::QualType &targetType,
Expand Down
13 changes: 13 additions & 0 deletions tools/clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9448,6 +9448,19 @@ bool Expr::isIntegerConstantExpr(llvm::APSInt &Value, const ASTContext &Ctx,
return true;
}

bool Expr::isVulkanSpecConstantExpr(const ASTContext &Ctx,
APValue *Result) const {
auto *D = dyn_cast<DeclRefExpr>(this);
if (!D)
return false;
auto *V = dyn_cast<VarDecl>(D->getDecl());
if (!V || !V->hasAttr<VKConstantIdAttr>())
return false;
if (const Expr *I = V->getAnyInitializer())
return I->IgnoreParenCasts()->isCXX11ConstantExpr(Ctx, Result);
return true;
}

bool Expr::isCXX98IntegralConstantExpr(const ASTContext &Ctx) const {
return CheckICE(this, Ctx).Kind == IK_ICE;
}
Expand Down
58 changes: 41 additions & 17 deletions tools/clang/lib/CodeGen/CGHLSLMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,19 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
};
} // namespace

static uint32_t GetIntConstAttrArg(ASTContext &astContext, const Expr *expr,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You've got a fair amount of mixing coding style, and DXC is a bit inconsistent.

In general if you're adding code in a part of the codebase that isn't otherwise following a consistent style we follow the LLVM Coding Standards. LLVM uses CamelCase for variable names and type names (see: https://llvm.org/docs/CodingStandards.html#name-types-functions-variables-and-enumerators-properly).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

uint32_t defaultVal = 0) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
uint32_t defaultVal = 0) {
uint32_t defaultVal) {

Doesn't look like this is ever called without a default value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've changed this to use llvm::Optional so that an absent default indicates that expr should be non-null. (See note on SemaHLSL below.)

if (expr) {
llvm::APSInt apsInt;
APValue apValue;
if (expr->isIntegerConstantExpr(apsInt, astContext))
return (uint32_t)apsInt.getSExtValue();
if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt())
return (uint32_t)apValue.getInt().getSExtValue();
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
}
llvm_unreachable("Expression must be a constant expression or spec constant")
}

We should be erroring in Sema if this isn't a constant expression or spec constant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed as suggested.

return defaultVal;
}

//------------------------------------------------------------------------------
//
// CGMSHLSLRuntime methods.
Expand Down Expand Up @@ -1422,6 +1435,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
}

DiagnosticsEngine &Diags = CGM.getDiags();
ASTContext &astContext = CGM.getTypes().getContext();

std::unique_ptr<DxilFunctionProps> funcProps =
llvm::make_unique<DxilFunctionProps>();
Expand Down Expand Up @@ -1632,10 +1646,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {

// Populate numThreads
if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {

funcProps->numThreads[0] = Attr->getX();
funcProps->numThreads[1] = Attr->getY();
funcProps->numThreads[2] = Attr->getZ();
funcProps->numThreads[0] = GetIntConstAttrArg(astContext, Attr->getX(), 1);
funcProps->numThreads[1] = GetIntConstAttrArg(astContext, Attr->getY(), 1);
funcProps->numThreads[2] = GetIntConstAttrArg(astContext, Attr->getZ(), 1);

if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
unsigned DiagID = Diags.getCustomDiagID(
Expand Down Expand Up @@ -1808,7 +1821,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {

if (const auto *pAttr = FD->getAttr<HLSLNodeIdAttr>()) {
funcProps->NodeShaderID.Name = pAttr->getName().str();
funcProps->NodeShaderID.Index = pAttr->getArrayIndex();
funcProps->NodeShaderID.Index =
GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
} else {
funcProps->NodeShaderID.Name = FD->getName().str();
funcProps->NodeShaderID.Index = 0;
Expand All @@ -1819,20 +1833,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
}
if (const auto *pAttr = FD->getAttr<HLSLNodeShareInputOfAttr>()) {
funcProps->NodeShaderSharedInput.Name = pAttr->getName().str();
funcProps->NodeShaderSharedInput.Index = pAttr->getArrayIndex();
funcProps->NodeShaderSharedInput.Index =
GetIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
}
if (const auto *pAttr = FD->getAttr<HLSLNodeDispatchGridAttr>()) {
funcProps->Node.DispatchGrid[0] = pAttr->getX();
funcProps->Node.DispatchGrid[1] = pAttr->getY();
funcProps->Node.DispatchGrid[2] = pAttr->getZ();
funcProps->Node.DispatchGrid[0] =
GetIntConstAttrArg(astContext, pAttr->getX(), 1);
funcProps->Node.DispatchGrid[1] =
GetIntConstAttrArg(astContext, pAttr->getY(), 1);
funcProps->Node.DispatchGrid[2] =
GetIntConstAttrArg(astContext, pAttr->getZ(), 1);
}
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxDispatchGridAttr>()) {
funcProps->Node.MaxDispatchGrid[0] = pAttr->getX();
funcProps->Node.MaxDispatchGrid[1] = pAttr->getY();
funcProps->Node.MaxDispatchGrid[2] = pAttr->getZ();
funcProps->Node.MaxDispatchGrid[0] =
GetIntConstAttrArg(astContext, pAttr->getX(), 1);
funcProps->Node.MaxDispatchGrid[1] =
GetIntConstAttrArg(astContext, pAttr->getY(), 1);
funcProps->Node.MaxDispatchGrid[2] =
GetIntConstAttrArg(astContext, pAttr->getZ(), 1);
}
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxRecursionDepthAttr>()) {
funcProps->Node.MaxRecursionDepth = pAttr->getCount();
funcProps->Node.MaxRecursionDepth =
GetIntConstAttrArg(astContext, pAttr->getCount(), 0);
}
if (!FD->getAttr<HLSLNumThreadsAttr>()) {
// NumThreads wasn't specified.
Expand Down Expand Up @@ -2346,8 +2368,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;

if (parmDecl->hasAttr<HLSLMaxRecordsAttr>()) {
node.MaxRecords =
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount();
node.MaxRecords = GetIntConstAttrArg(
astContext,
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount(), 1);
}
if (parmDecl->hasAttr<HLSLGloballyCoherentAttr>())
node.Flags.SetGloballyCoherent();
Expand Down Expand Up @@ -2378,7 +2401,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
// OutputID from attribute
if (const auto *Attr = parmDecl->getAttr<HLSLNodeIdAttr>()) {
node.OutputID.Name = Attr->getName().str();
node.OutputID.Index = Attr->getArrayIndex();
node.OutputID.Index =
GetIntConstAttrArg(astContext, Attr->getArrayIndex(), 0);
} else {
node.OutputID.Name = parmDecl->getName().str();
node.OutputID.Index = 0;
Expand Down Expand Up @@ -2437,7 +2461,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
node.MaxRecordsSharedWith = ix;
}
if (const auto *Attr = parmDecl->getAttr<HLSLMaxRecordsAttr>())
node.MaxRecords = Attr->getMaxCount();
node.MaxRecords = GetIntConstAttrArg(astContext, Attr->getMaxCount(), 0);
}

if (inputPatchCount > 1) {
Expand Down
1 change: 1 addition & 0 deletions tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1815,6 +1815,7 @@ DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
SpirvInstruction *specConstant) {
specConstant->setRValue();
spvContext.registerSpecConstant(decl, specConstant);
registerVariableForDecl(decl, createDeclSpirvInfo(specConstant));
}

Expand Down
31 changes: 23 additions & 8 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2523,6 +2523,24 @@ isFieldMergeWithPrevious(const StructType::FieldInfo &previous,
return previous.fieldIndex == field.fieldIndex;
}

uint32_t EmitTypeHandler::getAttrArgInstr(ASTContext &astContext,
const Expr *expr,
uint32_t defaultVal) {
if (expr) {
llvm::APSInt apsInt;
APValue apValue;
if (expr->isIntegerConstantExpr(apsInt, astContext))
return getOrCreateConstantInt(apsInt, context.getUIntType(32), false);
if (expr->isVulkanSpecConstantExpr(astContext, &apValue) &&
apValue.isInt()) {
auto *declRefExpr = dyn_cast<DeclRefExpr>(expr);
auto *decl = dyn_cast<const VarDecl>(declRefExpr->getDecl());
return getOrAssignResultId(context.getSpecConstant(decl));
}
}
return defaultVal;
}

uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
// First get the decorations that would apply to this type.
bool alreadyExists = false;
Expand Down Expand Up @@ -2655,27 +2673,24 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
if (hlsl::IsHLSLNodeOutputType(nodeDecl->getType())) {
StringRef name = nodeDecl->getName();
unsigned index = 0;
if (auto nodeID = nodeDecl->getAttr<HLSLNodeIdAttr>()) {
if (auto *nodeID = nodeDecl->getAttr<HLSLNodeIdAttr>()) {
name = nodeID->getName();
index = nodeID->getArrayIndex();
index = getAttrArgInstr(astContext, nodeID->getArrayIndex());
}

auto *str = new (context) SpirvConstantString(name);
uint32_t nodeName = getOrCreateConstantString(str);
emitDecoration(id, spv::Decoration::PayloadNodeNameAMDX, {nodeName},
llvm::None, true);
if (index) {
uint32_t baseIndex = getOrCreateConstantInt(
llvm::APInt(32, index), context.getUIntType(32), false);
emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX,
{baseIndex}, llvm::None, true);
emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX, {index},
llvm::None, true);
}
}

uint32_t maxRecords;
if (const auto *attr = nodeDecl->getAttr<HLSLMaxRecordsAttr>()) {
maxRecords = getOrCreateConstantInt(llvm::APInt(32, attr->getMaxCount()),
context.getUIntType(32), false);
maxRecords = getAttrArgInstr(astContext, attr->getMaxCount(), 1);
} else {
maxRecords = getOrCreateConstantInt(llvm::APInt(32, 1),
context.getUIntType(32), false);
Expand Down
3 changes: 3 additions & 0 deletions tools/clang/lib/SPIRV/EmitVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class EmitTypeHandler {
EmitTypeHandler(const EmitTypeHandler &) = delete;
EmitTypeHandler &operator=(const EmitTypeHandler &) = delete;

uint32_t getAttrArgInstr(ASTContext &astContext, const Expr *expr,
uint32_t defaultVal = 0);

// Emits the instruction for the given type into the typeConstantBinary and
// returns the result-id for the type. If the type has already been emitted,
// it only returns its result-id.
Expand Down
Loading