Skip to content

Commit 9221570

Browse files
authored
[Validator] Check size of PSV. (microsoft#6924)
Check size of PSV part matches the PSVVersion. Updated DxilPipelineStateValidation::ReadOrWrite to read based on initInfo.PSVVersion. And return fail when size mismatch in RWMode::Read. Fixes microsoft#6817
1 parent b05313c commit 9221570

File tree

7 files changed

+816
-55
lines changed

7 files changed

+816
-55
lines changed

docs/ReleaseNotes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Place release notes for the upcoming release below this line and remove this lin
2323

2424
- The incomplete WaveMatrix implementation has been removed.
2525
- DXIL Validator Hash is open sourced.
26+
- DXIL container validation for PSV0 part allows any content ordering inside string and semantic index tables.
2627

2728
### Version 1.8.2407
2829

include/dxc/DxilContainer/DxilPipelineStateValidation.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,8 @@ struct PSVStringTable {
226226
PSVStringTable() : Table(nullptr), Size(0) {}
227227
PSVStringTable(const char *table, uint32_t size) : Table(table), Size(size) {}
228228
const char *Get(uint32_t offset) const {
229-
assert(offset < Size && Table && Table[Size - 1] == '\0');
229+
if (!(offset < Size && Table && Table[Size - 1] == '\0'))
230+
return nullptr;
230231
return Table + offset;
231232
}
232233
};
@@ -344,7 +345,8 @@ struct PSVSemanticIndexTable {
344345
PSVSemanticIndexTable(const uint32_t *table, uint32_t entries)
345346
: Table(table), Entries(entries) {}
346347
const uint32_t *Get(uint32_t offset) const {
347-
assert(offset < Entries && Table);
348+
if (!(offset < Entries && Table))
349+
return nullptr;
348350
return Table + offset;
349351
}
350352
};
@@ -638,7 +640,8 @@ class DxilPipelineStateValidation {
638640
_T *GetRecord(void *pRecords, uint32_t recordSize, uint32_t numRecords,
639641
uint32_t index) const {
640642
if (pRecords && index < numRecords && sizeof(_T) <= recordSize) {
641-
assert((size_t)index * (size_t)recordSize <= UINT_MAX);
643+
if (!((size_t)index * (size_t)recordSize <= UINT_MAX))
644+
return nullptr;
642645
return reinterpret_cast<_T *>(reinterpret_cast<uint8_t *>(pRecords) +
643646
(index * recordSize));
644647
}
@@ -1126,6 +1129,10 @@ void InitPSVSignatureElement(PSVSignatureElement0 &E,
11261129
const DxilSignatureElement &SE,
11271130
bool i1ToUnknownCompat);
11281131

1132+
// Setup PSVInitInfo with DxilModule.
1133+
// Note that the StringTable and PSVSemanticIndexTable are not done.
1134+
void SetupPSVInitInfo(PSVInitInfo &InitInfo, const DxilModule &DM);
1135+
11291136
// Setup shader properties for PSVRuntimeInfo* with DxilModule.
11301137
void SetShaderProps(PSVRuntimeInfo0 *pInfo, const DxilModule &DM);
11311138
void SetShaderProps(PSVRuntimeInfo1 *pInfo1, const DxilModule &DM);

lib/DxilContainer/DxilContainerAssembler.cpp

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -738,30 +738,14 @@ class DxilPSVWriter : public DxilPartWriter {
738738
DxilPSVWriter(const DxilModule &mod, uint32_t PSVVersion = UINT_MAX)
739739
: m_Module(mod), m_PSVInitInfo(PSVVersion) {
740740
m_Module.GetValidatorVersion(m_ValMajor, m_ValMinor);
741-
// Constraint PSVVersion based on validator version
742-
uint32_t PSVVersionConstraint = hlsl::GetPSVVersion(m_ValMajor, m_ValMinor);
743-
if (PSVVersion > PSVVersionConstraint)
744-
m_PSVInitInfo.PSVVersion = PSVVersionConstraint;
745-
746-
const ShaderModel *SM = m_Module.GetShaderModel();
747-
UINT uCBuffers = m_Module.GetCBuffers().size();
748-
UINT uSamplers = m_Module.GetSamplers().size();
749-
UINT uSRVs = m_Module.GetSRVs().size();
750-
UINT uUAVs = m_Module.GetUAVs().size();
751-
m_PSVInitInfo.ResourceCount = uCBuffers + uSamplers + uSRVs + uUAVs;
741+
hlsl::SetupPSVInitInfo(m_PSVInitInfo, m_Module);
742+
752743
// TODO: for >= 6.2 version, create more efficient structure
753744
if (m_PSVInitInfo.PSVVersion > 0) {
754-
m_PSVInitInfo.ShaderStage = (PSVShaderKind)SM->GetKind();
755745
// Copy Dxil Signatures
756746
m_StringBuffer.push_back('\0'); // For empty semantic name (system value)
757-
m_PSVInitInfo.SigInputElements =
758-
m_Module.GetInputSignature().GetElements().size();
759747
m_SigInputElements.resize(m_PSVInitInfo.SigInputElements);
760-
m_PSVInitInfo.SigOutputElements =
761-
m_Module.GetOutputSignature().GetElements().size();
762748
m_SigOutputElements.resize(m_PSVInitInfo.SigOutputElements);
763-
m_PSVInitInfo.SigPatchConstOrPrimElements =
764-
m_Module.GetPatchConstOrPrimSignature().GetElements().size();
765749
m_SigPatchConstOrPrimElements.resize(
766750
m_PSVInitInfo.SigPatchConstOrPrimElements);
767751
uint32_t i = 0;
@@ -791,20 +775,6 @@ class DxilPSVWriter : public DxilPartWriter {
791775
m_PSVInitInfo.StringTable.Size = m_StringBuffer.size();
792776
m_PSVInitInfo.SemanticIndexTable.Table = m_SemanticIndexBuffer.data();
793777
m_PSVInitInfo.SemanticIndexTable.Entries = m_SemanticIndexBuffer.size();
794-
// Set up ViewID and signature dependency info
795-
m_PSVInitInfo.UsesViewID =
796-
m_Module.m_ShaderFlags.GetViewID() ? true : false;
797-
m_PSVInitInfo.SigInputVectors =
798-
m_Module.GetInputSignature().NumVectorsUsed(0);
799-
for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) {
800-
m_PSVInitInfo.SigOutputVectors[streamIndex] =
801-
m_Module.GetOutputSignature().NumVectorsUsed(streamIndex);
802-
}
803-
m_PSVInitInfo.SigPatchConstOrPrimVectors = 0;
804-
if (SM->IsHS() || SM->IsDS() || SM->IsMS()) {
805-
m_PSVInitInfo.SigPatchConstOrPrimVectors =
806-
m_Module.GetPatchConstOrPrimSignature().NumVectorsUsed(0);
807-
}
808778
}
809779
if (!m_PSV.InitNew(m_PSVInitInfo, nullptr, &m_PSVBufferSize)) {
810780
DXASSERT(false, "PSV InitNew failed computing size!");

lib/DxilContainer/DxilPipelineStateValidation.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,43 @@ void hlsl::InitPSVSignatureElement(PSVSignatureElement0 &E,
110110
E.DynamicMaskAndStream |= (SE.GetDynIdxCompMask()) & 0xF;
111111
}
112112

113+
void hlsl::SetupPSVInitInfo(PSVInitInfo &InitInfo, const DxilModule &DM) {
114+
// Constraint PSVVersion based on validator version
115+
unsigned ValMajor, ValMinor;
116+
DM.GetValidatorVersion(ValMajor, ValMinor);
117+
unsigned PSVVersionConstraint = hlsl::GetPSVVersion(ValMajor, ValMinor);
118+
if (InitInfo.PSVVersion > PSVVersionConstraint)
119+
InitInfo.PSVVersion = PSVVersionConstraint;
120+
121+
const ShaderModel *SM = DM.GetShaderModel();
122+
uint32_t uCBuffers = DM.GetCBuffers().size();
123+
uint32_t uSamplers = DM.GetSamplers().size();
124+
uint32_t uSRVs = DM.GetSRVs().size();
125+
uint32_t uUAVs = DM.GetUAVs().size();
126+
InitInfo.ResourceCount = uCBuffers + uSamplers + uSRVs + uUAVs;
127+
128+
if (InitInfo.PSVVersion > 0) {
129+
InitInfo.ShaderStage = (PSVShaderKind)SM->GetKind();
130+
InitInfo.SigInputElements = DM.GetInputSignature().GetElements().size();
131+
InitInfo.SigPatchConstOrPrimElements =
132+
DM.GetPatchConstOrPrimSignature().GetElements().size();
133+
InitInfo.SigOutputElements = DM.GetOutputSignature().GetElements().size();
134+
135+
// Set up ViewID and signature dependency info
136+
InitInfo.UsesViewID = DM.m_ShaderFlags.GetViewID() ? true : false;
137+
InitInfo.SigInputVectors = DM.GetInputSignature().NumVectorsUsed(0);
138+
for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) {
139+
InitInfo.SigOutputVectors[streamIndex] =
140+
DM.GetOutputSignature().NumVectorsUsed(streamIndex);
141+
}
142+
InitInfo.SigPatchConstOrPrimVectors = 0;
143+
if (SM->IsHS() || SM->IsDS() || SM->IsMS()) {
144+
InitInfo.SigPatchConstOrPrimVectors =
145+
DM.GetPatchConstOrPrimSignature().NumVectorsUsed(0);
146+
}
147+
}
148+
}
149+
113150
void hlsl::SetShaderProps(PSVRuntimeInfo0 *pInfo, const DxilModule &DM) {
114151
const ShaderModel *SM = DM.GetShaderModel();
115152
pInfo->MinimumExpectedWaveLaneCount = 0;

0 commit comments

Comments
 (0)