@@ -52,14 +52,23 @@ LogicalResult UpcastMXFPOp::verify() {
5252 " all dimensions except the last must match between operands" );
5353 }
5454
55- auto dotEncoding =
56- dyn_cast_or_null<DotOperandEncodingAttr>(xTy.getEncoding ());
55+ auto layoutX = xTy.getEncoding ();
56+ auto layoutScale = scaleTy.getEncoding ();
57+ if (bool (layoutX) != bool (layoutScale)) {
58+ return emitOpError (
59+ " Expected either both or neither operands to have an encoding" );
60+ }
61+ // Nothing to check if no encoding. This is used to infer the return type in
62+ // AccelerateMatmul.cpp
63+ if (!layoutX) {
64+ return success ();
65+ }
66+
67+ auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
5768 if (!dotEncoding) {
5869 return emitOpError (" Expected a DotOperandEncodingAttr for values" );
5970 }
60-
61- auto blockedScale =
62- dyn_cast_or_null<BlockedEncodingAttr>(scaleTy.getEncoding ());
71+ auto blockedScale = dyn_cast<BlockedEncodingAttr>(layoutScale);
6372 if (!blockedScale) {
6473 return emitOpError (" Expected a BlockOperandEncoding for scales" );
6574 }
@@ -86,22 +95,23 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
8695 auto xShape = xTy.getShape ();
8796
8897 auto encoding = xTy.getEncoding ();
89- if (!encoding) {
90- return emitOptionalError (loc, " expected an encoding" );
91- }
92- if (!mlir::isa<DotOperandEncodingAttr>(encoding)) {
93- return emitOptionalError (loc, " expected a dotOperand encoding" );
94- }
9598
9699 if (typeEncoded == ScaleDotElemType::E2M1) {
97- auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
98- auto newVEncoding = DotOperandEncodingAttr::get (
99- ctx, oldEncoding.getOpIdx (), oldEncoding.getParent (),
100- oldEncoding.getKWidth () * 2 );
100+ RankedTensorType retTy;
101+
101102 auto newShape = SmallVector<int64_t >(xShape);
102103 newShape.back () *= 2 ;
103- inferredReturnTypes.push_back (
104- RankedTensorType::get (newShape, FloatType::getBF16 (ctx), newVEncoding));
104+ if (!encoding) {
105+ retTy = RankedTensorType::get (xShape, FloatType::getBF16 (ctx));
106+ } else {
107+ auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
108+ auto newVEncoding = DotOperandEncodingAttr::get (
109+ ctx, oldEncoding.getOpIdx (), oldEncoding.getParent (),
110+ oldEncoding.getKWidth () * 2 );
111+ retTy = RankedTensorType::get (newShape, FloatType::getBF16 (ctx),
112+ newVEncoding);
113+ }
114+ inferredReturnTypes.push_back (retTy);
105115 } else {
106116 inferredReturnTypes.push_back (xTy);
107117 }
0 commit comments