Skip to content

Commit 9ab90fa

Browse files
committed
Address code review comments
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 698a0bd commit 9ab90fa

File tree

1 file changed

+38
-37
lines changed

1 file changed

+38
-37
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -323,44 +323,45 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
323323
TensorValue scale = createScale(opDesc.scale, newScaleEncoding, rewriter);
324324

325325
return createUpcastMxfpOp(op, scale, opDesc.elemType, rewriter);
326-
} else {
327-
auto scaleEncoding = dyn_cast<ttg::BlockedEncodingAttr>(
328-
opDesc.scale.getType().getEncoding());
329-
assert(scaleEncoding && "Expecting blocked encoding for scale");
330-
331-
// Referring to
332-
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
333-
// the scalingBlockSize should be 32 for E5M2, E4M3 and E2M1
334-
unsigned scalingBlockSize = 32;
335-
// 2 FP4E2M1 are packed in 1 I8
336-
if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
337-
scalingBlockSize = 16;
338-
SmallVector<unsigned, 2> sizePerThread(rank, 1);
339-
sizePerThread[rank - 1 - opIdx] = scalingBlockSize;
340-
auto newOpEncoding = ttg::BlockedEncodingAttr::get(
341-
ctx, sizePerThread, scaleEncoding.getThreadsPerWarp(),
342-
scaleEncoding.getWarpsPerCTA(), scaleEncoding.getCTAOrder(),
343-
scaleEncoding.getCTALayout());
344-
345-
TensorValue op =
346-
createArg(opDesc.op, opDesc.elemType, newOpEncoding, rewriter);
347-
TensorValue scale = opDesc.scale;
348-
349-
auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get(
350-
ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
351-
dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
352-
dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize());
353-
auto retDotOpEncoding = ttg::DotOperandEncodingAttr::get(
354-
ctx, opIdx, retDpasEncoding, retDpasEncoding.getOpsPerChannel());
355-
356-
auto upcastOp = createUpcastMxfpOp(op, scale, opDesc.elemType, rewriter);
357-
358-
auto retType = cast<RankedTensorType>(upcastOp.getType());
359-
retType = RankedTensorType::get(
360-
retType.getShape(), retType.getElementType(), retDotOpEncoding);
361-
return rewriter.create<ttg::ConvertLayoutOp>(opDesc.op.getLoc(), retType,
362-
upcastOp);
363326
}
327+
328+
// Temporary code: remove once upcast_mxfp supports dot encoding.
329+
auto scaleEncoding = dyn_cast<ttg::BlockedEncodingAttr>(
330+
opDesc.scale.getType().getEncoding());
331+
assert(scaleEncoding && "Expecting blocked encoding for scale");
332+
333+
// Referring to
334+
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
335+
// the scalingBlockSize should be 32 for E5M2, E4M3 and E2M1
336+
unsigned scalingBlockSize = 32;
337+
// Two FP4 (E2M1) values are packed into one i8, so halve the block size.
338+
if (opDesc.elemType == tt::ScaleDotElemType::E2M1)
339+
scalingBlockSize = 16;
340+
SmallVector<unsigned, 2> sizePerThread(rank, 1);
341+
sizePerThread[rank - 1 - opIdx] = scalingBlockSize;
342+
auto newOpEncoding = ttg::BlockedEncodingAttr::get(
343+
ctx, sizePerThread, scaleEncoding.getThreadsPerWarp(),
344+
scaleEncoding.getWarpsPerCTA(), scaleEncoding.getCTAOrder(),
345+
scaleEncoding.getCTALayout());
346+
347+
TensorValue op =
348+
createArg(opDesc.op, opDesc.elemType, newOpEncoding, rewriter);
349+
TensorValue scale = opDesc.scale;
350+
351+
auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get(
352+
ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
353+
dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
354+
dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize());
355+
auto retDotOpEncoding = ttg::DotOperandEncodingAttr::get(
356+
ctx, opIdx, retDpasEncoding, retDpasEncoding.getOpsPerChannel());
357+
358+
auto upcastOp = createUpcastMxfpOp(op, scale, opDesc.elemType, rewriter);
359+
360+
auto upcastRetType = cast<RankedTensorType>(upcastOp.getType());
361+
retType = RankedTensorType::get(retType.getShape(),
362+
retType.getElementType(), retDotOpEncoding);
363+
return rewriter.create<ttg::ConvertLayoutOp>(opDesc.op.getLoc(),
364+
upcastRetType, upcastOp);
364365
}
365366

366367
template <unsigned opIdx>

0 commit comments

Comments
 (0)