Skip to content
Open
1 change: 1 addition & 0 deletions include/dxc/DXIL/DxilMetadataHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ class DxilMDHelper {
static const unsigned kDxilNodeMaxRecordsSharedWithTag = 4;
static const unsigned kDxilNodeOutputArraySizeTag = 5;
static const unsigned kDxilNodeAllowSparseNodesTag = 6;
static const unsigned kDxilNodeMaxRecordsPerNodeTag = 7;

// Node Record Type
static const unsigned kDxilNodeRecordSizeTag = 0;
Expand Down
1 change: 1 addition & 0 deletions include/dxc/DXIL/DxilNodeProps.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ struct NodeIOProperties {
int MaxRecordsSharedWith = -1;
unsigned OutputArraySize = 0;
bool AllowSparseNodes = false;
unsigned MaxRecordsPerNode = 0;

public:
NodeIOProperties() {}
Expand Down
4 changes: 4 additions & 0 deletions include/dxc/DxilContainer/RDAT_LibraryTypes.inl
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ RDAT_ENUM_START(NodeAttribKind, uint32_t)
RDAT_ENUM_VALUE(OutputArraySize, 6)
RDAT_ENUM_VALUE(AllowSparseNodes, 7)
RDAT_ENUM_VALUE(RecordAlignmentInBytes, 8)
RDAT_ENUM_VALUE(MaxRecordsPerNode, 9)
RDAT_ENUM_VALUE_NODEF(LastValue)
RDAT_ENUM_END()

Expand Down Expand Up @@ -415,6 +416,9 @@ RDAT_STRUCT_TABLE(NodeShaderIOAttrib, NodeShaderIOAttribTable)
getAttribKind() ==
hlsl::RDAT::NodeAttribKind::RecordAlignmentInBytes)
RDAT_VALUE(uint32_t, RecordAlignmentInBytes)
RDAT_UNION_ELIF(MaxRecordsPerNode,
getAttribKind() == hlsl::RDAT::NodeAttribKind::MaxRecordsPerNode)
RDAT_VALUE(uint32_t, MaxRecordsPerNode)
RDAT_UNION_ENDIF()
RDAT_UNION_END()
RDAT_STRUCT_END()
Expand Down
12 changes: 12 additions & 0 deletions lib/DXIL/DxilMetadataHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1985,6 +1985,7 @@ void DxilMDHelper::SerializeNodeProps(SmallVectorImpl<llvm::Metadata *> &MDVals,
MDVals.push_back(Uint32ToConstMD(nodeoutput.OutputArraySize));
MDVals.push_back(BoolToConstMD(nodeoutput.AllowSparseNodes));
MDVals.push_back(Uint32ToConstMD(nodeoutput.RecordType.alignment));
MDVals.push_back(Uint32ToConstMD(nodeoutput.MaxRecordsPerNode));
}
}

Expand Down Expand Up @@ -2047,6 +2048,7 @@ void DxilMDHelper::DeserializeNodeProps(const MDTuple *pProps, unsigned &idx,
nodeoutput.RecordType.alignment =
ConstMDToUint32(pProps->getOperand(idx++));
}
nodeoutput.MaxRecordsPerNode = ConstMDToUint32(pProps->getOperand(idx++));
}
}

Expand Down Expand Up @@ -2832,6 +2834,13 @@ DxilMDHelper::EmitDxilNodeIOState(const hlsl::NodeIOProperties &Node) {
NodeOpIDVals.emplace_back(Uint32ToConstMD(Node.OutputID.Index));
MDVals.emplace_back(MDNode::get(m_Ctx, NodeOpIDVals));
}

if (Node.MaxRecordsPerNode) {
MDVals.emplace_back(
Uint32ToConstMD(DxilMDHelper::kDxilNodeMaxRecordsPerNodeTag));
MDVals.emplace_back(Uint32ToConstMD(Node.MaxRecordsPerNode));
}

} else {
DXASSERT(Node.Flags.IsInputRecord(), "Invalid NodeIO Kind");
if (Node.MaxRecords) {
Expand Down Expand Up @@ -2922,6 +2931,9 @@ NodeIOProperties DxilMDHelper::LoadDxilNodeIOState(const llvm::MDOperand &MDO) {
Node.OutputID.Name = StringMDToString(pNode->getOperand(0));
Node.OutputID.Index = ConstMDToUint32(pNode->getOperand(1));
} break;
case DxilMDHelper::kDxilNodeMaxRecordsPerNodeTag: {
Node.MaxRecordsPerNode = ConstMDToUint32(MDO);
} break;
default:
m_bExtraMetadata = true;
break;
Expand Down
8 changes: 8 additions & 0 deletions lib/DxilContainer/DxilContainerAssembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,14 @@ class DxilRDATWriter : public DxilPartWriter {
nAttrib.MaxRecordsSharedWith = N.MaxRecordsSharedWith;
nodeAttribs.push_back(Builder.InsertRecord(nAttrib));
}

if (N.MaxRecordsPerNode > 0) {
nAttrib = {};
nAttrib.AttribKind = (uint32_t)NodeAttribKind::MaxRecordsPerNode;
nAttrib.MaxRecordsPerNode = N.MaxRecordsPerNode;
nodeAttribs.push_back(Builder.InsertRecord(nAttrib));
}

if (N.AllowSparseNodes) {
nAttrib = {};
nAttrib.AttribKind = (uint32_t)RDAT::NodeAttribKind::AllowSparseNodes;
Expand Down
1 change: 1 addition & 0 deletions tools/clang/include/clang/AST/HlslTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ bool IsHLSLObjectWithImplicitROMemberAccess(clang::QualType type);
bool IsHLSLRWNodeInputRecordType(clang::QualType type);
bool IsHLSLRONodeInputRecordType(clang::QualType type);
bool IsHLSLNodeOutputType(clang::QualType type);
bool IsHLSLNodeOutputArrayType(clang::QualType type);

DXIL::NodeIOKind GetNodeIOType(clang::QualType type);

Expand Down
6 changes: 6 additions & 0 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,12 @@ def HLSLMaxRecords : InheritableAttr {
let Documentation = [Undocumented];
}

def HLSLMaxRecordsPerNode : InheritableAttr {
let Spellings = [CXX11<"", "MaxRecordsPerNode", 2015>];
let Args = [IntArgument<"maxCount">];
let Documentation = [Undocumented];
}

def HLSLMaxRecordsSharedWith : InheritableParamAttr {
let Spellings = [CXX11<"", "maxrecordssharedwith", 2015>];
let Args = [IdentifierArgument<"Name">];
Expand Down
1 change: 1 addition & 0 deletions tools/clang/include/clang/Basic/DiagnosticGroups.td
Original file line number Diff line number Diff line change
Expand Up @@ -805,4 +805,5 @@ def HLSLParameterUsage : DiagGroup<"parameter-usage">;
def HLSLAvailability: DiagGroup<"hlsl-availability">;
def HLSLBarrier : DiagGroup<"hlsl-barrier">;
def HLSLLegacyLiterals : DiagGroup<"hlsl-legacy-literal">;
def HLSLMaxNodesPerRecordAttr : DiagGroup<"hlsl-require-max-records-per-node">;
// HLSL Change Ends
2 changes: 2 additions & 0 deletions tools/clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -7974,6 +7974,8 @@ def warn_hlsl_legacy_integer_literal_signedness: Warning<
InGroup<HLSLLegacyLiterals>, DefaultIgnore;
def err_hlsl_unsupported_semantic_index: Error<
"'%0' is defined with semantic index %1, but only values 0 through %2 are supported">;
def warn_hlsl_max_records_per_node_required_attribute: Warning<
"MaxRecordsPerNode is a required attribute SM6.9+">, InGroup<HLSLMaxNodesPerRecordAttr>, DefaultError;
// HLSL Change Ends

// SPIRV Change Starts
Expand Down
9 changes: 9 additions & 0 deletions tools/clang/lib/AST/HlslTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,15 @@ bool IsHLSLNodeOutputType(clang::QualType type) {
static_cast<uint32_t>(DXIL::NodeIOFlags::Output);
}

bool IsHLSLNodeOutputArrayType(clang::QualType type) {
return (static_cast<uint32_t>(GetNodeIOType(type)) &
(static_cast<uint32_t>(DXIL::NodeIOFlags::Output) |
static_cast<uint32_t>(DXIL::NodeIOFlags::NodeArray) |
static_cast<uint32_t>(DXIL::NodeIOFlags::RecordGranularityMask))) ==
(static_cast<uint32_t>(DXIL::NodeIOFlags::Output) |
static_cast<uint32_t>(DXIL::NodeIOFlags::NodeArray));
}

bool IsHLSLStructuredBufferType(clang::QualType type) {
if (const RecordType *RT = type->getAs<RecordType>()) {
StringRef name = RT->getDecl()->getName();
Expand Down
13 changes: 13 additions & 0 deletions tools/clang/lib/CodeGen/CGHLSLMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2435,6 +2435,19 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
}
if (const auto *Attr = parmDecl->getAttr<HLSLMaxRecordsAttr>())
node.MaxRecords = Attr->getMaxCount();
// Check for SM6.9 Attributes
if (SM->IsSM69Plus()) {
// Though MaxRecordsPerNode is a required attribute, this requirement
// can be overriden, in which case when the attribute is not present
// set the value to MaxRecords.
node.MaxRecordsPerNode = node.MaxRecords;
if (const auto *Attr = parmDecl->getAttr<HLSLMaxRecordsPerNodeAttr>()) {
node.MaxRecordsPerNode = Attr->getMaxCount();
DXASSERT(node.MaxRecordsPerNode <= node.MaxRecords,
"MaxRecordsPerNode value should be less than or equal to the "
"MaxRecords value");
}
}
}

if (inputPatchCount > 1) {
Expand Down
1 change: 1 addition & 0 deletions tools/clang/lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ void Parser::ParseGNUAttributeArgs(IdentifierInfo *AttrName,
case AttributeList::AT_HLSLNodeMaxRecursionDepth:
case AttributeList::AT_HLSLMaxRecordsSharedWith:
case AttributeList::AT_HLSLMaxRecords:
case AttributeList::AT_HLSLMaxRecordsPerNode:
case AttributeList::AT_HLSLNodeArraySize:
case AttributeList::AT_HLSLRootSignature:
case AttributeList::AT_HLSLOutputControlPoints:
Expand Down
21 changes: 19 additions & 2 deletions tools/clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13381,6 +13381,12 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A,
}
break;
}
case AttributeList::AT_HLSLMaxRecordsPerNode: {
declAttr = new (S.Context) HLSLMaxRecordsPerNodeAttr(
A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
A.getAttributeSpellingListIndex());
break;
}
case AttributeList::AT_HLSLNodeArraySize: {
declAttr = ::new (S.Context) HLSLNodeArraySizeAttr(
A.getRange(), S.Context, ValidateAttributeIntArg(S, A),
Expand Down Expand Up @@ -15964,7 +15970,7 @@ void DiagnoseNodeEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName,
for (ParmVarDecl *PD : FD->params()) {
QualType ParamType = PD->getType().getCanonicalType();

// Find parameter that is the node input record
// Find parameter that is the node output
if (hlsl::IsHLSLNodeOutputType(ParamType)) {
// Node records are template types
if (RecordDecl *NodeStructDecl =
Expand All @@ -15974,6 +15980,15 @@ void DiagnoseNodeEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName,
DiagnoseDispatchGridSemantics(S, NodeStructDecl, PD->getLocation(),
OutputFound);
}
if (hlsl::IsHLSLNodeOutputArrayType(ParamType)) {
std::string profile = S.getLangOpts().HLSLProfile;
const ShaderModel *SM = hlsl::ShaderModel::GetByName(profile.c_str());
if (SM->IsSM69Plus() && !PD->getAttr<HLSLMaxRecordsPerNodeAttr>()) {
S.Diags.Report(
PD->getLocation(),
diag::warn_hlsl_max_records_per_node_required_attribute);
}
}
}
}

Expand All @@ -15992,6 +16007,7 @@ void DiagnoseNodeEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName,
auto *NodeArraySizeAttr = Param->getAttr<HLSLNodeArraySizeAttr>();
auto *UnboundedSparseNodesAttr =
Param->getAttr<HLSLUnboundedSparseNodesAttr>();
auto *MaxRecordsPerNodeAttr = Param->getAttr<HLSLMaxRecordsPerNodeAttr>();
// Check any node input is compatible with the node launch type
if (hlsl::IsHLSLNodeInputType(ParamTy)) {
InputCount++;
Expand Down Expand Up @@ -16022,7 +16038,8 @@ void DiagnoseNodeEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName,
// If node output is not an array, diagnose array only attributes
if (((uint32_t)GetNodeIOType(ParamTy) &
(uint32_t)DXIL::NodeIOFlags::NodeArray) == 0) {
Attr *ArrayAttrs[] = {NodeArraySizeAttr, UnboundedSparseNodesAttr};
Attr *ArrayAttrs[] = {NodeArraySizeAttr, UnboundedSparseNodesAttr,
MaxRecordsPerNodeAttr};
for (auto *A : ArrayAttrs) {
if (A) {
S.Diags.Report(A->getLocation(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: %dxc -T lib_6_9 %s | FileCheck %s

// Test for required attribute

struct RECORD1
{
uint value;
uint value2;
};
[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeDispatchGrid(1, 1, 1)]
[NumThreads(128, 1, 1)]
void node_1_2(
[NodeArraySize(128)] [MaxRecords(64)] NodeOutputArray<RECORD1> OutputArray
)
{
}
// CHECK: error: MaxRecordsPerNode is a required attribute SM6.9+
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: %dxc -T lib_6_9 %s | FileCheck %s

// Tests for [MaxRecordsPerNode] attribute

struct RECORD1
{
uint value;
uint value2;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeDispatchGrid(1, 1, 1)]
[NumThreads(128, 1, 1)]
void node_1_0(
[NodeArraySize(128)] [MaxRecords(64)] [MaxRecordsPerNode(16)] NodeOutputArray<RECORD1> OutputArray
)
{
}
// CHECK-NOT: error: MaxRecordsPerNode is a required attribute SM6.9+ [-Whlsl-require-max-records-per-node]

// Test Emission of metadata
[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeDispatchGrid(1, 1, 1)]
[NumThreads(128, 1, 1)]
void node_1_1(
[NodeArraySize(128)] [MaxRecords(64)] [MaxRecordsPerNode(32)] NodeOutputArray<RECORD1> OutputArray
)
{
}
// CHECK: define void @node_1_1()
// CHECK: ret void

// CHECK: !{void ()* @node_1_1, !"node_1_1", null, null, [[NODE_1_1:![0-9]+]]}
// CHECK: [[NODE_1_1]] = !{i32 8, i32 15, i32 13, i32 1, i32 15, !20, i32 16, i32 -1, i32 18, !11, i32 21, [[OUTPUTS:![0-9]+]]
// CHECK: [[OUTPUTS]] = !{[[OUTPUT1:![0-9]+]]}
// CHECK: [[OUTPUT1]] = !{i32 1, i32 22, i32 2, {{![0-9]+}}, i32 3, i32 64, i32 5, i32 128, i32 0, {{![0-9]+}}, i32 7, i32 32}

Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: %dxc -T lib_6_9 -Wno-hlsl-require-max-records-per-node %s | FileCheck %s

// Tests for overrriding [MaxRecordsPerNode] attribute requirement with -Wno
// Test MaxRecordsPerNode set to MaxRecords

struct RECORD1
{
uint value;
uint value2;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeDispatchGrid(1, 1, 1)]
[NumThreads(128, 1, 1)]
void node_1_0(
[NodeArraySize(128)] [MaxRecords(64)] NodeOutputArray<RECORD1> OutputArray
)
{
}

// CHECK-NOT: error: MaxRecordsPerNode is a required attribute SM6.9+ [-Whlsl-require-max-records-per-node]
// CHECK: !{void ()* @node_1_0, !"node_1_0", null, null, [[NODE_1_0:![0-9]+]]}
// CHECK: [[NODE_1_0]] = !{i32 8, i32 15, i32 13, i32 1, i32 15, !10, i32 16, i32 -1, i32 18, !11, i32 21, [[OUTPUTS:![0-9]+]]
// CHECK: [[OUTPUTS]] = !{[[OUTPUT1:![0-9]+]]}
// CHECK: [[OUTPUT1]] = !{i32 1, i32 22, i32 2, !14, i32 3, i32 [[MAXRECORDS:[0-9]+]], i32 5, i32 128, i32 0, !15, i32 7, i32 [[MAXRECORDS]]}

Loading