@@ -329,121 +329,64 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
329 329   patterns.add<CanonicalizeConvertFromSplit>(context);
330 330 }
331 331
332- LogicalResult UpcastMXFPOp::verify() {
333-   auto fpType = getFpType();
334-
335-   auto xTy = getSrc().getType();
336-   auto scaleTy = getScale().getType();
337-   Builder b(getContext());
338-   if (xTy.getElementType() != b.getBF16Type() &&
339-       xTy.getElementType() != b.getF16Type() &&
340-       xTy.getElementType() != b.getI8Type()) {
341-     return emitOpError(
342-         "element type of the first operand must be bf16/fp16 or i8");
343-   }
344-
345-   if (scaleTy.getElementType() != b.getI8Type()) {
346-     return emitOpError("element type of the second operand must be uint8");
347-   }
348-
349-   auto xShape = xTy.getShape();
350-   auto scaleShape = scaleTy.getShape();
351-
352-   if (xShape.size() != scaleShape.size() || xShape.size() < 2) {
353-     return emitOpError(
354-         "operands must have the same number of dimensions, at least 2");
355-   }
356-
357-   if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
358-         fpType == ScaleDotElemType::E5M2)) {
359-     return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
360-   }
361-
362-   auto layoutX = xTy.getEncoding();
363-   auto layoutScale = scaleTy.getEncoding();
364-   if (bool(layoutX) != bool(layoutScale)) {
365-     return emitOpError(
366-         "Expected either both or neither operands to have an encoding");
367-   }
368-   // Nothing to check if no encoding. This is used to infer the return type in
369-   // AccelerateMatmul.cpp
370-   if (!layoutX) {
371-     return success();
372-   }
373-
374-   auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
375-   if (!dotEncoding) {
376-     return emitOpError("Expected a DotOperandEncodingAttr for values");
377-   }
378-   if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
379-     return emitOpError(
380-         "Expected a BlockOperandEncoding or LinearOperandEncoding "
381-         "for scales");
382-   }
383-
384-   if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
385-     // Necessary to keep all of the scales of a given block of values in the
386-     // same warp
387-     auto threadsPerWarp =
388-         cast<DistributedEncodingTrait>(layoutScale).getThreadsPerWarp();
389-     if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
390-       return emitOpError("Expected threads per warp to be {16, 2}");
332+ LogicalResult Fp4ToFpOp::verify() {
333+   auto srcTy = cast<RankedTensorType>(getSrc().getType());
334+   auto resTy = cast<RankedTensorType>(getResult().getType());
335+   auto rank = srcTy.getRank();
336+
337+   if (rank != resTy.getRank())
338+     return emitError() << "source rank " << rank << " != result rank "
339+                        << resTy.getRank();
340+
341+   auto srcShape = srcTy.getShape();
342+   auto resShape = resTy.getShape();
343+   auto axis = getAxis();
344+
345+   if (!(0 <= axis && axis < rank))
346+     return emitError() << "axis " << axis << " out of range for rank " << rank;
347+
348+   auto elemType = resTy.getElementType();
349+   if (!(elemType.isBF16() || elemType.isF16()))
350+     return emitError() << "only bf16 or f16 is supported for now, got "
351+                        << elemType;
352+
353+   for (int i = 0; i < rank; ++i) {
354+     if (i == axis) {
355+       if (resShape[i] != srcShape[i] * 2)
356+         return emitError() << "axis " << axis
357+                            << " dimension must be 2x source dimension (src="
358+                            << srcShape[i] << ", dst=" << resShape[i] << ")";
359+     } else {
360+       if (resShape[i] != srcShape[i])
361+         return emitError() << "dimension " << i
362+                            << " mismatch (src=" << srcShape[i]
363+                            << ", dst=" << resShape[i] << ", axis=" << axis
364+                            << ")";
391 365     }
392 366   }
393-
394-   // Change to support fp8 types
395-   const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
396-   // Figure out the K dimension for the input A/B. For A/B scale, the K
397-   // dimension is always the last dimension.
398-   const int opIdx = dotEncoding.getOpIdx();
399-   const bool hasBatch = xShape.size() == 3;
400-   const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
401-
402-   if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
403-     return emitOpError("K dimension of first operand must be 16 times "
404-                        "larger than last/K dimension of the second operand");
405-   }
406-
407-   // Check other dimensions match too. For input A/B, we need to figure out the
408-   // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
409-   const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
410-   if (hasBatch && xShape[0] != scaleShape[0])
411-     return emitOpError("batch dimension must match between operands");
412-   if (xShape[mnIdx] != scaleShape[hasBatch]) {
413-     return emitOpError("M/N dimension must match between operands");
414-   }
415-
416 367   return success();
417 368 }
418 369
419- RankedTensorType
420- UpcastMXFPOp::deduceOutputType(TypedValue<RankedTensorType> inputTensor,
421-                                ScaleDotElemType inputElemType,
422-                                Type outputElemType) {
423-   MLIRContext *ctx = inputTensor.getContext();
424-   auto xTy = inputTensor.getType();
425-   if (inputElemType != ScaleDotElemType::E2M1)
426-     return xTy;
427-
428-   auto xShape = xTy.getShape();
429-   auto newShape = llvm::to_vector(xShape);
430-   auto encoding = xTy.getEncoding();
431-   if (!encoding) {
432-     newShape.back() *= 2;
433-     return RankedTensorType::get(xShape, outputElemType);
434-   }
435-
436-   auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
437-   auto newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
438-                                                   oldEncoding.getParent(),
439-                                                   oldEncoding.getKWidth() * 2);
440-   // Figure out the K dimension for the input A/B, given that the return
441-   // type is upcasted A/B type so we need to update the proper dim size.
442-   const int opIdx = oldEncoding.getOpIdx();
443-   const bool hasBatch = xShape.size() == 3;
444-   const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
445-   newShape[kIdx] *= 2;
446-   return RankedTensorType::get(newShape, outputElemType, newVEncoding);
370+ void Fp4ToFpOp::build(OpBuilder &builder, OperationState &state,
371+                       TypedValue<RankedTensorType> src, Type elemType,
372+                       int32_t axis) {
373+   auto srcTy = src.getType();
374+   auto shape = llvm::to_vector(srcTy.getShape());
375+   auto rank = srcTy.getRank();
376+   assert(0 <= axis && axis < rank);
377+   shape[axis] *= 2;
378+
379+   Attribute inEnc = srcTy.getEncoding();
380+   Attribute outEnc;
381+   auto result =
382+       inEnc.getDialect()
383+           .getRegisteredInterface<triton::DialectInferLayoutInterface>()
384+           ->inferFp4ToFpOpEncoding(shape, axis, inEnc, outEnc,
385+                                    /*fwdInference=*/true, state.location);
386+   assert(succeeded(result));
387+
388+   auto resultTy = RankedTensorType::get(shape, elemType, outEnc);
389+   build(builder, state, resultTy, src, axis);
447 390 }
448 391
449 392 OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {