Skip to content

Commit e28c323

Browse files
committed
Address comments.
1 parent e960360 commit e28c323

File tree

4 files changed

+90
-55
lines changed

4 files changed

+90
-55
lines changed

clang/lib/Headers/amxfp8intrin.h

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*===---------- amxfp8intrin.h - AMX intrinsics -*- C++ -*------------===
1+
/*===------------- amxfp8intrin.h - AMX intrinsics -*- C++ -*----------------===
22
*
33
* Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
* See https://llvm.org/LICENSE.txt for license information.
@@ -15,9 +15,68 @@
1515
#define __AMXFP8INTRIN_H
1616
#ifdef __x86_64__
1717

18+
19+
/// Compute dot-product of brain-float8 (BF8) or hybrid-float8 (HF8)
20+
/// floating-point pairs in tiles \a a and \a b, accumulating the
21+
/// intermediate single-precision (32-bit) floating-point elements with
22+
/// elements in \a dst, and store the 32-bit result back to tile \a dst.
23+
///
24+
/// \headerfile <immintrin.h>
25+
///
26+
/// \code
27+
/// void _tile_dpbf8ps (__tile dst, __tile a, __tile b)
28+
/// \endcode
29+
///
30+
/// This intrinsic corresponds to the \c TDPBF8PS instruction.
31+
///
32+
/// \param dst
33+
/// The destination tile. Max size is 1024 Bytes.
34+
/// \param a
35+
/// The 1st source tile. Max size is 1024 Bytes.
36+
/// \param b
37+
/// The 2nd source tile. Max size is 1024 Bytes.
1838
#define _tile_dpbf8ps __builtin_ia32_tdpbf8ps
39+
40+
/// \code
41+
/// void _tile_dpbhf8ps (__tile dst, __tile a, __tile b)
42+
/// \endcode
43+
///
44+
/// This intrinsic corresponds to the \c TDPBHF8PS instruction.
45+
///
46+
/// \param dst
47+
/// The destination tile. Max size is 1024 Bytes.
48+
/// \param a
49+
/// The 1st source tile. Max size is 1024 Bytes.
50+
/// \param b
51+
/// The 2nd source tile. Max size is 1024 Bytes.
1952
#define _tile_dpbhf8ps __builtin_ia32_tdpbhf8ps
53+
54+
/// \code
55+
/// void _tile_dphbf8ps (__tile dst, __tile a, __tile b)
56+
/// \endcode
57+
///
58+
/// This intrinsic corresponds to the \c TDPHBF8PS instruction.
59+
///
60+
/// \param dst
61+
/// The destination tile. Max size is 1024 Bytes.
62+
/// \param a
63+
/// The 1st source tile. Max size is 1024 Bytes.
64+
/// \param b
65+
/// The 2nd source tile. Max size is 1024 Bytes.
2066
#define _tile_dphbf8ps __builtin_ia32_tdphbf8ps
67+
68+
/// \code
69+
/// void _tile_dphf8ps (__tile dst, __tile a, __tile b)
70+
/// \endcode
71+
///
72+
/// This intrinsic corresponds to the \c TDPHF8PS instruction.
73+
///
74+
/// \param dst
75+
/// The destination tile. Max size is 1024 Bytes.
76+
/// \param a
77+
/// The 1st source tile. Max size is 1024 Bytes.
78+
/// \param b
79+
/// The 2nd source tile. Max size is 1024 Bytes.
2180
#define _tile_dphf8ps __builtin_ia32_tdphf8ps
2281

2382
#endif /* __x86_64__ */

llvm/lib/Target/X86/X86.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,8 @@ def FeatureAMXCOMPLEX : SubtargetFeature<"amx-complex", "HasAMXCOMPLEX", "true",
271271
"Support AMX-COMPLEX instructions",
272272
[FeatureAMXTILE]>;
273273
def FeatureAMXFP8 : SubtargetFeature<"amx-fp8", "HasAMXFP8", "true",
274-
"Support AMX-FP8 instructions",
275-
[FeatureAMXTILE]>;
274+
"Support AMX-FP8 instructions",
275+
[FeatureAMXTILE]>;
276276
def FeatureCMPCCXADD : SubtargetFeature<"cmpccxadd", "HasCMPCCXADD", "true",
277277
"Support CMPCCXADD instructions">;
278278
def FeatureRAOINT : SubtargetFeature<"raoint", "HasRAOINT", "true",

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -37410,7 +37410,11 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
3741037410
case X86::PTDPBUSD:
3741137411
case X86::PTDPBUUD:
3741237412
case X86::PTDPBF16PS:
37413-
case X86::PTDPFP16PS: {
37413+
case X86::PTDPFP16PS:
37414+
case X86::PTDPBF8PS:
37415+
case X86::PTDPBHF8PS:
37416+
case X86::PTDPHBF8PS:
37417+
case X86::PTDPHF8PS: {
3741437418
unsigned Opc;
3741537419
switch (MI.getOpcode()) {
3741637420
// clang-format off
@@ -37421,6 +37425,10 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
3742137425
case X86::PTDPBUUD: Opc = X86::TDPBUUD; break;
3742237426
case X86::PTDPBF16PS: Opc = X86::TDPBF16PS; break;
3742337427
case X86::PTDPFP16PS: Opc = X86::TDPFP16PS; break;
37428+
case X86::PTDPBF8PS: Opc = X86::TDPBF8PS; break;
37429+
case X86::PTDPBHF8PS: Opc = X86::TDPBHF8PS; break;
37430+
case X86::PTDPHBF8PS: Opc = X86::TDPHBF8PS; break;
37431+
case X86::PTDPHF8PS: Opc = X86::TDPHF8PS; break;
3742437432
// clang-format on
3742537433
}
3742637434

@@ -37503,38 +37511,6 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
3750337511
MI.eraseFromParent(); // The pseudo is gone now.
3750437512
return BB;
3750537513
}
37506-
case X86::PTDPBF8PS:
37507-
case X86::PTDPBHF8PS:
37508-
case X86::PTDPHBF8PS:
37509-
case X86::PTDPHF8PS: {
37510-
const DebugLoc &DL = MI.getDebugLoc();
37511-
unsigned Opc;
37512-
switch (MI.getOpcode()) {
37513-
default:
37514-
llvm_unreachable("Unexpected instruction!");
37515-
case X86::PTDPBF8PS:
37516-
Opc = X86::TDPBF8PS;
37517-
break;
37518-
case X86::PTDPBHF8PS:
37519-
Opc = X86::TDPBHF8PS;
37520-
break;
37521-
case X86::PTDPHBF8PS:
37522-
Opc = X86::TDPHBF8PS;
37523-
break;
37524-
case X86::PTDPHF8PS:
37525-
Opc = X86::TDPHF8PS;
37526-
break;
37527-
}
37528-
37529-
MachineInstrBuilder MIB = BuildMI(*BB, MI, DL, TII->get(Opc));
37530-
MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Define);
37531-
MIB.addReg(TMMImmToTMMReg(MI.getOperand(0).getImm()), RegState::Undef);
37532-
MIB.addReg(TMMImmToTMMReg(MI.getOperand(1).getImm()), RegState::Undef);
37533-
MIB.addReg(TMMImmToTMMReg(MI.getOperand(2).getImm()), RegState::Undef);
37534-
37535-
MI.eraseFromParent();
37536-
return BB;
37537-
}
3753837514
}
3753937515
}
3754037516

llvm/lib/Target/X86/X86InstrAMX.td

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,9 @@ let Predicates = [HasAMXFP8, In64BitMode] in {
274274
let Constraints = "$src1 = $dst" in {
275275
class AMX_FP8_BASE<bits<8> Opcode, string Opstr> :
276276
I<Opcode, MRMSrcReg4VOp3, (outs TILE:$dst),
277-
(ins TILE:$src1, TILE:$src2, TILE:$src3),
278-
!strconcat(Opstr, "\t{$src3, $src2, $dst|$dst, $src2, $src3}"),
279-
[]>, VEX, VVVV;
277+
(ins TILE:$src1, TILE:$src2, TILE:$src3),
278+
!strconcat(Opstr, "\t{$src3, $src2, $dst|$dst, $src2, $src3}"),
279+
[]>, VEX, VVVV;
280280
}
281281

282282
def TDPBF8PS : AMX_FP8_BASE<0xfd, "tdpbf8ps">, T_MAP5, PS;
@@ -287,22 +287,22 @@ let Predicates = [HasAMXFP8, In64BitMode] in {
287287
let usesCustomInserter = 1 in {
288288
// Pseudo instructions, using immediates instead of tile registers.
289289
// To be translated to the actual instructions in X86ISelLowering.cpp
290-
def PTDPBF8PS : PseudoI<(outs), (ins u8imm:$src1,
291-
u8imm:$src2, u8imm:$src3),
292-
[(int_x86_tdpbf8ps timm:$src1,
293-
timm:$src2, timm:$src3)]>;
294-
def PTDPBHF8PS : PseudoI<(outs), (ins u8imm:$src1,
295-
u8imm:$src2, u8imm:$src3),
296-
[(int_x86_tdpbhf8ps timm:$src1,
297-
timm:$src2, timm:$src3)]>;
298-
def PTDPHBF8PS : PseudoI<(outs), (ins u8imm:$src1,
299-
u8imm:$src2, u8imm:$src3),
300-
[(int_x86_tdphbf8ps timm:$src1,
301-
timm:$src2, timm:$src3)]>;
302-
def PTDPHF8PS : PseudoI<(outs), (ins u8imm:$src1,
303-
u8imm:$src2, u8imm:$src3),
304-
[(int_x86_tdphf8ps timm:$src1,
305-
timm:$src2, timm:$src3)]>;
290+
def PTDPBF8PS : PseudoI<(outs),
291+
(ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
292+
[(int_x86_tdpbf8ps timm:$src1, timm:$src2,
293+
timm:$src3)]>;
294+
def PTDPBHF8PS : PseudoI<(outs),
295+
(ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
296+
[(int_x86_tdpbhf8ps timm:$src1, timm:$src2,
297+
timm:$src3)]>;
298+
def PTDPHBF8PS : PseudoI<(outs),
299+
(ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
300+
[(int_x86_tdphbf8ps timm:$src1, timm:$src2,
301+
timm:$src3)]>;
302+
def PTDPHF8PS : PseudoI<(outs),
303+
(ins u8imm:$src1, u8imm:$src2, u8imm:$src3),
304+
[(int_x86_tdphf8ps timm:$src1, timm:$src2,
305+
timm:$src3)]>;
306306
}
307307
}
308308
}

0 commit comments

Comments
 (0)