@@ -97,23 +97,25 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
9797 rdcspv::Id variable;
9898 // constant ID for the index of this attribute
9999 rdcspv::Id indexConst;
100- // base gvec4 type for this input. We always fetch uvec4 from the buffer but then bitcast to
101- // vec4 or ivec4 if needed
100+ // only for inputs - we load as uvec4 and bitcast to this type (vec4/ivec4) as needed. This is a
101+ // 4-component vector always
102102 rdcspv::Id fetchVec4Type;
103- // the actual gvec4 type for the input, possibly needed to convert to from the above if it's
104- // declared as a 16-bit type since we always fetch 32-bit.
105- rdcspv::Id vec4Type;
106- // the base type for this attribute. Must be present already by definition! This is the same
107- // scalar type as vec4Type but with the correct number of components.
103+ // the type with the right number of components but the component is rounded up to a 32-bit type
108104 rdcspv::Id baseType;
109105 // Uniform Pointer type ID for this output. Used only for output data, to write to output SSBO
106+ // underlying type is baseType
110107 rdcspv::Id ssboPtrType;
111108 // Output Pointer type ID for this attribute.
109+ // underlying type is baseType
112110 // For inputs, used to 'write' to the global at the start.
113111 // For outputs, used to 'read' from the global at the end.
114112 rdcspv::Id privatePtrType;
115113 };
116114
115+ rdcspv::Id uint32Type = editor.DeclareType (rdcspv::scalar<uint32_t >());
116+ rdcspv::Id sint32Type = editor.DeclareType (rdcspv::scalar<int32_t >());
117+ rdcspv::Id floatType = editor.DeclareType (rdcspv::scalar<float >());
118+
117119 rdcarray<inputOutputIDs> ins;
118120 ins.resize (numInputs);
119121 rdcarray<inputOutputIDs> outs;
@@ -123,6 +125,7 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
123125 std::set<rdcspv::Id> outputs;
124126
125127 std::map<rdcspv::Id, rdcspv::Id> typeReplacements;
128+ rdcarray<rdcspv::Id> expandedPtrTypes, expandedPtrVars;
126129
127130 // keep track of any builtins we're preserving
128131 std::set<rdcspv::Id> builtinKeeps;
@@ -238,6 +241,65 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
238241
239242 if (id)
240243 {
244+ rdcspv::DataType &dataType = editor.GetDataType (ptr.type );
245+
246+ rdcspv::Id expandedPtr;
247+
248+ // expand up input/output pointers to sub-32-bit types to be 32-bit
249+ if (dataType.scalar ().width < 32 && dataType.scalar ().width > 0 )
250+ {
251+ VarType varType = dataType.scalar ().Type ();
252+
253+ if (varType == VarType::Half)
254+ varType = VarType::Float;
255+ else if (varType == VarType::SShort || varType == VarType::SByte)
256+ varType = VarType::SInt;
257+ else if (varType == VarType::UShort || varType == VarType::UByte)
258+ varType = VarType::UInt;
259+
260+ if (dataType.type == rdcspv::DataType::VectorType)
261+ {
262+ const uint32_t compCount = dataType.vector ().count ;
263+ expandedPtr = editor.GetType (rdcspv::Vector (rdcspv::scalar (varType), compCount));
264+
265+ // if this expanded type doesn't exist, add it while preserving the iterator
266+ if (expandedPtr == rdcspv::Id ())
267+ {
268+ if (varType == VarType::Float)
269+ expandedPtr = editor.AddOperation (
270+ it, rdcspv::OpTypeVector (editor.MakeId (), floatType, compCount));
271+ else if (varType == VarType::UInt)
272+ expandedPtr = editor.AddOperation (
273+ it, rdcspv::OpTypeVector (editor.MakeId (), uint32Type, compCount));
274+ else
275+ expandedPtr = editor.AddOperation (
276+ it, rdcspv::OpTypeVector (editor.MakeId (), sint32Type, compCount));
277+ ++it;
278+ }
279+ }
280+ else
281+ {
282+ expandedPtr = editor.GetType (rdcspv::scalar (varType));
283+
284+ // if this expanded type doesn't exist, add it while preserving the iterator
285+ if (expandedPtr == rdcspv::Id ())
286+ {
287+ if (varType == VarType::Float)
288+ expandedPtr = editor.AddOperation (it, rdcspv::OpTypeFloat (editor.MakeId (), 32 ));
289+ else
290+ expandedPtr = editor.AddOperation (
291+ it, rdcspv::OpTypeInt (editor.MakeId (), 32 , varType == VarType::SInt));
292+ ++it;
293+ }
294+ }
295+
296+ ptr.type = expandedPtr;
297+
298+ // record the original pointer type so we can patch with conversions any loads/stores
299+ if (!expandedPtrTypes.contains (ptr.result ))
300+ expandedPtrTypes.push_back (ptr.result );
301+ }
302+
241303 rdcspv::Pointer privPtr (ptr.type , rdcspv::StorageClass::Private);
242304
243305 rdcspv::Id origId = editor.GetType (privPtr);
@@ -269,6 +331,9 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
269331
270332 bool mod = false ;
271333
334+ if (expandedPtrTypes.contains (var.resultType ))
335+ expandedPtrVars.push_back (var.result );
336+
272337 if (builtinKeeps.find (var.result ) != builtinKeeps.end ())
273338 {
274339 // if this variable is one we're keeping as a builtin, we need to do something different.
@@ -410,6 +475,77 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
410475 }
411476 }
412477
478+ for (rdcspv::Iter it = editor.Begin (rdcspv::Section::Functions); it; ++it)
479+ {
480+ // identify any loads or stores via expanded pointer types, these will expect a different return
481+ // type so insert an appropriate conversion to expand/contract. The pointers themselves will be
482+ // handled either via globals being patched or access chains above being patched (at the same
483+ // time as we patch from input/output to private). The only thing remaining is the potential
484+ // type mismatch of storing a half to a float or loading a half from a float etc.
485+ //
486+ // we do this before patching the types in any OpAccessChain so we can identify such loads or
487+ // stores either due to using one of the old pointer types, or because the pointer is the global
488+ // directly (it must be one or the other)
489+ if (it.opcode () == rdcspv::Op::Load)
490+ {
491+ rdcspv::OpLoad load (it);
492+
493+ rdcspv::Id ptrType = editor.GetIDType (load.pointer );
494+
495+ if (expandedPtrTypes.contains (ptrType) || expandedPtrVars.contains (load.pointer ))
496+ {
497+ // this pointer was expanded, get the new type and update the load to a temp id
498+ rdcspv::Id tmpLoadedVal = editor.MakeId ();
499+
500+ editor.PreModify (it);
501+ it.word (1 ) = editor.GetDataType (ptrType).InnerType ().value ();
502+ it.word (2 ) = tmpLoadedVal.value ();
503+ editor.PostModify (it);
504+
505+ ++it;
506+
507+ rdcspv::Scalar scalarType = editor.GetDataType (load.resultType ).scalar ();
508+
509+ if (scalarType.type == rdcspv::Op::TypeFloat)
510+ editor.AddOperation (it, rdcspv::OpFConvert (load.resultType , load.result , tmpLoadedVal));
511+ else if (scalarType.signedness )
512+ editor.AddOperation (it, rdcspv::OpSConvert (load.resultType , load.result , tmpLoadedVal));
513+ else
514+ editor.AddOperation (it, rdcspv::OpUConvert (load.resultType , load.result , tmpLoadedVal));
515+ }
516+ }
517+ else if (it.opcode () == rdcspv::Op::Store)
518+ {
519+ rdcspv::OpStore store (it);
520+
521+ rdcspv::Id ptrType = editor.GetIDType (store.pointer );
522+
523+ if (expandedPtrTypes.contains (ptrType) || expandedPtrVars.contains (store.pointer ))
524+ {
525+ // this pointer was expanded, get the new type and update the store to use a temp id
526+ rdcspv::Id tmpStoreVal = editor.MakeId ();
527+
528+ rdcspv::Id storedType = editor.GetDataType (ptrType).InnerType ();
529+ rdcspv::Scalar scalarType = editor.GetDataType (editor.GetIDType (store.object )).scalar ();
530+
531+ if (scalarType.type == rdcspv::Op::TypeFloat)
532+ editor.AddOperation (it, rdcspv::OpFConvert (storedType, tmpStoreVal, store.object ));
533+ else if (scalarType.signedness )
534+ editor.AddOperation (it, rdcspv::OpSConvert (storedType, tmpStoreVal, store.object ));
535+ else
536+ editor.AddOperation (it, rdcspv::OpUConvert (storedType, tmpStoreVal, store.object ));
537+
538+ ++it;
539+
540+ RDCASSERT (it.opcode () == rdcspv::Op::Store);
541+
542+ editor.PreModify (it);
543+ it.word (2 ) = tmpStoreVal.value ();
544+ editor.PostModify (it);
545+ }
546+ }
547+ }
548+
413549 for (rdcspv::Iter it = editor.Begin (rdcspv::Section::Functions); it; ++it)
414550 {
415551 // identify functions with result types we might want to replace
@@ -518,25 +654,31 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
518654
519655 io.variable = patchData.outputs [i].ID ;
520656
521- // base type - either a scalar or a vector, since matrix outputs are decayed to vectors
522- {
523- rdcspv::Scalar scalarType = rdcspv::scalar (refl.outputSignature [i].varType );
657+ VarType varType = refl.outputSignature [i].varType ;
524658
525- io. vec4Type = editor. DeclareType ( rdcspv::Vector (scalarType, 4 )) ;
659+ const uint32_t compCount = refl. outputSignature [i]. compCount ;
526660
527- if (refl.outputSignature [i].compCount > 1 )
528- io.baseType =
529- editor.DeclareType (rdcspv::Vector (scalarType, refl.outputSignature [i].compCount ));
530- else
531- io.baseType = editor.DeclareType (scalarType);
532- }
661+ // upconvert to 32-bit as needed
662+ if (varType == VarType::Half)
663+ varType = VarType::Float;
664+ else if (varType == VarType::SShort || varType == VarType::SByte)
665+ varType = VarType::SInt;
666+ else if (varType == VarType::UShort || varType == VarType::UByte)
667+ varType = VarType::UInt;
668+
669+ rdcspv::Scalar scalarType = rdcspv::scalar (varType);
670+
671+ if (compCount > 1 )
672+ io.baseType = editor.DeclareType (rdcspv::Vector (scalarType, compCount));
673+ else
674+ io.baseType = editor.DeclareType (scalarType);
533675
534676 io.ssboPtrType = editor.DeclareType (rdcspv::Pointer (io.baseType , bufferClass));
535677 io.privatePtrType =
536678 editor.DeclareType (rdcspv::Pointer (io.baseType , rdcspv::StorageClass::Private));
537679
538- RDCASSERT (io.baseType && io.vec4Type && io. indexConst && io.privatePtrType && io.ssboPtrType ,
539- io.baseType , io. vec4Type , io. indexConst , io.privatePtrType , io.ssboPtrType );
680+ RDCASSERT (io.baseType && io.indexConst && io.privatePtrType && io.ssboPtrType , io. baseType ,
681+ io.indexConst , io.privatePtrType , io.ssboPtrType );
540682 }
541683
542684 // repeat for inputs
@@ -551,49 +693,41 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
551693
552694 io.variable = patchData.inputs [i].ID ;
553695
554- VarType vType = refl.inputSignature [i].varType ;
696+ VarType varType = refl.inputSignature [i].varType ;
697+
698+ const uint32_t compCount = refl.inputSignature [i].compCount ;
555699
556- rdcspv::Scalar scalarType = rdcspv::scalar (vType);
700+ // upconvert to 32-bit as needed
701+ if (varType == VarType::Half)
702+ varType = VarType::Float;
703+ else if (varType == VarType::SShort || varType == VarType::SByte)
704+ varType = VarType::SInt;
705+ else if (varType == VarType::UShort || varType == VarType::UByte)
706+ varType = VarType::UInt;
707+
708+ rdcspv::Scalar scalarType = rdcspv::scalar (varType);
557709
558710 // 64-bit values are loaded as uvec4 and then packed in pairs, so we need to declare
559711 // fetchVec4Type as uvec4
560- if (vType == VarType::Double || vType == VarType::ULong || vType == VarType::SLong)
712+ if (varType == VarType::Double || varType == VarType::ULong || varType == VarType::SLong)
561713 {
562- io.fetchVec4Type = io.vec4Type =
563- editor.DeclareType (rdcspv::Vector (rdcspv::scalar<uint32_t >(), 4 ));
714+ io.fetchVec4Type = editor.DeclareType (rdcspv::Vector (rdcspv::scalar<uint32_t >(), 4 ));
564715 }
565716 else
566717 {
567- io.vec4Type = editor.DeclareType (rdcspv::Vector (scalarType, 4 ));
568-
569- // if the underlying scalar is actually
570- switch (vType)
571- {
572- case VarType::Half:
573- io.fetchVec4Type = editor.DeclareType (rdcspv::Vector (rdcspv::scalar<float >(), 4 ));
574- break ;
575- case VarType::SShort:
576- case VarType::SByte:
577- io.fetchVec4Type = editor.DeclareType (rdcspv::Vector (rdcspv::scalar<int32_t >(), 4 ));
578- break ;
579- case VarType::UShort:
580- case VarType::UByte:
581- io.fetchVec4Type = editor.DeclareType (rdcspv::Vector (rdcspv::scalar<uint32_t >(), 4 ));
582- break ;
583- default : io.fetchVec4Type = io.vec4Type ; break ;
584- }
718+ io.fetchVec4Type = editor.DeclareType (rdcspv::Vector (scalarType, 4 ));
585719 }
586720
587721 if (refl.inputSignature [i].compCount > 1 )
588- io.baseType = editor.DeclareType (rdcspv::Vector (scalarType, refl. inputSignature [i]. compCount ));
722+ io.baseType = editor.DeclareType (rdcspv::Vector (scalarType, compCount));
589723 else
590724 io.baseType = editor.DeclareType (scalarType);
591725
592726 io.privatePtrType =
593727 editor.DeclareType (rdcspv::Pointer (io.baseType , rdcspv::StorageClass::Private));
594728
595- RDCASSERT (io.baseType && io.vec4Type && io.indexConst && io.privatePtrType , io. baseType ,
596- io.vec4Type , io.indexConst , io.privatePtrType );
729+ RDCASSERT (io.fetchVec4Type && io.baseType && io.indexConst && io.privatePtrType ,
730+ io.fetchVec4Type , io. baseType , io.indexConst , io.privatePtrType );
597731 }
598732
599733 rdcspv::Id u32Type = editor.DeclareType (rdcspv::scalar<uint32_t >());
@@ -1107,19 +1241,8 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
11071241 if (ins[i].fetchVec4Type != uvec4Type)
11081242 result = ops.add (rdcspv::OpBitcast (ins[i].fetchVec4Type , editor.MakeId (), result));
11091243
1110- // we always fetch as full 32-bit values, but if the input was declared as a different
1111- // size (typically ushort or half) then convert here
1112- if (ins[i].fetchVec4Type != ins[i].vec4Type )
1113- {
1114- if (VarTypeCompType (vType) == CompType::Float)
1115- result = ops.add (rdcspv::OpFConvert (ins[i].vec4Type , editor.MakeId (), result));
1116- else if (VarTypeCompType (vType) == CompType::UInt)
1117- result = ops.add (rdcspv::OpUConvert (ins[i].vec4Type , editor.MakeId (), result));
1118- else
1119- result = ops.add (rdcspv::OpSConvert (ins[i].vec4Type , editor.MakeId (), result));
1120- }
1121-
1122- uint32_t comp = Bits::CountTrailingZeroes (uint32_t (refl.inputSignature [i].regChannelMask ));
1244+ uint32_t firstComp =
1245+ Bits::CountTrailingZeroes (uint32_t (refl.inputSignature [i].regChannelMask ));
11231246
11241247 if (vType == VarType::Double || vType == VarType::ULong || vType == VarType::SLong)
11251248 {
@@ -1198,8 +1321,8 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
11981321 // for one component, extract x
11991322
12001323 // baseType value = result.x;
1201- result =
1202- ops. add ( rdcspv::OpCompositeExtract (ins[i].baseType , editor.MakeId (), result, {comp }));
1324+ result = ops. add (
1325+ rdcspv::OpCompositeExtract (ins[i].baseType , editor.MakeId (), result, {firstComp }));
12031326 }
12041327 else if (refl.inputSignature [i].compCount != 4 )
12051328 {
@@ -1208,7 +1331,7 @@ static void ConvertToMeshOutputCompute(const ShaderReflection &refl,
12081331 rdcarray<uint32_t > swizzle;
12091332
12101333 for (uint32_t c = 0 ; c < refl.inputSignature [i].compCount ; c++)
1211- swizzle.push_back (c + comp );
1334+ swizzle.push_back (c + firstComp );
12121335
12131336 // baseTypeN value = result.xyz;
12141337 result = ops.add (
0 commit comments