44#include " triton/Dialect/Triton/IR/Utility.h"
55#include " triton/Dialect/TritonGPU/IR/Attributes.h"
66#include " triton/Dialect/TritonGPU/IR/Dialect.h"
7+ #include " triton/Tools/Sys/GetEnv.hpp"
78
89#define GET_OP_CLASSES
910#include " triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
@@ -109,6 +110,8 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
109110 auto xShape = xTy.getShape ();
110111
111112 auto encoding = xTy.getEncoding ();
113+ bool upcastMXFPUseDotOpEnc =
114+ mlir::triton::tools::getBoolEnv (" TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING" );
112115
113116 if (typeEncoded == ScaleDotElemType::E2M1) {
114117 RankedTensorType retTy;
@@ -118,34 +121,47 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
118121 newShape.back () *= 2 ;
119122 retTy = RankedTensorType::get (xShape, FloatType::getBF16 (ctx));
120123 } else {
121- auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
122-
123- const int opIdx = oldEncoding.getOpIdx ();
124- const bool hasBatch = xShape.size () == 3 ;
125- const int kIdx = (opIdx == 0 ? 1 : 0 ) + hasBatch;
126- newShape[kIdx ] *= 2 ;
127124 Type elemType = FloatType::getBF16 (ctx);
128-
129- // Note: For Intel the dot operands layout's kWidth parameter must match
130- // the parent's DPAS layout opsPerChannel so we need to materialize a new
131- // DPAS layout.
132125 Attribute newVEncoding;
133- if (auto dpasEncoding =
134- dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent ())) {
135- auto newDpasEncoding = intel::DpasEncodingAttr::get (
136- ctx, dpasEncoding.getRepeatCount (), dpasEncoding.getSystolicDepth (),
137- dpasEncoding.getExecutionSize (),
138- intel::DpasEncodingAttr::getOpsPerChannel (elemType),
139- dpasEncoding.getWarpsPerCTA (), dpasEncoding.getRepCluster (),
140- dpasEncoding.getSubGroupSize ());
141- newVEncoding = DotOperandEncodingAttr::get (
142- ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel ());
126+ if (upcastMXFPUseDotOpEnc) {
127+ auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
128+
129+ const int opIdx = oldEncoding.getOpIdx ();
130+ const bool hasBatch = xShape.size () == 3 ;
131+ const int kIdx = (opIdx == 0 ? 1 : 0 ) + hasBatch;
132+ newShape[kIdx ] *= 2 ;
133+
134+ // Note: For Intel the dot operands layout's kWidth parameter must match
135+ // the parent's DPAS layout opsPerChannel so we need to materialize a
136+ // new DPAS layout.
137+ if (auto dpasEncoding =
138+ dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent ())) {
139+ auto newDpasEncoding = intel::DpasEncodingAttr::get (
140+ ctx, dpasEncoding.getRepeatCount (),
141+ dpasEncoding.getSystolicDepth (), dpasEncoding.getExecutionSize (),
142+ intel::DpasEncodingAttr::getOpsPerChannel (elemType),
143+ dpasEncoding.getWarpsPerCTA (), dpasEncoding.getRepCluster (),
144+ dpasEncoding.getSubGroupSize ());
145+ newVEncoding = DotOperandEncodingAttr::get (
146+ ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel ());
147+ } else {
148+ // Figure out the K dimension for the input A/B, given that the return
149+ // type is upcasted A/B type so we need to update the proper dim size.
150+ newVEncoding = DotOperandEncodingAttr::get (
151+ ctx, oldEncoding.getOpIdx (), oldEncoding.getParent (),
152+ oldEncoding.getKWidth () * 2 );
153+ }
143154 } else {
144- // Figure out the K dimension for the input A/B, given that the return
145- // type is upcasted A/B type so we need to update the proper dim size.
146- newVEncoding = DotOperandEncodingAttr::get (ctx, oldEncoding.getOpIdx (),
147- oldEncoding.getParent (),
148- oldEncoding.getKWidth () * 2 );
155+ auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding);
156+ assert (oldEncoding &&
157+ " Expected a blocked encoding for UpcastMXFP op result." );
158+ newShape.back () *= 2 ;
159+ SmallVector<unsigned > sizePerThread = oldEncoding.getSizePerThread ();
160+ sizePerThread.back () *= 2 ;
161+ newVEncoding = BlockedEncodingAttr::get (
162+ ctx, sizePerThread, oldEncoding.getThreadsPerWarp (),
163+ oldEncoding.getWarpsPerCTA (), oldEncoding.getCTAOrder (),
164+ oldEncoding.getCTALayout ());
149165 }
150166 retTy = RankedTensorType::get (newShape, elemType, newVEncoding);
151167 }
0 commit comments