@@ -3861,58 +3861,114 @@ bool AMDGPUDAGToDAGISel::SelectVOP3OpSelMods(SDValue In, SDValue &Src,
38613861 return SelectVOP3Mods (In, Src, SrcMods);
38623862}
38633863
3864+ // Match lowered fpext from bf16 to f32. This is a bit operation extending
3865+ // a 16-bit value with 16-bit of zeroes at LSB:
3866+ //
3867+ // 1. (f32 (bitcast (build_vector (i16 0), (i16 (bitcast bf16:val)))))
3868+ // 2. (f32 (bitcast (and i32:val, 0xffff0000))) -> IsExtractHigh = true
3869+ // 3. (f32 (bitcast (shl i32:va, 16) -> IsExtractHigh = false
3870+ static SDValue matchBF16FPExtendLike (SDValue Op, bool &IsExtractHigh) {
3871+ if (Op.getValueType () != MVT::f32 || Op.getOpcode () != ISD::BITCAST)
3872+ return SDValue ();
3873+ Op = Op.getOperand (0 );
3874+
3875+ IsExtractHigh = false ;
3876+ if (Op.getValueType () == MVT::v2i16 && Op.getOpcode () == ISD::BUILD_VECTOR) {
3877+ auto Low16 = dyn_cast<ConstantSDNode>(Op.getOperand (0 ));
3878+ if (!Low16 || !Low16->isZero ())
3879+ return SDValue ();
3880+ Op = stripBitcast (Op.getOperand (1 ));
3881+ if (Op.getValueType () != MVT::bf16 )
3882+ return SDValue ();
3883+ return Op;
3884+ }
3885+
3886+ if (Op.getValueType () != MVT::i32 )
3887+ return SDValue ();
3888+
3889+ if (Op.getOpcode () == ISD::AND) {
3890+ if (auto Mask = dyn_cast<ConstantSDNode>(Op.getOperand (1 ))) {
3891+ if (Mask->getZExtValue () == 0xffff0000 ) {
3892+ IsExtractHigh = true ;
3893+ return Op.getOperand (0 );
3894+ }
3895+ }
3896+ return SDValue ();
3897+ }
3898+
3899+ if (Op.getOpcode () == ISD::SHL) {
3900+ if (auto Amt = dyn_cast<ConstantSDNode>(Op.getOperand (1 ))) {
3901+ if (Amt->getZExtValue () == 16 )
3902+ return Op.getOperand (0 );
3903+ }
3904+ }
3905+
3906+ return SDValue ();
3907+ }
3908+
38643909// The return value is not whether the match is possible (which it always is),
38653910// but whether or not it a conversion is really used.
38663911bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsImpl (SDValue In, SDValue &Src,
3867- unsigned &Mods) const {
3912+ unsigned &Mods,
3913+ MVT VT) const {
38683914 Mods = 0 ;
38693915 SelectVOP3ModsImpl (In, Src, Mods);
38703916
3917+ bool IsExtractHigh = false ;
38713918 if (Src.getOpcode () == ISD::FP_EXTEND) {
38723919 Src = Src.getOperand (0 );
3873- assert (Src.getValueType () == MVT::f16 );
3874- Src = stripBitcast (Src);
3920+ } else if (VT == MVT::bf16 ) {
3921+ SDValue B16 = matchBF16FPExtendLike (Src, IsExtractHigh);
3922+ if (!B16)
3923+ return false ;
3924+ Src = B16;
3925+ } else
3926+ return false ;
38753927
3876- // Be careful about folding modifiers if we already have an abs. fneg is
3877- // applied last, so we don't want to apply an earlier fneg.
3878- if ((Mods & SISrcMods::ABS) == 0 ) {
3879- unsigned ModsTmp;
3880- SelectVOP3ModsImpl (Src, Src, ModsTmp);
3928+ if (Src.getValueType () != VT &&
3929+ (VT != MVT::bf16 || Src.getValueType () != MVT::i32 ))
3930+ return false ;
38813931
3882- if ((ModsTmp & SISrcMods::NEG) != 0 )
3883- Mods ^= SISrcMods::NEG;
3932+ Src = stripBitcast (Src);
38843933
3885- if ((ModsTmp & SISrcMods::ABS) != 0 )
3886- Mods |= SISrcMods::ABS;
3887- }
3934+ // Be careful about folding modifiers if we already have an abs. fneg is
3935+ // applied last, so we don't want to apply an earlier fneg.
3936+ if ((Mods & SISrcMods::ABS) == 0 ) {
3937+ unsigned ModsTmp;
3938+ SelectVOP3ModsImpl (Src, Src, ModsTmp);
3939+
3940+ if ((ModsTmp & SISrcMods::NEG) != 0 )
3941+ Mods ^= SISrcMods::NEG;
38883942
3889- // op_sel/op_sel_hi decide the source type and source.
3890- // If the source's op_sel_hi is set, it indicates to do a conversion from fp16.
3891- // If the sources's op_sel is set, it picks the high half of the source
3892- // register.
3943+ if ((ModsTmp & SISrcMods::ABS) != 0 )
3944+ Mods |= SISrcMods::ABS;
3945+ }
38933946
3894- Mods |= SISrcMods::OP_SEL_1;
3895- if (isExtractHiElt (Src, Src)) {
3896- Mods |= SISrcMods::OP_SEL_0;
3947+ // op_sel/op_sel_hi decide the source type and source.
3948+ // If the source's op_sel_hi is set, it indicates to do a conversion from
3949+ // fp16. If the sources's op_sel is set, it picks the high half of the source
3950+ // register.
38973951
3898- // TODO: Should we try to look for neg/abs here?
3899- }
3952+ Mods |= SISrcMods::OP_SEL_1;
3953+ if (IsExtractHigh ||
3954+ (Src.getValueSizeInBits () == 16 && isExtractHiElt (Src, Src))) {
3955+ Mods |= SISrcMods::OP_SEL_0;
39003956
3901- // Prevent unnecessary subreg COPY to VGPR_16
3902- if (Src.getOpcode () == ISD::TRUNCATE &&
3903- Src.getOperand (0 ).getValueType () == MVT::i32 ) {
3904- Src = Src.getOperand (0 );
3905- }
3906- return true ;
3957+ // TODO: Should we try to look for neg/abs here?
39073958 }
39083959
3909- return false ;
3960+ // Prevent unnecessary subreg COPY to VGPR_16
3961+ if (Src.getOpcode () == ISD::TRUNCATE &&
3962+ Src.getOperand (0 ).getValueType () == MVT::i32 ) {
3963+ Src = Src.getOperand (0 );
3964+ }
3965+ return true ;
39103966}
39113967
39123968bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt (SDValue In, SDValue &Src,
39133969 SDValue &SrcMods) const {
39143970 unsigned Mods = 0 ;
3915- if (!SelectVOP3PMadMixModsImpl (In, Src, Mods))
3971+ if (!SelectVOP3PMadMixModsImpl (In, Src, Mods, MVT:: f16 ))
39163972 return false ;
39173973 SrcMods = CurDAG->getTargetConstant (Mods, SDLoc (In), MVT::i32 );
39183974 return true ;
@@ -3921,7 +3977,24 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src,
39213977bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixMods (SDValue In, SDValue &Src,
39223978 SDValue &SrcMods) const {
39233979 unsigned Mods = 0 ;
3924- SelectVOP3PMadMixModsImpl (In, Src, Mods);
3980+ SelectVOP3PMadMixModsImpl (In, Src, Mods, MVT::f16 );
3981+ SrcMods = CurDAG->getTargetConstant (Mods, SDLoc (In), MVT::i32 );
3982+ return true ;
3983+ }
3984+
3985+ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixBF16ModsExt (SDValue In, SDValue &Src,
3986+ SDValue &SrcMods) const {
3987+ unsigned Mods = 0 ;
3988+ if (!SelectVOP3PMadMixModsImpl (In, Src, Mods, MVT::bf16 ))
3989+ return false ;
3990+ SrcMods = CurDAG->getTargetConstant (Mods, SDLoc (In), MVT::i32 );
3991+ return true ;
3992+ }
3993+
3994+ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixBF16Mods (SDValue In, SDValue &Src,
3995+ SDValue &SrcMods) const {
3996+ unsigned Mods = 0 ;
3997+ SelectVOP3PMadMixModsImpl (In, Src, Mods, MVT::bf16 );
39253998 SrcMods = CurDAG->getTargetConstant (Mods, SDLoc (In), MVT::i32 );
39263999 return true ;
39274000}
0 commit comments