@@ -331,121 +331,64 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<CanonicalizeConvertFromSplit>(context);
 }

-LogicalResult UpcastMXFPOp::verify() {
-  auto fpType = getFpType();
-
-  auto xTy = getSrc().getType();
-  auto scaleTy = getScale().getType();
-  Builder b(getContext());
-  if (xTy.getElementType() != b.getBF16Type() &&
-      xTy.getElementType() != b.getF16Type() &&
-      xTy.getElementType() != b.getI8Type()) {
-    return emitOpError(
-        "element type of the first operand must be bf16/fp16 or i8");
-  }
-
-  if (scaleTy.getElementType() != b.getI8Type()) {
-    return emitOpError("element type of the second operand must be uint8");
-  }
-
-  auto xShape = xTy.getShape();
-  auto scaleShape = scaleTy.getShape();
-
-  if (xShape.size() != scaleShape.size() || xShape.size() < 2) {
-    return emitOpError(
-        "operands must have the same number of dimensions, at least 2");
-  }
-
-  if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
-        fpType == ScaleDotElemType::E5M2)) {
-    return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
-  }
-
-  auto layoutX = xTy.getEncoding();
-  auto layoutScale = scaleTy.getEncoding();
-  if (bool(layoutX) != bool(layoutScale)) {
-    return emitOpError(
-        "Expected either both or neither operands to have an encoding");
-  }
-  // Nothing to check if no encoding. This is used to infer the return type in
-  // AccelerateMatmul.cpp
-  if (!layoutX) {
-    return success();
-  }
-
-  auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
-  if (!dotEncoding) {
-    return emitOpError("Expected a DotOperandEncodingAttr for values");
-  }
-  if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
-    return emitOpError(
-        "Expected a BlockOperandEncoding or LinearOperandEncoding "
-        "for scales");
-  }
-
-  if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
-    // Necessary to keep all of the scales of a given block of values in the
-    // same warp
-    auto threadsPerWarp =
-        cast<DistributedEncodingTrait>(layoutScale).getThreadsPerWarp();
-    if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
-      return emitOpError("Expected threads per warp to be {16, 2}");
+LogicalResult Fp4ToFpOp::verify() {
+  auto srcTy = cast<RankedTensorType>(getSrc().getType());
+  auto resTy = cast<RankedTensorType>(getResult().getType());
+  auto rank = srcTy.getRank();
+
+  if (rank != resTy.getRank())
+    return emitError() << "source rank " << rank << " != result rank "
+                       << resTy.getRank();
+
+  auto srcShape = srcTy.getShape();
+  auto resShape = resTy.getShape();
+  auto axis = getAxis();
+
+  if (!(0 <= axis && axis < rank))
+    return emitError() << "axis " << axis << " out of range for rank " << rank;
+
+  auto elemType = resTy.getElementType();
+  if (!(elemType.isBF16() || elemType.isF16()))
+    return emitError() << "only bf16 or f16 is supported for now, got "
+                       << elemType;
+
+  for (int i = 0; i < rank; ++i) {
+    if (i == axis) {
+      if (resShape[i] != srcShape[i] * 2)
+        return emitError() << "axis " << axis
+                           << " dimension must be 2x source dimension (src="
+                           << srcShape[i] << ", dst=" << resShape[i] << ")";
+    } else {
+      if (resShape[i] != srcShape[i])
+        return emitError() << "dimension " << i
+                           << " mismatch (src=" << srcShape[i]
+                           << ", dst=" << resShape[i] << ", axis=" << axis
+                           << ")";
     }
   }
-
-  // Change to support fp8 types
-  const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
-  // Figure out the K dimension for the input A/B. For A/B scale, the K
-  // dimension is always the last dimension.
-  const int opIdx = dotEncoding.getOpIdx();
-  const bool hasBatch = xShape.size() == 3;
-  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
-
-  if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
-    return emitOpError("K dimension of first operand must be 16 times "
-                       "larger than last/K dimension of the second operand");
-  }
-
-  // Check other dimensions match too. For input A/B, we need to figure out the
-  // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
-  const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
-  if (hasBatch && xShape[0] != scaleShape[0])
-    return emitOpError("batch dimension must match between operands");
-  if (xShape[mnIdx] != scaleShape[hasBatch]) {
-    return emitOpError("M/N dimension must match between operands");
-  }
-
   return success();
 }

-RankedTensorType
-UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
-                               ScaleDotElemType inputElemType,
-                               Type outputElemType) {
-  MLIRContext *ctx = inputTensor.getContext();
-  auto xTy = inputTensor.getType();
-  if (inputElemType != ScaleDotElemType::E2M1)
-    return xTy;
-
-  auto xShape = xTy.getShape();
-  auto newShape = llvm::to_vector(xShape);
-  auto encoding = xTy.getEncoding();
-  if (!encoding) {
-    newShape.back() *= 2;
-    return RankedTensorType::get(xShape, outputElemType);
-  }
-
-  auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
-  auto newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
-                                                  oldEncoding.getParent(),
-                                                  oldEncoding.getKWidth() * 2);
-  // Figure out the K dimension for the input A/B, given that the return
-  // type is upcasted A/B type so we need to update the proper dim size.
-  const int opIdx = oldEncoding.getOpIdx();
-  const bool hasBatch = xShape.size() == 3;
-  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
-  newShape[kIdx] *= 2;
-  return RankedTensorType::get(newShape, outputElemType, newVEncoding);
+void Fp4ToFpOp::build(OpBuilder &builder, OperationState &state,
+                      TypedValue<RankedTensorType> src, Type elemType,
+                      int32_t axis) {
+  auto srcTy = src.getType();
+  auto shape = llvm::to_vector(srcTy.getShape());
+  auto rank = srcTy.getRank();
+  assert(0 <= axis && axis < rank);
+  shape[axis] *= 2;
+
+  Attribute inEnc = srcTy.getEncoding();
+  Attribute outEnc;
+  auto result =
+      inEnc.getDialect()
+          .getRegisteredInterface<triton::DialectInferLayoutInterface>()
+          ->inferFp4ToFpOpEncoding(shape, axis, inEnc, outEnc,
+                                   /*fwdInference=*/true, state.location);
+  assert(succeeded(result));
+
+  auto resultTy = RankedTensorType::get(shape, elemType, outEnc);
+  build(builder, state, resultTy, src, axis);
 }

 OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {
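A minimal usage sketch of the new builder overload, for illustration only. Assumed context (not part of the patch): a rewrite pattern with an OpBuilder rewriter, a Location loc, and an operand src already typed as TypedValue<RankedTensorType> with i8 elements, each packing two fp4 values; all names here are hypothetical.

  // Hypothetical call site: unpack the fp4 payload along the innermost
  // dimension into bf16. The overload added above doubles shape[axis] and
  // infers the result encoding via
  // DialectInferLayoutInterface::inferFp4ToFpOpEncoding, so only the element
  // type and axis need to be supplied.
  int32_t axis = src.getType().getRank() - 1;
  Value unpacked =
      rewriter.create<Fp4ToFpOp>(loc, src, rewriter.getBF16Type(), axis);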