Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 68 additions & 64 deletions llvm/lib/Target/AMDGPU/VOP3PInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -158,36 +158,52 @@ defm V_PK_MAXIMUM3_F16 : VOP3PInst<"v_pk_maximum3_f16", VOP3P_Profile<VOP_V2F16_
}
} // End isCommutable = 1, FPDPRounding = 1

class MadFmaMixPatOp<dag op1, bit true16, bit isHi16 = 0> {
dag ret = !if(true16,
!if(isHi16,
(REG_SEQUENCE VGPR_32, (i16 (IMPLICIT_DEF)), lo16, op1, hi16),
(REG_SEQUENCE VGPR_32, op1, lo16, (i16 (IMPLICIT_DEF)), hi16)),
op1);
}

// TODO: Make sure we're doing the right thing with denormals. Note
// that FMA and MAD will differ.
multiclass MadFmaMixPats<SDPatternOperator fma_like,
Instruction mix_inst,
Instruction mixlo_inst,
Instruction mixhi_inst,
ValueType VT = f16,
ValueType vecVT = v2f16> {
multiclass MadFmaMixPatsImpl<SDPatternOperator fma_like,
Instruction mix_inst,
Instruction mixlo_inst,
Instruction mixhi_inst,
ValueType VT,
ValueType vecVT,
bit true16> {
defvar VOP3PMadMixModsPat = !if (!eq(VT, bf16), VOP3PMadMixBF16Mods, VOP3PMadMixMods);
defvar VOP3PMadMixModsExtPat = !if (!eq(VT, bf16), VOP3PMadMixBF16ModsExt, VOP3PMadMixModsExt);

// At least one of the operands needs to be an fpextend of an f16
// for this to be worthwhile, so we need three patterns here.
// TODO: Could we use a predicate to inspect src1/2/3 instead?
def : GCNPat <
(f32 (fma_like (f32 (VOP3PMadMixModsExtPat VT:$src0, i32:$src0_mods)),
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_mods)),
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_mods)))),
(mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
(mix_inst $src0_mods, MadFmaMixPatOp<(VT $src0), true16>.ret,
$src1_mods, MadFmaMixPatOp<(VT $src1), true16>.ret,
$src2_mods, MadFmaMixPatOp<(VT $src2), true16>.ret,
DSTCLAMP.NONE)>;
def : GCNPat <
(f32 (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_mods)),
(f32 (VOP3PMadMixModsExtPat VT:$src1, i32:$src1_mods)),
(f32 (VOP3PMadMixModsPat f32:$src2, i32:$src2_mods)))),
(mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
(mix_inst $src0_mods, MadFmaMixPatOp<(VT $src0), true16>.ret,
$src1_mods, MadFmaMixPatOp<(VT $src1), true16>.ret,
$src2_mods, $src2,
DSTCLAMP.NONE)>;
def : GCNPat <
(f32 (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_mods)),
(f32 (VOP3PMadMixModsPat f32:$src1, i32:$src1_mods)),
(f32 (VOP3PMadMixModsExtPat VT:$src2, i32:$src2_mods)))),
(mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
(mix_inst $src0_mods, MadFmaMixPatOp<(VT $src0), true16>.ret,
$src1_mods, $src1,
$src2_mods, MadFmaMixPatOp<(VT $src2), true16>.ret,
DSTCLAMP.NONE)>;

def : GCNPat <
Expand All @@ -198,13 +214,13 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
(VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$hi_src0, i32:$hi_src0_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$hi_src1, i32:$hi_src1_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$hi_src2, i32:$hi_src2_modifiers))))))),
(vecVT (mixhi_inst $hi_src0_modifiers, $hi_src0,
$hi_src1_modifiers, $hi_src1,
$hi_src2_modifiers, $hi_src2,
(vecVT (mixhi_inst $hi_src0_modifiers, MadFmaMixPatOp<(VT $hi_src0), true16>.ret,
$hi_src1_modifiers, MadFmaMixPatOp<(VT $hi_src1), true16>.ret,
$hi_src2_modifiers, MadFmaMixPatOp<(VT $hi_src2), true16>.ret,
DSTCLAMP.ENABLE,
(mixlo_inst $lo_src0_modifiers, $lo_src0,
$lo_src1_modifiers, $lo_src1,
$lo_src2_modifiers, $lo_src2,
(mixlo_inst $lo_src0_modifiers, MadFmaMixPatOp<(VT $lo_src0), true16>.ret,
$lo_src1_modifiers, MadFmaMixPatOp<(VT $lo_src1), true16>.ret,
$lo_src2_modifiers, MadFmaMixPatOp<(VT $lo_src2), true16>.ret,
DSTCLAMP.ENABLE,
(i32 (IMPLICIT_DEF)))))
>;
Expand Down Expand Up @@ -233,28 +249,25 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
(VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))),
(mixlo_inst $src0_modifiers, $src0,
$src1_modifiers, $src1,
$src2_modifiers, $src2,
(mixlo_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), true16>.ret,
$src1_modifiers, MadFmaMixPatOp<(VT $src1), true16>.ret,
$src2_modifiers, MadFmaMixPatOp<(VT $src2), true16>.ret,
DSTCLAMP.NONE,
(i32 (IMPLICIT_DEF)))
>;

// FIXME: Special case handling for maxhi (especially for clamp)
// because dealing with the write to high half of the register is
// difficult.
foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in
let True16Predicate = p in {

def : GCNPat <
(build_vector VT:$elt0, (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers)))))),
(vecVT (mixhi_inst $src0_modifiers, $src0,
$src1_modifiers, $src1,
$src2_modifiers, $src2,
(vecVT (mixhi_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), true16>.ret,
$src1_modifiers, MadFmaMixPatOp<(VT $src1), true16>.ret,
$src2_modifiers, MadFmaMixPatOp<(VT $src2), true16>.ret,
DSTCLAMP.NONE,
VGPR_32:$elt0))
MadFmaMixPatOp<(VT $elt0), true16>.ret))
>;

def : GCNPat <
Expand All @@ -263,51 +276,42 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
(AMDGPUclamp (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))))),
(vecVT (mixhi_inst $src0_modifiers, $src0,
$src1_modifiers, $src1,
$src2_modifiers, $src2,
(vecVT (mixhi_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), true16>.ret,
$src1_modifiers, MadFmaMixPatOp<(VT $src1), true16>.ret,
$src2_modifiers, MadFmaMixPatOp<(VT $src2), true16>.ret,
DSTCLAMP.ENABLE,
VGPR_32:$elt0))
MadFmaMixPatOp<(VT $elt0), true16>.ret))
>;
}

} // end True16Predicate
multiclass MadFmaMixPats<SDPatternOperator fma_like,
Instruction mix_inst,
Instruction mixlo_inst,
Instruction mixhi_inst,
ValueType VT = f16,
ValueType vecVT = v2f16> {

let True16Predicate = UseRealTrue16Insts in {
def : GCNPat <
(build_vector (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))), VT:$elt1),
(vecVT (mixlo_inst $src0_modifiers, $src0,
$src1_modifiers, $src1,
$src2_modifiers, $src2,
DSTCLAMP.NONE,
(REG_SEQUENCE VGPR_32, (VT (IMPLICIT_DEF)), lo16, $elt1, hi16)))
>;
defvar VOP3PMadMixModsPat = !if (!eq(VT, bf16), VOP3PMadMixBF16Mods, VOP3PMadMixMods);

def : GCNPat <
(build_vector VT:$elt0, (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers)))))),
(vecVT (mixhi_inst $src0_modifiers, $src0,
$src1_modifiers, $src1,
$src2_modifiers, $src2,
DSTCLAMP.NONE,
(REG_SEQUENCE VGPR_32, $elt0, lo16, (VT (IMPLICIT_DEF)), hi16)))
>;
foreach p = [NotHasTrue16BitInsts, UseFakeTrue16Insts] in
let True16Predicate = p in
defm : MadFmaMixPatsImpl<fma_like, mix_inst, mixlo_inst, mixhi_inst, VT, vecVT, /*true16*/ 0>;

let True16Predicate = UseRealTrue16Insts in {
defm : MadFmaMixPatsImpl<fma_like, mix_inst, mixlo_inst, mixhi_inst, VT, vecVT, /*true16*/ 1>;

def : GCNPat <
(build_vector (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))), VT:$elt1),
(vecVT (mixlo_inst $src0_modifiers, MadFmaMixPatOp<(VT $src0), /*true16*/1>.ret,
$src1_modifiers, MadFmaMixPatOp<(VT $src1), /*true16*/1>.ret,
$src2_modifiers, MadFmaMixPatOp<(VT $src2), /*true16*/1>.ret,
DSTCLAMP.NONE,
MadFmaMixPatOp<(VT $elt1), /*true16*/1, /*isHi16*/1>.ret))
>;
}

def : GCNPat <
(build_vector
VT:$elt0,
(AMDGPUclamp (VT (fpround (fma_like (f32 (VOP3PMadMixModsPat VT:$src0, i32:$src0_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src1, i32:$src1_modifiers)),
(f32 (VOP3PMadMixModsPat VT:$src2, i32:$src2_modifiers))))))),
(vecVT (mixhi_inst $src0_modifiers, $src0,
$src1_modifiers, $src1,
$src2_modifiers, $src2,
DSTCLAMP.ENABLE,
(REG_SEQUENCE VGPR_32, $elt0, lo16, (VT (IMPLICIT_DEF)), hi16)))
>;
} // end True16Predicate
}

class MinimumMaximumByMinimum3Maximum3VOP3P<SDPatternOperator node,
Expand Down
Loading