Skip to content

Commit ae4cd7f

Browse files
committed
use vgpr16 for madmixfma
1 parent 0a4b87d commit ae4cd7f

File tree

1 file changed

+68
-64
lines changed

1 file changed

+68
-64
lines changed

llvm/lib/Target/AMDGPU/VOP3PInstructions.td

Lines changed: 68 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -158,36 +158,52 @@ defm V_PK_MAXIMUM3_F16 : VOP3PInst<"v_pk_maximum3_f16", VOP3P_Profile<VOP_V2F16_
158158
}
159159
} // End isCommutable = 1, FPDPRounding = 1
160160

161+
class MadFmaMixPatOp<dag op1, bit true16, bit isHi16 = 0> {
162+
dag ret = !if(true16,
163+
!if(isHi16,
164+
(REG_SEQUENCE VGPR_32, (i16 (IMPLICIT_DEF)), lo16, op1, hi16),
165+
(REG_SEQUENCE VGPR_32, op1, lo16, (i16 (IMPLICIT_DEF)), hi16)),
166+
op1);
167+
}
168+
161169
// TODO: Make sure we're doing the right thing with denormals. Note
162170
// that FMA and MAD will differ.
163-
multiclass MadFmaMixPats<SDPatternOperator fma_like,
164-
Instruction mix_inst,
165-
Instruction mixlo_inst,
166-
Instruction mixhi_inst,
167-
ValueType VT = f16,
168-
ValueType vecVT = v2f16> {
171+
multiclass MadFmaMixPatsImpl<SDPatternOperator fma_like,
172+
Instruction mix_inst,
173+
Instruction mixlo_inst,
174+
Instruction mixhi_inst,
175+
ValueType VT,
176+
ValueType vecVT,
177+
bit true16> {
169178
defvar VOP3PMadMixModsPat = !if (!eq(VT, bf16), VOP3PMadMixBF16Mods, VOP3PMadMixMods);
170179
defvar VOP3PMadMixModsExtPat = !if (!eq(VT, bf16), VOP3PMadMixBF16ModsExt, VOP3PMadMixModsExt);
180+
171181
// At least one of the operands needs to be an fpextend of an f16
172182
// for this to be worthwhile, so we need three patterns here.
173183
// TODO: Could we use a predicate to inspect src1/2/3 instead?
174184
def : GCNPat <
175185
(f32 (fma_like (f32 (VOP3PMadMixModsExtPat VT:$src0, i32:$src0_mods)),
176186
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_mods)),
177187
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_mods)))),
178-
(mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
188+
(mix_inst $src0_mods, MadFmaMixPatOp<(VT $src0), true16>.ret,
189+
$src1_mods, MadFmaMixPatOp<(VT $src1), true16>.ret,
190+
$src2_mods, MadFmaMixPatOp<(VT $src2), true16>.ret,
179191
DSTCLAMP.NONE)>;
180192
def : GCNPat <
181193
(f32 (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_mods)),
182194
(f32 (VOP3PMadMixModsExtPat VT:$src1, i32:$src1_mods)),
183195
(f32 (VOP3PMadMixModsPat f32:$src2, i32:$src2_mods)))),
184-
(mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
196+
(mix_inst $src0_mods, MadFmaMixPatOp<(VT $src0), true16>.ret,
197+
$src1_mods, MadFmaMixPatOp<(VT $src1), true16>.ret,
198+
$src2_mods, $src2,
185199
DSTCLAMP.NONE)>;
186200
def : GCNPat <
187201
(f32 (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_mods)),
188202
(f32 (VOP3PMadMixModsPat f32:$src1, i32:$src1_mods)),
189203
(f32 (VOP3PMadMixModsExtPat VT:$src2, i32:$src2_mods)))),
190-
(mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
204+
(mix_inst $src0_mods, MadFmaMixPatOp<(VT $src0), true16>.ret,
205+
$src1_mods, $src1,
206+
$src2_mods, MadFmaMixPatOp<(VT $src2), true16>.ret,
191207
DSTCLAMP.NONE)>;
192208

193209
def : GCNPat <
@@ -198,13 +214,13 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
198214
(VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$hi_src0, i32:$hi_src0_modifiers)),
199215
(f32 (VOP3PMadMixModsPat VT:$hi_src1, i32:$hi_src1_modifiers)),
200216
(f32 (VOP3PMadMixModsPat VT:$hi_src2, i32:$hi_src2_modifiers))))))),
201-
(vecVT (mixhi_inst $hi_src0_modifiers, $hi_src0,
202-
$hi_src1_modifiers, $hi_src1,
203-
$hi_src2_modifiers, $hi_src2,
217+
(vecVT (mixhi_inst $hi_src0_modifiers, MadFmaMixPatOp<(VT $hi_src0), true16>.ret,
218+
$hi_src1_modifiers, MadFmaMixPatOp<(VT $hi_src1), true16>.ret,
219+
$hi_src2_modifiers, MadFmaMixPatOp<(VT $hi_src2), true16>.ret,
204220
DSTCLAMP.ENABLE,
205-
(mixlo_inst $lo_src0_modifiers, $lo_src0,
206-
$lo_src1_modifiers, $lo_src1,
207-
$lo_src2_modifiers, $lo_src2,
221+
(mixlo_inst $lo_src0_modifiers, MadFmaMixPatOp<(VT $lo_src0), true16>.ret,
222+
$lo_src1_modifiers, MadFmaMixPatOp<(VT $lo_src1), true16>.ret,
223+
$lo_src2_modifiers, MadFmaMixPatOp<(VT $lo_src2), true16>.ret,
208224
DSTCLAMP.ENABLE,
209225
(i32 (IMPLICIT_DEF)))))
210226
>;
@@ -233,28 +249,25 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
233249
(VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
234250
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
235251
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))),
236-
(mixlo_inst $src0_modifiers, $src0,
237-
$src1_modifiers, $src1,
238-
$src2_modifiers, $src2,
252+
(mixlo_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), true16>.ret,
253+
$src1_modifiers, MadFmaMixPatOp<(VT $src1), true16>.ret,
254+
$src2_modifiers, MadFmaMixPatOp<(VT $src2), true16>.ret,
239255
DSTCLAMP.NONE,
240256
(i32 (IMPLICIT_DEF)))
241257
>;
242258

243259
// FIXME: Special case handling for maxhi (especially for clamp)
244260
// because dealing with the write to high half of the register is
245261
// difficult.
246-
foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in
247-
let True16Predicate = p in {
248-
249262
def : GCNPat <
250263
(build_vector VT:$elt0, (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
251264
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
252265
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers)))))),
253-
(vecVT (mixhi_inst $src0_modifiers, $src0,
254-
$src1_modifiers, $src1,
255-
$src2_modifiers, $src2,
266+
(vecVT (mixhi_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), true16>.ret,
267+
$src1_modifiers, MadFmaMixPatOp<(VT $src1), true16>.ret,
268+
$src2_modifiers, MadFmaMixPatOp<(VT $src2), true16>.ret,
256269
DSTCLAMP.NONE,
257-
VGPR_32:$elt0))
270+
MadFmaMixPatOp<(VT $elt0), true16>.ret))
258271
>;
259272

260273
def : GCNPat <
@@ -263,51 +276,42 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
263276
(AMDGPUclamp (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
264277
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
265278
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))))),
266-
(vecVT (mixhi_inst $src0_modifiers, $src0,
267-
$src1_modifiers, $src1,
268-
$src2_modifiers, $src2,
279+
(vecVT (mixhi_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), true16>.ret,
280+
$src1_modifiers, MadFmaMixPatOp<(VT $src1), true16>.ret,
281+
$src2_modifiers, MadFmaMixPatOp<(VT $src2), true16>.ret,
269282
DSTCLAMP.ENABLE,
270-
VGPR_32:$elt0))
283+
MadFmaMixPatOp<(VT $elt0), true16>.ret))
271284
>;
285+
}
272286

273-
} // end True16Predicate
287+
multiclass MadFmaMixPats<SDPatternOperator fma_like,
288+
Instruction mix_inst,
289+
Instruction mixlo_inst,
290+
Instruction mixhi_inst,
291+
ValueType VT = f16,
292+
ValueType vecVT = v2f16> {
274293

275-
let True16Predicate = UseRealTrue16Insts in {
276-
def : GCNPat <
277-
(build_vector (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
278-
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
279-
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))), VT:$elt1),
280-
(vecVT (mixlo_inst $src0_modifiers, $src0,
281-
$src1_modifiers, $src1,
282-
$src2_modifiers, $src2,
283-
DSTCLAMP.NONE,
284-
(REG_SEQUENCE VGPR_32, (VT (IMPLICIT_DEF)), lo16, $elt1, hi16)))
285-
>;
294+
defvar VOP3PMadMixModsPat = !if (!eq(VT, bf16), VOP3PMadMixBF16Mods, VOP3PMadMixMods);
286295

287-
def : GCNPat <
288-
(build_vector VT:$elt0, (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
289-
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
290-
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers)))))),
291-
(vecVT (mixhi_inst $src0_modifiers, $src0,
292-
$src1_modifiers, $src1,
293-
$src2_modifiers, $src2,
294-
DSTCLAMP.NONE,
295-
(REG_SEQUENCE VGPR_32, $elt0, lo16, (VT (IMPLICIT_DEF)), hi16)))
296-
>;
296+
foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in
297+
let True16Predicate = p in
298+
defm : MadFmaMixPatsImpl<fma_like, mix_inst, mixlo_inst, mixhi_inst, VT, vecVT, /*true16*/ 0>;
299+
300+
let True16Predicate = UseRealTrue16Insts in {
301+
defm : MadFmaMixPatsImpl<fma_like, mix_inst, mixlo_inst, mixhi_inst, VT, vecVT, /*true16*/ 1>;
302+
303+
def : GCNPat <
304+
(build_vector (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
305+
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
306+
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))), VT:$elt1),
307+
(vecVT (mixlo_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), /*true16*/1>.ret,
308+
$src1_modifiers, MadFmaMixPatOp<(VT $src1), /*true16*/1>.ret,
309+
$src2_modifiers, MadFmaMixPatOp<(VT $src2), /*true16*/1>.ret,
310+
DSTCLAMP.NONE,
311+
MadFmaMixPatOp<(VT $elt1), /*true16*/1, /*isHi16*/1>.ret))
312+
>;
313+
}
297314

298-
def : GCNPat <
299-
(build_vector
300-
VT:$elt0,
301-
(AMDGPUclamp (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
302-
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
303-
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))))),
304-
(vecVT (mixhi_inst $src0_modifiers, $src0,
305-
$src1_modifiers, $src1,
306-
$src2_modifiers, $src2,
307-
DSTCLAMP.ENABLE,
308-
(REG_SEQUENCE VGPR_32, $elt0, lo16, (VT (IMPLICIT_DEF)), hi16)))
309-
>;
310-
} // end True16Predicate
311315
}
312316

313317
class MinimumMaximumByMinimum3Maximum3VOP3P<SDPatternOperator node,

0 commit comments

Comments
 (0)