@@ -88,6 +88,10 @@ HexagonTargetLowering::initializeHVXLowering() {
8888 addRegisterClass (MVT::v64f32, &Hexagon::HvxWRRegClass);
8989 addRegisterClass (MVT::v128f16, &Hexagon::HvxWRRegClass);
9090 }
91+ if (Subtarget.useHVXV81Ops ()) {
92+ addRegisterClass (MVT::v64bf16, &Hexagon::HvxVRRegClass);
93+ addRegisterClass (MVT::v128bf16, &Hexagon::HvxWRRegClass);
94+ }
9195 }
9296
9397 // Set up operation actions.
@@ -162,6 +166,30 @@ HexagonTargetLowering::initializeHVXLowering() {
162166 setPromoteTo (ISD::VECTOR_SHUFFLE, MVT::v64f32, ByteW);
163167 setPromoteTo (ISD::VECTOR_SHUFFLE, MVT::v32f32, ByteV);
164168
169+ if (Subtarget.useHVXV81Ops ()) {
170+ setPromoteTo (ISD::VECTOR_SHUFFLE, MVT::v128bf16, ByteW);
171+ setPromoteTo (ISD::VECTOR_SHUFFLE, MVT::v64bf16, ByteV);
172+ setPromoteTo (ISD::SETCC, MVT::v64bf16, MVT::v64f32);
173+ setPromoteTo (ISD::FADD, MVT::v64bf16, MVT::v64f32);
174+ setPromoteTo (ISD::FSUB, MVT::v64bf16, MVT::v64f32);
175+ setPromoteTo (ISD::FMUL, MVT::v64bf16, MVT::v64f32);
176+ setPromoteTo (ISD::FMINNUM, MVT::v64bf16, MVT::v64f32);
177+ setPromoteTo (ISD::FMAXNUM, MVT::v64bf16, MVT::v64f32);
178+
179+ setOperationAction (ISD::SPLAT_VECTOR, MVT::v64bf16, Legal);
180+ setOperationAction (ISD::INSERT_SUBVECTOR, MVT::v64bf16, Custom);
181+ setOperationAction (ISD::EXTRACT_SUBVECTOR, MVT::v64bf16, Custom);
182+
183+ setOperationAction (ISD::MLOAD, MVT::v64bf16, Custom);
184+ setOperationAction (ISD::MSTORE, MVT::v64bf16, Custom);
185+ setOperationAction (ISD::BUILD_VECTOR, MVT::v64bf16, Custom);
186+ setOperationAction (ISD::CONCAT_VECTORS, MVT::v64bf16, Custom);
187+
188+ setOperationAction (ISD::SPLAT_VECTOR, MVT::bf16 , Custom);
189+ setOperationAction (ISD::INSERT_VECTOR_ELT, MVT::bf16 , Custom);
190+ setOperationAction (ISD::BUILD_VECTOR, MVT::bf16 , Custom);
191+ }
192+
165193 for (MVT P : FloatW) {
166194 setOperationAction (ISD::LOAD, P, Custom);
167195 setOperationAction (ISD::STORE, P, Custom);
@@ -1667,14 +1695,15 @@ HexagonTargetLowering::LowerHvxBuildVector(SDValue Op, SelectionDAG &DAG)
16671695 // In case of an MVT::f16 or MVT::bf16 BUILD_VECTOR (MVT::f16 is
16681696 // not a legal type, and bf16 is handled the same way), just bitcast
16691697 // the node to use i16 types and bitcast the result back to the original element type
1670- if (VecTy.getVectorElementType () == MVT::f16 ) {
1671- SmallVector<SDValue,64 > NewOps;
1698+ if (VecTy.getVectorElementType () == MVT::f16 ||
1699+ VecTy.getVectorElementType () == MVT::bf16 ) {
1700+ SmallVector<SDValue, 64 > NewOps;
16721701 for (unsigned i = 0 ; i != Size; i++)
16731702 NewOps.push_back (DAG.getBitcast (MVT::i16 , Ops[i]));
16741703
1675- SDValue T0 = DAG. getNode (ISD::BUILD_VECTOR, dl,
1676- tyVector (VecTy, MVT::i16 ), NewOps);
1677- return DAG.getBitcast (tyVector (VecTy, MVT:: f16 ), T0);
1704+ SDValue T0 =
1705+ DAG. getNode (ISD::BUILD_VECTOR, dl, tyVector (VecTy, MVT::i16 ), NewOps);
1706+ return DAG.getBitcast (tyVector (VecTy, VecTy. getVectorElementType () ), T0);
16781707 }
16791708
16801709 // First, split the BUILD_VECTOR for vector pairs. We could generate
@@ -1698,7 +1727,7 @@ HexagonTargetLowering::LowerHvxSplatVector(SDValue Op, SelectionDAG &DAG)
16981727 MVT VecTy = ty (Op);
16991728 MVT ArgTy = ty (Op.getOperand (0 ));
17001729
1701- if (ArgTy == MVT::f16 ) {
1730+ if (ArgTy == MVT::f16 || ArgTy == MVT:: bf16 ) {
17021731 MVT SplatTy = MVT::getVectorVT (MVT::i16 , VecTy.getVectorNumElements ());
17031732 SDValue ToInt16 = DAG.getBitcast (MVT::i16 , Op.getOperand (0 ));
17041733 SDValue ToInt32 = DAG.getNode (ISD::ANY_EXTEND, dl, MVT::i32 , ToInt16);
@@ -1831,12 +1860,12 @@ HexagonTargetLowering::LowerHvxInsertElement(SDValue Op, SelectionDAG &DAG)
18311860 if (ElemTy == MVT::i1)
18321861 return insertHvxElementPred (VecV, IdxV, ValV, dl, DAG);
18331862
1834- if (ElemTy == MVT::f16 ) {
1863+ if (ElemTy == MVT::f16 || ElemTy == MVT:: bf16 ) {
18351864 SDValue T0 = DAG.getNode (ISD::INSERT_VECTOR_ELT, dl,
18361865 tyVector (VecTy, MVT::i16 ),
18371866 DAG.getBitcast (tyVector (VecTy, MVT::i16 ), VecV),
18381867 DAG.getBitcast (MVT::i16 , ValV), IdxV);
1839- return DAG.getBitcast (tyVector (VecTy, MVT:: f16 ), T0);
1868+ return DAG.getBitcast (tyVector (VecTy, ElemTy ), T0);
18401869 }
18411870
18421871 return insertHvxElementReg (VecV, IdxV, ValV, dl, DAG);
@@ -2334,6 +2363,20 @@ SDValue HexagonTargetLowering::LowerHvxFpExtend(SDValue Op,
23342363 MVT VecTy = ty (Op);
23352364 MVT ArgTy = ty (Op.getOperand (0 ));
23362365 const SDLoc &dl (Op);
2366+
2367+ if (ArgTy == MVT::v64bf16) {
2368+ MVT HalfTy = typeSplit (VecTy).first ;
2369+ SDValue BF16Vec = Op.getOperand (0 );
2370+ SDValue Zeroes = getInstr (Hexagon::V6_vxor, dl, HalfTy, {BF16Vec, BF16Vec}, DAG);
2371+ // Interleave zero vector with the bf16 vector, with zeroes in the lower half
2372+ // of each 32 bit lane, effectively extending the bf16 values to fp32 values.
2373+ SDValue ShuffVec = getInstr (Hexagon::V6_vshufoeh, dl, VecTy, {BF16Vec, Zeroes}, DAG);
2374+ VectorPair VecPair = opSplit (ShuffVec, dl, DAG);
2375+ SDValue Result = getInstr (Hexagon::V6_vshuffvdd, dl, VecTy,
2376+ {VecPair.second , VecPair.first , DAG.getSignedConstant (-4 , dl, MVT::i32 )}, DAG);
2377+ return Result;
2378+ }
2379+
23372380 assert (VecTy == MVT::v64f32 && ArgTy == MVT::v64f16);
23382381
23392382 SDValue F16Vec = Op.getOperand (0 );
0 commit comments