Skip to content

Commit b6e3836

Browse files
committed
Remove Intel code
1 parent b1c3c72 commit b6e3836

File tree

2 files changed

+50
-34
lines changed

2 files changed

+50
-34
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -325,11 +325,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
325325
// with Intel layouts.
326326
// More details:
327327
// https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
328-
if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
329-
SmallVector<unsigned> order(rank);
330-
std::iota(order.rbegin(), order.rend(), 0);
331-
return order;
332-
}
328+
// if (dyn_cast<intel::DpasEncodingAttr>(dotLayout.getParent())) {
329+
// SmallVector<unsigned> order(rank);
330+
// std::iota(order.rbegin(), order.rend(), 0);
331+
// return order;
332+
// }
333333
return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
334334
}
335335
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(layout)) {
@@ -1129,10 +1129,11 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
11291129
return amdWmmaParent.getTotalElemsPerThreadForOperand(
11301130
shape, eltTy, getKWidth(), getOpIdx());
11311131
}
1132-
if (auto dpasParent = mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
1133-
return dpasParent.getTotalElemsPerThreadForOperand(
1134-
shape, eltTy, getKWidth(), getOpIdx());
1135-
}
1132+
// if (auto dpasParent =
1133+
// mlir::dyn_cast<intel::DpasEncodingAttr>(mmaParent)) {
1134+
// return dpasParent.getTotalElemsPerThreadForOperand(
1135+
// shape, eltTy, getKWidth(), getOpIdx());
1136+
// }
11361137
}
11371138
if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
11381139
auto shapePerCTA = getShapePerCTA(*this, shape);
@@ -1197,17 +1198,19 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
11971198
return {};
11981199
}
11991200
SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
1200-
// FIXME: delete if branch for `DpasEncodingAttr` and provide more
1201-
// general solution to make `getOrderForDotOperand` function compatible
1202-
// with Intel layouts.
1203-
// More details:
1204-
// https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
1205-
if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
1206-
return ::getOrder(*this);
1207-
} else {
1208-
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
1209-
/*kMajor*/ true);
1210-
}
1201+
// // FIXME: delete if branch for `DpasEncodingAttr` and provide more
1202+
// // general solution to make `getOrderForDotOperand` function compatible
1203+
// // with Intel layouts.
1204+
// // More details:
1205+
// // https://github.com/intel/intel-xpu-backend-for-triton/pull/2517
1206+
// if (mlir::dyn_cast<intel::DpasEncodingAttr>(getParent())) {
1207+
// return ::getOrder(*this);
1208+
// } else {
1209+
// return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
1210+
// /*kMajor*/ true);
1211+
// }
1212+
return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
1213+
/*kMajor*/ true);
12111214
}
12121215

12131216
LogicalResult DotOperandEncodingAttr::verify(
@@ -1250,19 +1253,16 @@ LogicalResult DotOperandEncodingAttr::verify(
12501253
return success();
12511254
}
12521255

1253-
if (auto parentAttr = mlir::dyn_cast<intel::DpasEncodingAttr>(parent)) {
1254-
if (kWidth != parentAttr.getOpsPerChannel())
1255-
return emitError() << "ttg.dot_op kWidth parameter must match the "
1256-
"parent's opsPerChannel";
1257-
return success();
1258-
}
1259-
1260-
if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) {
1261-
if (kWidth != 0)
1262-
return emitError() << "ttg.dot_op kWidth parameter is not supported "
1263-
"when the parent is a warp layout";
1256+
if (auto parentAttr = mlir::dyn_cast<MmaEncodingTrait>(parent)) {
12641257
return success();
12651258
}
1259+
//
1260+
// if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) {
1261+
// if (kWidth != 0)
1262+
// return emitError() << "ttg.dot_op kWidth parameter is not supported "
1263+
// "when the parent is a warp layout";
1264+
// return success();
1265+
// }
12661266

12671267
if (auto parentAttr = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
12681268
if (kWidth != 0)
@@ -3248,8 +3248,7 @@ struct CanonicalizeConvertFromConvert
32483248
auto srcType = op.getSrc().getType();
32493249
auto dstType = op.getType();
32503250
if (mlir::isa<DotOperandEncodingAttr>(dstType.getEncoding()) &&
3251-
(mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()) ||
3252-
mlir::isa<intel::DpasEncodingAttr>(srcType.getEncoding())))
3251+
mlir::isa<NvidiaMmaEncodingAttr>(srcType.getEncoding()))
32533252
return failure();
32543253

32553254
// for hopper MMAv3

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ along the row (resp. col) dimension.
125125
//===----------------------------------------------------------------------===//
126126

127127
def WarpEncodingAttr : TritonGPU_Attr<"WarpEncoding", "intel_warp_encoding",
128-
[], TritonIntelGPU_Dialect> {
128+
[MmaEncodingTrait], TritonIntelGPU_Dialect> {
129129
let mnemonic = "warp";
130130

131131
let description = [{
@@ -144,6 +144,23 @@ def WarpEncodingAttr : TritonGPU_Attr<"WarpEncoding", "intel_warp_encoding",
144144
let extraClassDeclaration = [{
145145
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
146146
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
147+
148+
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const {
149+
llvm::report_fatal_error("NYI. WarpEncodingAttr::getRepOrder");
150+
};
151+
152+
bool supportReduction() const {
153+
llvm::report_fatal_error("NYI. WarpEncodingAttr::supportReduction");
154+
};
155+
156+
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth,unsigned opIdx) const {
157+
llvm::report_fatal_error("NYI. WarpEncodingAttr::getSizePerThreadForOperand");
158+
};
159+
160+
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const {
161+
llvm::report_fatal_error("NYI. WarpEncodingAttr::getElemsPerThreadForOperands");
162+
};
163+
147164
}];
148165

149166
let hasCustomAssemblyFormat = 1;

0 commit comments

Comments
 (0)