@@ -331,64 +331,121 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
331331 patterns.add <CanonicalizeConvertFromSplit>(context);
332332}
333333
334- LogicalResult Fp4ToFpOp::verify () {
335- auto srcTy = cast<RankedTensorType>(getSrc ().getType ());
336- auto resTy = cast<RankedTensorType>(getResult ().getType ());
337- auto rank = srcTy.getRank ();
338-
339- if (rank != resTy.getRank ())
340- return emitError () << " source rank " << rank << " != result rank "
341- << resTy.getRank ();
342-
343- auto srcShape = srcTy.getShape ();
344- auto resShape = resTy.getShape ();
345- auto axis = getAxis ();
346-
347- if (!(0 <= axis && axis < rank))
348- return emitError () << " axis " << axis << " out of range for rank " << rank;
349-
350- auto elemType = resTy.getElementType ();
351- if (!(elemType.isBF16 () || elemType.isF16 ()))
352- return emitError () << " only bf16 or f16 is supported for now, got "
353- << elemType;
354-
355- for (int i = 0 ; i < rank; ++i) {
356- if (i == axis) {
357- if (resShape[i] != srcShape[i] * 2 )
358- return emitError () << " axis " << axis
359- << " dimension must be 2x source dimension (src="
360- << srcShape[i] << " , dst=" << resShape[i] << " )" ;
361- } else {
362- if (resShape[i] != srcShape[i])
363- return emitError () << " dimension " << i
364- << " mismatch (src=" << srcShape[i]
365- << " , dst=" << resShape[i] << " , axis=" << axis
366- << " )" ;
334+ LogicalResult UpcastMXFPOp::verify () {
335+ auto fpType = getFpType ();
336+
337+ auto xTy = getSrc ().getType ();
338+ auto scaleTy = getScale ().getType ();
339+ Builder b (getContext ());
340+ if (xTy.getElementType () != b.getBF16Type () &&
341+ xTy.getElementType () != b.getF16Type () &&
342+ xTy.getElementType () != b.getI8Type ()) {
343+ return emitOpError (
344+ " element type of the first operand must be bf16/fp16 or i8" );
345+ }
346+
347+ if (scaleTy.getElementType () != b.getI8Type ()) {
348+ return emitOpError (" element type of the second operand must be uint8" );
349+ }
350+
351+ auto xShape = xTy.getShape ();
352+ auto scaleShape = scaleTy.getShape ();
353+
354+ if (xShape.size () != scaleShape.size () || xShape.size () < 2 ) {
355+ return emitOpError (
356+ " operands must have the same number of dimensions, at least 2" );
357+ }
358+
359+ if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
360+ fpType == ScaleDotElemType::E5M2)) {
361+ return emitOpError (" NYI: fpType must be E2M1, E4M3, or E5M2" );
362+ }
363+
364+ auto layoutX = xTy.getEncoding ();
365+ auto layoutScale = scaleTy.getEncoding ();
366+ if (bool (layoutX) != bool (layoutScale)) {
367+ return emitOpError (
368+ " Expected either both or neither operands to have an encoding" );
369+ }
370+ // Nothing to check if no encoding. This is used to infer the return type in
371+ // AccelerateMatmul.cpp
372+ if (!layoutX) {
373+ return success ();
374+ }
375+
376+ auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
377+ if (!dotEncoding) {
378+ return emitOpError (" Expected a DotOperandEncodingAttr for values" );
379+ }
380+ if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
381+ return emitOpError (
382+ " Expected a BlockOperandEncoding or LinearOperandEncoding "
383+ " for scales" );
384+ }
385+
386+ if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent ())) {
387+ // Necessary to keep all of the scales of a given block of values in the
388+ // same warp
389+ auto threadsPerWarp =
390+ cast<DistributedEncodingTrait>(layoutScale).getThreadsPerWarp ();
391+ if (threadsPerWarp != ArrayRef<unsigned >({16 , 2 })) {
392+ return emitOpError (" Expected threads per warp to be {16, 2}" );
367393 }
368394 }
395+
396+ // Change to support fp8 types
397+ const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1 ;
398+ // Figure out the K dimension for the input A/B. For A/B scale, the K
399+ // dimension is always the last dimension.
400+ const int opIdx = dotEncoding.getOpIdx ();
401+ const bool hasBatch = xShape.size () == 3 ;
402+ const int kIdx = (opIdx == 0 ? 1 : 0 ) + hasBatch;
403+
404+ if (xShape[kIdx ] != (32 / elemsPacked) * scaleShape.back ()) {
405+ return emitOpError (" K dimension of first operand must be 16 times "
406+ " larger than last/K dimension of the second operand" );
407+ }
408+
409+ // Check other dimensions match too. For input A/B, we need to figure out the
410+ // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
411+ const int mnIdx = (opIdx == 0 ? 0 : 1 ) + hasBatch;
412+ if (hasBatch && xShape[0 ] != scaleShape[0 ])
413+ return emitOpError (" batch dimension must match between operands" );
414+ if (xShape[mnIdx] != scaleShape[hasBatch]) {
415+ return emitOpError (" M/N dimension must match between operands" );
416+ }
417+
369418 return success ();
370419}
371420
372- void Fp4ToFpOp::build (OpBuilder &builder, OperationState &state,
373- TypedValue<RankedTensorType> src, Type elemType,
374- int32_t axis) {
375- auto srcTy = src.getType ();
376- auto shape = llvm::to_vector (srcTy.getShape ());
377- auto rank = srcTy.getRank ();
378- assert (0 <= axis && axis < rank);
379- shape[axis] *= 2 ;
380-
381- Attribute inEnc = srcTy.getEncoding ();
382- Attribute outEnc;
383- auto result =
384- inEnc.getDialect ()
385- .getRegisteredInterface <triton::DialectInferLayoutInterface>()
386- ->inferFp4ToFpOpEncoding (shape, axis, inEnc, outEnc,
387- /* fwdInference=*/ true , state.location );
388- assert (succeeded (result));
389-
390- auto resultTy = RankedTensorType::get (shape, elemType, outEnc);
391- build (builder, state, resultTy, src, axis);
421+ RankedTensorType
422+ UpcastMXFPOp::deduceOutputType (TypedValue<RankedTensorType> inputTensor,
423+ ScaleDotElemType inputElemType,
424+ Type outputElemType) {
425+ MLIRContext *ctx = inputTensor.getContext ();
426+ auto xTy = inputTensor.getType ();
427+ if (inputElemType != ScaleDotElemType::E2M1)
428+ return xTy;
429+
430+ auto xShape = xTy.getShape ();
431+ auto newShape = llvm::to_vector (xShape);
432+ auto encoding = xTy.getEncoding ();
433+ if (!encoding) {
434+ newShape.back () *= 2 ;
435+ return RankedTensorType::get (xShape, outputElemType);
436+ }
437+
438+ auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
439+ auto newVEncoding = DotOperandEncodingAttr::get (ctx, oldEncoding.getOpIdx (),
440+ oldEncoding.getParent (),
441+ oldEncoding.getKWidth () * 2 );
442+ // Figure out the K dimension for the input A/B, given that the return
443+ // type is upcasted A/B type so we need to update the proper dim size.
444+ const int opIdx = oldEncoding.getOpIdx ();
445+ const bool hasBatch = xShape.size () == 3 ;
446+ const int kIdx = (opIdx == 0 ? 1 : 0 ) + hasBatch;
447+ newShape[kIdx ] *= 2 ;
448+ return RankedTensorType::get (newShape, outputElemType, newVEncoding);
392449}
393450
394451OpFoldResult MemDescTransOp::fold (FoldAdaptor adaptor) {
0 commit comments