@@ -323,44 +323,45 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
323323 TensorValue scale = createScale (opDesc.scale , newScaleEncoding, rewriter);
324324
325325 return createUpcastMxfpOp (op, scale, opDesc.elemType , rewriter);
326- } else {
327- auto scaleEncoding = dyn_cast<ttg::BlockedEncodingAttr>(
328- opDesc.scale .getType ().getEncoding ());
329- assert (scaleEncoding && " Expecting blocked encoding for scale" );
330-
331- // Referring to
332- // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
333- // the scalingBlockSize should be 32 for E5M2, E4M3 and E2M1
334- unsigned scalingBlockSize = 32 ;
335- // 2 FP4E2M1 are packed in 1 I8
336- if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
337- scalingBlockSize = 16 ;
338- SmallVector<unsigned , 2 > sizePerThread (rank, 1 );
339- sizePerThread[rank - 1 - opIdx] = scalingBlockSize;
340- auto newOpEncoding = ttg::BlockedEncodingAttr::get (
341- ctx, sizePerThread, scaleEncoding.getThreadsPerWarp (),
342- scaleEncoding.getWarpsPerCTA (), scaleEncoding.getCTAOrder (),
343- scaleEncoding.getCTALayout ());
344-
345- TensorValue op =
346- createArg (opDesc.op , opDesc.elemType , newOpEncoding, rewriter);
347- TensorValue scale = opDesc.scale ;
348-
349- auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get (
350- ctx, dpasEnc.getRepeatCount (), dpasEnc.getSystolicDepth (),
351- dpasEnc.getExecutionSize (), opsPerChannel, dpasEnc.getWarpsPerCTA (),
352- dpasEnc.getRepCluster (), dpasEnc.getSubGroupSize ());
353- auto retDotOpEncoding = ttg::DotOperandEncodingAttr::get (
354- ctx, opIdx, retDpasEncoding, retDpasEncoding.getOpsPerChannel ());
355-
356- auto upcastOp = createUpcastMxfpOp (op, scale, opDesc.elemType , rewriter);
357-
358- auto retType = cast<RankedTensorType>(upcastOp.getType ());
359- retType = RankedTensorType::get (
360- retType.getShape (), retType.getElementType (), retDotOpEncoding);
361- return rewriter.create <ttg::ConvertLayoutOp>(opDesc.op .getLoc (), retType,
362- upcastOp);
363326 }
327+
328+ // Temporary code: remove once upcast_mxfp supports dot encoding.
329+ auto scaleEncoding = dyn_cast<ttg::BlockedEncodingAttr>(
330+ opDesc.scale .getType ().getEncoding ());
331+ assert (scaleEncoding && " Expecting blocked encoding for scale" );
332+
333+ // Referring to
334+ // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
335+ // the scalingBlockSize should be 32 for E5M2, E4M3 and E2M1
336+ unsigned scalingBlockSize = 32 ;
337+ // 2 FP4E2M1 are packed in 1 I8
338+ if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
339+ scalingBlockSize = 16 ;
340+ SmallVector<unsigned , 2 > sizePerThread (rank, 1 );
341+ sizePerThread[rank - 1 - opIdx] = scalingBlockSize;
342+ auto newOpEncoding = ttg::BlockedEncodingAttr::get (
343+ ctx, sizePerThread, scaleEncoding.getThreadsPerWarp (),
344+ scaleEncoding.getWarpsPerCTA (), scaleEncoding.getCTAOrder (),
345+ scaleEncoding.getCTALayout ());
346+
347+ TensorValue op =
348+ createArg (opDesc.op , opDesc.elemType , newOpEncoding, rewriter);
349+ TensorValue scale = opDesc.scale ;
350+
351+ auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get (
352+ ctx, dpasEnc.getRepeatCount (), dpasEnc.getSystolicDepth (),
353+ dpasEnc.getExecutionSize (), opsPerChannel, dpasEnc.getWarpsPerCTA (),
354+ dpasEnc.getRepCluster (), dpasEnc.getSubGroupSize ());
355+ auto retDotOpEncoding = ttg::DotOperandEncodingAttr::get (
356+ ctx, opIdx, retDpasEncoding, retDpasEncoding.getOpsPerChannel ());
357+
358+ auto upcastOp = createUpcastMxfpOp (op, scale, opDesc.elemType , rewriter);
359+
360+ auto upcastRetType = cast<RankedTensorType>(upcastOp.getType ());
361+ retType = RankedTensorType::get (retType.getShape (),
362+ retType.getElementType (), retDotOpEncoding);
363+ return rewriter.create <ttg::ConvertLayoutOp>(opDesc.op .getLoc (),
364+ upcastRetType, upcastOp);
364365 }
365366
366367 template <unsigned opIdx>
0 commit comments