Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tools/clang/include/clang/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,9 @@ class Expr : public Stmt {
bool isConstantInitializer(ASTContext &Ctx, bool ForRef,
const Expr **Culprit = nullptr) const;

bool isVulkanSpecConstantExpr(const ASTContext &Ctx,
APValue *Result = nullptr) const;

/// EvalStatus is a struct with detailed info about an evaluation in progress.
struct EvalStatus {
/// HasSideEffects - Whether the evaluated expression has side effects.
Expand Down
14 changes: 7 additions & 7 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def HLSLMaxTessFactor: InheritableAttr {
}
def HLSLNumThreads: InheritableAttr {
let Spellings = [CXX11<"", "numthreads", 2015>];
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
let Documentation = [Undocumented];
}
def HLSLRootSignature: InheritableAttr {
Expand Down Expand Up @@ -1016,7 +1016,7 @@ def HLSLNodeIsProgramEntry : InheritableAttr {

def HLSLNodeId : InheritableAttr {
let Spellings = [CXX11<"", "nodeid", 2017>];
let Args = [StringArgument<"Name">,DefaultIntArgument<"ArrayIndex", 0>];
let Args = [StringArgument<"Name">, ExprArgument<"ArrayIndex", 1>];
let Documentation = [Undocumented];
}

Expand All @@ -1028,25 +1028,25 @@ def HLSLNodeLocalRootArgumentsTableIndex : InheritableAttr {

def HLSLNodeShareInputOf : InheritableAttr {
let Spellings = [CXX11<"", "nodeshareinputof", 2017>];
let Args = [StringArgument<"Name">,UnsignedArgument<"ArrayIndex", 1>];
let Args = [StringArgument<"Name">,ExprArgument<"ArrayIndex", 1>];
let Documentation = [Undocumented];
}

def HLSLNodeDispatchGrid: InheritableAttr {
let Spellings = [CXX11<"", "nodedispatchgrid", 2015>];
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
let Documentation = [Undocumented];
}

def HLSLNodeMaxDispatchGrid: InheritableAttr {
let Spellings = [CXX11<"", "nodemaxdispatchgrid", 2015>];
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
let Documentation = [Undocumented];
}

def HLSLNodeMaxRecursionDepth : InheritableAttr {
let Spellings = [CXX11<"", "nodemaxrecursiondepth", 2017>];
let Args = [UnsignedArgument<"Count">];
let Args = [ExprArgument<"Count">];
let Documentation = [Undocumented];
}

Expand Down Expand Up @@ -1194,7 +1194,7 @@ def HLSLHitObject : InheritableAttr {

def HLSLMaxRecords : InheritableAttr {
let Spellings = [CXX11<"", "MaxRecords", 2015>];
let Args = [IntArgument<"maxCount">];
let Args = [ExprArgument<"maxCount">];
let Documentation = [Undocumented];
}

Expand Down
2 changes: 2 additions & 0 deletions tools/clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -7740,6 +7740,8 @@ def warn_hlsl_entry_attribute_without_shader_attribute : Warning<
InGroup<HLSLEntryAttributeWithoutShaderAttrType>;
def err_hlsl_attribute_expects_float_literal : Error<
"attribute %0 must have a float literal argument">;
def err_hlsl_attribute_expects_integer_const_expr : Error<
"attribute %0 argument %1 must be integer constant expression">;
def warn_hlsl_comma_in_init : Warning<
"comma expression used where a constructor list may have been intended">,
InGroup<HLSLCommaInInit>;
Expand Down
10 changes: 10 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,15 @@ class SpirvContext {
instructionsWithLoweredType.end();
}

SpirvInstruction *getSpecConstant(const VarDecl *decl) {
return specConstants[decl];
}

void registerSpecConstant(const VarDecl *decl,
SpirvInstruction *specConstant) {
specConstants[decl] = specConstant;
}

void registerDispatchGridIndex(const RecordDecl *decl, unsigned index) {
auto iter = dispatchGridIndices.find(decl);
if (iter == dispatchGridIndices.end()) {
Expand Down Expand Up @@ -536,6 +545,7 @@ class SpirvContext {
llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
llvm::DenseMap<unsigned, SpirvIntrinsicType *> spirvIntrinsicTypesById;
llvm::SmallVector<const SpirvIntrinsicType *, 8> spirvIntrinsicTypes;
llvm::MapVector<const VarDecl *, SpirvInstruction *> specConstants;
const AccelerationStructureTypeNV *accelerationStructureTypeNV;
const RayQueryTypeKHR *rayQueryTypeKHR;

Expand Down
2 changes: 0 additions & 2 deletions tools/clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,6 @@ unsigned CaculateInitListArraySizeForHLSL(clang::Sema *sema,
const clang::InitListExpr *InitList,
const clang::QualType EltTy);

bool ContainsLongVector(clang::QualType);

bool IsConversionToLessOrEqualElements(clang::Sema *self,
const clang::ExprResult &sourceExpr,
const clang::QualType &targetType,
Expand Down
13 changes: 13 additions & 0 deletions tools/clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9448,6 +9448,19 @@ bool Expr::isIntegerConstantExpr(llvm::APSInt &Value, const ASTContext &Ctx,
return true;
}

bool Expr::isVulkanSpecConstantExpr(const ASTContext &Ctx,
APValue *Result) const {
auto *D = dyn_cast<DeclRefExpr>(this);
if (!D)
return false;
auto *V = dyn_cast<VarDecl>(D->getDecl());
if (!V || !V->hasAttr<VKConstantIdAttr>())
return false;
if (const Expr *I = V->getAnyInitializer())
return I->IgnoreParenCasts()->isCXX11ConstantExpr(Ctx, Result);
return true;
}

bool Expr::isCXX98IntegralConstantExpr(const ASTContext &Ctx) const {
return CheckICE(this, Ctx).Kind == IK_ICE;
}
Expand Down
64 changes: 47 additions & 17 deletions tools/clang/lib/CodeGen/CGHLSLMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,25 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
};
} // namespace

static uint32_t
getIntConstAttrArg(ASTContext &astContext, const Expr *expr,
llvm::Optional<uint32_t> defaultVal = llvm::None) {
if (expr) {
llvm::APSInt apsInt;
APValue apValue;
if (expr->isIntegerConstantExpr(apsInt, astContext))
return (uint32_t)apsInt.getSExtValue();
if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt())
return (uint32_t)apValue.getInt().getSExtValue();
llvm_unreachable(
"Expression must be a constant expression or spec constant");
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
}
llvm_unreachable("Expression must be a constant expression or spec constant")
}

We should be erroring in Sema if this isn't a constant expression or spec constant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed as suggested.

if (!defaultVal.hasValue()) {
DXASSERT(defaultVal.hasValue(), "missing attribute parameter");
}
return defaultVal.getValue();
}

//------------------------------------------------------------------------------
//
// CGMSHLSLRuntime methods.
Expand Down Expand Up @@ -1422,6 +1441,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
}

DiagnosticsEngine &Diags = CGM.getDiags();
ASTContext &astContext = CGM.getTypes().getContext();

std::unique_ptr<DxilFunctionProps> funcProps =
llvm::make_unique<DxilFunctionProps>();
Expand Down Expand Up @@ -1632,10 +1652,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {

// Populate numThreads
if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {

funcProps->numThreads[0] = Attr->getX();
funcProps->numThreads[1] = Attr->getY();
funcProps->numThreads[2] = Attr->getZ();
funcProps->numThreads[0] = getIntConstAttrArg(astContext, Attr->getX());
funcProps->numThreads[1] = getIntConstAttrArg(astContext, Attr->getY());
funcProps->numThreads[2] = getIntConstAttrArg(astContext, Attr->getZ());

if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
unsigned DiagID = Diags.getCustomDiagID(
Expand Down Expand Up @@ -1808,7 +1827,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {

if (const auto *pAttr = FD->getAttr<HLSLNodeIdAttr>()) {
funcProps->NodeShaderID.Name = pAttr->getName().str();
funcProps->NodeShaderID.Index = pAttr->getArrayIndex();
funcProps->NodeShaderID.Index =
getIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
} else {
funcProps->NodeShaderID.Name = FD->getName().str();
funcProps->NodeShaderID.Index = 0;
Expand All @@ -1819,20 +1839,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
}
if (const auto *pAttr = FD->getAttr<HLSLNodeShareInputOfAttr>()) {
funcProps->NodeShaderSharedInput.Name = pAttr->getName().str();
funcProps->NodeShaderSharedInput.Index = pAttr->getArrayIndex();
funcProps->NodeShaderSharedInput.Index =
getIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
}
if (const auto *pAttr = FD->getAttr<HLSLNodeDispatchGridAttr>()) {
funcProps->Node.DispatchGrid[0] = pAttr->getX();
funcProps->Node.DispatchGrid[1] = pAttr->getY();
funcProps->Node.DispatchGrid[2] = pAttr->getZ();
funcProps->Node.DispatchGrid[0] =
getIntConstAttrArg(astContext, pAttr->getX());
funcProps->Node.DispatchGrid[1] =
getIntConstAttrArg(astContext, pAttr->getY());
funcProps->Node.DispatchGrid[2] =
getIntConstAttrArg(astContext, pAttr->getZ());
}
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxDispatchGridAttr>()) {
funcProps->Node.MaxDispatchGrid[0] = pAttr->getX();
funcProps->Node.MaxDispatchGrid[1] = pAttr->getY();
funcProps->Node.MaxDispatchGrid[2] = pAttr->getZ();
funcProps->Node.MaxDispatchGrid[0] =
getIntConstAttrArg(astContext, pAttr->getX());
funcProps->Node.MaxDispatchGrid[1] =
getIntConstAttrArg(astContext, pAttr->getY());
funcProps->Node.MaxDispatchGrid[2] =
getIntConstAttrArg(astContext, pAttr->getZ());
}
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxRecursionDepthAttr>()) {
funcProps->Node.MaxRecursionDepth = pAttr->getCount();
funcProps->Node.MaxRecursionDepth =
getIntConstAttrArg(astContext, pAttr->getCount());
}
if (!FD->getAttr<HLSLNumThreadsAttr>()) {
// NumThreads wasn't specified.
Expand Down Expand Up @@ -2346,8 +2374,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;

if (parmDecl->hasAttr<HLSLMaxRecordsAttr>()) {
node.MaxRecords =
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount();
node.MaxRecords = getIntConstAttrArg(
astContext,
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount());
}
if (parmDecl->hasAttr<HLSLGloballyCoherentAttr>())
node.Flags.SetGloballyCoherent();
Expand Down Expand Up @@ -2378,7 +2407,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
// OutputID from attribute
if (const auto *Attr = parmDecl->getAttr<HLSLNodeIdAttr>()) {
node.OutputID.Name = Attr->getName().str();
node.OutputID.Index = Attr->getArrayIndex();
node.OutputID.Index =
getIntConstAttrArg(astContext, Attr->getArrayIndex(), 0);
} else {
node.OutputID.Name = parmDecl->getName().str();
node.OutputID.Index = 0;
Expand Down Expand Up @@ -2437,7 +2467,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
node.MaxRecordsSharedWith = ix;
}
if (const auto *Attr = parmDecl->getAttr<HLSLMaxRecordsAttr>())
node.MaxRecords = Attr->getMaxCount();
node.MaxRecords = getIntConstAttrArg(astContext, Attr->getMaxCount());
}

if (inputPatchCount > 1) {
Expand Down
1 change: 1 addition & 0 deletions tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1815,6 +1815,7 @@ DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
SpirvInstruction *specConstant) {
specConstant->setRValue();
spvContext.registerSpecConstant(decl, specConstant);
registerVariableForDecl(decl, createDeclSpirvInfo(specConstant));
}

Expand Down
18 changes: 18 additions & 0 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2522,6 +2522,24 @@ isFieldMergeWithPrevious(const StructType::FieldInfo &previous,
return previous.fieldIndex == field.fieldIndex;
}

uint32_t EmitTypeHandler::getAttrArgInstr(ASTContext &astContext,
const Expr *expr,
uint32_t defaultVal) {
if (expr) {
llvm::APSInt apsInt;
APValue apValue;
if (expr->isIntegerConstantExpr(apsInt, astContext))
return getOrCreateConstantInt(apsInt, context.getUIntType(32), false);
if (expr->isVulkanSpecConstantExpr(astContext, &apValue) &&
apValue.isInt()) {
auto *declRefExpr = dyn_cast<DeclRefExpr>(expr);
auto *decl = dyn_cast<const VarDecl>(declRefExpr->getDecl());
return getOrAssignResultId(context.getSpecConstant(decl));
}
}
return defaultVal;
}

uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
// First get the decorations that would apply to this type.
bool alreadyExists = false;
Expand Down
3 changes: 3 additions & 0 deletions tools/clang/lib/SPIRV/EmitVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class EmitTypeHandler {
EmitTypeHandler(const EmitTypeHandler &) = delete;
EmitTypeHandler &operator=(const EmitTypeHandler &) = delete;

uint32_t getAttrArgInstr(ASTContext &astContext, const Expr *expr,
uint32_t defaultVal = 0);

// Emits the instruction for the given type into the typeConstantBinary and
// returns the result-id for the type. If the type has already been emitted,
// it only returns its result-id.
Expand Down
Loading