Skip to content

Commit b80b025

Browse files
changpeng, rampitec, and shiltian
authored
AMDGPU: Implement MC layer support for gfx1250 wmma instructions. (#148570)
Regular wmma/swmmac plus matrix reuse only.

---------

Co-authored-by: Stanislav Mekhanoshin <[email protected]>
Co-authored-by: Shilei Tian <[email protected]>
1 parent 5e303e8 commit b80b025

File tree

11 files changed

+2682
-77
lines changed

11 files changed

+2682
-77
lines changed

llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ class AMDGPUOperand : public MCParsedAsmOperand {
157157
ImmTyNegHi,
158158
ImmTyIndexKey8bit,
159159
ImmTyIndexKey16bit,
160+
ImmTyIndexKey32bit,
160161
ImmTyDPP8,
161162
ImmTyDppCtrl,
162163
ImmTyDppRowMask,
@@ -174,8 +175,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
174175
ImmTyWaitEXP,
175176
ImmTyWaitVAVDst,
176177
ImmTyWaitVMVSrc,
177-
ImmTyByteSel,
178178
ImmTyBitOp3,
179+
ImmTyMatrixAReuse,
180+
ImmTyMatrixBReuse,
181+
ImmTyByteSel,
179182
};
180183

181184
// Immediate operand kind.
@@ -419,6 +422,9 @@ class AMDGPUOperand : public MCParsedAsmOperand {
419422
bool isCPol() const { return isImmTy(ImmTyCPol); }
420423
bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); }
421424
bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); }
425+
bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
426+
bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
427+
bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
422428
bool isTFE() const { return isImmTy(ImmTyTFE); }
423429
bool isFORMAT() const { return isImmTy(ImmTyFORMAT) && isUInt<7>(getImm()); }
424430
bool isDppFI() const { return isImmTy(ImmTyDppFI); }
@@ -747,6 +753,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
747753
return isRegOrInlineNoMods(AMDGPU::VReg_256RegClassID, MVT::f64);
748754
}
749755

756+
bool isVISrc_512_f64() const {
757+
return isRegOrInlineNoMods(AMDGPU::VReg_512RegClassID, MVT::f64);
758+
}
759+
750760
bool isVISrc_128B16() const {
751761
return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::i16);
752762
}
@@ -1116,6 +1126,7 @@ class AMDGPUOperand : public MCParsedAsmOperand {
11161126
case ImmTyCPol: OS << "CPol"; break;
11171127
case ImmTyIndexKey8bit: OS << "index_key"; break;
11181128
case ImmTyIndexKey16bit: OS << "index_key"; break;
1129+
case ImmTyIndexKey32bit: OS << "index_key"; break;
11191130
case ImmTyTFE: OS << "TFE"; break;
11201131
case ImmTyD16: OS << "D16"; break;
11211132
case ImmTyFORMAT: OS << "FORMAT"; break;
@@ -1162,8 +1173,10 @@ class AMDGPUOperand : public MCParsedAsmOperand {
11621173
case ImmTyWaitEXP: OS << "WaitEXP"; break;
11631174
case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break;
11641175
case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break;
1165-
case ImmTyByteSel: OS << "ByteSel" ; break;
11661176
case ImmTyBitOp3: OS << "BitOp3"; break;
1177+
case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
1178+
case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
1179+
case ImmTyByteSel: OS << "ByteSel" ; break;
11671180
}
11681181
// clang-format on
11691182
}
@@ -1700,6 +1713,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
17001713
AMDGPUOperand::ImmTy ImmTy);
17011714
ParseStatus parseIndexKey8bit(OperandVector &Operands);
17021715
ParseStatus parseIndexKey16bit(OperandVector &Operands);
1716+
ParseStatus parseIndexKey32bit(OperandVector &Operands);
17031717

17041718
ParseStatus parseDfmtNfmt(int64_t &Format);
17051719
ParseStatus parseUfmt(int64_t &Format);
@@ -7153,7 +7167,9 @@ ParseStatus AMDGPUAsmParser::tryParseIndexKey(OperandVector &Operands,
71537167
if (!Res.isSuccess())
71547168
return Res;
71557169

7156-
if (ImmTy == AMDGPUOperand::ImmTyIndexKey16bit && (ImmVal < 0 || ImmVal > 1))
7170+
if ((ImmTy == AMDGPUOperand::ImmTyIndexKey16bit ||
7171+
ImmTy == AMDGPUOperand::ImmTyIndexKey32bit) &&
7172+
(ImmVal < 0 || ImmVal > 1))
71577173
return Error(Loc, Twine("out of range ", StringRef(Pref)));
71587174

71597175
if (ImmTy == AMDGPUOperand::ImmTyIndexKey8bit && (ImmVal < 0 || ImmVal > 3))
@@ -7171,6 +7187,10 @@ ParseStatus AMDGPUAsmParser::parseIndexKey16bit(OperandVector &Operands) {
71717187
return tryParseIndexKey(Operands, AMDGPUOperand::ImmTyIndexKey16bit);
71727188
}
71737189

7190+
/// Parse the index_key modifier for the 32-bit-index variant.
///
/// Thin wrapper that hands \p Operands to tryParseIndexKey() tagged as
/// ImmTyIndexKey32bit and returns its status unchanged.
ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
  const AMDGPUOperand::ImmTy KeyKind = AMDGPUOperand::ImmTyIndexKey32bit;
  return tryParseIndexKey(Operands, KeyKind);
}
7193+
71747194
// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
71757195
// values to live in a joint format operand in the MCInst encoding.
71767196
ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9272,6 +9292,14 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
92729292
DefaultVal);
92739293
}
92749294

9295+
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
9296+
addOptionalImmOperand(Inst, Operands, OptIdx,
9297+
AMDGPUOperand::ImmTyMatrixAReuse, 0);
9298+
9299+
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_b_reuse))
9300+
addOptionalImmOperand(Inst, Operands, OptIdx,
9301+
AMDGPUOperand::ImmTyMatrixBReuse, 0);
9302+
92759303
int NegLoIdx = AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::neg_lo);
92769304
if (NegLoIdx != -1)
92779305
addOptionalImmOperand(Inst, Operands, OptIdx, AMDGPUOperand::ImmTyNegLo);
@@ -9378,6 +9406,10 @@ void AMDGPUAsmParser::cvtSWMMAC(MCInst &Inst, const OperandVector &Operands) {
93789406
addOptionalImmOperand(Inst, Operands, OptIdx,
93799407
AMDGPUOperand::ImmTyIndexKey16bit);
93809408

9409+
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::index_key_32bit))
9410+
addOptionalImmOperand(Inst, Operands, OptIdx,
9411+
AMDGPUOperand::ImmTyIndexKey32bit);
9412+
93819413
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::clamp))
93829414
addOptionalImmOperand(Inst, Operands, OptIdx, AMDGPUOperand::ImmTyClamp);
93839415

llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,6 +1332,16 @@ void AMDGPUInstPrinter::printIndexKey16bit(const MCInst *MI, unsigned OpNo,
13321332
O << " index_key:" << Imm;
13331333
}
13341334

1335+
void AMDGPUInstPrinter::printIndexKey32bit(const MCInst *MI, unsigned OpNo,
1336+
const MCSubtargetInfo &STI,
1337+
raw_ostream &O) {
1338+
auto Imm = MI->getOperand(OpNo).getImm() & 0x7;
1339+
if (Imm == 0)
1340+
return;
1341+
1342+
O << " index_key:" << Imm;
1343+
}
1344+
13351345
void AMDGPUInstPrinter::printInterpSlot(const MCInst *MI, unsigned OpNum,
13361346
const MCSubtargetInfo &STI,
13371347
raw_ostream &O) {

llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ class AMDGPUInstPrinter : public MCInstPrinter {
132132
const MCSubtargetInfo &STI, raw_ostream &O);
133133
void printIndexKey16bit(const MCInst *MI, unsigned OpNo,
134134
const MCSubtargetInfo &STI, raw_ostream &O);
135+
void printIndexKey32bit(const MCInst *MI, unsigned OpNo,
136+
const MCSubtargetInfo &STI, raw_ostream &O);
135137
void printInterpSlot(const MCInst *MI, unsigned OpNo,
136138
const MCSubtargetInfo &STI, raw_ostream &O);
137139
void printInterpAttr(const MCInst *MI, unsigned OpNo,

llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -381,9 +381,11 @@ void AMDGPUMCCodeEmitter::encodeInstruction(const MCInst &MI,
381381

382382
// Set unused op_sel_hi bits to 1 for VOP3P and MAI instructions.
383383
// Note that accvgpr_read/write are MAI, have src0, but do not use op_sel.
384-
if ((Desc.TSFlags & SIInstrFlags::VOP3P) ||
385-
Opcode == AMDGPU::V_ACCVGPR_READ_B32_vi ||
386-
Opcode == AMDGPU::V_ACCVGPR_WRITE_B32_vi) {
384+
if (((Desc.TSFlags & SIInstrFlags::VOP3P) ||
385+
Opcode == AMDGPU::V_ACCVGPR_READ_B32_vi ||
386+
Opcode == AMDGPU::V_ACCVGPR_WRITE_B32_vi) &&
387+
// Matrix B reuse operand reuses op_sel_hi.
388+
!AMDGPU::hasNamedOperand(Opcode, AMDGPU::OpName::matrix_b_reuse)) {
387389
Encoding |= getImplicitOpSelHiEncoding(Opcode);
388390
}
389391

llvm/lib/Target/AMDGPU/SIInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,7 @@ def op_sel_hi0 : ArrayOperand0<"op_sel_hi", "OpSelHi">;
12471247
def neg_lo0 : ArrayOperand0<"neg_lo", "NegLo">;
12481248
def neg_hi0 : ArrayOperand0<"neg_hi", "NegHi">;
12491249

1250+
// Custom i32 operand record for the 32-bit index_key variant.
// NOTE(review): presumably bound by name to parseIndexKey32bit /
// printIndexKey32bit via CustomOperand's defaults - confirm.
def IndexKey32bit : CustomOperand<i32, 1>;
12501251
def IndexKey16bit : CustomOperand<i32, 1>;
12511252
def IndexKey8bit : CustomOperand<i32, 1>;
12521253

@@ -1302,6 +1303,9 @@ let PrintMethod = "printBitOp3" in
13021303
def BitOp3 : NamedIntOperand<"bitop3">;
13031304
def bitop3_0 : DefaultOperand<BitOp3, 0>;
13041305

1306+
// Named-bit operand written as `matrix_a_reuse` in assembly.
def MatrixAReuse : NamedBitOperand<"matrix_a_reuse">;
1307+
// Named-bit operand written as `matrix_b_reuse` in assembly.
def MatrixBReuse : NamedBitOperand<"matrix_b_reuse">;
1308+
13051309
class KImmFPOperand<ValueType vt> : ImmOperand<vt> {
13061310
let OperandNamespace = "AMDGPU";
13071311
let OperandType = "OPERAND_KIMM"#vt.Size;
@@ -1641,6 +1645,7 @@ def WMMAModsF16NegAbs : ComplexPattern<untyped, 2, "SelectWMMAModsF16NegAbs">;
16411645
def WMMAVISrc : ComplexPattern<untyped, 1, "SelectWMMAVISrc">;
16421646
def SWMMACIndex8 : ComplexPattern<untyped, 2, "SelectSWMMACIndex8">;
16431647
def SWMMACIndex16 : ComplexPattern<untyped, 2, "SelectSWMMACIndex16">;
1648+
// Complex pattern over an untyped root yielding 2 operands; matching is
// delegated to the selector function named "SelectSWMMACIndex32".
def SWMMACIndex32 : ComplexPattern<untyped, 2, "SelectSWMMACIndex32">;
16441649

16451650
def VOP3OpSel : ComplexPattern<untyped, 2, "SelectVOP3OpSel">;
16461651

@@ -2654,6 +2659,7 @@ class VOPProfile <list<ValueType> _ArgVT, bit _EnableClamp = 0> {
26542659
isModifierType<Src2VT>.ret,
26552660
HasOMod);
26562661
field bit HasNeg = HasModifiers;
2662+
field bit HasMatrixReuse = 0;
26572663

26582664
field bit HasSrc0Mods = HasModifiers;
26592665
field bit HasSrc1Mods = !if(HasModifiers, !or(HasSrc1FloatMods, HasSrc1IntMods), 0);

llvm/lib/Target/AMDGPU/SIRegisterInfo.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,6 +1294,7 @@ def VISrc_256_f32 : SrcRegOrImm9 <VReg_256, "OPERAND_REG_INLINE_C_FP32">;
12941294
def VISrc_256_f64 : SrcRegOrImm9 <VReg_256, "OPERAND_REG_INLINE_C_FP64">;
12951295
def VISrc_512_b32 : SrcRegOrImm9 <VReg_512, "OPERAND_REG_INLINE_C_INT32">;
12961296
def VISrc_512_f32 : SrcRegOrImm9 <VReg_512, "OPERAND_REG_INLINE_C_FP32">;
1297+
// Source operand over VReg_512 with operand type OPERAND_REG_INLINE_C_FP64
// (the f64 counterpart of VISrc_512_f32 above).
def VISrc_512_f64 : SrcRegOrImm9 <VReg_512, "OPERAND_REG_INLINE_C_FP64">;
12971298
def VISrc_1024_b32 : SrcRegOrImm9 <VReg_1024, "OPERAND_REG_INLINE_C_INT32">;
12981299
def VISrc_1024_f32 : SrcRegOrImm9 <VReg_1024, "OPERAND_REG_INLINE_C_FP32">;
12991300

0 commit comments

Comments
 (0)