@@ -52,50 +52,52 @@ LogicalResult UpcastMXFPOp::verify() {
   }
 
   /// TODO: Temporarily disabled this check to allow for the blocked encoding.
-  /// we need to re-enable this check once we have the dot op encoding
-  /// UpcastMXFPOp lowering
-  // auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
-  // if (!dotEncoding) {
-  // return emitOpError("Expected a DotOperandEncodingAttr for values");
-  // }
+  /// Enable once we have the dot op encoding UpcastMXFPOp lowering.
+  auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
+  if (mlir::triton::tools::getBoolEnv(
+          "TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING") &&
+      !dotEncoding) {
+    return emitOpError("Expected a DotOperandEncodingAttr for values");
+  }
   if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
     return emitOpError(
         "Expected a BlockOperandEncoding or LinearOperandEncoding "
         "for scales");
   }
+  if (!dotEncoding)
+    return success();
 
-  // if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
-  // // Necessary to keep all of the scales of a given block of values in the
-  // // same warp
-  // auto threadsPerWarp =
-  // cast<DistributedEncodingTrait>(layoutScale).getThreadsPerWarp();
-  // if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
-  // return emitOpError("Expected threads per warp to be {16, 2}");
-  // }
-  // }
-
-  // // Change to support fp8 types
-  // const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
-  // // Figure out the K dimension for the input A/B. For A/B scale, the K
-  // // dimension is always the last dimension.
-  // const int opIdx = dotEncoding.getOpIdx();
-  // const bool hasBatch = xShape.size() == 3;
-  // const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
-
-  // if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
-  // return emitOpError("K dimension of first operand must be 16 times "
-  // "larger than last/K dimension of the second operand");
-  // }
-
-  // // Check other dimensions match too. For input A/B, we need to figure out
-  // the
-  // // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
-  // const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
-  // if (hasBatch && xShape[0] != scaleShape[0])
-  // return emitOpError("batch dimension must match between operands");
-  // if (xShape[mnIdx] != scaleShape[hasBatch]) {
-  // return emitOpError("M/N dimension must match between operands");
-  // }
+  if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
+    // Necessary to keep all of the scales of a given block of values in the
+    // same warp
+    auto threadsPerWarp =
+        cast<DistributedEncodingTrait>(layoutScale).getThreadsPerWarp();
+    if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
+      return emitOpError("Expected threads per warp to be {16, 2}");
+    }
+  }
+
+  // Change to support fp8 types
+  const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
+  // Figure out the K dimension for the input A/B. For A/B scale, the K
+  // dimension is always the last dimension.
+  const int opIdx = dotEncoding.getOpIdx();
+  const bool hasBatch = xShape.size() == 3;
+  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+
+  if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
+    return emitOpError("K dimension of first operand must be 16 times "
+                       "larger than last/K dimension of the second operand");
+  }
+
+  // Check other dimensions match too. For input A/B, we need to figure out the
+  // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
+  const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
+  if (hasBatch && xShape[0] != scaleShape[0])
+    return emitOpError("batch dimension must match between operands");
+  if (xShape[mnIdx] != scaleShape[hasBatch]) {
+    return emitOpError("M/N dimension must match between operands");
+  }
 
   return success();
 }
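
The re-enabled block above encodes a fixed packing contract: an e2m1 operand carries two fp4 values per i8 element (elemsPacked == 2), so its K dimension must be 32 / elemsPacked = 16 times the scale tensor's trailing K dimension, while the M/N and optional batch dimensions must match. A minimal standalone sketch of that constraint, not part of the patch, with made-up example shapes and plain std::vector standing in for the MLIR tensor types:

// Illustrative sketch only: mirrors the shape checks in UpcastMXFPOp::verify()
// for an e2m1 operand A (opIdx == 0) without a batch dimension.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int elemsPacked = 2; // e2m1: two fp4 values packed per i8 element
  std::vector<int64_t> xShape = {128, 512};    // A: M x K (packed), example values
  std::vector<int64_t> scaleShape = {128, 32}; // scale: M x (K / 32), example values

  const int opIdx = 0;                               // operand A
  const bool hasBatch = xShape.size() == 3;          // no batch dim here
  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;  // K is dim 1 for A
  const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch; // M is dim 0 for A

  // K of the packed operand must be (32 / elemsPacked) = 16x the scale's K.
  assert(xShape[kIdx] == (32 / elemsPacked) * scaleShape.back());
  // M/N (and batch, when present) must match between operand and scale.
  assert(xShape[mnIdx] == scaleShape[hasBatch]);
  return 0;
}

For an e4m3/e5m2 operand, elemsPacked is 1 and the same check requires a 32x ratio, which is why the code computes 32 / elemsPacked rather than hard-coding 16.
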
@@ -110,8 +112,6 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
   auto xShape = xTy.getShape();
 
   auto encoding = xTy.getEncoding();
-  bool upcastMXFPUseDotOpEnc =
-      mlir::triton::tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING");
 
   if (typeEncoded == ScaleDotElemType::E2M1) {
     RankedTensorType retTy;
@@ -122,10 +122,8 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
       retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
     } else {
       Type elemType = FloatType::getBF16(ctx);
-      Attribute newVEncoding;
-      if (upcastMXFPUseDotOpEnc) {
-        auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
-
+      Attribute newVEncoding = nullptr;
+      if (auto oldEncoding = dyn_cast<DotOperandEncodingAttr>(encoding)) {
         const int opIdx = oldEncoding.getOpIdx();
         const bool hasBatch = xShape.size() == 3;
         const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
@@ -151,10 +149,9 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
               ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
               oldEncoding.getKWidth() * 2);
         }
-      } else {
-        auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding);
-        assert(oldEncoding &&
-               "Expected a blocked encoding for UpcastMXFP op result.");
+      } else if (auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding)) {
+        // TODO: Temporary code, remove once upcast_mxfp supports dot encoding.
+        assert(!tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"));
         newShape.back() *= 2;
         SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
         sizePerThread.back() *= 2;
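
In the blocked-encoding fallback shown in the last hunk, inferReturnTypes widens the packed e2m1 input to bf16 by doubling the last dimension and the corresponding sizePerThread entry, while the dot-operand branch instead builds a DotOperandEncodingAttr with kWidth doubled. A small sketch of that shape and layout bookkeeping, not part of the patch, with illustrative values:

// Illustrative sketch only: the blocked-encoding branch of the e2m1 path in
// UpcastMXFPOp::inferReturnTypes() doubles the packed last dimension and the
// matching sizePerThread entry, since each i8 element unpacks to two bf16s.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> xShape = {128, 512};     // packed e2m1 input (example)
  std::vector<unsigned> sizePerThread = {1, 4}; // from the blocked encoding (example)

  std::vector<int64_t> newShape = xShape;
  newShape.back() *= 2;      // 512 packed elements -> 1024 bf16 values
  sizePerThread.back() *= 2; // each thread covers twice as many elements

  std::cout << "result: " << newShape[0] << " x " << newShape[1]
            << ", sizePerThread = {" << sizePerThread[0] << ", "
            << sizePerThread[1] << "}\n";
  return 0;
}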