@@ -325,11 +325,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
325325 // with Intel layouts.
326326 // More details:
327327 // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
328- if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent ())) {
329- SmallVector<unsigned > order (rank);
330- std::iota (order.rbegin (), order.rend (), 0 );
331- return order;
332- }
328+ // if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
329+ // SmallVector<unsigned> order(rank);
330+ // std::iota(order.rbegin(), order.rend(), 0);
331+ // return order;
332+ // }
333333 return getOrderForDotOperand (dotLayout.getOpIdx (), rank, /* kMajor*/ true );
334334 }
335335 if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -1129,10 +1129,11 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
11291129 return amdWmmaParent.getTotalElemsPerThreadForOperand (
11301130 shape, eltTy, getKWidth (), getOpIdx ());
11311131 }
1132- if (auto dpasParent = mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
1133- return dpasParent.getTotalElemsPerThreadForOperand (
1134- shape, eltTy, getKWidth (), getOpIdx ());
1135- }
1132+ // if (auto dpasParent =
1133+ // mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
1134+ // return dpasParent.getTotalElemsPerThreadForOperand(
1135+ // shape, eltTy, getKWidth(), getOpIdx());
1136+ // }
11361137 }
11371138 if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent ())) {
11381139 auto shapePerCTA = getShapePerCTA (*this , shape);
@@ -1197,17 +1198,19 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
11971198 return {};
11981199}
11991200SmallVector<unsigned > DotOperandEncodingAttr::getThreadOrder () const {
1200- // FIXME: delete if branch for `DpasEncodingAttr` and provide more
1201- // general solution to make `getOrderForDotOperand` function compatible
1202- // with Intel layouts.
1203- // More details:
1204- // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
1205- if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent ())) {
1206- return ::getOrder (*this );
1207- } else {
1208- return getOrderForDotOperand (getOpIdx (), getWarpsPerCTA ().size (),
1209- /* kMajor*/ true );
1210- }
1201+ // // FIXME: delete if branch for `DpasEncodingAttr` and provide more
1202+ // // general solution to make `getOrderForDotOperand` function compatible
1203+ // // with Intel layouts.
1204+ // // More details:
1205+ // // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
1206+ // if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
1207+ // return ::getOrder(*this);
1208+ // } else {
1209+ // return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
1210+ // /*kMajor*/ true);
1211+ // }
1212+ return getOrderForDotOperand (getOpIdx (), getWarpsPerCTA ().size (),
1213+ /* kMajor*/ true );
12111214}
12121215
12131216LogicalResult DotOperandEncodingAttr::verify (
@@ -1250,19 +1253,16 @@ LogicalResult DotOperandEncodingAttr::verify(
12501253 return success ();
12511254 }
12521255
1253- if (auto parentAttr = mlir::dyn_cast<intel::DpasEncodingAttr>(parent)) {
1254- if (kWidth != parentAttr.getOpsPerChannel ())
1255- return emitError () << " ttg.dot_op kWidth parameter must match the "
1256- " parent's opsPerChannel" ;
1257- return success ();
1258- }
1259-
1260- if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) {
1261- if (kWidth != 0 )
1262- return emitError () << " ttg.dot_op kWidth parameter is not supported "
1263- " when the parent is a warp layout" ;
1256+ if (auto parentAttr = mlir::dyn_cast<MmaEncodingTrait>(parent)) {
12641257 return success ();
12651258 }
1259+ //
1260+ // if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) {
1261+ // if (kWidth != 0)
1262+ // return emitError() << "ttg.dot_op kWidth parameter is not supported "
1263+ // "when the parent is a warp layout";
1264+ // return success();
1265+ // }
12661266
12671267 if (auto parentAttr = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
12681268 if (kWidth != 0 )
@@ -3248,8 +3248,7 @@ struct CanonicalizeConvertFromConvert
32483248 auto srcType = op.getSrc ().getType ();
32493249 auto dstType = op.getType ();
32503250 if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding ()) &&
3251- (mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding ()) ||
3252- mlir::isa<intel::DpasEncodingAttr>(srcType.getEncoding ())))
3251+ mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding ()))
32533252 return failure ();
32543253
32553254 // for hopper MMAv3
0 commit comments