diff --git a/tools/clang/include/clang/AST/Expr.h b/tools/clang/include/clang/AST/Expr.h index 55fd184a79..3c011d1531 100644 --- a/tools/clang/include/clang/AST/Expr.h +++ b/tools/clang/include/clang/AST/Expr.h @@ -531,6 +531,9 @@ class Expr : public Stmt { bool isConstantInitializer(ASTContext &Ctx, bool ForRef, const Expr **Culprit = nullptr) const; + bool isVulkanSpecConstantExpr(const ASTContext &Ctx, + APValue *Result = nullptr) const; + /// EvalStatus is a struct with detailed info about an evaluation in progress. struct EvalStatus { /// HasSideEffects - Whether the evaluated expression has side effects. diff --git a/tools/clang/include/clang/Basic/Attr.td b/tools/clang/include/clang/Basic/Attr.td index 83137dbc3a..5dbf7496b1 100644 --- a/tools/clang/include/clang/Basic/Attr.td +++ b/tools/clang/include/clang/Basic/Attr.td @@ -668,7 +668,7 @@ def HLSLMaxTessFactor: InheritableAttr { } def HLSLNumThreads: InheritableAttr { let Spellings = [CXX11<"", "numthreads", 2015>]; - let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">]; + let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">]; let Documentation = [Undocumented]; } def HLSLRootSignature: InheritableAttr { @@ -1016,7 +1016,7 @@ def HLSLNodeIsProgramEntry : InheritableAttr { def HLSLNodeId : InheritableAttr { let Spellings = [CXX11<"", "nodeid", 2017>]; - let Args = [StringArgument<"Name">,DefaultIntArgument<"ArrayIndex", 0>]; + let Args = [StringArgument<"Name">, ExprArgument<"ArrayIndex", 1>]; let Documentation = [Undocumented]; } @@ -1028,25 +1028,25 @@ def HLSLNodeLocalRootArgumentsTableIndex : InheritableAttr { def HLSLNodeShareInputOf : InheritableAttr { let Spellings = [CXX11<"", "nodeshareinputof", 2017>]; - let Args = [StringArgument<"Name">,UnsignedArgument<"ArrayIndex", 1>]; + let Args = [StringArgument<"Name">,ExprArgument<"ArrayIndex", 1>]; let Documentation = [Undocumented]; } def HLSLNodeDispatchGrid: InheritableAttr { let Spellings = [CXX11<"", "nodedispatchgrid", 2015>]; - let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">]; + let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">]; let Documentation = [Undocumented]; } def HLSLNodeMaxDispatchGrid: InheritableAttr { let Spellings = [CXX11<"", "nodemaxdispatchgrid", 2015>]; - let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">]; + let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">]; let Documentation = [Undocumented]; } def HLSLNodeMaxRecursionDepth : InheritableAttr { let Spellings = [CXX11<"", "nodemaxrecursiondepth", 2017>]; - let Args = [UnsignedArgument<"Count">]; + let Args = [ExprArgument<"Count">]; let Documentation = [Undocumented]; } @@ -1194,7 +1194,7 @@ def HLSLHitObject : InheritableAttr { def HLSLMaxRecords : InheritableAttr { let Spellings = [CXX11<"", "MaxRecords", 2015>]; - let Args = [IntArgument<"maxCount">]; + let Args = [ExprArgument<"maxCount">]; let Documentation = [Undocumented]; } diff --git a/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td b/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td index cbd9412566..24e8181dfb 100644 --- a/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/tools/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -7740,6 +7740,8 @@ def warn_hlsl_entry_attribute_without_shader_attribute : Warning< InGroup; def err_hlsl_attribute_expects_float_literal : Error< "attribute %0 must have a float literal argument">; +def err_hlsl_attribute_expects_integer_const_expr : Error< + "attribute %0 argument %1 must be integer constant expression">; def warn_hlsl_comma_in_init : Warning< "comma expression used where a constructor list may have been intended">, InGroup; diff --git a/tools/clang/include/clang/SPIRV/SpirvContext.h b/tools/clang/include/clang/SPIRV/SpirvContext.h index 8e0458e731..8321da9e63 100644 --- a/tools/clang/include/clang/SPIRV/SpirvContext.h +++ b/tools/clang/include/clang/SPIRV/SpirvContext.h @@ -456,6 +456,15 @@ class SpirvContext { instructionsWithLoweredType.end(); } + SpirvInstruction *getSpecConstant(const VarDecl *decl) { + return specConstants[decl]; + } + + void registerSpecConstant(const VarDecl *decl, + SpirvInstruction *specConstant) { + specConstants[decl] = specConstant; + } + void registerDispatchGridIndex(const RecordDecl *decl, unsigned index) { auto iter = dispatchGridIndices.find(decl); if (iter == dispatchGridIndices.end()) { @@ -536,6 +545,7 @@ class SpirvContext { llvm::DenseSet functionTypes; llvm::DenseMap spirvIntrinsicTypesById; llvm::SmallVector spirvIntrinsicTypes; + llvm::MapVector specConstants; const AccelerationStructureTypeNV *accelerationStructureTypeNV; const RayQueryTypeKHR *rayQueryTypeKHR; diff --git a/tools/clang/include/clang/Sema/SemaHLSL.h b/tools/clang/include/clang/Sema/SemaHLSL.h index 80ce8ddd7d..2dafbf7734 100644 --- a/tools/clang/include/clang/Sema/SemaHLSL.h +++ b/tools/clang/include/clang/Sema/SemaHLSL.h @@ -160,8 +160,6 @@ unsigned CaculateInitListArraySizeForHLSL(clang::Sema *sema, const clang::InitListExpr *InitList, const clang::QualType EltTy); -bool ContainsLongVector(clang::QualType); - bool IsConversionToLessOrEqualElements(clang::Sema *self, const clang::ExprResult &sourceExpr, const clang::QualType &targetType, diff --git a/tools/clang/lib/AST/ExprConstant.cpp b/tools/clang/lib/AST/ExprConstant.cpp index c24e44022f..d6939477d5 100644 --- a/tools/clang/lib/AST/ExprConstant.cpp +++ b/tools/clang/lib/AST/ExprConstant.cpp @@ -9448,6 +9448,19 @@ bool Expr::isIntegerConstantExpr(llvm::APSInt &Value, const ASTContext &Ctx, return true; } +bool Expr::isVulkanSpecConstantExpr(const ASTContext &Ctx, + APValue *Result) const { + auto *D = dyn_cast(this); + if (!D) + return false; + auto *V = dyn_cast(D->getDecl()); + if (!V || !V->hasAttr()) + return false; + if (const Expr *I = V->getAnyInitializer()) + return I->IgnoreParenCasts()->isCXX11ConstantExpr(Ctx, Result); + return true; +} + bool Expr::isCXX98IntegralConstantExpr(const ASTContext &Ctx) const { return CheckICE(this, Ctx).Kind == IK_ICE; } diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp index b5add521a6..5e5217c633 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp +++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp @@ -328,6 +328,25 @@ class CGMSHLSLRuntime : public CGHLSLRuntime { }; } // namespace +static uint32_t +getIntConstAttrArg(ASTContext &astContext, const Expr *expr, + llvm::Optional defaultVal = llvm::None) { + if (expr) { + llvm::APSInt apsInt; + APValue apValue; + if (expr->isIntegerConstantExpr(apsInt, astContext)) + return (uint32_t)apsInt.getSExtValue(); + if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt()) + return (uint32_t)apValue.getInt().getSExtValue(); + llvm_unreachable( + "Expression must be a constant expression or spec constant"); + } + if (!defaultVal.hasValue()) { + DXASSERT(defaultVal.hasValue(), "missing attribute parameter"); + } + return defaultVal.getValue(); +} + //------------------------------------------------------------------------------ // // CGMSHLSLRuntime methods. @@ -1422,6 +1441,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { } DiagnosticsEngine &Diags = CGM.getDiags(); + ASTContext &astContext = CGM.getTypes().getContext(); std::unique_ptr funcProps = llvm::make_unique(); @@ -1632,10 +1652,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { // Populate numThreads if (const HLSLNumThreadsAttr *Attr = FD->getAttr()) { - - funcProps->numThreads[0] = Attr->getX(); - funcProps->numThreads[1] = Attr->getY(); - funcProps->numThreads[2] = Attr->getZ(); + funcProps->numThreads[0] = getIntConstAttrArg(astContext, Attr->getX()); + funcProps->numThreads[1] = getIntConstAttrArg(astContext, Attr->getY()); + funcProps->numThreads[2] = getIntConstAttrArg(astContext, Attr->getZ()); if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) { unsigned DiagID = Diags.getCustomDiagID( @@ -1808,7 +1827,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { if (const auto *pAttr = FD->getAttr()) { funcProps->NodeShaderID.Name = pAttr->getName().str(); - funcProps->NodeShaderID.Index = pAttr->getArrayIndex(); + funcProps->NodeShaderID.Index = + getIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0); } else { funcProps->NodeShaderID.Name = FD->getName().str(); funcProps->NodeShaderID.Index = 0; @@ -1819,20 +1839,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { } if (const auto *pAttr = FD->getAttr()) { funcProps->NodeShaderSharedInput.Name = pAttr->getName().str(); - funcProps->NodeShaderSharedInput.Index = pAttr->getArrayIndex(); + funcProps->NodeShaderSharedInput.Index = + getIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0); } if (const auto *pAttr = FD->getAttr()) { - funcProps->Node.DispatchGrid[0] = pAttr->getX(); - funcProps->Node.DispatchGrid[1] = pAttr->getY(); - funcProps->Node.DispatchGrid[2] = pAttr->getZ(); + funcProps->Node.DispatchGrid[0] = + getIntConstAttrArg(astContext, pAttr->getX()); + funcProps->Node.DispatchGrid[1] = + getIntConstAttrArg(astContext, pAttr->getY()); + funcProps->Node.DispatchGrid[2] = + getIntConstAttrArg(astContext, pAttr->getZ()); } if (const auto *pAttr = FD->getAttr()) { - funcProps->Node.MaxDispatchGrid[0] = pAttr->getX(); - funcProps->Node.MaxDispatchGrid[1] = pAttr->getY(); - funcProps->Node.MaxDispatchGrid[2] = pAttr->getZ(); + funcProps->Node.MaxDispatchGrid[0] = + getIntConstAttrArg(astContext, pAttr->getX()); + funcProps->Node.MaxDispatchGrid[1] = + getIntConstAttrArg(astContext, pAttr->getY()); + funcProps->Node.MaxDispatchGrid[2] = + getIntConstAttrArg(astContext, pAttr->getZ()); } if (const auto *pAttr = FD->getAttr()) { - funcProps->Node.MaxRecursionDepth = pAttr->getCount(); + funcProps->Node.MaxRecursionDepth = + getIntConstAttrArg(astContext, pAttr->getCount()); } if (!FD->getAttr()) { // NumThreads wasn't specified. @@ -2346,8 +2374,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++; if (parmDecl->hasAttr()) { - node.MaxRecords = - parmDecl->getAttr()->getMaxCount(); + node.MaxRecords = getIntConstAttrArg( + astContext, + parmDecl->getAttr()->getMaxCount()); } if (parmDecl->hasAttr()) node.Flags.SetGloballyCoherent(); @@ -2378,7 +2407,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { // OutputID from attribute if (const auto *Attr = parmDecl->getAttr()) { node.OutputID.Name = Attr->getName().str(); - node.OutputID.Index = Attr->getArrayIndex(); + node.OutputID.Index = + getIntConstAttrArg(astContext, Attr->getArrayIndex(), 0); } else { node.OutputID.Name = parmDecl->getName().str(); node.OutputID.Index = 0; @@ -2437,7 +2467,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) { node.MaxRecordsSharedWith = ix; } if (const auto *Attr = parmDecl->getAttr()) - node.MaxRecords = Attr->getMaxCount(); + node.MaxRecords = getIntConstAttrArg(astContext, Attr->getMaxCount()); } if (inputPatchCount > 1) { diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp index 005060ab10..45cc8b3f22 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp @@ -1815,6 +1815,7 @@ DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) { void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl, SpirvInstruction *specConstant) { specConstant->setRValue(); + spvContext.registerSpecConstant(decl, specConstant); registerVariableForDecl(decl, createDeclSpirvInfo(specConstant)); } diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index bf802ad208..fcb8d50020 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -2522,6 +2522,24 @@ isFieldMergeWithPrevious(const StructType::FieldInfo &previous, return previous.fieldIndex == field.fieldIndex; } +uint32_t EmitTypeHandler::getAttrArgInstr(ASTContext &astContext, + const Expr *expr, + uint32_t defaultVal) { + if (expr) { + llvm::APSInt apsInt; + APValue apValue; + if (expr->isIntegerConstantExpr(apsInt, astContext)) + return getOrCreateConstantInt(apsInt, context.getUIntType(32), false); + if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && + apValue.isInt()) { + auto *declRefExpr = dyn_cast(expr); + auto *decl = dyn_cast(declRefExpr->getDecl()); + return getOrAssignResultId(context.getSpecConstant(decl)); + } + } + return defaultVal; +} + uint32_t EmitTypeHandler::emitType(const SpirvType *type) { // First get the decorations that would apply to this type. bool alreadyExists = false; diff --git a/tools/clang/lib/SPIRV/EmitVisitor.h b/tools/clang/lib/SPIRV/EmitVisitor.h index b766650098..10c3ea68b0 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.h +++ b/tools/clang/lib/SPIRV/EmitVisitor.h @@ -67,6 +67,9 @@ class EmitTypeHandler { EmitTypeHandler(const EmitTypeHandler &) = delete; EmitTypeHandler &operator=(const EmitTypeHandler &) = delete; + uint32_t getAttrArgInstr(ASTContext &astContext, const Expr *expr, + uint32_t defaultVal = 0); + // Emits the instruction for the given type into the typeConstantBinary and // returns the result-id for the type. If the type has already been emitted, // it only returns its result-id. diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 5019fe3926..3142186174 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -11252,15 +11252,11 @@ SpirvEmitter::processIntrinsicIsValid(const CXXMemberCallExpr *callExpr) { const auto *declRefExpr = dyn_cast(baseExpr->IgnoreImpCasts()); const auto *paramDecl = dyn_cast(declRefExpr->getDecl()); - int nodeIndex = 0; - if (HLSLNodeIdAttr *nodeId = paramDecl->getAttr()) { - nodeIndex = nodeId->getArrayIndex(); - } SpirvInstruction *payload = doExpr(baseExpr); if (!shaderIndex) { - shaderIndex = spvBuilder.getConstantInt(astContext.UnsignedIntTy, - llvm::APInt(32, nodeIndex)); + shaderIndex = + spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)); } return spvBuilder.createIsNodePayloadValid(payload, shaderIndex, loc); @@ -13671,6 +13667,88 @@ void SpirvEmitter::processInlineSpirvAttributes(const FunctionDecl *decl) { } } +SpirvInstruction * +SpirvEmitter::evalIntConstAttrArg(const Expr *expr, + llvm::Optional defaultVal) { + if (expr) { + QualType type = expr->getType(); + assert(type->isIntegerType()); + SpirvInstruction *ret = doExpr(expr); + assert(ret->getopcode() == spv::Op::OpConstant || + ret->getopcode() == spv::Op::OpSpecConstant); + if (type->isSignedIntegerType()) + ret->setAstResultType(astContext.UnsignedIntTy); + return ret; + } + if (defaultVal.hasValue()) + return spvBuilder.getConstantInt(astContext.UnsignedIntTy, + llvm::APInt(32, defaultVal.getValue())); + return nullptr; +} + +bool SpirvEmitter::processNumThreadsAttr(const FunctionDecl *decl) { + auto *numThreadsAttr = decl->getAttr(); + if (!numThreadsAttr) + return false; + + bool localSizeId = false; + Expr *x = numThreadsAttr->getX(), *y = numThreadsAttr->getY(), + *z = numThreadsAttr->getZ(); + + // SPIR-V spec says LocalSizeId missing "before version 1.2" but SPIRV-Tools + // validation excludes 1.2 as well. + switch (featureManager.getTargetEnv()) { + case SPV_ENV_VULKAN_1_0: + case SPV_ENV_VULKAN_1_1: + case SPV_ENV_VULKAN_1_1_SPIRV_1_4: + case SPV_ENV_VULKAN_1_2: + break; + default: + if (x->isVulkanSpecConstantExpr(astContext) || + y->isVulkanSpecConstantExpr(astContext) || + z->isVulkanSpecConstantExpr(astContext)) { + auto f = [this](Expr *E) -> SpirvInstruction * { + if (E) { + llvm::APSInt apsInt; + APValue apValue; + if (E->isIntegerConstantExpr(apsInt, astContext)) + return spvBuilder.getConstantInt(astContext.UnsignedIntTy, apsInt); + if (E->isVulkanSpecConstantExpr(astContext, &apValue) && + apValue.isInt()) { + auto *declRefExpr = dyn_cast(E); + auto *varDecl = dyn_cast(declRefExpr->getDecl()); + return declIdMapper.getDeclEvalInfo(varDecl, + declRefExpr->getExprLoc()); + } + } + return spvBuilder.getConstantInt(astContext.UnsignedIntTy, + llvm::APInt(32, 1)); + }; + spvBuilder.addExecutionModeId(entryFunction, + spv::ExecutionMode::LocalSizeId, + {f(x), f(y), f(z)}, decl->getLocation()); + return true; + } + } + + auto f = [this](Expr *E) -> unsigned { + if (E) { + llvm::APSInt apsInt; + APValue apValue; + if (E->isIntegerConstantExpr(apsInt, astContext)) + return (unsigned)apsInt.getZExtValue(); + if (E->isVulkanSpecConstantExpr(astContext, &apValue) && + apValue.isInt()) { + return apValue.getInt().getZExtValue(); + } + } + return 1U; + }; + spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, + {f(x), f(y), f(z)}, decl->getLocation()); + return true; +} + bool SpirvEmitter::processGeometryShaderAttributes(const FunctionDecl *decl, uint32_t *arraySize) { bool success = true; @@ -13862,28 +13940,18 @@ void SpirvEmitter::checkForWaveSizeAttr(const FunctionDecl *decl) { } void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) { - auto *numThreadsAttr = decl->getAttr(); - assert(numThreadsAttr && "thread group size missing from entry-point"); - - uint32_t x = static_cast(numThreadsAttr->getX()); - uint32_t y = static_cast(numThreadsAttr->getY()); - uint32_t z = static_cast(numThreadsAttr->getZ()); - - spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, - {x, y, z}, decl->getLocation()); + if (!processNumThreadsAttr(decl)) { + assert(false && "thread group size missing from entry-point"); + } checkForWaveSizeAttr(decl); } void SpirvEmitter::processNodeShaderAttributes(const FunctionDecl *decl) { - uint32_t x = 1, y = 1, z = 1; - if (auto *numThreadsAttr = decl->getAttr()) { - x = static_cast(numThreadsAttr->getX()); - y = static_cast(numThreadsAttr->getY()); - z = static_cast(numThreadsAttr->getZ()); + if (!processNumThreadsAttr(decl)) { + spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, + {1, 1, 1}, decl->getLocation()); } - spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, - {x, y, z}, decl->getLocation()); auto *nodeLaunchAttr = decl->getAttr(); StringRef launchType = nodeLaunchAttr ? nodeLaunchAttr->getLaunchType() : ""; @@ -13893,20 +13961,20 @@ void SpirvEmitter::processNodeShaderAttributes(const FunctionDecl *decl) { decl->getLocation()); } - uint64_t nodeId = 0; - if (const auto nodeIdAttr = decl->getAttr()) - nodeId = static_cast(nodeIdAttr->getArrayIndex()); - spvBuilder.addExecutionModeId( - entryFunction, spv::ExecutionMode::ShaderIndexAMDX, - {spvBuilder.getConstantInt(astContext.UnsignedIntTy, - llvm::APInt(32, nodeId))}, - decl->getLocation()); + SpirvInstruction *nodeId = nullptr; + if (const auto *nodeIdAttr = decl->getAttr()) + nodeId = evalIntConstAttrArg(nodeIdAttr->getArrayIndex(), 0); + else + nodeId = + spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)); + spvBuilder.addExecutionModeId(entryFunction, + spv::ExecutionMode::ShaderIndexAMDX, {nodeId}, + decl->getLocation()); if (const auto *nodeMaxRecursionDepthAttr = decl->getAttr()) { - SpirvInstruction *count = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, - llvm::APInt(32, nodeMaxRecursionDepthAttr->getCount())); + SpirvInstruction *count = + evalIntConstAttrArg(nodeMaxRecursionDepthAttr->getCount()); spvBuilder.addExecutionModeId(entryFunction, spv::ExecutionMode::MaxNodeRecursionAMDX, {count}, decl->getLocation()); @@ -13916,32 +13984,25 @@ void SpirvEmitter::processNodeShaderAttributes(const FunctionDecl *decl) { decl->getAttr()) { SpirvInstruction *name = spvBuilder.getConstantString(nodeShareInputOfAttr->getName()); - SpirvInstruction *index = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, - llvm::APInt(32, nodeShareInputOfAttr->getArrayIndex())); + SpirvInstruction *index = + evalIntConstAttrArg(nodeShareInputOfAttr->getArrayIndex(), 0); spvBuilder.addExecutionModeId(entryFunction, spv::ExecutionMode::SharesInputWithAMDX, {name, index}, decl->getLocation()); } if (const auto *dispatchGrid = decl->getAttr()) { - SpirvInstruction *gridX = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, llvm::APInt(32, dispatchGrid->getX())); - SpirvInstruction *gridY = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, llvm::APInt(32, dispatchGrid->getY())); - SpirvInstruction *gridZ = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, llvm::APInt(32, dispatchGrid->getZ())); + SpirvInstruction *gridX = evalIntConstAttrArg(dispatchGrid->getX(), 1); + SpirvInstruction *gridY = evalIntConstAttrArg(dispatchGrid->getY(), 1); + SpirvInstruction *gridZ = evalIntConstAttrArg(dispatchGrid->getZ(), 1); spvBuilder.addExecutionModeId(entryFunction, spv::ExecutionMode::StaticNumWorkgroupsAMDX, {gridX, gridY, gridZ}, decl->getLocation()); } else if (const auto *maxDispatchGrid = decl->getAttr()) { - SpirvInstruction *gridX = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, llvm::APInt(32, maxDispatchGrid->getX())); - SpirvInstruction *gridY = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, llvm::APInt(32, maxDispatchGrid->getY())); - SpirvInstruction *gridZ = spvBuilder.getConstantInt( - astContext.UnsignedIntTy, llvm::APInt(32, maxDispatchGrid->getZ())); + SpirvInstruction *gridX = evalIntConstAttrArg(maxDispatchGrid->getX(), 1); + SpirvInstruction *gridY = evalIntConstAttrArg(maxDispatchGrid->getY(), 1); + SpirvInstruction *gridZ = evalIntConstAttrArg(maxDispatchGrid->getZ(), 1); spvBuilder.addExecutionModeId(entryFunction, spv::ExecutionMode::MaxNumWorkgroupsAMDX, {gridX, gridY, gridZ}, decl->getLocation()); @@ -14158,14 +14219,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing( bool SpirvEmitter::processMeshOrAmplificationShaderAttributes( const FunctionDecl *decl, uint32_t *outVerticesArraySize) { - if (auto *numThreadsAttr = decl->getAttr()) { - uint32_t x, y, z; - x = static_cast(numThreadsAttr->getX()); - y = static_cast(numThreadsAttr->getY()); - z = static_cast(numThreadsAttr->getZ()); - spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, - {x, y, z}, decl->getLocation()); - } + processNumThreadsAttr(decl); // Early return for amplification shaders as they only take the 'numthreads' // attribute. @@ -15801,13 +15855,16 @@ void SpirvEmitter::addDerivativeGroupExecutionMode() { if (!canUseDerivativeGroupExecutionMode(sm, usingEXTMeshShader)) return; - SpirvExecutionMode *numThreadsEm = - cast(spvBuilder.getModule()->findExecutionMode( - entryFunction, spv::ExecutionMode::LocalSize)); + SpirvExecutionMode *numThreadsEm = dyn_cast_or_null( + spvBuilder.getModule()->findExecutionMode(entryFunction, + spv::ExecutionMode::LocalSize)); + // If there is no LocalSize, there must be LocalSizeId. + if (!numThreadsEm) + return addDerivativeGroupExecutionModeId(); auto numThreads = numThreadsEm->getParams(); // The layout of the quad is determined by the numer of threads in each - // dimention. From the HLSL spec + // dimension. From the HLSL spec // (https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html): // // Where numthreads has an X value divisible by 4 and Y and Z are both 1, the @@ -15830,6 +15887,61 @@ void SpirvEmitter::addDerivativeGroupExecutionMode() { spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation()); } +void SpirvEmitter::addDerivativeGroupExecutionModeId() { + assert(spvContext.isCS()); + + SpirvExecutionModeId *numThreadsEm = + dyn_cast(spvBuilder.getModule()->findExecutionMode( + entryFunction, spv::ExecutionMode::LocalSizeId)); + auto numThreads = numThreadsEm->getParams(); + bool numThreadsHasSpecConst = false; + auto f = [&numThreadsHasSpecConst]( + SpirvInstruction *arg) -> llvm::Optional { + if (auto con = dyn_cast(arg)) { + if (con->isSpecConstant()) + numThreadsHasSpecConst = true; + return (unsigned)con->getValue().getZExtValue(); + } + return llvm::None; + }; + + // The layout of the quad is determined by the numer of threads in each + // dimension. From the HLSL spec + // (https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html): + // + // Where numthreads has an X value divisible by 4 and Y and Z are both 1, the + // quad layouts are determined according to 1D quad rules. Where numthreads X + // and Y values are divisible by 2, the quad layouts are determined according + // to 2D quad rules. Using derivative operations in any numthreads + // configuration not matching either of these is invalid and will produce an + // error. + static_assert(spv::ExecutionMode::DerivativeGroupQuadsNV == + spv::ExecutionMode::DerivativeGroupQuadsKHR); + static_assert(spv::ExecutionMode::DerivativeGroupLinearNV == + spv::ExecutionMode::DerivativeGroupLinearKHR); + spv::ExecutionMode em = spv::ExecutionMode::DerivativeGroupQuadsNV; + auto x = f(numThreads[0]), y = f(numThreads[1]), z = f(numThreads[2]); + if (x.hasValue() && x.getValue() % 4 == 0 && y.hasValue() && + y.getValue() == 1 && z.hasValue() && z.getValue() == 1) { + em = spv::ExecutionMode::DerivativeGroupLinearNV; + } else { + assert((!x.hasValue() || x.getValue() % 2 == 0) && + (!y.hasValue() || y.getValue() % 2 == 0)); + } + + if (numThreadsHasSpecConst) { + // This code probably belongs in DiagnoseNumThreadsForDerivativeOp() in + // SemaHLSL.cpp, but that function apparently isn't invoked in all + // applicable situations. + diags.Report( + numThreadsEm->getSourceLocation(), + diags.getCustomDiagID(DiagnosticsEngine::Level::Warning, + "NumThreads spec constant default value used to " + "determine derivative group mode")); + } + spvBuilder.addExecutionMode(entryFunction, em, {}, SourceLocation()); +} + SpirvVariable *SpirvEmitter::createPCFParmVarAndInitFromStageInputVar( const ParmVarDecl *param) { const QualType type = param->getType(); diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index ada8db3068..72d1329664 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -896,6 +896,12 @@ class SpirvEmitter : public ASTConsumer { /// \brief Handle inline SPIR-V attributes for the entry function. void processInlineSpirvAttributes(const FunctionDecl *entryFunction); + SpirvInstruction * + evalIntConstAttrArg(const Expr *expr, + llvm::Optional defaultVal = llvm::None); + + bool processNumThreadsAttr(const FunctionDecl *decl); + /// \brief Adds necessary execution modes for the hull/domain shaders based on /// the HLSL attributes of the entry point function. /// In the case of hull shaders, also writes the number of output control @@ -1366,6 +1372,7 @@ class SpirvEmitter : public ASTConsumer { /// This decision is made according to the rules in /// https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_Derivatives.html. void addDerivativeGroupExecutionMode(); + void addDerivativeGroupExecutionModeId(); /// Creates an input variable for `param` that will be used by the patch /// constant function. The parameter is also added to the patch constant diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 4039e45426..747b4a403a 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -12416,21 +12416,48 @@ void Sema::CheckHLSLFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall) { } } +static uint32_t +getIntConstAttrArg(Sema &S, const Attr *attr, unsigned argNum, const Expr *expr, + llvm::Optional defaultVal = llvm::None) { + if (expr) { + llvm::APSInt apsInt; + APValue apValue; + if (expr->isIntegerConstantExpr(apsInt, S.getASTContext())) + return (uint32_t)apsInt.getSExtValue(); + if (expr->isVulkanSpecConstantExpr(S.getASTContext(), &apValue) && + apValue.isInt()) + return (uint32_t)apValue.getInt().getSExtValue(); + S.Diag(expr->getExprLoc(), + diag::err_hlsl_attribute_expects_integer_const_expr) + << attr->getSpelling() << argNum; + return 0; + } + if (!defaultVal.hasValue()) { + S.Diag(attr->getLocation(), + diag::err_hlsl_attribute_expects_integer_const_expr) + << attr->getSpelling() << argNum; + return 0; + } + return defaultVal.getValue(); +} + ///////////////////////////////////////////////////////////////////////////// // Check HLSL intrinsic calls reachable from entry/export functions. -static void DiagnoseNumThreadsForDerivativeOp(const HLSLNumThreadsAttr *Attr, - SourceLocation LocDeriv, - FunctionDecl *FD, - const FunctionDecl *EntryDecl, - DiagnosticsEngine &Diags) { +static void DiagnoseNumThreadsForDerivativeOp( + Sema &S, const HLSLNumThreadsAttr *Attr, SourceLocation LocDeriv, + FunctionDecl *FD, const FunctionDecl *EntryDecl, DiagnosticsEngine &Diags) { bool invalidNumThreads = false; - if (Attr->getY() != 1) { + uint32_t x = getIntConstAttrArg(S, Attr, 1, Attr->getX()); + uint32_t y = getIntConstAttrArg(S, Attr, 2, Attr->getY()); + uint32_t z = getIntConstAttrArg(S, Attr, 3, Attr->getZ()); + + if (y != 1) { // 2D mode requires x and y to be multiple of 2. - invalidNumThreads = !((Attr->getX() % 2) == 0 && (Attr->getY() % 2) == 0); + invalidNumThreads = !((x % 2) == 0 && (y % 2) == 0); } else { // 1D mode requires x to be multiple of 4 and y and z to be 1. - invalidNumThreads = (Attr->getX() % 4) != 0 || (Attr->getZ() != 1); + invalidNumThreads = (x % 4) != 0 || (z != 1); } if (invalidNumThreads) { Diags.Report(LocDeriv, diag::warn_hlsl_derivatives_wrong_numthreads) @@ -12476,7 +12503,7 @@ static void DiagnoseDerivativeOp(Sema &S, FunctionDecl *FD, SourceLocation Loc, if (const HLSLNumThreadsAttr *Attr = EntryDecl->getAttr()) { - DiagnoseNumThreadsForDerivativeOp(Attr, Loc, FD, EntryDecl, Diags); + DiagnoseNumThreadsForDerivativeOp(S, Attr, Loc, FD, EntryDecl, Diags); } } @@ -13915,12 +13942,12 @@ FlattenedTypeIterator::CompareTypesForInit(HLSLExternalSource &source, //////////////////////////////////////////////////////////////////////////////// // Attribute processing support. // -static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, - unsigned index = 0) { - int64_t value = 0; +static Expr *ValidateAttributeIntArgExpr(Sema &S, const AttributeList &Attr, + unsigned index, int64_t *value, + bool allowDefinedConstant = false) { + Expr *E = nullptr; if (Attr.getNumArgs() > index) { - Expr *E = nullptr; if (!Attr.isArgExpr(index)) { // For case arg is constant variable. IdentifierLoc *loc = Attr.getArgAsIdent(index); @@ -13931,13 +13958,13 @@ static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, if (!decl) { S.Diag(Attr.getLoc(), diag::warn_hlsl_attribute_expects_uint_literal) << Attr.getName(); - return value; + return nullptr; } Expr *init = decl->getInit(); if (!init) { S.Diag(Attr.getLoc(), diag::warn_hlsl_attribute_expects_uint_literal) << Attr.getName(); - return value; + return nullptr; } E = init; } else @@ -13947,11 +13974,13 @@ static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, bool displayError = false; if (E->isTypeDependent() || E->isValueDependent() || !E->isCXX11ConstantExpr(S.Context, &ArgNum)) { - displayError = true; + displayError = + !allowDefinedConstant || + !(E->isVulkanSpecConstantExpr(S.Context, &ArgNum) && ArgNum.isInt()); } else { if (ArgNum.isInt()) { - value = ArgNum.getInt().getSExtValue(); - if (!(E->getType()->isIntegralOrEnumerationType()) || value < 0) { + *value = ArgNum.getInt().getSExtValue(); + if (!(E->getType()->isIntegralOrEnumerationType()) || *value < 0) { S.Diag(Attr.getLoc(), diag::warn_hlsl_attribute_expects_uint_literal) << Attr.getName(); } @@ -13961,8 +13990,8 @@ static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, if (ArgNum.getFloat().convertToInteger( floatInt, llvm::APFloat::rmTowardZero, &isPrecise) == llvm::APFloat::opStatus::opOK) { - value = floatInt.getSExtValue(); - if (value < 0) { + *value = floatInt.getSExtValue(); + if (*value < 0) { S.Diag(Attr.getLoc(), diag::warn_hlsl_attribute_expects_uint_literal) << Attr.getName(); @@ -13980,9 +14009,23 @@ static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, S.Diag(Attr.getLoc(), diag::err_attribute_argument_type) << Attr.getName() << AANT_ArgumentIntegerConstant << E->getSourceRange(); + return nullptr; } } + return E; +} + +static Expr *ValidateAttributeIntArgExpr(Sema &S, const AttributeList &Attr, + unsigned index = 0) { + int64_t value = 0; + return ValidateAttributeIntArgExpr(S, Attr, index, &value, true); +} + +static int ValidateAttributeIntArg(Sema &S, const AttributeList &Attr, + unsigned index = 0) { + int64_t value = 0; + ValidateAttributeIntArgExpr(S, Attr, index, &value); return (int)value; } @@ -14328,19 +14371,27 @@ HLSLMaxRecordsAttr *ValidateMaxRecordsAttributes(Sema &S, Decl *D, Expr *ArgExpr = A.getArgAsExpr(0); IntegerLiteral *LiteralInt = dyn_cast(ArgExpr->IgnoreParenCasts()); + clang::SourceLocation Loc = {}; - if (ExistingMRSWA || ExistingMRA->getMaxCount() != LiteralInt->getValue()) { - clang::SourceLocation Loc = ExistingMRA ? ExistingMRA->getLocation() - : ExistingMRSWA->getLocation(); + if (ExistingMRSWA) { + Loc = ExistingMRSWA->getLocation(); + } else if (ExistingMRA) { + uint32_t maxCount = + getIntConstAttrArg(S, ExistingMRA, 1, ExistingMRA->getMaxCount(), 0); + if (LiteralInt->getValue().getLimitedValue() != maxCount) + Loc = ExistingMRA->getLocation(); + } + + if (Loc.isValid()) { S.Diag(A.getLoc(), diag::err_hlsl_maxrecord_attrs_on_same_arg); S.Diag(Loc, diag::note_conflicting_attribute); return nullptr; } } - return ::new (S.Context) - HLSLMaxRecordsAttr(A.getRange(), S.Context, ValidateAttributeIntArg(S, A), - A.getAttributeSpellingListIndex()); + return ::new (S.Context) HLSLMaxRecordsAttr( + A.getRange(), S.Context, ValidateAttributeIntArgExpr(S, A), + A.getAttributeSpellingListIndex()); } // This function validates the wave size attribute in a stand-alone way, @@ -14549,19 +14600,19 @@ void Sema::DiagnoseCoherenceMismatch(const Expr *SrcExpr, QualType TargetType, } } -void ValidateDispatchGridValues(DiagnosticsEngine &Diags, - const AttributeList &A, Attr *declAttr) { +void ValidateDispatchGridValues(Sema &S, const AttributeList &A, + Attr *declAttr) { unsigned x = 1, y = 1, z = 1; if (HLSLNodeDispatchGridAttr *pA = dyn_cast(declAttr)) { - x = pA->getX(); - y = pA->getY(); - z = pA->getZ(); + x = getIntConstAttrArg(S, pA, 1, pA->getX()); + y = getIntConstAttrArg(S, pA, 2, pA->getY()); + z = getIntConstAttrArg(S, pA, 3, pA->getZ()); } else if (HLSLNodeMaxDispatchGridAttr *pA = dyn_cast(declAttr)) { - x = pA->getX(); - y = pA->getY(); - z = pA->getZ(); + x = getIntConstAttrArg(S, pA, 1, pA->getX()); + y = getIntConstAttrArg(S, pA, 2, pA->getY()); + z = getIntConstAttrArg(S, pA, 3, pA->getZ()); } else { llvm_unreachable("ValidateDispatchGridValues() called for wrong attribute"); } @@ -14570,26 +14621,26 @@ void ValidateDispatchGridValues(DiagnosticsEngine &Diags, // If a component is out of range, we reset it to 0 to avoid also generating // a secondary error if the product would be out of range if (x < 1 || x > MaxComponentValue) { - Diags.Report(A.getArgAsExpr(0)->getExprLoc(), - diag::err_hlsl_dispatchgrid_component) + S.Diags.Report(A.getArgAsExpr(0)->getExprLoc(), + diag::err_hlsl_dispatchgrid_component) << A.getName() << "X" << A.getRange(); x = 0; } if (y < 1 || y > MaxComponentValue) { - Diags.Report(A.getArgAsExpr(1)->getExprLoc(), - diag::err_hlsl_dispatchgrid_component) + S.Diags.Report(A.getArgAsExpr(1)->getExprLoc(), + diag::err_hlsl_dispatchgrid_component) << A.getName() << "Y" << A.getRange(); y = 0; } if (z < 1 || z > MaxComponentValue) { - Diags.Report(A.getArgAsExpr(2)->getExprLoc(), - diag::err_hlsl_dispatchgrid_component) + S.Diags.Report(A.getArgAsExpr(2)->getExprLoc(), + diag::err_hlsl_dispatchgrid_component) << A.getName() << "Z" << A.getRange(); z = 0; } uint64_t product = (uint64_t)x * (uint64_t)y * (uint64_t)z; if (product > MaxProductValue) - Diags.Report(A.getLoc(), diag::err_hlsl_dispatchgrid_product) + S.Diags.Report(A.getLoc(), diag::err_hlsl_dispatchgrid_product) << A.getName() << A.getRange(); } @@ -14749,7 +14800,8 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, case AttributeList::AT_HLSLNodeId: declAttr = ::new (S.Context) HLSLNodeIdAttr( A.getRange(), S.Context, ValidateAttributeStringArg(S, A, nullptr, 0), - ValidateAttributeIntArg(S, A, 1), A.getAttributeSpellingListIndex()); + ValidateAttributeIntArgExpr(S, A, 1), + A.getAttributeSpellingListIndex()); break; case AttributeList::AT_HLSLNodeTrackRWInputSharing: declAttr = ::new (S.Context) HLSLNodeTrackRWInputSharingAttr( @@ -14852,18 +14904,20 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, A.getAttributeSpellingListIndex()); break; case AttributeList::AT_HLSLNumThreads: { - int X = ValidateAttributeIntArg(S, A, 0); - int Y = ValidateAttributeIntArg(S, A, 1); - int Z = ValidateAttributeIntArg(S, A, 2); - int N = X * Y * Z; + int64_t X = 1, Y = 1, Z = 1; + auto *arg0 = ValidateAttributeIntArgExpr(S, A, 0, &X, true); + auto *arg1 = ValidateAttributeIntArgExpr(S, A, 1, &Y, true); + auto *arg2 = ValidateAttributeIntArgExpr(S, A, 2, &Z, true); + int64_t N = X * Y * Z; if (N > 0 && N <= 1024) { - auto numThreads = ::new (S.Context) HLSLNumThreadsAttr( - A.getRange(), S.Context, X, Y, Z, A.getAttributeSpellingListIndex()); + auto numThreads = ::new (S.Context) + HLSLNumThreadsAttr(A.getRange(), S.Context, arg0, arg1, arg2, + A.getAttributeSpellingListIndex()); declAttr = numThreads; } else { // If the number of threads is invalid, diagnose and drop the attribute. S.Diags.Report(A.getLoc(), diag::warn_hlsl_numthreads_group_size) - << N << X << Y << Z << A.getRange(); + << (int)N << (int)X << (int)Y << (int)Z << A.getRange(); return; } break; @@ -14965,31 +15019,37 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, case AttributeList::AT_HLSLNodeShareInputOf: declAttr = ::new (S.Context) HLSLNodeShareInputOfAttr( A.getRange(), S.Context, ValidateAttributeStringArg(S, A, nullptr, 0), - ValidateAttributeIntArg(S, A, 1), A.getAttributeSpellingListIndex()); + ValidateAttributeIntArgExpr(S, A, 1), + A.getAttributeSpellingListIndex()); break; case AttributeList::AT_HLSLNodeDispatchGrid: declAttr = ::new (S.Context) HLSLNodeDispatchGridAttr( - A.getRange(), S.Context, ValidateAttributeIntArg(S, A), - ValidateAttributeIntArg(S, A, 1), ValidateAttributeIntArg(S, A, 2), + A.getRange(), S.Context, ValidateAttributeIntArgExpr(S, A), + ValidateAttributeIntArgExpr(S, A, 1), + ValidateAttributeIntArgExpr(S, A, 2), A.getAttributeSpellingListIndex()); - ValidateDispatchGridValues(S.Diags, A, declAttr); + ValidateDispatchGridValues(S, A, declAttr); break; case AttributeList::AT_HLSLNodeMaxDispatchGrid: declAttr = ::new (S.Context) HLSLNodeMaxDispatchGridAttr( - A.getRange(), S.Context, ValidateAttributeIntArg(S, A), - ValidateAttributeIntArg(S, A, 1), ValidateAttributeIntArg(S, A, 2), + A.getRange(), S.Context, ValidateAttributeIntArgExpr(S, A), + ValidateAttributeIntArgExpr(S, A, 1), + ValidateAttributeIntArgExpr(S, A, 2), A.getAttributeSpellingListIndex()); - ValidateDispatchGridValues(S.Diags, A, declAttr); + ValidateDispatchGridValues(S, A, declAttr); break; - case AttributeList::AT_HLSLNodeMaxRecursionDepth: + case AttributeList::AT_HLSLNodeMaxRecursionDepth: { + int64_t maxRecursionDepth = 0; declAttr = ::new (S.Context) HLSLNodeMaxRecursionDepthAttr( - A.getRange(), S.Context, ValidateAttributeIntArg(S, A), + A.getRange(), S.Context, + ValidateAttributeIntArgExpr(S, A, 0, &maxRecursionDepth, true), A.getAttributeSpellingListIndex()); - if (cast(declAttr)->getCount() > 32) + if (maxRecursionDepth > 32) S.Diags.Report(declAttr->getLocation(), diag::err_hlsl_maxrecursiondepth_exceeded) << declAttr->getRange(); break; + } default: Handled = false; break; // SPIRV Change: was return; @@ -16391,8 +16451,13 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, Attr *noconst = const_cast(A); HLSLNumThreadsAttr *ACast = static_cast(noconst); Indent(Indentation, Out); - Out << "[numthreads(" << ACast->getX() << ", " << ACast->getY() << ", " - << ACast->getZ() << ")]\n"; + Out << "[numthreads("; + ACast->getX()->printPretty(Out, nullptr, Policy); + Out << ", "; + ACast->getY()->printPretty(Out, nullptr, Policy); + Out << ", "; + ACast->getZ()->printPretty(Out, nullptr, Policy); + Out << ")]\n"; break; } @@ -16624,11 +16689,16 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, Attr *noconst = const_cast(A); HLSLNodeIdAttr *ACast = static_cast(noconst); Indent(Indentation, Out); - if (ACast->getArrayIndex() > 0) - Out << "[NodeId(\"" << ACast->getName() << "\"," << ACast->getArrayIndex() - << ")]\n"; - else - Out << "[NodeId(\"" << ACast->getName() << "\")]\n"; + Out << "[NodeId(\"" << ACast->getName(); + if (auto *lit = dyn_cast(ACast->getArrayIndex())) { + if (!lit->getValue().isStrictlyPositive()) { + Out << "\")]\n"; + break; + } + } + Out << "\","; + ACast->getArrayIndex()->printPretty(Out, nullptr, Policy); + Out << ")]\n"; break; } @@ -16646,11 +16716,16 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, HLSLNodeShareInputOfAttr *ACast = static_cast(noconst); Indent(Indentation, Out); - if (ACast->getArrayIndex() > 0) - Out << "[NodeShareInputOf(\"" << ACast->getName() << "\"," - << ACast->getArrayIndex() << ")]\n"; - else - Out << "[NodeShareInputOf(\"" << ACast->getName() << "\")]\n"; + Out << "[NodeShareInputOf(\"" << ACast->getName(); + if (auto *lit = dyn_cast(ACast->getArrayIndex())) { + if (!lit->getValue().isStrictlyPositive()) { + Out << "\")]\n"; + break; + } + } + Out << "\","; + ACast->getArrayIndex()->printPretty(Out, nullptr, Policy); + Out << ")]\n"; break; } @@ -17183,8 +17258,10 @@ void DiagnoseNodeEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName, // thread group size is (1,1,1) if (NodeLaunchTy == DXIL::NodeLaunchType::Thread) { if (auto NumThreads = FD->getAttr()) { - if (NumThreads->getX() != 1 || NumThreads->getY() != 1 || - NumThreads->getZ() != 1) { + uint32_t x = getIntConstAttrArg(S, NumThreads, 1, NumThreads->getX()); + uint32_t y = getIntConstAttrArg(S, NumThreads, 2, NumThreads->getY()); + uint32_t z = getIntConstAttrArg(S, NumThreads, 3, NumThreads->getZ()); + if (x != 1 || y != 1 || z != 1) { S.Diags.Report(NumThreads->getLocation(), diag::err_hlsl_wg_thread_launch_group_size) << NumThreads->getRange(); diff --git a/tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.derivative-group.hlsl b/tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.derivative-group.hlsl new file mode 100644 index 0000000000..cde93df5b6 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.derivative-group.hlsl @@ -0,0 +1,35 @@ +// RUN: %dxc -T cs_6_6 -E main -fspv-extension=SPV_NV_compute_shader_derivatives -fcgl %s -spirv 2>&1 | FileCheck -check-prefix=CHECK-LINEAR %s +// RUN: %dxc -T cs_6_6 -E main -fspv-extension=SPV_NV_compute_shader_derivatives -fcgl -DQUADS %s -spirv 2>&1 | FileCheck -check-prefix=CHECK-QUADS %s + +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +// CHECK-LINEAR: OpCapability ComputeDerivativeGroupLinearKHR +// CHECK-LINEAR: OpExecutionMode %{{[^ ]*}} DerivativeGroupLinearKHR +// CHECK-QUADS: OpCapability ComputeDerivativeGroupQuadsKHR +// CHECK-QUADS: OpExecutionMode %{{[^ ]*}} DerivativeGroupQuadsKHR + +SamplerState ss : register(s2); +SamplerComparisonState scs; + +RWStructuredBuffer o; +Texture1D t1; + +#ifdef QUADS +[[vk::constant_id(0)]] +const uint NumThreadsX = 2; +[[vk::constant_id(1)]] +const uint NumThreadsY = 2; +#else +[[vk::constant_id(0)]] +const uint NumThreadsX = 24; +[[vk::constant_id(1)]] +const uint NumThreadsY = 1; +#endif + +[numthreads(NumThreadsX,NumThreadsY,1)] +void main(uint3 id : SV_GroupThreadID) +{ + o[0] = t1.Sample(ss, 1); +} + diff --git a/tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.hlsl b/tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.hlsl new file mode 100644 index 0000000000..7cdd2c45af --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/vk.spec-constant.attributes.hlsl @@ -0,0 +1,67 @@ +// RUN: %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 -DSPEC=1 %s | FileCheck %s +// RUN: not %dxc -spirv -Vd -Od -T lib_6_8 -fspv-target-env=vulkan1.3 %s 2>&1 | FileCheck --check-prefix=NOSPEC %s + +// Note: validation disabled until NodePayloadAMDX pointers are allowed +// as function arguments + +struct InputPayload { + uint grid : SV_DispatchGrid; +}; + +struct OutputPayload { + uint foo; +}; + +#ifdef SPEC +[[vk::constant_id(0)]] const +#endif +uint MaxPayloads = 1; +#ifdef SPEC +[[vk::constant_id(1)]] const +#endif +uint WorkgroupSizeX = 1; +#ifdef SPEC +[[vk::constant_id(2)]] const +#endif +uint ShaderIndex = 0; +#ifdef SPEC +[[vk::constant_id(3)]] const +#endif +uint NumThreadsX = 512; + +[Shader("node")] +[NodeLaunch("broadcasting")] +[NumThreads(NumThreadsX, 1, 1)] +[NodeDispatchGrid(WorkgroupSizeX, 1, 1)] + +void main(const uint svGroupIndex : SV_GroupIndex, + DispatchNodeInputRecord inputRecord, + [NodeID("main", ShaderIndex)] + [MaxRecords(MaxPayloads)] + NodeOutput nodeOutput) { + ThreadNodeOutputRecords outRec = nodeOutput.GetThreadNodeOutputRecords(1); + outRec.OutputComplete(); +} + +// CHECK: OpExecutionModeId %{{[_0-9A-Za-z]*}} LocalSizeId [[NUMTHREADSX:%[_0-9A-Za-z]*]] [[U1:%[_0-9A-Za-z]*]] [[U1]] +// CHECK: OpExecutionModeId %{{[_0-9A-Za-z]*}} StaticNumWorkgroupsAMDX [[WGSIZEX:%[_0-9A-Za-z]*]] [[U1]] [[U1]] +// CHECK: OpDecorate [[MAXPAYLOADS:%[_0-9A-Za-z]*]] SpecId 0 +// CHECK: OpDecorate [[WGSIZEX]] SpecId 1 +// CHECK: OpDecorate [[SHADERINDEX:%[_0-9A-Za-z]*]] SpecId 2 +// CHECK: OpDecorate [[NUMTHREADSX]] SpecId 3 +// CHECK: OpDecorateId %{{[_0-9A-Za-z]*}} NodeMaxPayloadsAMDX [[U1:%[_0-9A-Za-z]*]] +// CHECK-DAG: OpDecorateId %{{[_0-9A-Za-z]*}} PayloadNodeBaseIndexAMDX [[SHADERINDEX]] +// CHECK-DAG: OpDecorateId %{{[_0-9A-Za-z]*}} NodeMaxPayloadsAMDX [[MAXPAYLOADS]] +// CHECK: [[UINT:%[_0-9A-Za-z]*]] = OpTypeInt 32 0 +// CHECK-DAG: [[U0:%[_0-9A-Za-z]*]] = OpConstant [[UINT]] 0 +// CHECK-DAG: [[U1]] = OpConstant [[UINT]] 1 +// CHECK-DAG: [[MAXPAYLOADS:%[_0-9A-Za-z]*]] = OpSpecConstant [[UINT]] 1 +// CHECK-DAG: [[WGSIZEX:%[_0-9A-Za-z]*]] = OpSpecConstant [[UINT]] 1 +// CHECK-DAG: [[SHADERINDEX:%[_0-9A-Za-z]*]] = OpSpecConstant [[UINT]] 0 +// CHECK-DAG: [[NUMTHREADSX:%[_0-9A-Za-z]*]] = OpSpecConstant [[UINT]] 512 + +// NOSPEC-DAG: error: 'MaxRecords' attribute requires an integer constant +// NOSPEC-DAG: error: 'NodeID' attribute requires an integer constant +// NOSPEC-DAG: error: 'NodeDispatchGrid' attribute requires an integer constant +// NOSPEC-DAG: error: 'NumThreads' attribute requires an integer constant + diff --git a/tools/clang/test/HLSLFileCheck/hlsl/workgraph/sparsenodes.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/sparsenodes.hlsl index 10335ee864..d5645e2849 100644 --- a/tools/clang/test/HLSLFileCheck/hlsl/workgraph/sparsenodes.hlsl +++ b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/sparsenodes.hlsl @@ -10,34 +10,40 @@ // AST: FunctionDecl {{.*}} node_1_0 'void (NodeOutputArray)' // AST-NEXT: ParmVarDecl {{.*}} used OutputArray_1_0 'NodeOutputArray':'NodeOutputArray' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 31 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 31 // AST-NEXT: HLSLNodeArraySizeAttr {{.*}} 129 // AST-NEXT: HLSLAllowSparseNodesAttr // AST: FunctionDecl {{.*}} node_1_1 'void (NodeOutputArray)' // AST-NEXT: ParmVarDecl {{.*}} used OutputArray_1_1 'NodeOutputArray':'NodeOutputArray' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 37 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 37 // AST-NEXT: HLSLUnboundedSparseNodesAttr // AST: FunctionDecl {{.*}} node_1_2 'void (NodeOutput)' // AST-NEXT: ParmVarDecl {{.*}} used Output_1_2 'NodeOutput':'NodeOutput' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 47 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 47 // AST-NEXT: HLSLAllowSparseNodesAttr // AST: FunctionDecl {{.*}} node_2_0 'void (EmptyNodeOutputArray)' // AST-NEXT: ParmVarDecl {{.*}} used OutputArray_2_0 'EmptyNodeOutputArray' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 41 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 41 // AST-NEXT: HLSLNodeArraySizeAttr {{.*}} 131 // AST-NEXT: HLSLAllowSparseNodesAttr // AST: FunctionDecl {{.*}} node_2_1 'void (EmptyNodeOutputArray)' // AST-NEXT: ParmVarDecl {{.*}} used OutputArray_2_1 'EmptyNodeOutputArray' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 43 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 43 // AST-NEXT: HLSLUnboundedSparseNodesAttr // AST: FunctionDecl {{.*}} node_2_2 'void (EmptyNodeOutput)' // AST-NEXT: ParmVarDecl {{.*}} used Output_2_2 'EmptyNodeOutput' -// AST-NEXT: HLSLMaxRecordsAttr {{.*}} 53 +// AST-NEXT: HLSLMaxRecordsAttr {{.*}} +// AST-NEXT: IntegerLiteral {{.*}} 53 // AST-NEXT: HLSLAllowSparseNodesAttr // ==== -fcgl Metadata Checks ==== diff --git a/tools/clang/test/HLSLFileCheck/shader_targets/nodes/NodeOutput.hlsl b/tools/clang/test/HLSLFileCheck/shader_targets/nodes/NodeOutput.hlsl index 9343ad9831..6c69441468 100644 --- a/tools/clang/test/HLSLFileCheck/shader_targets/nodes/NodeOutput.hlsl +++ b/tools/clang/test/HLSLFileCheck/shader_targets/nodes/NodeOutput.hlsl @@ -23,20 +23,28 @@ struct MY_INPUT_RECORD // CHECK:FunctionDecl 0x{{.*}} myFancyNode 'void (GroupNodeInputRecords, NodeOutput, NodeOutput, NodeOutputArray, EmptyNodeOutput)' // CHECK-NEXT:ParmVarDecl 0x{{.*}} myInput 'GroupNodeInputRecords':'GroupNodeInputRecords' -// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} 4 +// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 4 // CHECK-NEXT: ParmVarDecl 0x{{.*}} myFascinatingNode 'NodeOutput':'NodeOutput' -// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} 4 +// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 4 // CHECK-NEXT: ParmVarDecl 0x{{.*}} myRecords 'NodeOutput':'NodeOutput' -// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} 4 -// CHECK-NEXT: HLSLNodeIdAttr 0x{{.*}} "myNiftyNode" 3 +// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 4 +// CHECK-NEXT: HLSLNodeIdAttr 0x{{.*}} "myNiftyNode" +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 3 // CHECK-NEXT: ParmVarDecl 0x{{.*}} col:65 myMaterials 'NodeOutputArray' // CHECK-NEXT:HLSLNodeArraySizeAttr 0x{{.*}} 63 // CHECK-NEXT:HLSLAllowSparseNodesAttr 0x{{.*}} // CHECK-NEXT:HLSLMaxRecordsSharedWithAttr 0x{{.*}} myRecords // CHECK-NEXT:ParmVarDecl 0x{{.*}} myProgressCounter 'EmptyNodeOutput' -// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} 20 +// CHECK-NEXT: HLSLMaxRecordsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 20 // CHECK-NEXT: CompoundStmt 0x -// CHECK-NEXT: HLSLNumThreadsAttr 0x{{.*}} 4 5 6 +// CHECK-NEXT: HLSLNumThreadsAttr 0x{{.*}} +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 4 +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 5 +// CHECK-NEXT: IntegerLiteral 0x{{.*}} 6 // CHECK-NEXT: HLSLNodeLaunchAttr 0x{{.*}} "coalescing" // CHECK-NEXT: HLSLShaderAttr 0x{{.*}} "node" [Shader("node")] diff --git a/tools/clang/test/SemaHLSL/hlsl/objects/NodeObjects/node-objects.hlsl b/tools/clang/test/SemaHLSL/hlsl/objects/NodeObjects/node-objects.hlsl index 2fd316bba6..5d2a684178 100644 --- a/tools/clang/test/SemaHLSL/hlsl/objects/NodeObjects/node-objects.hlsl +++ b/tools/clang/test/SemaHLSL/hlsl/objects/NodeObjects/node-objects.hlsl @@ -40,8 +40,14 @@ struct RECORD // AST: `-DeclRefExpr 0x{{.+}} 'DispatchNodeInputRecord':'DispatchNodeInputRecord' lvalue ParmVar 0x[[Param]] 'input' 'DispatchNodeInputRecord':'DispatchNodeInputRecord' // attributes. // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" -// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} 64 1 1 -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1024 1 1 +// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 64 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1024 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: `-HLSLShaderAttr 0x{{.+}} "node" @@ -67,8 +73,14 @@ void node_DispatchNodeInputRecord(DispatchNodeInputRecord input) // AST: `-DeclRefExpr 0x{{.+}} 'RWDispatchNodeInputRecord':'RWDispatchNodeInputRecord' lvalue ParmVar 0x[[Param]] 'input' 'RWDispatchNodeInputRecord':'RWDispatchNodeInputRecord' // attributes. // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" -// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} 16 1 1 -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1024 1 1 +// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 16 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1024 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: `-HLSLShaderAttr 0x{{.+}} "node" [Shader("node")] @@ -85,7 +97,8 @@ void node_RWDispatchNodeInputRecord(RWDispatchNodeInputRecord input) // AST: FunctionDecl 0x{{.+}} node_GroupNodeInputRecords 'void (GroupNodeInputRecords)' // AST: |-ParmVarDecl 0x[[Param:[0-9a-f]+]] col:81 used inputs 'GroupNodeInputRecords':'GroupNodeInputRecords' -// AST: | `-HLSLMaxRecordsAttr 0x{{.+}} 256 +// AST: | `-HLSLMaxRecordsAttr 0x{{.+}} +// AST: `-IntegerLiteral 0x{{.+}} 256 // call to wrapper // AST: `-CallExpr 0x{{.+}} 'GroupNodeInputRecords':'GroupNodeInputRecords' // AST: |-ImplicitCastExpr 0x{{.+}} 'GroupNodeInputRecords (*)(GroupNodeInputRecords)' @@ -94,7 +107,10 @@ void node_RWDispatchNodeInputRecord(RWDispatchNodeInputRecord input) // AST: `-DeclRefExpr 0x{{.+}} 'GroupNodeInputRecords':'GroupNodeInputRecords' lvalue ParmVar 0x[[Param]] 'inputs' 'GroupNodeInputRecords':'GroupNodeInputRecords' // attributes. // AST: |-HLSLNodeIsProgramEntryAttr 0x{{.+}} -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1024 1 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1024 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "coalescing" // AST: `-HLSLShaderAttr 0x{{.+}} "node" @@ -112,7 +128,8 @@ void node_GroupNodeInputRecords([MaxRecords(256)] GroupNodeInputRecords // AST: FunctionDecl 0x{{.+}} node_RWGroupNodeInputRecords 'void (RWGroupNodeInputRecords)' // AST: |-ParmVarDecl 0x[[Param:[0-9a-f]+]] col:84 used input2 'RWGroupNodeInputRecords':'RWGroupNodeInputRecords' -// AST: | `-HLSLMaxRecordsAttr 0x{{.+}} 4 +// AST: | `-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | `-IntegerLiteral 0x{{.+}} 4 // call to wrapper // AST: CallExpr 0x{{.+}} 'RWGroupNodeInputRecords':'RWGroupNodeInputRecords' // AST: |-ImplicitCastExpr 0x{{.+}} 'RWGroupNodeInputRecords (*)(RWGroupNodeInputRecords)' @@ -121,7 +138,10 @@ void node_GroupNodeInputRecords([MaxRecords(256)] GroupNodeInputRecords // AST: | `-DeclRefExpr 0x{{.+}} 'RWGroupNodeInputRecords':'RWGroupNodeInputRecords' lvalue ParmVar 0x[[Param]] 'input2' 'RWGroupNodeInputRecords':'RWGroupNodeInputRecords' // attributes. // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "coalescing" -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: `-HLSLShaderAttr 0x{{.+}} "node" struct RECORD2 @@ -192,7 +212,10 @@ void node_RWThreadNodeInputRecord(RWThreadNodeInputRecord input) // AST: `-ImplicitCastExpr 0x{{.+}} 'EmptyNodeInput' // AST: `-DeclRefExpr 0x{{.+}} 'EmptyNodeInput' lvalue ParmVar 0x[[Param]] 'input' 'EmptyNodeInput' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 2 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 2 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: | |-HLSLNodeIsProgramEntryAttr // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "coalescing" // AST: | `-HLSLShaderAttr 0x{{.+}}> "node" @@ -219,8 +242,14 @@ void node_EmptyNodeInput(EmptyNodeInput input) // AST: `-DeclRefExpr 0x{{.+}} 'NodeOutput':'NodeOutput' lvalue ParmVar 0x[[Param]] 'output3' 'NodeOutput':'NodeOutput' // attributes. // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 32 1 1 -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 1024 1 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 32 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1024 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | `-HLSLShaderAttr 0x{{.+}} "node" [Shader("node")] @@ -239,7 +268,8 @@ void node_NodeOutput(NodeOutput output3) // EmptyNodeOutput // AST: FunctionDecl 0x{{.+}} node_EmptyNodeOutput 'void (EmptyNodeOutput)' // AST: | |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} col:35 used loadStressChild 'EmptyNodeOutput' -// AST: | | `-HLSLMaxRecordsAttr 0x{{.+}} 12 +// AST: | | `-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | `-IntegerLiteral 0x{{.+}} 12 // call to wrapper // AST: CallExpr 0x{{.+}} 'EmptyNodeOutput':'EmptyNodeOutput' // AST: |-ImplicitCastExpr 0x{{.+}} 'EmptyNodeOutput (*)(EmptyNodeOutput)' @@ -247,8 +277,14 @@ void node_NodeOutput(NodeOutput output3) // AST: `-ImplicitCastExpr 0x{{.+}} 'EmptyNodeOutput' // AST: `-DeclRefExpr 0x{{.+}} 'EmptyNodeOutput' lvalue ParmVar 0x[[Param]] 'loadStressChild' 'EmptyNodeOutput' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | `-HLSLShaderAttr 0x{{.+}} "node" void loadStressEmptyRecWorker( EmptyNodeOutput outputNode) @@ -269,7 +305,8 @@ void node_EmptyNodeOutput( // NodeOutputArray // AST: FunctionDecl 0x{{.+}} node_NodeOutputArray 'void (NodeOutputArray)' // AST: | |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} col:30 used OutputArray_1_0 'NodeOutputArray':'NodeOutputArray' -// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} 31 +// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | | `-IntegerLiteral 0x{{.+}} 31 // AST: | | |-HLSLNodeArraySizeAttr 0x{{.+}} 129 // AST: | | `-HLSLAllowSparseNodesAttr 0x{{.+}} // call to wrapper @@ -279,8 +316,14 @@ void node_EmptyNodeOutput( // AST: `-ImplicitCastExpr 0x{{.+}} 'NodeOutputArray':'NodeOutputArray' // AST: `-DeclRefExpr 0x{{.+}} 'NodeOutputArray':'NodeOutputArray' lvalue ParmVar 0x[[Param]] 'OutputArray_1_0' 'NodeOutputArray':'NodeOutputArray' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // AST: | `-HLSLShaderAttr 0x{{.+}} "node" @@ -304,7 +347,8 @@ void node_NodeOutputArray( // EmptyNodeOutputArray // AST: FunctionDecl 0x{{.+}} node_EmptyNodeOutputArray 'void (EmptyNodeOutputArray)' // AST: | |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} col:64 used EmptyOutputArray 'EmptyNodeOutputArray' -// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} 64 +// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | | `-IntegerLiteral 0x{{.+}} 64 // AST: | | `-HLSLNodeArraySizeAttr 0x{{.+}} 128 // call to wrapper // AST: CallExpr 0x{{.+}} 'EmptyNodeOutputArray':'EmptyNodeOutputArray' @@ -313,8 +357,14 @@ void node_NodeOutputArray( // AST: `-ImplicitCastExpr 0x{{.+}} 'EmptyNodeOutputArray' // AST: `-DeclRefExpr 0x{{.+}} 'EmptyNodeOutputArray' lvalue ParmVar 0x[[Param]] 'EmptyOutputArray' 'EmptyNodeOutputArray' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 128 1 1 -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 128 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // AST: | `-HLSLShaderAttr 0x{{.+}} "node" @@ -335,7 +385,8 @@ void node_EmptyNodeOutputArray( // GroupNodeOutputRecords // AST: FunctionDecl 0x{{.+}} node_GroupNodeOutputRecords 'void (NodeOutputArray)' // AST: | |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} col:68 used OutputArray 'NodeOutputArray':'NodeOutputArray' -// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} 64 +// AST: | | |-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | | `-IntegerLiteral 0x{{.+}} 64 // AST: | | `-HLSLNodeArraySizeAttr 0x{{.+}} 128 // call to wrapper // AST: CallExpr 0x{{.+}} 'GroupNodeOutputRecords':'GroupNodeOutputRecords' @@ -344,8 +395,14 @@ void node_EmptyNodeOutputArray( // AST: `-ImplicitCastExpr 0x{{.+}} 'GroupNodeOutputRecords':'GroupNodeOutputRecords' // AST: `-DeclRefExpr 0x{{.+}} 'GroupNodeOutputRecords':'GroupNodeOutputRecords' lvalue Var 0x{{.+}} 'outRec' 'GroupNodeOutputRecords':'GroupNodeOutputRecords' // attributes. -// AST: | |-HLSLNumThreadsAttr 0x{{.+}} 128 1 1 -// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: | |-HLSLNumThreadsAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 128 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 +// AST: | |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | |-IntegerLiteral 0x{{.+}} 1 +// AST: | | `-IntegerLiteral 0x{{.+}} 1 // AST: | |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // AST: | `-HLSLShaderAttr 0x{{.+}} "node" [Shader("node")] @@ -366,7 +423,8 @@ void node_GroupNodeOutputRecords( // ThreadNodeOutputRecords // AST: FunctionDecl 0x{{.+}} node_ThreadNodeOutputRecords 'void (NodeOutputArray)' // AST: |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} used OutputArray_1_0 'NodeOutputArray':'NodeOutputArray' -// AST: | |-HLSLMaxRecordsAttr 0x{{.+}} 31 +// AST: | |-HLSLMaxRecordsAttr 0x{{.+}} +// AST: | | `-IntegerLiteral 0x{{.+}} 31 // AST: | |-HLSLNodeArraySizeAttr 0x{{.+}} 129 // AST: | `-HLSLAllowSparseNodesAttr 0x{{.+}} // call to wrapper @@ -376,8 +434,14 @@ void node_GroupNodeOutputRecords( // AST: `-ImplicitCastExpr 0x{{.+}} 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' // AST: `-DeclRefExpr 0x{{.+}} 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' lvalue Var 0x{{.+}} 'outRec' 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' // attributes. -// AST: |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// AST: |-HLSLNumThreadsAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 +// AST: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: |-IntegerLiteral 0x{{.+}} 1 +// AST: `-IntegerLiteral 0x{{.+}} 1 // AST: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // AST: `-HLSLShaderAttr 0x{{.+}} "node" diff --git a/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-EmptyNodeOutputArrayTypes.hlsl b/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-EmptyNodeOutputArrayTypes.hlsl index aa5a7123f6..278161d65d 100644 --- a/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-EmptyNodeOutputArrayTypes.hlsl +++ b/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-EmptyNodeOutputArrayTypes.hlsl @@ -54,7 +54,8 @@ void node_2_0( // CHECK: `-FunctionDecl 0x{{.+}} node_2_0 'void (EmptyNodeOutputArray)' // CHECK-NEXT: |-ParmVarDecl 0x[[Param:[0-9a-f]+]] {{.+}} used OutputArray_2_0 'EmptyNodeOutputArray' -// CHECK-NEXT: | |-HLSLMaxRecordsAttr 0x{{.+}} 41 +// CHECK-NEXT: | |-HLSLMaxRecordsAttr 0x{{.+}} +// CHECK-NEXT: | | `-IntegerLiteral 0x{{.+}} 'literal int' 41 // CHECK-NEXT: | |-HLSLNodeArraySizeAttr 0x{{.+}} 131 // CHECK-NEXT: | `-HLSLAllowSparseNodesAttr 0x{{.+}} // CHECK-NEXT: |-CompoundStmt 0x{{.+}} @@ -68,7 +69,13 @@ void node_2_0( // CHECK-NEXT: | | `-IntegerLiteral 0x{{.+}}{{.+}} 'literal int' 1 // CHECK-NEXT: | `-ImplicitCastExpr 0x{{.+}} 'unsigned int' // CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 10 -// CHECK-NEXT: |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// CHECK-NEXT: |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// CHECK-NEXT: |-HLSLNumThreadsAttr 0x{{.+}} +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 1 // CHECK-NEXT: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // CHECK-NEXT: `-HLSLShaderAttr 0x{{.+}} "node" diff --git a/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-NodeOutputArrayTypes.hlsl b/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-NodeOutputArrayTypes.hlsl index 6d381c5953..5bf163edbf 100644 --- a/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-NodeOutputArrayTypes.hlsl +++ b/tools/clang/test/SemaHLSL/hlsl/workgraph/ast-NodeOutputArrayTypes.hlsl @@ -152,7 +152,8 @@ void node_1_1( // CHECK:`-FunctionDecl 0x{{.+}} line:16:6 node_1_1 'void (NodeOutputArray)' // CHECK-NEXT: |-ParmVarDecl 0x[[ParmVar:[0-9a-f]+]] col:30 used OutputArray_1_1 'NodeOutputArray':'NodeOutputArray' -// CHECK-NEXT: | |-HLSLMaxRecordsAttr 0x{{.+}} 37 +// CHECK-NEXT: | |-HLSLMaxRecordsAttr 0x{{.+}} +// CHECK-NEXT: | | `-IntegerLiteral 0x{{.+}} 'literal int' 37 // CHECK-NEXT: | `-HLSLUnboundedSparseNodesAttr 0x{{.+}} // CHECK-NEXT: |-CompoundStmt 0x{{.+}} // CHECK-NEXT: | |-DeclStmt 0x{{.+}} @@ -170,7 +171,13 @@ void node_1_1( // CHECK-NEXT: | `-CXXMemberCallExpr 0x{{.+}} 'void' // CHECK-NEXT: | `-MemberExpr 0x{{.+}} '' .OutputComplete 0x[[OutComplete]] // CHECK-NEXT: | `-DeclRefExpr 0x{{.+}} 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' lvalue Var 0x[[OutRec]] 'outRec' 'ThreadNodeOutputRecords':'ThreadNodeOutputRecords' -// CHECK-NEXT: |-HLSLNumThreadsAttr 0x{{.+}} 1 1 1 -// CHECK-NEXT: |-HLSLNodeDispatchGridAttr 0x{{.+}} 1 1 1 +// CHECK-NEXT: |-HLSLNumThreadsAttr 0x{{.+}} +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: |-HLSLNodeDispatchGridAttr 0x{{.+}} +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | |-IntegerLiteral 0x{{.+}} 'literal int' 1 +// CHECK-NEXT: | `-IntegerLiteral 0x{{.+}} 'literal int' 1 // CHECK-NEXT: |-HLSLNodeLaunchAttr 0x{{.+}} "broadcasting" // CHECK-NEXT: `-HLSLShaderAttr 0x{{.+}} "node"