Skip to content

Commit eb35bbe

Browse files
committed
Improve handling of 16-bit output from vertex shaders on Vulkan
* We now promote all I/O variables to 32-bit minimum when converting them to private since we might not have the capabilities available to leave them as 16-bit. We instead snoop on any reads/writes to the private variables and insert the necessary width-expansions or contractions to convert between the stored & private 32-bit types and the shader's expected type.
1 parent c98a1c3 commit eb35bbe

File tree

1 file changed

+184
-61
lines changed

1 file changed

+184
-61
lines changed

renderdoc/driver/vulkan/vk_postvs.cpp

Lines changed: 184 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -97,23 +97,25 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
9797
rdcspv::Id variable;
9898
// constant ID for the index of this attribute
9999
rdcspv::Id indexConst;
100-
// base gvec4 type for this input. We always fetch uvec4 from the buffer but then bitcast to
101-
// vec4 or ivec4 if needed
100+
// only for inputs - we load as uvec4 and bitcast to this type (vec4/ivec4) as needed. This is a
101+
// 4-component vector always
102102
rdcspv::Id fetchVec4Type;
103-
// the actual gvec4 type for the input, possibly needed to convert to from the above if it's
104-
// declared as a 16-bit type since we always fetch 32-bit.
105-
rdcspv::Id vec4Type;
106-
// the base type for this attribute. Must be present already by definition! This is the same
107-
// scalar type as vec4Type but with the correct number of components.
103+
// the type with the right number of components but the component is rounded up to a 32-bit type
108104
rdcspv::Id baseType;
109105
// Uniform Pointer type ID for this output. Used only for output data, to write to output SSBO
106+
// underlying type is baseType
110107
rdcspv::Id ssboPtrType;
111108
// Output Pointer type ID for this attribute.
109+
// underlying type is baseType
112110
// For inputs, used to 'write' to the global at the start.
113111
// For outputs, used to 'read' from the global at the end.
114112
rdcspv::Id privatePtrType;
115113
};
116114

115+
rdcspv::Id uint32Type = editor.DeclareType(rdcspv::scalar<uint32_t>());
116+
rdcspv::Id sint32Type = editor.DeclareType(rdcspv::scalar<int32_t>());
117+
rdcspv::Id floatType = editor.DeclareType(rdcspv::scalar<float>());
118+
117119
rdcarray<inputOutputIDs> ins;
118120
ins.resize(numInputs);
119121
rdcarray<inputOutputIDs> outs;
@@ -123,6 +125,7 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
123125
std::set<rdcspv::Id> outputs;
124126

125127
std::map<rdcspv::Id, rdcspv::Id> typeReplacements;
128+
rdcarray<rdcspv::Id> expandedPtrTypes, expandedPtrVars;
126129

127130
// keep track of any builtins we're preserving
128131
std::set<rdcspv::Id> builtinKeeps;
@@ -238,6 +241,65 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
238241

239242
if(id)
240243
{
244+
rdcspv::DataType &dataType = editor.GetDataType(ptr.type);
245+
246+
rdcspv::Id expandedPtr;
247+
248+
// expand up input/output pointers to sub-32-bit types to be 32-bit
249+
if(dataType.scalar().width < 32 && dataType.scalar().width > 0)
250+
{
251+
VarType varType = dataType.scalar().Type();
252+
253+
if(varType == VarType::Half)
254+
varType = VarType::Float;
255+
else if(varType == VarType::SShort || varType == VarType::SByte)
256+
varType = VarType::SInt;
257+
else if(varType == VarType::UShort || varType == VarType::UByte)
258+
varType = VarType::UInt;
259+
260+
if(dataType.type == rdcspv::DataType::VectorType)
261+
{
262+
const uint32_t compCount = dataType.vector().count;
263+
expandedPtr = editor.GetType(rdcspv::Vector(rdcspv::scalar(varType), compCount));
264+
265+
// if this expanded vector type doesn't exist yet, add it while preserving the iterator
266+
if(expandedPtr == rdcspv::Id())
267+
{
268+
if(varType == VarType::Float)
269+
expandedPtr = editor.AddOperation(
270+
it, rdcspv::OpTypeVector(editor.MakeId(), floatType, compCount));
271+
else if(varType == VarType::UInt)
272+
expandedPtr = editor.AddOperation(
273+
it, rdcspv::OpTypeVector(editor.MakeId(), uint32Type, compCount));
274+
else
275+
expandedPtr = editor.AddOperation(
276+
it, rdcspv::OpTypeVector(editor.MakeId(), sint32Type, compCount));
277+
++it;
278+
}
279+
}
280+
else
281+
{
282+
expandedPtr = editor.GetType(rdcspv::scalar(varType));
283+
284+
// if this expanded scalar type doesn't exist yet, add it while preserving the iterator
285+
if(expandedPtr == rdcspv::Id())
286+
{
287+
if(varType == VarType::Float)
288+
expandedPtr = editor.AddOperation(it, rdcspv::OpTypeFloat(editor.MakeId(), 32));
289+
else
290+
expandedPtr = editor.AddOperation(
291+
it, rdcspv::OpTypeInt(editor.MakeId(), 32, varType == VarType::SInt));
292+
++it;
293+
}
294+
}
295+
296+
ptr.type = expandedPtr;
297+
298+
// record the original pointer type so we can patch any loads/stores through it with conversions
299+
if(!expandedPtrTypes.contains(ptr.result))
300+
expandedPtrTypes.push_back(ptr.result);
301+
}
302+
241303
rdcspv::Pointer privPtr(ptr.type, rdcspv::StorageClass::Private);
242304

243305
rdcspv::Id origId = editor.GetType(privPtr);
@@ -269,6 +331,9 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
269331

270332
bool mod = false;
271333

334+
if(expandedPtrTypes.contains(var.resultType))
335+
expandedPtrVars.push_back(var.result);
336+
272337
if(builtinKeeps.find(var.result) != builtinKeeps.end())
273338
{
274339
// if this variable is one we're keeping as a builtin, we need to do something different.
@@ -410,6 +475,77 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
410475
}
411476
}
412477

478+
for(rdcspv::Iter it = editor.Begin(rdcspv::Section::Functions); it; ++it)
479+
{
480+
// identify any loads or stores via expanded pointer types, these will expect a different return
481+
// type so insert an appropriate conversion to expand/contract. The pointers themselves will be
482+
// handled either via globals being patched or access chains above being patched (at the same
483+
// time as we patch from input/output to private). The only thing remaining is the potential
484+
// type mismatch of storing a half to a float or loading a half from a float etc.
485+
//
486+
// we do this before patching the types in any OpAccessChain so we can identify such loads or
487+
// stores either due to using one of the old pointer types, or because the pointer is the global
488+
// directly (it must be one or the other)
489+
if(it.opcode() == rdcspv::Op::Load)
490+
{
491+
rdcspv::OpLoad load(it);
492+
493+
rdcspv::Id ptrType = editor.GetIDType(load.pointer);
494+
495+
if(expandedPtrTypes.contains(ptrType) || expandedPtrVars.contains(load.pointer))
496+
{
497+
// this pointer was expanded, get the new type and update the load to a temp id
498+
rdcspv::Id tmpLoadedVal = editor.MakeId();
499+
500+
editor.PreModify(it);
501+
it.word(1) = editor.GetDataType(ptrType).InnerType().value();
502+
it.word(2) = tmpLoadedVal.value();
503+
editor.PostModify(it);
504+
505+
++it;
506+
507+
rdcspv::Scalar scalarType = editor.GetDataType(load.resultType).scalar();
508+
509+
if(scalarType.type == rdcspv::Op::TypeFloat)
510+
editor.AddOperation(it, rdcspv::OpFConvert(load.resultType, load.result, tmpLoadedVal));
511+
else if(scalarType.signedness)
512+
editor.AddOperation(it, rdcspv::OpSConvert(load.resultType, load.result, tmpLoadedVal));
513+
else
514+
editor.AddOperation(it, rdcspv::OpUConvert(load.resultType, load.result, tmpLoadedVal));
515+
}
516+
}
517+
else if(it.opcode() == rdcspv::Op::Store)
518+
{
519+
rdcspv::OpStore store(it);
520+
521+
rdcspv::Id ptrType = editor.GetIDType(store.pointer);
522+
523+
if(expandedPtrTypes.contains(ptrType) || expandedPtrVars.contains(store.pointer))
524+
{
525+
// this pointer was expanded, get the new type and update the store to use a temp id
526+
rdcspv::Id tmpStoreVal = editor.MakeId();
527+
528+
rdcspv::Id storedType = editor.GetDataType(ptrType).InnerType();
529+
rdcspv::Scalar scalarType = editor.GetDataType(editor.GetIDType(store.object)).scalar();
530+
531+
if(scalarType.type == rdcspv::Op::TypeFloat)
532+
editor.AddOperation(it, rdcspv::OpFConvert(storedType, tmpStoreVal, store.object));
533+
else if(scalarType.signedness)
534+
editor.AddOperation(it, rdcspv::OpSConvert(storedType, tmpStoreVal, store.object));
535+
else
536+
editor.AddOperation(it, rdcspv::OpUConvert(storedType, tmpStoreVal, store.object));
537+
538+
++it;
539+
540+
RDCASSERT(it.opcode() == rdcspv::Op::Store);
541+
542+
editor.PreModify(it);
543+
it.word(2) = tmpStoreVal.value();
544+
editor.PostModify(it);
545+
}
546+
}
547+
}
548+
413549
for(rdcspv::Iter it = editor.Begin(rdcspv::Section::Functions); it; ++it)
414550
{
415551
// identify functions with result types we might want to replace
@@ -518,25 +654,31 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
518654

519655
io.variable = patchData.outputs[i].ID;
520656

521-
// base type - either a scalar or a vector, since matrix outputs are decayed to vectors
522-
{
523-
rdcspv::Scalar scalarType = rdcspv::scalar(refl.outputSignature[i].varType);
657+
VarType varType = refl.outputSignature[i].varType;
524658

525-
io.vec4Type = editor.DeclareType(rdcspv::Vector(scalarType, 4));
659+
const uint32_t compCount = refl.outputSignature[i].compCount;
526660

527-
if(refl.outputSignature[i].compCount > 1)
528-
io.baseType =
529-
editor.DeclareType(rdcspv::Vector(scalarType, refl.outputSignature[i].compCount));
530-
else
531-
io.baseType = editor.DeclareType(scalarType);
532-
}
661+
// upconvert to 32-bit as needed
662+
if(varType == VarType::Half)
663+
varType = VarType::Float;
664+
else if(varType == VarType::SShort || varType == VarType::SByte)
665+
varType = VarType::SInt;
666+
else if(varType == VarType::UShort || varType == VarType::UByte)
667+
varType = VarType::UInt;
668+
669+
rdcspv::Scalar scalarType = rdcspv::scalar(varType);
670+
671+
if(compCount > 1)
672+
io.baseType = editor.DeclareType(rdcspv::Vector(scalarType, compCount));
673+
else
674+
io.baseType = editor.DeclareType(scalarType);
533675

534676
io.ssboPtrType = editor.DeclareType(rdcspv::Pointer(io.baseType, bufferClass));
535677
io.privatePtrType =
536678
editor.DeclareType(rdcspv::Pointer(io.baseType, rdcspv::StorageClass::Private));
537679

538-
RDCASSERT(io.baseType && io.vec4Type && io.indexConst && io.privatePtrType && io.ssboPtrType,
539-
io.baseType, io.vec4Type, io.indexConst, io.privatePtrType, io.ssboPtrType);
680+
RDCASSERT(io.baseType && io.indexConst && io.privatePtrType && io.ssboPtrType, io.baseType,
681+
io.indexConst, io.privatePtrType, io.ssboPtrType);
540682
}
541683

542684
// repeat for inputs
@@ -551,49 +693,41 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
551693

552694
io.variable = patchData.inputs[i].ID;
553695

554-
VarType vType = refl.inputSignature[i].varType;
696+
VarType varType = refl.inputSignature[i].varType;
697+
698+
const uint32_t compCount = refl.inputSignature[i].compCount;
555699

556-
rdcspv::Scalar scalarType = rdcspv::scalar(vType);
700+
// upconvert to 32-bit as needed
701+
if(varType == VarType::Half)
702+
varType = VarType::Float;
703+
else if(varType == VarType::SShort || varType == VarType::SByte)
704+
varType = VarType::SInt;
705+
else if(varType == VarType::UShort || varType == VarType::UByte)
706+
varType = VarType::UInt;
707+
708+
rdcspv::Scalar scalarType = rdcspv::scalar(varType);
557709

558710
// 64-bit values are loaded as uvec4 and then packed in pairs, so we need to declare vec4ID as
559711
// uvec4
560-
if(vType == VarType::Double || vType == VarType::ULong || vType == VarType::SLong)
712+
if(varType == VarType::Double || varType == VarType::ULong || varType == VarType::SLong)
561713
{
562-
io.fetchVec4Type = io.vec4Type =
563-
editor.DeclareType(rdcspv::Vector(rdcspv::scalar<uint32_t>(), 4));
714+
io.fetchVec4Type = editor.DeclareType(rdcspv::Vector(rdcspv::scalar<uint32_t>(), 4));
564715
}
565716
else
566717
{
567-
io.vec4Type = editor.DeclareType(rdcspv::Vector(scalarType, 4));
568-
569-
// if the underlying scalar is actually
570-
switch(vType)
571-
{
572-
case VarType::Half:
573-
io.fetchVec4Type = editor.DeclareType(rdcspv::Vector(rdcspv::scalar<float>(), 4));
574-
break;
575-
case VarType::SShort:
576-
case VarType::SByte:
577-
io.fetchVec4Type = editor.DeclareType(rdcspv::Vector(rdcspv::scalar<int32_t>(), 4));
578-
break;
579-
case VarType::UShort:
580-
case VarType::UByte:
581-
io.fetchVec4Type = editor.DeclareType(rdcspv::Vector(rdcspv::scalar<uint32_t>(), 4));
582-
break;
583-
default: io.fetchVec4Type = io.vec4Type; break;
584-
}
718+
io.fetchVec4Type = editor.DeclareType(rdcspv::Vector(scalarType, 4));
585719
}
586720

587721
if(refl.inputSignature[i].compCount > 1)
588-
io.baseType = editor.DeclareType(rdcspv::Vector(scalarType, refl.inputSignature[i].compCount));
722+
io.baseType = editor.DeclareType(rdcspv::Vector(scalarType, compCount));
589723
else
590724
io.baseType = editor.DeclareType(scalarType);
591725

592726
io.privatePtrType =
593727
editor.DeclareType(rdcspv::Pointer(io.baseType, rdcspv::StorageClass::Private));
594728

595-
RDCASSERT(io.baseType && io.vec4Type && io.indexConst && io.privatePtrType, io.baseType,
596-
io.vec4Type, io.indexConst, io.privatePtrType);
729+
RDCASSERT(io.fetchVec4Type && io.baseType && io.indexConst && io.privatePtrType,
730+
io.fetchVec4Type, io.baseType, io.indexConst, io.privatePtrType);
597731
}
598732

599733
rdcspv::Id u32Type = editor.DeclareType(rdcspv::scalar<uint32_t>());
@@ -1107,19 +1241,8 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
11071241
if(ins[i].fetchVec4Type != uvec4Type)
11081242
result = ops.add(rdcspv::OpBitcast(ins[i].fetchVec4Type, editor.MakeId(), result));
11091243

1110-
// we always fetch as full 32-bit values, but if the input was declared as a different
1111-
// size (typically ushort or half) then convert here
1112-
if(ins[i].fetchVec4Type != ins[i].vec4Type)
1113-
{
1114-
if(VarTypeCompType(vType) == CompType::Float)
1115-
result = ops.add(rdcspv::OpFConvert(ins[i].vec4Type, editor.MakeId(), result));
1116-
else if(VarTypeCompType(vType) == CompType::UInt)
1117-
result = ops.add(rdcspv::OpUConvert(ins[i].vec4Type, editor.MakeId(), result));
1118-
else
1119-
result = ops.add(rdcspv::OpSConvert(ins[i].vec4Type, editor.MakeId(), result));
1120-
}
1121-
1122-
uint32_t comp = Bits::CountTrailingZeroes(uint32_t(refl.inputSignature[i].regChannelMask));
1244+
uint32_t firstComp =
1245+
Bits::CountTrailingZeroes(uint32_t(refl.inputSignature[i].regChannelMask));
11231246

11241247
if(vType == VarType::Double || vType == VarType::ULong || vType == VarType::SLong)
11251248
{
@@ -1198,8 +1321,8 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
11981321
// for one component, extract x
11991322

12001323
// baseType value = result.x;
1201-
result =
1202-
ops.add(rdcspv::OpCompositeExtract(ins[i].baseType, editor.MakeId(), result, {comp}));
1324+
result = ops.add(
1325+
rdcspv::OpCompositeExtract(ins[i].baseType, editor.MakeId(), result, {firstComp}));
12031326
}
12041327
else if(refl.inputSignature[i].compCount != 4)
12051328
{
@@ -1208,7 +1331,7 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
12081331
rdcarray<uint32_t> swizzle;
12091332

12101333
for(uint32_t c = 0; c < refl.inputSignature[i].compCount; c++)
1211-
swizzle.push_back(c + comp);
1334+
swizzle.push_back(c + firstComp);
12121335

12131336
// baseTypeN value = result.xyz;
12141337
result = ops.add(

0 commit comments

Comments
 (0)