Skip to content

Commit 0370913

Browse files
changpeng and rampitec committed
AMDGPU: Support v_wmma_f32_16x16x128_f8f6f4 on gfx1250
Co-Authored-by: Stanislav Mekhanoshin <[email protected]>
1 parent 9878ef3 commit 0370913

31 files changed

+1696
-34
lines changed

clang/include/clang/Basic/BuiltinsAMDGPU.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,7 @@ TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_fp8, "V8hV16iV16iIsV8hIbI
705705
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_fp8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
706706
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
707707
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8, "V8hV16iV16iIsV8hIbIb", "nc", "gfx1250-insts,wavefrontsize32")
708+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4, "V8fIiV16iIiV16iIsV8f", "nc", "gfx1250-insts,wavefrontsize32")
708709
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
709710
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_fp8_bf8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")
710711
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8, "V8fV16iV16iIsV8fIbIb", "nc", "gfx1250-insts,wavefrontsize32")

clang/lib/CodeGen/TargetBuiltins/AMDGPU.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,7 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
855855
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_fp8:
856856
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_bf8_bf8:
857857
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x64_iu8:
858+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
858859
case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
859860
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_f16:
860861
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x64_bf16:
@@ -1118,6 +1119,10 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
11181119
ArgsForMatchingMatrixTypes = {4, 1};
11191120
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x64_iu8;
11201121
break;
1122+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4:
1123+
ArgsForMatchingMatrixTypes = {5, 1, 3};
1124+
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4;
1125+
break;
11211126
case AMDGPU::BI__builtin_amdgcn_wmma_f32_32x16x128_f4:
11221127
ArgsForMatchingMatrixTypes = {3, 0, 1};
11231128
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_32x16x128_f4;

clang/test/CodeGenOpenCL/builtins-amdgcn-gfx1250-wmma-w32.cl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,18 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c)
157157
*out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, true);
158158
}
159159

160+
// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x128_f8f6f4(
161+
// CHECK-GFX1250-NEXT: entry:
162+
// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = shufflevector <16 x i32> [[B:%.*]], <16 x i32> poison, <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
163+
// CHECK-GFX1250-NEXT: [[TMP1:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x128.f8f6f4.v8f32.v16i32.v12i32(i32 1, <16 x i32> [[A:%.*]], i32 2, <12 x i32> [[TMP0]], i16 0, <8 x float> [[C:%.*]])
164+
// CHECK-GFX1250-NEXT: store <8 x float> [[TMP1]], ptr addrspace(1) [[OUT:%.*]], align 32, !tbaa [[TBAA4]]
165+
// CHECK-GFX1250-NEXT: ret void
166+
//
167+
// Codegen test for the f8f6f4 WMMA builtin. The format immediates are 1 for
// matrix a and 2 for matrix b, and the C modifier immediate is 0. The CHECK
// lines above expect b to be narrowed to <12 x i32> before the intrinsic call
// while a stays <16 x i32>.
void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c)
168+
{
169+
*out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, 0, c);
170+
}
171+
160172
// CHECK-GFX1250-LABEL: @test_amdgcn_wmma_f32_16x16x32_f16(
161173
// CHECK-GFX1250-NEXT: entry:
162174
// CHECK-GFX1250-NEXT: [[TMP0:%.*]] = tail call <8 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v8f32.v16f16(i1 false, <16 x half> [[A:%.*]], i1 false, <16 x half> [[B:%.*]], i16 0, <8 x float> [[C:%.*]], i1 false, i1 true)

clang/test/SemaOpenCL/builtins-amdgcn-error-gfx1250-wmma-w32-param.cl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ void test_amdgcn_wmma_i32_16x16x64_iu8(global v8i* out, v8i a, v8i b, v8i c, int
114114
*out = __builtin_amdgcn_wmma_i32_16x16x64_iu8(0, a, 0, b, c, false, mod); // expected-error {{'__builtin_amdgcn_wmma_i32_16x16x64_iu8' must be a constant integer}}
115115
}
116116

117+
// Negative test: the matrix a format, matrix b format and C modifier arguments
// (positions 0, 2 and 4) must each be constant integers, so passing the
// runtime value 'mod' in any one of them has to be diagnosed.
void test_amdgcn_wmma_f32_16x16x128_f8f6f4(global v8f* out, v16i a, v16i b, v8f c, int mod)
118+
{
119+
*out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(mod, a, 2, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
120+
*out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, mod, b, 0, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
121+
*out = __builtin_amdgcn_wmma_f32_16x16x128_f8f6f4(1, a, 2, b, mod, c); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x128_f8f6f4' must be a constant integer}}
122+
}
123+
117124
void test_amdgcn_wmma_f32_16x16x32_f16(global v8f* out, v16h a, v16h b, v8f c, int mod)
118125
{
119126
*out = __builtin_amdgcn_wmma_f32_16x16x32_f16(mod, a, 0, b, 0, c, false, false); // expected-error {{'__builtin_amdgcn_wmma_f32_16x16x32_f16' must be a constant integer}}

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3717,6 +3717,20 @@ class AMDGPUWmmaIntrinsicModsAllDiff<LLVMType DstTy, LLVMType AB, LLVMType C> :
37173717
IntrWillReturn, IntrNoCallback, IntrNoFree]
37183718
>;
37193719

3720+
// WMMA intrinsic signature with a per-matrix format selector: A and B each
// take an i32 format immediate and are independently overloaded integer
// types, C takes an i16 modifier immediate, and C has the same (overloaded
// float) type as the result. Operands 0, 2 and 4 must be immediates (ImmArg).
class AMDGPUWmmaIntrinsicModsC_MatrixFMT :
3721+
Intrinsic<
3722+
[llvm_anyfloat_ty], // %D
3723+
[
3724+
llvm_i32_ty, // matrix_a_fmt
3725+
llvm_anyint_ty, // %A
3726+
llvm_i32_ty, // matrix_b_fmt
3727+
llvm_anyint_ty, // %B
3728+
llvm_i16_ty, // %C_mod: 0 - none, 1 - neg, 2 - abs, 3 - neg(abs)
3729+
LLVMMatchType<0>, // %C
3730+
],
3731+
[IntrNoMem, IntrConvergent, ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<4>>, IntrWillReturn, IntrNoCallback, IntrNoFree]
3732+
>;
3733+
37203734
defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
37213735
def int_amdgcn_wmma_f32_16x16x4_f32 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
37223736
def int_amdgcn_wmma_f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
@@ -3741,6 +3755,7 @@ def int_amdgcn_wmma_f32_16x16x128_fp8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint
37413755
def int_amdgcn_wmma_f32_16x16x128_bf8_fp8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
37423756
def int_amdgcn_wmma_f32_16x16x128_bf8_bf8 : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
37433757
def int_amdgcn_wmma_i32_16x16x64_iu8 : AMDGPUWmmaIntrinsicModsAB<llvm_anyint_ty, llvm_anyint_ty>;
3758+
def int_amdgcn_wmma_f32_16x16x128_f8f6f4 : AMDGPUWmmaIntrinsicModsC_MatrixFMT;
37443759
def int_amdgcn_wmma_f32_32x16x128_f4 : AMDGPUWmmaIntrinsicF4ModsC<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
37453760
}
37463761

llvm/lib/IR/Verifier.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6627,6 +6627,54 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
66276627
"invalid vector type for format", &Call, Src1, Call.getArgOperand(5));
66286628
break;
66296629
}
6630+
case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
  // Operand 0 / 2 select the data format of matrix A / B, operand 1 / 3 are
  // the matrices themselves. Verify that each format value is known and that
  // each matrix vector is wide enough for its format.
  Value *Src0 = Call.getArgOperand(1);
  Value *Src1 = Call.getArgOperand(3);

  unsigned FmtA = cast<ConstantInt>(Call.getArgOperand(0))->getZExtValue();
  unsigned FmtB = cast<ConstantInt>(Call.getArgOperand(2))->getZExtValue();
  Check(FmtA <= 4, "invalid value for matrix format", Call,
        Call.getArgOperand(0));
  Check(FmtB <= 4, "invalid value for matrix format", Call,
        Call.getArgOperand(2));

  // Minimum number of 32-bit elements per format (AMDGPU::MatrixFMT values):
  // 0 and 1 need 16, 2 and 3 need 12, 4 needs 8.
  auto minEltsForFormat = [](unsigned Fmt) -> unsigned {
    if (Fmt <= 1)
      return 16u;
    if (Fmt <= 3)
      return 12u;
    if (Fmt == 4)
      return 8u;
    llvm_unreachable("invalid format value");
  };

  // A or B must be a fixed vector of i32 with 8, 12 or 16 elements.
  auto isAcceptedMatrixType = [](FixedVectorType *VecTy) {
    if (!VecTy)
      return false;
    if (!VecTy->getElementType()->isIntegerTy(32))
      return false;
    switch (VecTy->getNumElements()) {
    case 8:
    case 12:
    case 16:
      return true;
    default:
      return false;
    }
  };

  auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType());
  auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType());
  Check(isAcceptedMatrixType(Src0Ty),
        "operand 1 must be 8, 12 or 16 element i32 vector", &Call, Src0);
  Check(isAcceptedMatrixType(Src1Ty),
        "operand 3 must be 8, 12 or 16 element i32 vector", &Call, Src1);

  // A vector wider than the format strictly needs is accepted; only a vector
  // that is too narrow is an error.
  Check(Src0Ty->getNumElements() >= minEltsForFormat(FmtA),
        "invalid vector type for format", &Call, Src0, Call.getArgOperand(0));
  Check(Src1Ty->getNumElements() >= minEltsForFormat(FmtB),
        "invalid vector type for format", &Call, Src1, Call.getArgOperand(2));
  break;
}
66306678
case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32:
66316679
case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: {
66326680
Value *V = Call.getArgOperand(0);

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,47 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
16941694
NewII->takeName(&II);
16951695
return IC.replaceInstUsesWith(II, NewII);
16961696
}
1697+
case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4: {
  // Operands 0 and 2 are the format selectors for matrix A (operand 1) and
  // matrix B (operand 3). Depending on the format, fewer registers are
  // required, so shrink an oversized source vector to the element count its
  // format needs and re-emit the call with the narrower types.
  const unsigned FmtA = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
  const unsigned FmtB = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue();

  bool MadeChange = false;

  // Returns Src unchanged when it is already no wider than the format needs;
  // otherwise extracts the leading elements and records that a rewrite is due.
  auto ShrinkToFormat = [&](Value *Src, unsigned Fmt) -> Value * {
    auto *SrcTy = cast<FixedVectorType>(Src->getType());
    const unsigned NumElts = AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt);
    if (SrcTy->getNumElements() <= NumElts)
      return Src;

    MadeChange = true;
    return IC.Builder.CreateExtractVector(
        FixedVectorType::get(SrcTy->getElementType(), NumElts), Src,
        IC.Builder.getInt64(0));
  };

  Value *Src0 = ShrinkToFormat(II.getArgOperand(1), FmtA);
  Value *Src1 = ShrinkToFormat(II.getArgOperand(3), FmtB);

  if (!MadeChange)
    return std::nullopt;

  // The intrinsic takes six operands; only the two matrices are replaced.
  SmallVector<Value *, 6> Args(II.args());
  Args[1] = Src0;
  Args[3] = Src1;

  // Overload types are the result/C type followed by the (possibly narrowed)
  // A and B types.
  CallInst *NewII = IC.Builder.CreateIntrinsic(
      IID, {II.getArgOperand(5)->getType(), Src0->getType(), Src1->getType()},
      Args, &II);
  NewII->takeName(&II);
  return IC.replaceInstUsesWith(II, NewII);
}
16971738
}
16981739
if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =
16991740
AMDGPU::getImageDimIntrinsicInfo(II.getIntrinsicID())) {

llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4714,6 +4714,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
47144714
case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_fp8:
47154715
case Intrinsic::amdgcn_wmma_f32_16x16x128_bf8_bf8:
47164716
case Intrinsic::amdgcn_wmma_i32_16x16x64_iu8:
4717+
case Intrinsic::amdgcn_wmma_f32_16x16x128_f8f6f4:
47174718
case Intrinsic::amdgcn_wmma_f32_32x16x128_f4:
47184719
case Intrinsic::amdgcn_swmmac_f16_16x16x64_f16:
47194720
case Intrinsic::amdgcn_swmmac_bf16_16x16x64_bf16:

llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
176176
ImmTyWaitVAVDst,
177177
ImmTyWaitVMVSrc,
178178
ImmTyBitOp3,
179+
ImmTyMatrixAFMT,
180+
ImmTyMatrixBFMT,
179181
ImmTyMatrixAReuse,
180182
ImmTyMatrixBReuse,
181183
ImmTyByteSel,
@@ -423,6 +425,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
423425
bool isIndexKey8bit() const { return isImmTy(ImmTyIndexKey8bit); }
424426
bool isIndexKey16bit() const { return isImmTy(ImmTyIndexKey16bit); }
425427
bool isIndexKey32bit() const { return isImmTy(ImmTyIndexKey32bit); }
428+
bool isMatrixAFMT() const { return isImmTy(ImmTyMatrixAFMT); }
429+
bool isMatrixBFMT() const { return isImmTy(ImmTyMatrixBFMT); }
426430
bool isMatrixAReuse() const { return isImmTy(ImmTyMatrixAReuse); }
427431
bool isMatrixBReuse() const { return isImmTy(ImmTyMatrixBReuse); }
428432
bool isTFE() const { return isImmTy(ImmTyTFE); }
@@ -1174,6 +1178,8 @@ class AMDGPUOperand : public MCParsedAsmOperand {
11741178
case ImmTyWaitVAVDst: OS << "WaitVAVDst"; break;
11751179
case ImmTyWaitVMVSrc: OS << "WaitVMVSrc"; break;
11761180
case ImmTyBitOp3: OS << "BitOp3"; break;
1181+
case ImmTyMatrixAFMT: OS << "ImmTyMatrixAFMT"; break;
1182+
case ImmTyMatrixBFMT: OS << "ImmTyMatrixBFMT"; break;
11771183
case ImmTyMatrixAReuse: OS << "ImmTyMatrixAReuse"; break;
11781184
case ImmTyMatrixBReuse: OS << "ImmTyMatrixBReuse"; break;
11791185
case ImmTyByteSel: OS << "ByteSel" ; break;
@@ -1714,6 +1720,10 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
17141720
ParseStatus parseIndexKey8bit(OperandVector &Operands);
17151721
ParseStatus parseIndexKey16bit(OperandVector &Operands);
17161722
ParseStatus parseIndexKey32bit(OperandVector &Operands);
1723+
ParseStatus tryParseMatrixFMT(OperandVector &Operands, StringRef Name,
1724+
AMDGPUOperand::ImmTy Type);
1725+
ParseStatus parseMatrixAFMT(OperandVector &Operands);
1726+
ParseStatus parseMatrixBFMT(OperandVector &Operands);
17171727

17181728
ParseStatus parseDfmtNfmt(int64_t &Format);
17191729
ParseStatus parseUfmt(int64_t &Format);
@@ -1849,6 +1859,7 @@ class AMDGPUAsmParser : public MCTargetAsmParser {
18491859
const unsigned CPol);
18501860
bool validateTFE(const MCInst &Inst, const OperandVector &Operands);
18511861
std::optional<StringRef> validateLdsDirect(const MCInst &Inst);
1862+
bool validateWMMA(const MCInst &Inst, const OperandVector &Operands);
18521863
unsigned getConstantBusLimit(unsigned Opcode) const;
18531864
bool usesConstantBus(const MCInst &Inst, unsigned OpIdx);
18541865
bool isInlineConstant(const MCInst &Inst, unsigned OpIdx) const;
@@ -5400,6 +5411,37 @@ bool AMDGPUAsmParser::validateTFE(const MCInst &Inst,
54005411
return true;
54015412
}
54025413

5414+
bool AMDGPUAsmParser::validateWMMA(const MCInst &Inst,
5415+
const OperandVector &Operands) {
5416+
unsigned Opc = Inst.getOpcode();
5417+
const MCRegisterInfo *TRI = getContext().getRegisterInfo();
5418+
const MCInstrDesc &Desc = MII.get(Opc);
5419+
5420+
auto validateFmt = [&](AMDGPU::OpName FmtOp, AMDGPU::OpName SrcOp) -> bool {
5421+
int FmtIdx = AMDGPU::getNamedOperandIdx(Opc, FmtOp);
5422+
if (FmtIdx == -1)
5423+
return true;
5424+
unsigned Fmt = Inst.getOperand(FmtIdx).getImm();
5425+
int SrcIdx = AMDGPU::getNamedOperandIdx(Opc, SrcOp);
5426+
unsigned RegSize =
5427+
TRI->getRegClass(Desc.operands()[SrcIdx].RegClass).getSizeInBits();
5428+
5429+
if (RegSize == AMDGPU::wmmaScaleF8F6F4FormatToNumRegs(Fmt) * 32)
5430+
return true;
5431+
5432+
static const char *FmtNames[] = {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8",
5433+
"MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
5434+
"MATRIX_FMT_FP4"};
5435+
5436+
Error(getRegLoc(mc2PseudoReg(Inst.getOperand(SrcIdx).getReg()), Operands),
5437+
"wrong register tuple size for " + Twine(FmtNames[Fmt]));
5438+
return false;
5439+
};
5440+
5441+
return validateFmt(AMDGPU::OpName::matrix_a_fmt, AMDGPU::OpName::src0) &&
5442+
validateFmt(AMDGPU::OpName::matrix_b_fmt, AMDGPU::OpName::src1);
5443+
}
5444+
54035445
bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
54045446
const SMLoc &IDLoc,
54055447
const OperandVector &Operands) {
@@ -5533,6 +5575,9 @@ bool AMDGPUAsmParser::validateInstruction(const MCInst &Inst,
55335575
if (!validateTFE(Inst, Operands)) {
55345576
return false;
55355577
}
5578+
if (!validateWMMA(Inst, Operands)) {
5579+
return false;
5580+
}
55365581

55375582
return true;
55385583
}
@@ -7191,6 +7236,26 @@ ParseStatus AMDGPUAsmParser::parseIndexKey32bit(OperandVector &Operands) {
71917236
return tryParseIndexKey(Operands, AMDGPUOperand::ImmTyIndexKey32bit);
71927237
}
71937238

7239+
/// Shared parser for the matrix format operands: forwards to
/// parseStringOrIntWithPrefix with the operand prefix \p Name, the list of
/// symbolic format names, and the immediate kind \p Type to create.
ParseStatus AMDGPUAsmParser::tryParseMatrixFMT(OperandVector &Operands,
                                               StringRef Name,
                                               AMDGPUOperand::ImmTy Type) {
  // NOTE(review): presumably a name's position in this list is its encoded
  // value (validateWMMA indexes the same names by format value) - confirm
  // against parseStringOrIntWithPrefix.
  ParseStatus Status = parseStringOrIntWithPrefix(
      Operands, Name,
      {"MATRIX_FMT_FP8", "MATRIX_FMT_BF8", "MATRIX_FMT_FP6", "MATRIX_FMT_BF6",
       "MATRIX_FMT_FP4"},
      Type);
  return Status;
}
7248+
7249+
/// Parses the "matrix_a_fmt" operand via tryParseMatrixFMT, tagging the
/// result as ImmTyMatrixAFMT.
ParseStatus AMDGPUAsmParser::parseMatrixAFMT(OperandVector &Operands) {
  const AMDGPUOperand::ImmTy Kind = AMDGPUOperand::ImmTyMatrixAFMT;
  return tryParseMatrixFMT(Operands, "matrix_a_fmt", Kind);
}
7253+
7254+
/// Parses the "matrix_b_fmt" operand via tryParseMatrixFMT, tagging the
/// result as ImmTyMatrixBFMT.
ParseStatus AMDGPUAsmParser::parseMatrixBFMT(OperandVector &Operands) {
  const AMDGPUOperand::ImmTy Kind = AMDGPUOperand::ImmTyMatrixBFMT;
  return tryParseMatrixFMT(Operands, "matrix_b_fmt", Kind);
}
7258+
71947259
// dfmt and nfmt (in a tbuffer instruction) are parsed as one to allow their
71957260
// values to live in a joint format operand in the MCInst encoding.
71967261
ParseStatus AMDGPUAsmParser::parseDfmtNfmt(int64_t &Format) {
@@ -9292,6 +9357,20 @@ void AMDGPUAsmParser::cvtVOP3P(MCInst &Inst, const OperandVector &Operands,
92929357
DefaultVal);
92939358
}
92949359

9360+
int MatrixAFMTIdx =
9361+
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_a_fmt);
9362+
if (MatrixAFMTIdx != -1) {
9363+
addOptionalImmOperand(Inst, Operands, OptIdx,
9364+
AMDGPUOperand::ImmTyMatrixAFMT, 0);
9365+
}
9366+
9367+
int MatrixBFMTIdx =
9368+
AMDGPU::getNamedOperandIdx(Opc, AMDGPU::OpName::matrix_b_fmt);
9369+
if (MatrixBFMTIdx != -1) {
9370+
addOptionalImmOperand(Inst, Operands, OptIdx,
9371+
AMDGPUOperand::ImmTyMatrixBFMT, 0);
9372+
}
9373+
92959374
if (AMDGPU::hasNamedOperand(Opc, AMDGPU::OpName::matrix_a_reuse))
92969375
addOptionalImmOperand(Inst, Operands, OptIdx,
92979376
AMDGPUOperand::ImmTyMatrixAReuse, 0);

0 commit comments

Comments
 (0)