|
| 1 | +#include "amdilc_spirv.h" |
| 2 | +#include "amdilc_internal.h" |
| 3 | + |
| 4 | +typedef struct { |
| 5 | + IlcSpvId varId; |
| 6 | + IlcSpvWord location; |
| 7 | +} IlcInputRegister; |
| 8 | + |
| 9 | +typedef struct { |
| 10 | + IlcSpvBuffer sourceBuffer; |
| 11 | + IlcSpvModule* module; |
| 12 | + IlcSpvId entryPointId; |
| 13 | + SpvExecutionModel execModel; |
| 14 | + const char* entryPointName; |
| 15 | + IlcSpvId* interfaces; |
| 16 | + unsigned interfaceCount; |
| 17 | + IlcInputRegister* existingInputRegisters; |
| 18 | + unsigned existingInputCount; |
| 19 | + unsigned outputPointsCount; |
| 20 | + IlcSpvId invocationVarId; |
| 21 | + IlcSpvId floatId; |
| 22 | + IlcSpvId float4Id; |
| 23 | + IlcSpvId intId; |
| 24 | + bool isInFunction; |
| 25 | + bool isAfterReturn; |
| 26 | +} IlcRecompiler; |
| 27 | + |
| 28 | +IlcRecompiledShader ilcRecompileKernel( |
| 29 | + const IlcSpvWord* spirvWords, |
| 30 | + unsigned wordCount, |
| 31 | + const unsigned* inputPassthroughLocations, |
| 32 | + unsigned passthroughCount) |
| 33 | +{ |
| 34 | + IlcSpvModule module; |
| 35 | + module.currentId = 0; |
| 36 | + for (int i = 0; i < ID_MAX; i++) { |
| 37 | + module.buffer[i] = (IlcSpvBuffer) { 0, NULL }; |
| 38 | + } |
| 39 | + |
| 40 | + IlcRecompiler recompiler = (IlcRecompiler){ |
| 41 | + .module = &module, |
| 42 | + .entryPointId = 0, |
| 43 | + .execModel = 0, |
| 44 | + .entryPointName = NULL, |
| 45 | + .interfaces = NULL, |
| 46 | + .interfaceCount = 0, |
| 47 | + .existingInputRegisters = NULL, |
| 48 | + .existingInputCount = 0, |
| 49 | + .outputPointsCount = 0, |
| 50 | + .invocationVarId = 0, |
| 51 | + .floatId = 0, |
| 52 | + .float4Id = 0, |
| 53 | + .intId = 0, |
| 54 | + .isInFunction = false, |
| 55 | + .isAfterReturn = false, |
| 56 | + }; |
| 57 | + //header will be inserted at finish |
| 58 | + unsigned bufferIndex = ID_CAPABILITIES; |
| 59 | + unsigned bufferStart = 5; |
| 60 | + for (unsigned i = 5; i < wordCount;) { |
| 61 | + SpvOp opCode = spirvWords[i] & SpvOpCodeMask; |
| 62 | + unsigned instrWordCount = spirvWords[i] >> SpvWordCountShift; |
| 63 | + unsigned newBufferIndex = getBufferIndex(opCode); |
| 64 | + if (newBufferIndex != bufferIndex) { |
| 65 | + if (bufferIndex != ID_ENTRY_POINTS && bufferIndex != ID_CODE) { |
| 66 | + // skip the entry point as it will be rewritten |
| 67 | + ilcSpvUnwrapBuffer(&module.buffer[bufferIndex], &spirvWords[bufferStart], i - bufferStart); |
| 68 | + } |
| 69 | + bufferIndex = newBufferIndex; |
| 70 | + bufferStart = i; |
| 71 | + } |
| 72 | + bool finishProcessing = false; |
| 73 | + switch (bufferIndex) { |
| 74 | + case ID_TYPES: |
| 75 | + case ID_TYPES_WITH_CONSTANTS: |
| 76 | + if (opCode == SpvOpTypeFloat && spirvWords[i + 2] == 32) { |
| 77 | + recompiler.floatId = spirvWords[i + 1]; |
| 78 | + } else if (opCode == SpvOpTypeVector && spirvWords[i + 2] == recompiler.floatId && spirvWords[i + 3] == 4) { |
| 79 | + recompiler.float4Id = spirvWords[i + 1]; |
| 80 | + } else if (opCode == SpvOpTypeInt && spirvWords[i + 2] == 32 && spirvWords[i + 3]) { |
| 81 | + recompiler.intId = spirvWords[i + 1]; |
| 82 | + } |
| 83 | + module.currentId = module.currentId < spirvWords[i + 1] ? spirvWords[i + 1] : module.currentId; |
| 84 | + break; |
| 85 | + case ID_ENTRY_POINTS: |
| 86 | + if (opCode == SpvOpEntryPoint) { |
| 87 | + recompiler.execModel = spirvWords[i + 1]; |
| 88 | + recompiler.entryPointId = spirvWords[i + 2]; |
| 89 | + recompiler.entryPointName = (const char*)&spirvWords[i + 3]; |
| 90 | + unsigned nameLength = (strlen(recompiler.entryPointName) + 4) / sizeof(IlcSpvWord); |
| 91 | + recompiler.interfaceCount = instrWordCount - nameLength - 3; |
| 92 | + recompiler.interfaces = malloc(recompiler.interfaceCount * sizeof(IlcSpvWord)); |
| 93 | + memcpy(recompiler.interfaces, &spirvWords[i + 3 + nameLength], recompiler.interfaceCount * sizeof(IlcSpvWord)); |
| 94 | + } |
| 95 | + break; |
| 96 | + case ID_EXEC_MODES: |
| 97 | + if (opCode == SpvOpExecutionMode && spirvWords[i + 2] == SpvExecutionModeOutputVertices) { |
| 98 | + recompiler.outputPointsCount = spirvWords[i + 3]; |
| 99 | + } |
| 100 | + break; |
| 101 | + case ID_VARIABLES: |
| 102 | + if (opCode == SpvOpVariable && spirvWords[i + 3] == SpvStorageClassInput) { |
| 103 | + bool foundLocation = false; |
| 104 | + IlcSpvWord locationIdx = 0; |
| 105 | + IlcSpvId varId = spirvWords[i + 2]; |
| 106 | + for (unsigned j = 0; !foundLocation && j < module.buffer[ID_DECORATIONS].wordCount;) { |
| 107 | + SpvOp decorOpCode = module.buffer[ID_DECORATIONS].words[j] & SpvOpCodeMask; |
| 108 | + unsigned decorInstrWordCount = module.buffer[ID_DECORATIONS].words[j] >> SpvWordCountShift; |
| 109 | + if (decorOpCode == SpvOpDecorate && module.buffer[ID_DECORATIONS].words[j + 1] == varId && |
| 110 | + module.buffer[ID_DECORATIONS].words[j + 2] == SpvDecorationLocation) { |
| 111 | + locationIdx = module.buffer[ID_DECORATIONS].words[j + 3]; |
| 112 | + foundLocation = true; |
| 113 | + } else if (decorOpCode == SpvOpDecorate && module.buffer[ID_DECORATIONS].words[j + 1] == varId && |
| 114 | + module.buffer[ID_DECORATIONS].words[j + 2] == SpvDecorationBuiltIn && |
| 115 | + module.buffer[ID_DECORATIONS].words[j + 3] == SpvBuiltInInvocationId) { |
| 116 | + recompiler.invocationVarId = varId; |
| 117 | + } |
| 118 | + j += decorInstrWordCount; |
| 119 | + } |
| 120 | + if (foundLocation) { |
| 121 | + recompiler.existingInputRegisters = realloc(recompiler.existingInputRegisters, (1 + recompiler.existingInputCount) * sizeof(IlcInputRegister)); |
| 122 | + recompiler.existingInputRegisters[recompiler.existingInputCount] = (IlcInputRegister) { |
| 123 | + .varId = varId, |
| 124 | + .location = locationIdx, |
| 125 | + }; |
| 126 | + recompiler.existingInputCount++; |
| 127 | + } |
| 128 | + } |
| 129 | + break; |
| 130 | + case ID_CODE: |
| 131 | + if (opCode == SpvOpFunction && spirvWords[i + 2] == recompiler.entryPointId) { |
| 132 | + recompiler.isInFunction = true; |
| 133 | + } else if (opCode == SpvOpStore) { |
| 134 | + module.currentId = module.currentId < spirvWords[i + 1] ? spirvWords[i + 1] : module.currentId; |
| 135 | + } else if (opCode == SpvOpLoad) { |
| 136 | + module.currentId = module.currentId < spirvWords[i + 2] ? spirvWords[i + 2] : module.currentId; |
| 137 | + } |
| 138 | + if (opCode == SpvOpReturn && recompiler.isInFunction) { |
| 139 | + finishProcessing = true; |
| 140 | + } else { |
| 141 | + // copy the code over |
| 142 | + ilcSpvUnwrapBuffer(&module.buffer[ID_CODE], &spirvWords[i], instrWordCount); |
| 143 | + } |
| 144 | + break; |
| 145 | + } |
| 146 | + if (finishProcessing) { |
| 147 | + break; |
| 148 | + } |
| 149 | + i += instrWordCount; |
| 150 | + } |
| 151 | + // HACK: just add offset to avoid collision |
| 152 | + module.currentId += 65536; |
| 153 | + //TODO: handle outputs checking |
| 154 | + IlcSpvId float4InputPtrTypeId = ilcSpvPutPointerType(&module, SpvStorageClassInput, recompiler.float4Id); |
| 155 | + IlcSpvId float4OutputPtrTypeId = ilcSpvPutPointerType(&module, SpvStorageClassOutput, recompiler.float4Id); |
| 156 | + if (recompiler.execModel == SpvExecutionModelTessellationControl) { |
| 157 | + if (recompiler.invocationVarId == 0) { |
| 158 | + IlcSpvId intPtrInputId = ilcSpvPutPointerType(&module, SpvStorageClassInput, recompiler.intId); |
| 159 | + recompiler.invocationVarId = ilcSpvPutVariable(&module, intPtrInputId, SpvStorageClassInput); |
| 160 | + IlcSpvWord builtInType = SpvBuiltInInvocationId; |
| 161 | + ilcSpvPutDecoration(&module, recompiler.invocationVarId, SpvDecorationBuiltIn, 1, &builtInType); |
| 162 | + ilcSpvPutName(&module, recompiler.invocationVarId, "invocationId"); |
| 163 | + recompiler.interfaces = realloc(recompiler.interfaces, (recompiler.interfaceCount + 1) * sizeof(IlcSpvId)); |
| 164 | + recompiler.interfaces[recompiler.interfaceCount] = recompiler.invocationVarId; |
| 165 | + recompiler.interfaceCount++; |
| 166 | + } |
| 167 | + int maxArraySize = -1; |
| 168 | + for (unsigned i = 0; i < passthroughCount; ++i) { |
| 169 | + maxArraySize = maxArraySize < (int)inputPassthroughLocations[i] ? inputPassthroughLocations[i] : maxArraySize; |
| 170 | + } |
| 171 | + for (unsigned i = 0; i < recompiler.existingInputCount; ++i) { |
| 172 | + maxArraySize = maxArraySize < (int)recompiler.existingInputRegisters[i].location ? recompiler.existingInputRegisters[i].location : maxArraySize; |
| 173 | + } |
| 174 | + maxArraySize++; |
| 175 | + if (maxArraySize <= 0) { |
| 176 | + goto finish; |
| 177 | + } |
| 178 | + //vertex count |
| 179 | + if (recompiler.outputPointsCount == 0) { |
| 180 | + LOGW("didn't handle output control point count\n"); |
| 181 | + recompiler.outputPointsCount = 3; |
| 182 | + } |
| 183 | + IlcSpvId vertexLengthId = ilcSpvPutConstant(&module, recompiler.intId, recompiler.outputPointsCount); |
| 184 | + |
| 185 | + //TODO: check input/output vertex count |
| 186 | + IlcSpvId inputArrTypeId = ilcSpvPutArrayType(&module, recompiler.float4Id, vertexLengthId); |
| 187 | + IlcSpvId inputVarTypeId = ilcSpvPutPointerType(&module, SpvStorageClassInput, inputArrTypeId); |
| 188 | + |
| 189 | + IlcSpvId outputLengthId = ilcSpvPutConstant(&module, recompiler.intId, maxArraySize); |
| 190 | + // array of registers per vertex |
| 191 | + IlcSpvId outputArrTypeId = ilcSpvPutArrayType(&module, recompiler.float4Id, outputLengthId); |
| 192 | + // array of registers per primitive |
| 193 | + IlcSpvId outputVArrTypeId = ilcSpvPutArrayType(&module, outputArrTypeId, vertexLengthId); |
| 194 | + IlcSpvId outputVArrPtrTypeId = ilcSpvPutPointerType(&module, SpvStorageClassOutput, outputVArrTypeId); |
| 195 | + IlcSpvId outputVArrId = ilcSpvPutVariable(&module, outputVArrPtrTypeId, SpvStorageClassOutput); |
| 196 | + ilcSpvPutName(&module, outputVArrId, "vertex_out"); |
| 197 | + |
| 198 | + recompiler.interfaces = realloc(recompiler.interfaces, (recompiler.interfaceCount + 1) * sizeof(IlcSpvId)); |
| 199 | + recompiler.interfaces[recompiler.interfaceCount] = outputVArrId; |
| 200 | + recompiler.interfaceCount++; |
| 201 | + IlcSpvWord outputLocationIdx = 0; |
| 202 | + ilcSpvPutDecoration(&module, outputVArrId, SpvDecorationLocation, 1, &outputLocationIdx); |
| 203 | + for (unsigned i = 0; i < passthroughCount; ++i) { |
| 204 | + bool includesLocation = false; |
| 205 | + IlcSpvId inputVariableId = 0; |
| 206 | + for (unsigned j = 0; j < recompiler.existingInputCount; ++j) { |
| 207 | + if (recompiler.existingInputRegisters[j].location == inputPassthroughLocations[i]) { |
| 208 | + includesLocation = true; |
| 209 | + inputVariableId = recompiler.existingInputRegisters[j].varId; |
| 210 | + } |
| 211 | + } |
| 212 | + if (!includesLocation) { |
| 213 | + char name[32]; |
| 214 | + snprintf(name, sizeof(name), "vicp_patched%u", inputPassthroughLocations[i]); |
| 215 | + inputVariableId = ilcSpvPutVariable(&module, inputVarTypeId, SpvStorageClassInput); |
| 216 | + ilcSpvPutName(&module, inputVariableId, name); |
| 217 | + ilcSpvPutDecoration(&module, inputVariableId, SpvDecorationLocation, 1, &inputPassthroughLocations[i]); |
| 218 | + |
| 219 | + recompiler.interfaces = realloc(recompiler.interfaces, (recompiler.interfaceCount + 1) * sizeof(IlcSpvId)); |
| 220 | + recompiler.interfaces[recompiler.interfaceCount] = inputVariableId; |
| 221 | + recompiler.interfaceCount++; |
| 222 | + } |
| 223 | + IlcSpvId inputIndexId = ilcSpvPutConstant(&module, recompiler.intId, inputPassthroughLocations[i]); |
| 224 | + IlcSpvId invocationValueId = ilcSpvPutLoad(&module, recompiler.intId, recompiler.invocationVarId); |
| 225 | + |
| 226 | + IlcSpvId inputPtrId = ilcSpvPutAccessChain(&module, float4InputPtrTypeId, inputVariableId, 1, &invocationValueId); |
| 227 | + IlcSpvId loadedInputId = ilcSpvPutLoad(&module, recompiler.float4Id, inputPtrId); |
| 228 | + IlcSpvId indexesId[] = { |
| 229 | + invocationValueId, |
| 230 | + inputIndexId, |
| 231 | + }; |
| 232 | + IlcSpvId dstId = ilcSpvPutAccessChain(&module, float4OutputPtrTypeId, outputVArrId, 2, indexesId ); |
| 233 | + ilcSpvPutStore(&module, dstId, loadedInputId); |
| 234 | + } |
| 235 | + } else { |
| 236 | + for (unsigned i = 0; i < passthroughCount; ++i) { |
| 237 | + bool includesLocation = false; |
| 238 | + IlcSpvId inputVariableId = 0; |
| 239 | + for (unsigned j = 0; j < recompiler.existingInputCount; ++j) { |
| 240 | + if (recompiler.existingInputRegisters[j].location == inputPassthroughLocations[i]) { |
| 241 | + includesLocation = true; |
| 242 | + inputVariableId = recompiler.existingInputRegisters[j].varId; |
| 243 | + } |
| 244 | + } |
| 245 | + if (includesLocation) { |
| 246 | + // no need to passthrough |
| 247 | + continue; |
| 248 | + } |
| 249 | + inputVariableId = ilcSpvPutVariable(&module, float4InputPtrTypeId, SpvStorageClassInput); |
| 250 | + IlcSpvId outputVariableId = ilcSpvPutVariable(&module, float4OutputPtrTypeId, SpvStorageClassOutput); |
| 251 | + ilcSpvPutDecoration(&module, inputVariableId, SpvDecorationLocation, 1, &inputPassthroughLocations[i]); |
| 252 | + ilcSpvPutDecoration(&module, outputVariableId, SpvDecorationLocation, 1, &inputPassthroughLocations[i]); |
| 253 | + |
| 254 | + IlcSpvId valueId = ilcSpvPutLoad(&module, recompiler.float4Id, inputVariableId); |
| 255 | + ilcSpvPutStore(&module, outputVariableId, valueId); |
| 256 | + |
| 257 | + recompiler.interfaces = realloc(recompiler.interfaces, (recompiler.interfaceCount + 2) * sizeof(IlcSpvId)); |
| 258 | + recompiler.interfaces[recompiler.interfaceCount] = outputVariableId; |
| 259 | + recompiler.interfaces[recompiler.interfaceCount + 1] = inputVariableId; |
| 260 | + recompiler.interfaceCount += 2; |
| 261 | + } |
| 262 | + } |
| 263 | +finish: |
| 264 | + ilcSpvPutReturn(&module); |
| 265 | + ilcSpvPutFunctionEnd(&module); |
| 266 | + recompiler.isInFunction = false; |
| 267 | + ilcSpvPutEntryPoint(&module, recompiler.entryPointId, recompiler.execModel, recompiler.entryPointName, |
| 268 | + recompiler.interfaceCount, recompiler.interfaces); |
| 269 | + //inject some code |
| 270 | + ilcSpvFinish(&module); |
| 271 | + free(recompiler.existingInputRegisters); |
| 272 | + free(recompiler.interfaces); |
| 273 | + |
| 274 | + return (IlcRecompiledShader) { |
| 275 | + .codeSize = sizeof(IlcSpvWord) * module.buffer[ID_MAIN].wordCount, |
| 276 | + .code = module.buffer[ID_MAIN].words, |
| 277 | + }; |
| 278 | +} |
0 commit comments