|
5 | 5 |
|
6 | 6 | #include "mlir/IR/DialectImplementation.h" |
7 | 7 | #include "mlir/IR/OpImplementation.h" |
8 | | - |
9 | | -#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" |
10 | | - |
11 | 8 | #include "mlir/Support/LLVM.h" |
12 | 9 | #include "triton/Analysis/Utility.h" |
13 | 10 | #include "triton/Dialect/Triton/IR/Utility.h" |
@@ -316,11 +313,11 @@ SmallVector<unsigned> getOrder(Attribute layout) { |
316 | 313 | // with Intel layouts. |
317 | 314 | // More details: |
318 | 315 | // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517 |
319 | | - if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) { |
320 | | - SmallVector<unsigned> order(rank); |
321 | | - std::iota(order.rbegin(), order.rend(), 0); |
322 | | - return order; |
323 | | - } |
| 316 | + // if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) { |
| 317 | + // SmallVector<unsigned> order(rank); |
| 318 | + // std::iota(order.rbegin(), order.rend(), 0); |
| 319 | + // return order; |
| 320 | + // } |
324 | 321 | return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true); |
325 | 322 | } |
326 | 323 | if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) { |
@@ -1120,10 +1117,11 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape, |
1120 | 1117 | return amdWmmaParent.getTotalElemsPerThreadForOperand( |
1121 | 1118 | shape, eltTy, getKWidth(), getOpIdx()); |
1122 | 1119 | } |
1123 | | - if (auto dpasParent = mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) { |
1124 | | - return dpasParent.getTotalElemsPerThreadForOperand( |
1125 | | - shape, eltTy, getKWidth(), getOpIdx()); |
1126 | | - } |
| 1120 | + // if (auto dpasParent = |
| 1121 | + // mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) { |
| 1122 | + // return dpasParent.getTotalElemsPerThreadForOperand( |
| 1123 | + // shape, eltTy, getKWidth(), getOpIdx()); |
| 1124 | + // } |
1127 | 1125 | } |
1128 | 1126 | if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) { |
1129 | 1127 | auto shapePerCTA = getShapePerCTA(*this, shape); |
@@ -1188,17 +1186,19 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const { |
1188 | 1186 | return {}; |
1189 | 1187 | } |
1190 | 1188 | SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const { |
1191 | | - // FIXME: delete if branch for `DpasEncodingAttr` and provide more |
1192 | | - // general solution to make `getOrderForDotOperand` function compatible |
1193 | | - // with Intel layouts. |
1194 | | - // More details: |
1195 | | - // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517 |
1196 | | - if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) { |
1197 | | - return ::getOrder(*this); |
1198 | | - } else { |
1199 | | - return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), |
1200 | | - /*kMajor*/ true); |
1201 | | - } |
| 1189 | + // // FIXME: delete if branch for `DpasEncodingAttr` and provide more |
| 1190 | + // // general solution to make `getOrderForDotOperand` function compatible |
| 1191 | + // // with Intel layouts. |
| 1192 | + // // More details: |
| 1193 | + // // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517 |
| 1194 | + // if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) { |
| 1195 | + // return ::getOrder(*this); |
| 1196 | + // } else { |
| 1197 | + // return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), |
| 1198 | + // /*kMajor*/ true); |
| 1199 | + // } |
| 1200 | + return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), |
| 1201 | + /*kMajor*/ true); |
1202 | 1202 | } |
1203 | 1203 |
|
1204 | 1204 | LogicalResult DotOperandEncodingAttr::verify( |
@@ -1241,19 +1241,19 @@ LogicalResult DotOperandEncodingAttr::verify( |
1241 | 1241 | return success(); |
1242 | 1242 | } |
1243 | 1243 |
|
1244 | | - if (auto parentAttr = mlir::dyn_cast<intel::DpasEncodingAttr>(parent)) { |
1245 | | - if (kWidth != parentAttr.getOpsPerChannel()) |
1246 | | - return emitError() << "ttg.dot_op kWidth parameter must match the " |
1247 | | - "parent's opsPerChannel"; |
1248 | | - return success(); |
1249 | | - } |
1250 | | - |
1251 | | - if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) { |
1252 | | - if (kWidth != 0) |
1253 | | - return emitError() << "ttg.dot_op kWidth parameter is not supported " |
1254 | | - "when the parent is a warp layout"; |
1255 | | - return success(); |
1256 | | - } |
| 1244 | + // if (auto parentAttr = mlir::dyn_cast<intel::DpasEncodingAttr>(parent)) { |
| 1245 | + // if (kWidth != parentAttr.getOpsPerChannel()) |
| 1246 | + // return emitError() << "ttg.dot_op kWidth parameter must match the " |
| 1247 | + // "parent's opsPerChannel"; |
| 1248 | + // return success(); |
| 1249 | + // } |
| 1250 | + // |
| 1251 | + // if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) { |
| 1252 | + // if (kWidth != 0) |
| 1253 | + // return emitError() << "ttg.dot_op kWidth parameter is not supported " |
| 1254 | + // "when the parent is a warp layout"; |
| 1255 | + // return success(); |
| 1256 | + // } |
1257 | 1257 |
|
1258 | 1258 | if (auto parentAttr = mlir::dyn_cast<BlockedEncodingAttr>(parent)) { |
1259 | 1259 | if (kWidth != 0) |
@@ -2518,9 +2518,6 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface { |
2518 | 2518 | } else if (auto linearAttr = mlir::dyn_cast<LinearEncodingAttr>(attr)) { |
2519 | 2519 | os << "linear"; |
2520 | 2520 | return AliasResult::FinalAlias; |
2521 | | - } else if (auto warpAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(attr)) { |
2522 | | - os << "warp"; |
2523 | | - return AliasResult::FinalAlias; |
2524 | 2521 | } /* else if (auto sliceAttr = dyn_cast<SliceEncodingAttr>(attr)) { |
2525 | 2522 | os << "slice"; |
2526 | 2523 | return AliasResult::FinalAlias; |
@@ -3239,8 +3236,7 @@ struct CanonicalizeConvertFromConvert |
3239 | 3236 | auto srcType = op.getSrc().getType(); |
3240 | 3237 | auto dstType = op.getType(); |
3241 | 3238 | if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) && |
3242 | | - (mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) || |
3243 | | - mlir::isa<intel::DpasEncodingAttr>(srcType.getEncoding()))) |
| 3239 | + mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding())) |
3244 | 3240 | return failure(); |
3245 | 3241 |
|
3246 | 3242 | // for hopper MMAv3 |
|
0 commit comments