@@ -133,4 +133,112 @@ LogicalResult ExtractSliceOp::verify() {
133133
134134 return success ();
135135}
136+
137+ LogicalResult UpcastMXFPOp::verify () {
138+ auto fpType = getFpType ();
139+
140+ auto xTy = getSrc ().getType ();
141+ auto scaleTy = getScale ().getType ();
142+ Builder b (getContext ());
143+ if (xTy.getElementType () != b.getBF16Type () &&
144+ xTy.getElementType () != b.getF16Type () &&
145+ xTy.getElementType () != b.getI8Type ()) {
146+ return emitOpError (
147+ " element type of the first operand must be bf16/fp16 or i8" );
148+ }
149+
150+ if (scaleTy.getElementType () != b.getI8Type ()) {
151+ return emitOpError (" element type of the second operand must be uint8" );
152+ }
153+
154+ auto xShape = xTy.getShape ();
155+ auto scaleShape = scaleTy.getShape ();
156+
157+ if (xShape.size () != scaleShape.size () || xShape.size () < 2 ) {
158+ return emitOpError (
159+ " operands must have the same number of dimensions, at least 2" );
160+ }
161+
162+ if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
163+ fpType == ScaleDotElemType::E5M2)) {
164+ return emitOpError (" NYI: fpType must be E2M1, E4M3, or E5M2" );
165+ }
166+
167+ auto layoutX = xTy.getEncoding ();
168+ auto layoutScale = scaleTy.getEncoding ();
169+ if (bool (layoutX) != bool (layoutScale)) {
170+ return emitOpError (
171+ " Expected either both or neither operands to have an encoding" );
172+ }
173+ // Nothing to check if no encoding. This is used to infer the return type in
174+ // AccelerateMatmul.cpp
175+ if (!layoutX) {
176+ return success ();
177+ }
178+
179+ auto dotEncoding = dyn_cast<gpu::DotOperandEncodingAttr>(layoutX);
180+ if (!dotEncoding) {
181+ return emitOpError (" Expected a DotOperandEncodingAttr for values" );
182+ }
183+ if (!isa<gpu::BlockedEncodingAttr, gpu::LinearEncodingAttr>(layoutScale)) {
184+ return emitOpError (
185+ " Expected a BlockOperandEncoding or LinearOperandEncoding "
186+ " for scales" );
187+ }
188+
189+ // Change to support fp8 types
190+ const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1 ;
191+ // Figure out the K dimension for the input A/B. For A/B scale, the K
192+ // dimension is always the last dimension.
193+ const int opIdx = dotEncoding.getOpIdx ();
194+ const bool hasBatch = xShape.size () == 3 ;
195+ const int kIdx = (opIdx == 0 ? 1 : 0 ) + hasBatch;
196+
197+ if (xShape[kIdx ] != (32 / elemsPacked) * scaleShape.back ()) {
198+ return emitOpError (" K dimension of first operand must be 16 times "
199+ " larger than last/K dimension of the second operand" );
200+ }
201+
202+ // Check other dimensions match too. For input A/B, we need to figure out the
203+ // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
204+ const int mnIdx = (opIdx == 0 ? 0 : 1 ) + hasBatch;
205+ if (hasBatch && xShape[0 ] != scaleShape[0 ])
206+ return emitOpError (" batch dimension must match between operands" );
207+ if (xShape[mnIdx] != scaleShape[hasBatch]) {
208+ return emitOpError (" M/N dimension must match between operands" );
209+ }
210+
211+ return success ();
212+ }
213+
214+ RankedTensorType
215+ UpcastMXFPOp::deduceOutputType (TypedValue<RankedTensorType> inputTensor,
216+ ScaleDotElemType inputElemType,
217+ Type outputElemType) {
218+ MLIRContext *ctx = inputTensor.getContext ();
219+ auto xTy = inputTensor.getType ();
220+ if (inputElemType != ScaleDotElemType::E2M1)
221+ return xTy;
222+
223+ auto xShape = xTy.getShape ();
224+ auto newShape = llvm::to_vector (xShape);
225+ auto encoding = xTy.getEncoding ();
226+ if (!encoding) {
227+ newShape.back () *= 2 ;
228+ return RankedTensorType::get (xShape, outputElemType);
229+ }
230+
231+ auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
232+ auto newVEncoding = DotOperandEncodingAttr::get (ctx, oldEncoding.getOpIdx (),
233+ oldEncoding.getParent (),
234+ oldEncoding.getKWidth () * 2 );
235+ // Figure out the K dimension for the input A/B, given that the return
236+ // type is upcasted A/B type so we need to update the proper dim size.
237+ const int opIdx = oldEncoding.getOpIdx ();
238+ const bool hasBatch = xShape.size () == 3 ;
239+ const int kIdx = (opIdx == 0 ? 1 : 0 ) + hasBatch;
240+ newShape[kIdx ] *= 2 ;
241+ return RankedTensorType::get (newShape, outputElemType, newVEncoding);
242+ }
243+
136244} // namespace mlir::triton::amdgpu
0 commit comments