|
5 | 5 |
|
6 | 6 | #include "mlir/IR/DialectImplementation.h" |
7 | 7 | #include "mlir/IR/OpImplementation.h" |
8 | | - |
9 | | -#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" |
10 | | - |
11 | 8 | #include "mlir/Support/LLVM.h" |
12 | 9 | #include "triton/Analysis/Utility.h" |
13 | 10 | #include "triton/Dialect/Triton/IR/Utility.h" |
@@ -325,11 +322,11 @@ SmallVector<unsigned> getOrder(Attribute layout) { |
325 | 322 | // with Intel layouts. |
326 | 323 | // More details: |
327 | 324 | // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517 |
328 | | - if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) { |
329 | | - SmallVector<unsigned> order(rank); |
330 | | - std::iota(order.rbegin(), order.rend(), 0); |
331 | | - return order; |
332 | | - } |
| 325 | + // if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) { |
| 326 | + // SmallVector<unsigned> order(rank); |
| 327 | + // std::iota(order.rbegin(), order.rend(), 0); |
| 328 | + // return order; |
| 329 | + // } |
333 | 330 | return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true); |
334 | 331 | } |
335 | 332 | if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) { |
@@ -1129,10 +1126,11 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape, |
1129 | 1126 | return amdWmmaParent.getTotalElemsPerThreadForOperand( |
1130 | 1127 | shape, eltTy, getKWidth(), getOpIdx()); |
1131 | 1128 | } |
1132 | | - if (auto dpasParent = mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) { |
1133 | | - return dpasParent.getTotalElemsPerThreadForOperand( |
1134 | | - shape, eltTy, getKWidth(), getOpIdx()); |
1135 | | - } |
| 1129 | + // if (auto dpasParent = |
| 1130 | + // mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) { |
| 1131 | + // return dpasParent.getTotalElemsPerThreadForOperand( |
| 1132 | + // shape, eltTy, getKWidth(), getOpIdx()); |
| 1133 | + // } |
1136 | 1134 | } |
1137 | 1135 | if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) { |
1138 | 1136 | auto shapePerCTA = getShapePerCTA(*this, shape); |
@@ -1197,17 +1195,19 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const { |
1197 | 1195 | return {}; |
1198 | 1196 | } |
1199 | 1197 | SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const { |
1200 | | - // FIXME: delete if branch for `DpasEncodingAttr` and provide more |
1201 | | - // general solution to make `getOrderForDotOperand` function compatible |
1202 | | - // with Intel layouts. |
1203 | | - // More details: |
1204 | | - // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517 |
1205 | | - if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) { |
1206 | | - return ::getOrder(*this); |
1207 | | - } else { |
1208 | | - return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), |
1209 | | - /*kMajor*/ true); |
1210 | | - } |
| 1198 | + // // FIXME: delete if branch for `DpasEncodingAttr` and provide more |
| 1199 | + // // general solution to make `getOrderForDotOperand` function compatible |
| 1200 | + // // with Intel layouts. |
| 1201 | + // // More details: |
| 1202 | + // // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517 |
| 1203 | + // if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) { |
| 1204 | + // return ::getOrder(*this); |
| 1205 | + // } else { |
| 1206 | + // return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), |
| 1207 | + // /*kMajor*/ true); |
| 1208 | + // } |
| 1209 | + return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), |
| 1210 | + /*kMajor*/ true); |
1211 | 1211 | } |
1212 | 1212 |
|
1213 | 1213 | LogicalResult DotOperandEncodingAttr::verify( |
@@ -1250,19 +1250,16 @@ LogicalResult DotOperandEncodingAttr::verify( |
1250 | 1250 | return success(); |
1251 | 1251 | } |
1252 | 1252 |
|
1253 | | - if (auto parentAttr = mlir::dyn_cast<intel::DpasEncodingAttr>(parent)) { |
1254 | | - if (kWidth != parentAttr.getOpsPerChannel()) |
1255 | | - return emitError() << "ttg.dot_op kWidth parameter must match the " |
1256 | | - "parent's opsPerChannel"; |
1257 | | - return success(); |
1258 | | - } |
1259 | | - |
1260 | | - if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) { |
1261 | | - if (kWidth != 0) |
1262 | | - return emitError() << "ttg.dot_op kWidth parameter is not supported " |
1263 | | - "when the parent is a warp layout"; |
| 1253 | + if (auto parentAttr = mlir::dyn_cast<MmaEncodingTrait>(parent)) { |
1264 | 1254 | return success(); |
1265 | 1255 | } |
| 1256 | + // |
| 1257 | + // if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) { |
| 1258 | + // if (kWidth != 0) |
| 1259 | + // return emitError() << "ttg.dot_op kWidth parameter is not supported " |
| 1260 | + // "when the parent is a warp layout"; |
| 1261 | + // return success(); |
| 1262 | + // } |
1266 | 1263 |
|
1267 | 1264 | if (auto parentAttr = mlir::dyn_cast<BlockedEncodingAttr>(parent)) { |
1268 | 1265 | if (kWidth != 0) |
@@ -2527,9 +2524,6 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface { |
2527 | 2524 | } else if (auto linearAttr = mlir::dyn_cast<LinearEncodingAttr>(attr)) { |
2528 | 2525 | os << "linear"; |
2529 | 2526 | return AliasResult::FinalAlias; |
2530 | | - } else if (auto warpAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(attr)) { |
2531 | | - os << "warp"; |
2532 | | - return AliasResult::FinalAlias; |
2533 | 2527 | } /* else if (auto sliceAttr = dyn_cast<SliceEncodingAttr>(attr)) { |
2534 | 2528 | os << "slice"; |
2535 | 2529 | return AliasResult::FinalAlias; |
@@ -3248,8 +3242,7 @@ struct CanonicalizeConvertFromConvert |
3248 | 3242 | auto srcType = op.getSrc().getType(); |
3249 | 3243 | auto dstType = op.getType(); |
3250 | 3244 | if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) && |
3251 | | - (mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) || |
3252 | | - mlir::isa<intel::DpasEncodingAttr>(srcType.getEncoding()))) |
| 3245 | + mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding())) |
3253 | 3246 | return failure(); |
3254 | 3247 |
|
3255 | 3248 | // for hopper MMAv3 |
|
0 commit comments