@@ -227,7 +227,8 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
       return elemType == tt::ScaleDotElemType::E2M1 ||
              elemType == tt::ScaleDotElemType::E4M3 ||
              elemType == tt::ScaleDotElemType::E5M2 ||
-             elemType == tt::ScaleDotElemType::BF16;
+             elemType == tt::ScaleDotElemType::BF16 ||
+             elemType == tt::ScaleDotElemType::FP16;
     };
     if (!supportsTypes(aElemType) || !supportsTypes(bElemType))
       return rewriter.notifyMatchFailure(scaledDotOp, "NYI: mxfp6 operand");
@@ -263,27 +264,31 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     assert((aDesc.scale || bDesc.scale) && "No scale provided");
     assert(!(aDesc.scale && bDesc.scale) && "NYI: Both LHS and RHS scale");

+    bool useFp16 = aDesc.elemType == tt::ScaleDotElemType::FP16 ||
+                   bDesc.elemType == tt::ScaleDotElemType::FP16;
+
     if (aDesc.scale) {
       TensorValue newA =
           convertScaledOperand<ttgi::DpasEncodingAttr::OpIdx::OperandA>(
-              aDesc, fastMath, dpasEnc, newRetType, mod, rewriter);
+              aDesc, useFp16, fastMath, dpasEnc, newRetType, mod, rewriter);
       TensorValue newB =
           convertUnscaledOperand<ttgi::DpasEncodingAttr::OpIdx::OperandB>(
-              bDesc, dpasEnc, newRetType, rewriter);
+              bDesc, useFp16, dpasEnc, newRetType, rewriter);
       return {newA, newB};
     }

     TensorValue newB =
         convertScaledOperand<ttgi::DpasEncodingAttr::OpIdx::OperandB>(
-            bDesc, fastMath, dpasEnc, newRetType, mod, rewriter);
+            bDesc, useFp16, fastMath, dpasEnc, newRetType, mod, rewriter);
     TensorValue newA =
         convertUnscaledOperand<ttgi::DpasEncodingAttr::OpIdx::OperandA>(
-            aDesc, dpasEnc, newRetType, rewriter);
+            aDesc, useFp16, dpasEnc, newRetType, rewriter);
     return {newA, newB};
   }

   template <ttgi::DpasEncodingAttr::OpIdx opIdx>
-  TensorValue convertScaledOperand(OpDescriptor opDesc, bool fastMath,
+  TensorValue convertScaledOperand(OpDescriptor opDesc, bool useFp16,
+                                   bool fastMath,
                                    ttg::intel::DpasEncodingAttr dpasEnc,
                                    RankedTensorType retType, ModuleOp mod,
                                    PatternRewriter &rewriter) const {
@@ -304,7 +309,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     auto newOpEncoding = ttg::DotOperandEncodingAttr::get(
         ctx, unsigned(opIdx), opEncoding, opEncoding.getOpsPerChannel());
     TensorValue op =
-        createArg(opDesc.op, opDesc.elemType, newOpEncoding, rewriter);
+        createArg(opDesc.op, opDesc.elemType, useFp16, newOpEncoding, rewriter);

     unsigned instrShapeM = dpasEnc.getDPASInstShapeA()[0];
     SmallVector<unsigned, 2> threadsPerWarp{instrShapeM,
@@ -332,7 +337,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
   }

   template <ttgi::DpasEncodingAttr::OpIdx opIdx>
-  TensorValue convertUnscaledOperand(OpDescriptor opDesc,
+  TensorValue convertUnscaledOperand(OpDescriptor opDesc, bool useFp16,
                                      ttg::intel::DpasEncodingAttr dpasEnc,
                                      RankedTensorType retType,
                                      PatternRewriter &rewriter) const {
@@ -341,7 +346,8 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     auto newOpEncoding = ttg::DotOperandEncodingAttr::get(
         opDesc.op.getContext(), unsigned(opIdx), dpasEnc,
         dpasEnc.getOpsPerChannel());
-    return createArg(opDesc.op, opDesc.elemType, newOpEncoding, rewriter);
+    return createArg(opDesc.op, opDesc.elemType, useFp16, newOpEncoding,
+                     rewriter);
   }

   ttg::intel::DpasEncodingAttr
@@ -385,7 +391,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
                                                 oldAcc);
   }

-  TensorValue createArg(TensorValue v, tt::ScaleDotElemType type,
+  TensorValue createArg(TensorValue v, tt::ScaleDotElemType type, bool useFp16,
                         Attribute vEncoding, PatternRewriter &rewriter) const {
     RankedTensorType vType = v.getType();
     auto newVType = RankedTensorType::get(vType.getShape(),
@@ -395,13 +401,16 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {

     // convert to bf16
     if (type != tt::ScaleDotElemType::E2M1 &&
-        type != tt::ScaleDotElemType::BF16) {
+        type != tt::ScaleDotElemType::BF16 &&
+        type != tt::ScaleDotElemType::FP16) {
       assert(type == tt::ScaleDotElemType::E5M2 ||
              type == tt::ScaleDotElemType::E4M3);
-      auto vTypeBf16 = RankedTensorType::get(
-          newVType.getShape(), rewriter.getBF16Type(), newVType.getEncoding());
+      auto upcastedType = RankedTensorType::get(
+          newVType.getShape(),
+          useFp16 ? rewriter.getF16Type() : rewriter.getBF16Type(),
+          newVType.getEncoding());
       ret = cast<TypedValue<RankedTensorType>>(
-          rewriter.create<tt::FpToFpOp>(v.getLoc(), vTypeBf16, ret)
+          rewriter.create<tt::FpToFpOp>(v.getLoc(), upcastedType, ret)
               .getResult());
     }
     return ret;
@@ -423,8 +432,11 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     if (!scale)
       return v;

+    Builder b(v.getContext());
+    bool useFp16 = elemType == tt::ScaleDotElemType::FP16;
+    Type outputElemType = useFp16 ? b.getF16Type() : b.getBF16Type();
     auto retTy = triton::gpu::intel::UpcastMXFPOp::deduceOutputType(
-        v, elemType, Builder(v.getContext()).getBF16Type());
+        v, elemType, outputElemType);
     return rewriter.create<ttgi::UpcastMXFPOp>(v.getLoc(), retTy, v, scale,
                                                elemType, fastMath);
   }
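
In short, the diff threads a `useFp16` flag through the decomposition so that fp8 operands (E4M3/E5M2) are upcast to f16 instead of bf16 whenever the other operand is already FP16, and the same element type is passed to `UpcastMXFPOp::deduceOutputType`. Below is a minimal standalone sketch of that type-selection rule; `ScaleDotElemType` is redefined locally and `pickUpcastType` is a hypothetical helper for illustration, not part of the Triton API.

```cpp
// Standalone sketch of the upcast-type selection added by this patch.
// The enum and helper are simplified stand-ins, not the real Triton code.
#include <cassert>
#include <iostream>
#include <string>

enum class ScaleDotElemType { E2M1, E4M3, E5M2, BF16, FP16 };

// Mirrors the patch's rule: compute in f16 if either operand is FP16,
// otherwise keep the existing bf16 path.
std::string pickUpcastType(ScaleDotElemType a, ScaleDotElemType b) {
  bool useFp16 = a == ScaleDotElemType::FP16 || b == ScaleDotElemType::FP16;
  return useFp16 ? "f16" : "bf16";
}

int main() {
  // fp8 (E4M3) paired with an FP16 operand -> upcast to f16.
  assert(pickUpcastType(ScaleDotElemType::E4M3, ScaleDotElemType::FP16) == "f16");
  // fp8 (E5M2) paired with a BF16 operand -> upcast to bf16, as before.
  assert(pickUpcastType(ScaleDotElemType::E5M2, ScaleDotElemType::BF16) == "bf16");
  std::cout << "type selection behaves as expected\n";
  return 0;
}
```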