@@ -7789,10 +7789,10 @@ bool EmitPass::validateInlineAsmConstraints(llvm::CallInst* inst, SmallVector<St
77897789 StringRef constraintStr(IA->getConstraintString());
77907790
77917791 //lambda for checking constraint types
7792- auto CheckConstraintTypes = [this](StringRef str)->bool
7792+ auto CheckConstraintTypes = [this](StringRef str, CVariable* cv )->bool
77937793 {
77947794 // TODO: Only "rw" (raw register operand) constraint allowed for now. Add more checks as needed
7795- if (str.equals("=rw"))
7795+ if (str.equals("=rw") && cv == m_destination )
77967796 {
77977797 return true;
77987798 }
@@ -7806,6 +7806,14 @@ bool EmitPass::validateInlineAsmConstraints(llvm::CallInst* inst, SmallVector<St
78067806 // and since output reg is always first, only match with '0'
78077807 return true;
78087808 }
7809+ else if (str.equals("i"))
7810+ {
7811+ return cv->IsImmediate();
7812+ }
7813+ else if (str.equals("rw.u"))
7814+ {
7815+ return cv->IsUniform();
7816+ }
78097817 else
78107818 {
78117819 assert(0 && "Unsupported constraint type!");
@@ -7819,31 +7827,19 @@ bool EmitPass::validateInlineAsmConstraints(llvm::CallInst* inst, SmallVector<St
78197827 bool success = true;
78207828
78217829 // Check the output constraint tokens
7822- unsigned index = 0;
7823- while (index < constraints.size() && constraints[index].startswith("="))
7830+ if (m_destination)
78247831 {
7825- success &= CheckConstraintTypes(constraints[index++] );
7832+ success &= CheckConstraintTypes(constraints[0], m_destination );
78267833 }
78277834
78287835 if (success)
78297836 {
78307837 // Check the input constraint tokens
7831- for (unsigned i = 0; i < inst->getNumArgOperands(); i++)
7838+ unsigned index = m_destination ? 1 : 0;
7839+ for (unsigned i = 0; i < inst->getNumArgOperands(); i++, index++)
78327840 {
7833- StringRef tstr = constraints[index++];
78347841 CVariable* cv = GetSymbol(inst->getArgOperand(i));
7835-
7836- if (tstr.endswith(".u"))
7837- {
7838- // Check if var is uniform
7839- if (!cv->IsUniform())
7840- {
7841- assert(0 && "Compiler cannot prove variable is uniform");
7842- return false;
7843- }
7844- tstr = tstr.substr(0, tstr.size() - 2);
7845- }
7846- success &= CheckConstraintTypes(tstr);
7842+ success &= CheckConstraintTypes(constraints[index], cv);
78477843 }
78487844 }
78497845 return success;
@@ -7855,95 +7851,88 @@ void EmitPass::EmitInlineAsm(llvm::CallInst* inst)
78557851{
78567852 std::stringstream& str = m_encoder->GetVISABuilder()->GetAsmTextStream();
78577853 InlineAsm* IA = cast<InlineAsm>(inst->getCalledValue());
7858- const char* asmStr = IA->getAsmString().c_str();
7859- const char* lastEmitted = asmStr;
7854+ string asmStr = IA->getAsmString();
78607855 smallvector<CVariable*, 8> opnds;
78617856 SmallVector<StringRef, 8> constraints;
7857+
7858+ if (asmStr.empty())
7859+ return;
7860+
78627861 if (!validateInlineAsmConstraints(inst, constraints))
78637862 {
78647863 assert(0 && "Constraints for inline assembly cannot be validated");
7864+ return;
78657865 }
78667866
78677867 if (m_destination)
78687868 {
7869- // Check if dest operand is also an input
7870- // If so, push the input operand as the destination
7871- bool hasReadWriteReg = false;
7872- unsigned opNum = 0;
7873- for (StringRef str : constraints)
7874- {
7875- if (str.startswith("="))
7876- continue;
7877- else if (str.equals("0"))
7878- {
7879- opnds.push_back(GetSymbol(inst->getArgOperand(opNum)));
7880- hasReadWriteReg = true;
7881- break;
7882- }
7883- opNum++;
7884- }
7885- if (!hasReadWriteReg)
7886- {
7887- opnds.push_back(m_destination);
7888- }
7869+ opnds.push_back(m_destination);
78897870 }
78907871 for (unsigned i = 0; i < inst->getNumArgOperands(); i++)
78917872 {
78927873 CVariable* cv = GetSymbol(inst->getArgOperand(i));
78937874 opnds.push_back(cv);
78947875 }
78957876
7896- str << endl << "/// Inlined ASM" << endl;
7897- while (*lastEmitted )
7877+ // Check for read/write registers
7878+ if (m_destination )
78987879 {
7899- switch (*lastEmitted )
7880+ for (unsigned i = 1; i < constraints.size(); i++ )
79007881 {
7901- default:
7902- {
7903- const char* literalEnd = lastEmitted + 1;
7904- while (*literalEnd && *literalEnd != '{' && *literalEnd != '|' &&
7905- *literalEnd != '}' && *literalEnd != '$' && *literalEnd != '\n')
7882+ if (constraints[i].equals("0"))
79067883 {
7907- ++literalEnd;
7884+ assert(i < opnds.size());
7885+ CVariable* cv = opnds[i];
7886+ if (!cv->IsImmediate())
7887+ {
7888+ // Replace dest reg with src reg
7889+ opnds[0] = cv;
7890+ }
7891+ else
7892+ {
7893+ // If src is immediate, dest may not be initialized, so initialize it
7894+ m_encoder->Copy(m_destination, cv);
7895+ m_encoder->Push();
7896+ }
7897+ break;
79087898 }
7909- str.write(lastEmitted, literalEnd - lastEmitted);
7910- lastEmitted = literalEnd;
7911- break;
79127899 }
7913- case '\n':
7914- {
7915- ++lastEmitted;
7916- str << '\n';
7917- break;
7918- }
7919- case '$':
7920- {
7921- ++lastEmitted;
7922- const char* idStart = lastEmitted;
7923- const char* idEnd = idStart;
7924- while (*idEnd >= '0' && *idEnd <= '9')
7925- ++idEnd;
7900+ }
79267901
7927- unsigned val = 0;
7928- if (StringRef(idStart, idEnd - idStart).getAsInteger(10, val))
7929- {
7930- assert(0 && "Invalid operand format");
7931- return;
7932- }
7933- lastEmitted = idEnd;
7902+ str << endl << "/// Inlined ASM" << endl;
7903+ // Look for variables to replace with the VISA variable
7904+ size_t startPos = 0;
7905+ while (startPos < asmStr.size())
7906+ {
7907+ size_t varPos = asmStr.find('$', startPos);
7908+ if (varPos == string::npos)
7909+ break;
79347910
7935- if (val >= opnds.size())
7936- {
7937- assert(0 && "Invalid operand index") ;
7938- return;
7939- }
7911+ // Find the operand number
7912+ const char* idStart = &(asmStr[varPos + 1]);
7913+ const char* idEnd = idStart ;
7914+ while (*idEnd >= '0' && *idEnd <= '9')
7915+ ++idEnd;
79407916
7941- str << m_encoder->GetVariableName(opnds[val]);
7942- break;
7917+ unsigned val = 0;
7918+ if (StringRef(idStart, idEnd - idStart).getAsInteger(10, val))
7919+ {
7920+ assert(0 && "Invalid operand format");
7921+ return;
79437922 }
7923+ if (val >= opnds.size())
7924+ {
7925+ assert(0 && "Invalid operand index");
7926+ return;
79447927 }
7928+ string varName = m_encoder->GetVariableName(opnds[val]);
7929+ asmStr.replace(varPos, (idEnd - idStart + 1), varName);
7930+
7931+ startPos = varPos + varName.size();
79457932 }
7946- if (str.str().back() != '\n') str << endl;
7933+
7934+ str << asmStr;
7935+ if (asmStr.back() != '\n') str << endl;
79477936 str << "/// End Inlined ASM" << endl << endl;
79487937}
79497938
0 commit comments