@@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
176176 ImmTyWaitVAVDst,
177177 ImmTyWaitVMVSrc,
178178 ImmTyBitOp3,
179+ ImmTyMatrixAFMT,
180+ ImmTyMatrixBFMT,
179181 ImmTyMatrixAReuse,
180182 ImmTyMatrixBReuse,
181183 ImmTyByteSel,
@@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
423425 bool isIndexKey8bit () const { return isImmTy (ImmTyIndexKey8bit); }
424426 bool isIndexKey16bit () const { return isImmTy (ImmTyIndexKey16bit); }
425427 bool isIndexKey32bit () const { return isImmTy (ImmTyIndexKey32bit); }
428+ bool isMatrixAFMT () const { return isImmTy (ImmTyMatrixAFMT); }
429+ bool isMatrixBFMT () const { return isImmTy (ImmTyMatrixBFMT); }
426430 bool isMatrixAReuse () const { return isImmTy (ImmTyMatrixAReuse); }
427431 bool isMatrixBReuse () const { return isImmTy (ImmTyMatrixBReuse); }
428432 bool isTFE () const { return isImmTy (ImmTyTFE); }
@@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
11741178 case ImmTyWaitVAVDst: OS << " WaitVAVDst" ; break ;
11751179 case ImmTyWaitVMVSrc: OS << " WaitVMVSrc" ; break ;
11761180 case ImmTyBitOp3: OS << " BitOp3" ; break ;
1181+ case ImmTyMatrixAFMT: OS << " ImmTyMatrixAFMT" ; break ;
1182+ case ImmTyMatrixBFMT: OS << " ImmTyMatrixBFMT" ; break ;
11771183 case ImmTyMatrixAReuse: OS << " ImmTyMatrixAReuse" ; break ;
11781184 case ImmTyMatrixBReuse: OS << " ImmTyMatrixBReuse" ; break ;
11791185 case ImmTyByteSel: OS << " ByteSel" ; break ;
@@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
17141720 ParseStatus parseIndexKey8bit (OperandVector &Operands);
17151721 ParseStatus parseIndexKey16bit (OperandVector &Operands);
17161722 ParseStatus parseIndexKey32bit (OperandVector &Operands);
1723+ ParseStatus tryParseMatrixFMT (OperandVector &Operands, StringRef Name,
1724+ AMDGPUOperand::ImmTy Type);
1725+ ParseStatus parseMatrixAFMT (OperandVector &Operands);
1726+ ParseStatus parseMatrixBFMT (OperandVector &Operands);
17171727
17181728 ParseStatus parseDfmtNfmt (int64_t &Format);
17191729 ParseStatus parseUfmt (int64_t &Format);
@@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
18491859 const unsigned CPol);
18501860 bool validateTFE (const MCInst &Inst, const OperandVector &Operands);
18511861 std::optional<StringRef> validateLdsDirect (const MCInst &Inst);
1862+ bool validateWMMA (const MCInst &Inst, const OperandVector &Operands);
18521863 unsigned getConstantBusLimit (unsigned Opcode) const ;
18531864 bool usesConstantBus (const MCInst &Inst, unsigned OpIdx);
18541865 bool isInlineConstant (const MCInst &Inst, unsigned OpIdx) const ;
@@ -5400,6 +5411,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
54005411 return true ;
54015412}
54025413
5414+ bool AMDGPUAsmParser::validateWMMA (const MCInst &Inst,
5415+ const OperandVector &Operands) {
5416+ unsigned Opc = Inst.getOpcode ();
5417+ const MCRegisterInfo *TRI = getContext ().getRegisterInfo ();
5418+ const MCInstrDesc &Desc = MII.get (Opc);
5419+
5420+ auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
5421+ int FmtIdx = AMDGPU::getNamedOperandIdx (Opc, FmtOp);
5422+ if (FmtIdx == -1 )
5423+ return true ;
5424+ unsigned Fmt = Inst.getOperand (FmtIdx).getImm ();
5425+ int SrcIdx = AMDGPU::getNamedOperandIdx (Opc, SrcOp);
5426+ unsigned RegSize =
5427+ TRI->getRegClass (Desc.operands ()[SrcIdx].RegClass ).getSizeInBits ();
5428+
5429+ if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs (Fmt) * 32 )
5430+ return true ;
5431+
5432+ static const char *FmtNames[] = {" MATRIX_FMT_FP8" , " MATRIX_FMT_BF8" ,
5433+ " MATRIX_FMT_FP6" , " MATRIX_FMT_BF6" ,
5434+ " MATRIX_FMT_FP4" };
5435+
5436+ Error (getRegLoc (mc2PseudoReg (Inst.getOperand (SrcIdx).getReg ()), Operands),
5437+ " wrong register tuple size for " + Twine (FmtNames[Fmt]));
5438+ return false ;
5439+ };
5440+
5441+ return validateFmt (AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
5442+ validateFmt (AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
5443+ }
5444+
54035445bool AMDGPUAsmParser::validateInstruction (const MCInst &Inst,
54045446 const SMLoc &IDLoc,
54055447 const OperandVector &Operands) {
@@ -5533,6 +5575,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
55335575 if (!validateTFE (Inst, Operands)) {
55345576 return false ;
55355577 }
5578+ if (!validateWMMA (Inst, Operands)) {
5579+ return false ;
5580+ }
55365581
55375582 return true ;
55385583}
@@ -7191,6 +7236,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
71917236 return tryParseIndexKey (Operands, AMDGPUOperand::ImmTyIndexKey32bit);
71927237}
71937238
7239+ ParseStatus AMDGPUAsmParser::tryParseMatrixFMT (OperandVector &Operands,
7240+ StringRef Name,
7241+ AMDGPUOperand::ImmTy Type) {
7242+ return parseStringOrIntWithPrefix (Operands, Name,
7243+ {" MATRIX_FMT_FP8" , " MATRIX_FMT_BF8" ,
7244+ " MATRIX_FMT_FP6" , " MATRIX_FMT_BF6" ,
7245+ " MATRIX_FMT_FP4" },
7246+ Type);
7247+ }
7248+
7249+ ParseStatus AMDGPUAsmParser::parseMatrixAFMT (OperandVector &Operands) {
7250+ return tryParseMatrixFMT (Operands, " matrix_a_fmt" ,
7251+ AMDGPUOperand::ImmTyMatrixAFMT);
7252+ }
7253+
7254+ ParseStatus AMDGPUAsmParser::parseMatrixBFMT (OperandVector &Operands) {
7255+ return tryParseMatrixFMT (Operands, " matrix_b_fmt" ,
7256+ AMDGPUOperand::ImmTyMatrixBFMT);
7257+ }
7258+
71947259// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
71957260// values to live in a joint format operand in the MCInst encoding.
71967261ParseStatus AMDGPUAsmParser::parseDfmtNfmt (int64_t &Format) {
@@ -9292,6 +9357,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
92929357 DefaultVal);
92939358 }
92949359
9360+ int MatrixAFMTIdx =
9361+ AMDGPU::getNamedOperandIdx (Opc, AMDGPU::OpName::matrix_a_fmt);
9362+ if (MatrixAFMTIdx != -1 ) {
9363+ addOptionalImmOperand (Inst, Operands, OptIdx,
9364+ AMDGPUOperand::ImmTyMatrixAFMT, 0 );
9365+ }
9366+
9367+ int MatrixBFMTIdx =
9368+ AMDGPU::getNamedOperandIdx (Opc, AMDGPU::OpName::matrix_b_fmt);
9369+ if (MatrixBFMTIdx != -1 ) {
9370+ addOptionalImmOperand (Inst, Operands, OptIdx,
9371+ AMDGPUOperand::ImmTyMatrixBFMT, 0 );
9372+ }
9373+
92959374 if (AMDGPU::hasNamedOperand (Opc, AMDGPU::OpName::matrix_a_reuse))
92969375 addOptionalImmOperand (Inst, Operands, OptIdx,
92979376 AMDGPUOperand::ImmTyMatrixAReuse, 0 );
0 commit comments