@@ -291,10 +291,12 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
291291 static_assert (opIdx == 0 || opIdx == 1 , " Illegal operand index" );
292292 assert (opDesc.scale && " Expecting valid operand & scale" );
293293
294- unsigned opsPerChannel = dpasEnc.getOpsPerChannel ();
295-
296294 MLIRContext *ctx = opDesc.op .getContext ();
295+ unsigned numWarps = ttg::TritonGPUDialect::getNumWarps (mod);
296+ unsigned warpSize = ttg::TritonGPUDialect::getThreadsPerWarp (mod);
297+ unsigned opsPerChannel = dpasEnc.getOpsPerChannel ();
297298 unsigned rank = retType.getRank ();
299+
298300 if (upcastMXFPUseDotOpEnc) {
299301 if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
300302 opsPerChannel *= 2 ;
@@ -312,7 +314,6 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
312314 unsigned instrShapeM = dpasEnc.getDPASInstShapeA ()[1 ];
313315 SmallVector<unsigned , 2 > threadsPerWarp{instrShapeM,
314316 warpSize / instrShapeM};
315- int numWarps = ttg::TritonGPUDialect::getNumWarps (mod);
316317 SmallVector<unsigned , 2 > warpsPerCTA (rank, 1 );
317318 warpsPerCTA[0 ] = numWarps;
318319 auto CTALayout = ttg::getCTALayout (retType.getEncoding ());
@@ -323,44 +324,52 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
323324 TensorValue scale = createScale (opDesc.scale , newScaleEncoding, rewriter);
324325
325326 return createUpcastMxfpOp (op, scale, opDesc.elemType , rewriter);
326- } else {
327- auto scaleEncoding = dyn_cast<ttg::BlockedEncodingAttr>(
328- opDesc.scale .getType ().getEncoding ());
329- assert (scaleEncoding && " Expecting blocked encoding for scale" );
330-
331- // Referring to
332- // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
333- // the scalingBlockSize should be 32 for E5M2, E4M3 and E2M1
334- unsigned scalingBlockSize = 32 ;
335- // 2 FP4E2M1 are packed in 1 I8
336- if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
337- scalingBlockSize = 16 ;
338- SmallVector<unsigned , 2 > sizePerThread (rank, 1 );
339- sizePerThread[rank - 1 - opIdx] = scalingBlockSize;
340- auto newOpEncoding = ttg::BlockedEncodingAttr::get (
341- ctx, sizePerThread, scaleEncoding.getThreadsPerWarp (),
342- scaleEncoding.getWarpsPerCTA (), scaleEncoding.getCTAOrder (),
343- scaleEncoding.getCTALayout ());
344-
345- TensorValue op =
346- createArg (opDesc.op , opDesc.elemType , newOpEncoding, rewriter);
347- TensorValue scale = opDesc.scale ;
348-
349- auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get (
350- ctx, dpasEnc.getRepeatCount (), dpasEnc.getSystolicDepth (),
351- dpasEnc.getExecutionSize (), opsPerChannel, dpasEnc.getWarpsPerCTA (),
352- dpasEnc.getRepCluster (), dpasEnc.getSubGroupSize ());
353- auto retDotOpEncoding = ttg::DotOperandEncodingAttr::get (
354- ctx, opIdx, retDpasEncoding, retDpasEncoding.getOpsPerChannel ());
355-
356- auto upcastOp = createUpcastMxfpOp (op, scale, opDesc.elemType , rewriter);
357-
358- auto retType = cast<RankedTensorType>(upcastOp.getType ());
359- retType = RankedTensorType::get (
360- retType.getShape (), retType.getElementType (), retDotOpEncoding);
361- return rewriter.create <ttg::ConvertLayoutOp>(opDesc.op .getLoc (), retType,
362- upcastOp);
363327 }
328+
329+ auto scaleEncoding = dyn_cast<ttg::BlockedEncodingAttr>(
330+ opDesc.scale .getType ().getEncoding ());
331+ assert (scaleEncoding && " Expecting blocked encoding for scale" );
332+
333+ // Referring to
334+ // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
335+ // the scalingBlockSize should be 32 for E5M2, E4M3 and E2M1
336+ unsigned scalingBlockSize = 32 ;
337+ // 2 FP4E2M1 are packed in 1 I8
338+ if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
339+ scalingBlockSize = 16 ;
340+ SmallVector<unsigned > sizePerThread = {1 , 1 };
341+ SmallVector<unsigned > threadsPerWarp = {1 , 1 };
342+ sizePerThread[!opIdx] = scalingBlockSize;
343+ threadsPerWarp[opIdx] = warpSize;
344+ SmallVector<unsigned > warpsPerCTA = {numWarps, 1 };
345+
346+ auto newOpEncoding = ttg::BlockedEncodingAttr::get (
347+ ctx, sizePerThread, threadsPerWarp, warpsPerCTA,
348+ scaleEncoding.getCTAOrder (), scaleEncoding.getCTALayout ());
349+ TensorValue op =
350+ createArg (opDesc.op , opDesc.elemType , newOpEncoding, rewriter);
351+
352+ warpsPerCTA = opIdx ? SmallVector<unsigned >{1 , numWarps}
353+ : SmallVector<unsigned >{numWarps, 1 };
354+ auto newScaleEncoding = ttg::BlockedEncodingAttr::get (
355+ ctx, {1 , 1 }, {warpSize, 1 }, warpsPerCTA, scaleEncoding.getCTAOrder (),
356+ scaleEncoding.getCTALayout ());
357+ TensorValue scale = createScale (opDesc.scale , newScaleEncoding, rewriter);
358+
359+ auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get (
360+ ctx, dpasEnc.getRepeatCount (), dpasEnc.getSystolicDepth (),
361+ dpasEnc.getExecutionSize (), opsPerChannel, dpasEnc.getWarpsPerCTA (),
362+ dpasEnc.getRepCluster (), dpasEnc.getSubGroupSize ());
363+ auto retDotOpEncoding = ttg::DotOperandEncodingAttr::get (
364+ ctx, opIdx, retDpasEncoding, retDpasEncoding.getOpsPerChannel ());
365+
366+ auto upcastOp = createUpcastMxfpOp (op, scale, opDesc.elemType , rewriter);
367+
368+ auto resultType = cast<RankedTensorType>(upcastOp.getType ());
369+ resultType = RankedTensorType::get (
370+ resultType.getShape (), resultType.getElementType (), retDotOpEncoding);
371+ return rewriter.create <ttg::ConvertLayoutOp>(opDesc.op .getLoc (), resultType,
372+ upcastOp);
364373 }
365374
366375 template <unsigned opIdx>
0 commit comments