@@ -303,13 +303,15 @@ LogicalResult UpcastMXFPOp::verify() {
303303
304304 auto xTy = getSrc ().getType ();
305305 auto scaleTy = getScale ().getType ();
306-
307- if (xTy.getElementType () != FloatType::getBF16 (getContext ()) &&
308- xTy.getElementType () != IntegerType::get (getContext (), 8 )) {
309- return emitOpError (" element type of the first operand must be bf16 or i8" );
306+ Builder b (getContext ());
307+ if (xTy.getElementType () != b.getBF16Type () &&
308+ xTy.getElementType () != b.getF16Type () &&
309+ xTy.getElementType () != b.getI8Type ()) {
310+ return emitOpError (
311+ " element type of the first operand must be bf16/fp16 or i8" );
310312 }
311313
312- if (scaleTy.getElementType () != IntegerType::get ( getContext (), 8 )) {
314+ if (scaleTy.getElementType () != b. getI8Type ( )) {
313315 return emitOpError (" element type of the second operand must be uint8" );
314316 }
315317
@@ -383,66 +385,55 @@ LogicalResult UpcastMXFPOp::verify() {
383385 return success ();
384386}
385387
386- LogicalResult UpcastMXFPOp::inferReturnTypes (
387- MLIRContext *ctx, std::optional<Location> loc, ValueRange operands ,
388- DictionaryAttr attributes, OpaqueProperties opaqueProperties ,
389- RegionRange regions, SmallVectorImpl< Type> &inferredReturnTypes ) {
390- auto xTy = cast<RankedTensorType>(operands[ 0 ]. getType () );
391- auto properties = opaqueProperties. as < const Properties *> ();
392- auto typeEncoded = properties-> fp_type . getValue ();
393- auto xShape = xTy. getShape () ;
388+ RankedTensorType
389+ UpcastMXFPOp::deduceOutputType (TypedValue<RankedTensorType> inputTensor ,
390+ ScaleDotElemType inputElemType ,
391+ Type outputElemType ) {
392+ MLIRContext *ctx = inputTensor. getContext ( );
393+ auto xTy = inputTensor. getType ();
394+ if (inputElemType != ScaleDotElemType::E2M1)
395+ return xTy;
394396
397+ auto xShape = xTy.getShape ();
398+ auto newShape = llvm::to_vector (xShape);
395399 auto encoding = xTy.getEncoding ();
396-
397- if (typeEncoded == ScaleDotElemType::E2M1) {
398- RankedTensorType retTy;
399-
400- auto newShape = SmallVector<int64_t >(xShape);
401- if (!encoding) {
402- newShape.back () *= 2 ;
403- retTy = RankedTensorType::get (xShape, FloatType::getBF16 (ctx));
404- } else {
405- Type elemType = FloatType::getBF16 (ctx);
406- Attribute newVEncoding = nullptr ;
407- auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
408- const int opIdx = oldEncoding.getOpIdx ();
409- const bool hasBatch = xShape.size () == 3 ;
410- const int kIdx = (opIdx == 0 ? 1 : 0 ) + hasBatch;
411- newShape[kIdx ] *= 2 ;
412-
413- // Note: For Intel the dot operands layout's kWidth parameter must match
414- // the parent's DPAS layout opsPerChannel so we need to materialize a
415- // new DPAS layout.
416- if (auto dpasEncoding =
417- dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent ())) {
418- unsigned opsPerChannel =
419- intel::DpasEncodingAttr::getOpsPerChannel (elemType);
420- // e2m1 is packed 2 elements per int8, we must handle continuous 2
421- // elements when upcasting to bf16
422- if (xTy.getElementType () == IntegerType::get (ctx, 8 ))
423- opsPerChannel *= 2 ;
424- auto newDpasEncoding = intel::DpasEncodingAttr::get (
425- ctx, dpasEncoding.getRepeatCount (), dpasEncoding.getSystolicDepth (),
426- dpasEncoding.getExecutionSize (), opsPerChannel,
427- dpasEncoding.getWarpsPerCTA (), dpasEncoding.getRepCluster (),
428- product<unsigned >(dpasEncoding.getThreadsPerWarp ()));
429- newVEncoding = DotOperandEncodingAttr::get (
430- ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel ());
431- } else {
432- // Figure out the K dimension for the input A/B, given that the return
433- // type is upcasted A/B type so we need to update the proper dim size.
434- newVEncoding = DotOperandEncodingAttr::get (ctx, oldEncoding.getOpIdx (),
435- oldEncoding.getParent (),
436- oldEncoding.getKWidth () * 2 );
437- }
438- retTy = RankedTensorType::get (newShape, elemType, newVEncoding);
439- }
440- inferredReturnTypes.push_back (retTy);
400+ if (!encoding) {
401+ newShape.back () *= 2 ;
402+ return RankedTensorType::get (xShape, outputElemType);
403+ }
404+
405+ Attribute newVEncoding = nullptr ;
406+ auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
407+ const int opIdx = oldEncoding.getOpIdx ();
408+ // Note: For Intel the dot operands layout's kWidth parameter must match
409+ // the parent's DPAS layout opsPerChannel so we need to materialize a
410+ // new DPAS layout.
411+ if (auto dpasEncoding =
412+ dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent ())) {
413+ unsigned opsPerChannel =
414+ intel::DpasEncodingAttr::getOpsPerChannel (outputElemType);
415+ // e2m1 is packed 2 elements per int8, we must handle continuous 2
416+ // elements when upcasting to bf16
417+ if (xTy.getElementType () == IntegerType::get (ctx, 8 ))
418+ opsPerChannel *= 2 ;
419+ auto newDpasEncoding = intel::DpasEncodingAttr::get (
420+ ctx, dpasEncoding.getRepeatCount (), dpasEncoding.getSystolicDepth (),
421+ dpasEncoding.getExecutionSize (), opsPerChannel,
422+ dpasEncoding.getWarpsPerCTA (), dpasEncoding.getRepCluster (),
423+ product<unsigned >(dpasEncoding.getThreadsPerWarp ()));
424+ newVEncoding = DotOperandEncodingAttr::get (
425+ ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel ());
441426 } else {
442- inferredReturnTypes.push_back (xTy);
427+ // Figure out the K dimension for the input A/B, given that the return
428+ // type is upcasted A/B type so we need to update the proper dim size.
429+ newVEncoding = DotOperandEncodingAttr::get (ctx, oldEncoding.getOpIdx (),
430+ oldEncoding.getParent (),
431+ oldEncoding.getKWidth () * 2 );
443432 }
444-
445- return success ();
433+ const bool hasBatch = xShape.size () == 3 ;
434+ const int kIdx = (opIdx == 0 ? 1 : 0 ) + hasBatch;
435+ newShape[kIdx ] *= 2 ;
436+ return RankedTensorType::get (newShape, outputElemType, newVEncoding);
446437}
447438
448439OpFoldResult MemDescTransOp::fold (FoldAdaptor adaptor) {
0 commit comments