Skip to content

Commit 85c4be7

Browse files
authored
Fix lowering bug for TF32 Type (#744)
1 parent d9f0127 commit 85c4be7

File tree

7 files changed

+285
-85
lines changed

7 files changed

+285
-85
lines changed

build_tools/patches/0008-amend-xegpu-transpose_bit_width-and-qualified-type-f.patch

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
From 1ee69567682e0f653d17d8eaaa3f54ec40201b44 Mon Sep 17 00:00:00 2001
1+
From 49cf7d3645dece35c0e5a4d48d2a00c801218656 Mon Sep 17 00:00:00 2001
22
From: Chao Chen <[email protected]>
3-
Date: Thu, 2 May 2024 14:53:44 +0000
4-
Subject: [PATCH 1/2] amend xegpu: transpose_bit_width and qualified type for
5-
atomic_amw
3+
Date: Fri, 10 May 2024 14:36:04 +0000
4+
Subject: [PATCH] amend xegpu definition: - add transpose_bit_width for load nd
5+
- fix type print for atomic_rmw - relax dpas verifier to accept 2D operand
66

77
---
8-
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 6 +++---
9-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 10 ++++++++++
10-
mlir/test/Dialect/XeGPU/XeGPUOps.mlir | 2 +-
11-
3 files changed, 14 insertions(+), 4 deletions(-)
8+
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 6 +++---
9+
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 19 ++++++++++++++-----
10+
mlir/test/Dialect/XeGPU/XeGPUOps.mlir | 2 +-
11+
3 files changed, 18 insertions(+), 9 deletions(-)
1212

1313
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
1414
index e477d9a0ca3f..5f95be1c87df 100644
@@ -42,7 +42,7 @@ index e477d9a0ca3f..5f95be1c87df 100644
4242
}
4343

4444
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
45-
index 22959224d56c..e550de6a97cd 100644
45+
index 22959224d56c..858afbd6d8aa 100644
4646
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
4747
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
4848
@@ -219,6 +219,16 @@ LogicalResult LoadNdOp::verify() {
@@ -62,6 +62,29 @@ index 22959224d56c..e550de6a97cd 100644
6262
if (array_len > 1) {
6363
auto it = tdescShape.begin();
6464
tdescShape.insert(it, array_len);
65+
@@ -413,9 +423,8 @@ LogicalResult DpasOp::verify() {
66+
int64_t lhsRank = getLhsType().getRank();
67+
int64_t rhsRank = getRhsType().getRank();
68+
69+
- if (lhsRank != rhsRank || lhsRank != 3)
70+
- return emitOpError(
71+
- "lhs and rhs rank does not match for dpas op, or their rank is not 3.");
72+
+ if (lhsRank != rhsRank)
73+
+ return emitOpError("lhs and rhs rank does not match for dpas op.");
74+
75+
if (getAcc() && getAccType() != getResultType())
76+
return emitOpError("Accumulator and Result for dpas op should have the "
77+
@@ -423,8 +432,8 @@ LogicalResult DpasOp::verify() {
78+
79+
auto lhsShape = getLhsType().getShape();
80+
auto rhsShape = getRhsType().getShape();
81+
- if (lhsShape[1] != rhsShape[0] || lhsShape[2] != rhsShape[2])
82+
- return emitOpError("K-dimension or vnni-factor mismatch.");
83+
+ if (lhsShape[1] != rhsShape[0])
84+
+ return emitOpError("K-dimension mismatch.");
85+
86+
return success();
87+
}
6588
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
6689
index 00d32d2a2ee9..ad037d3fbefd 100644
6790
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir

lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ lowerUnpackOrPack(XeGPUOneToNPatterRewriter &rewriter, mlir::Operation *op,
145145
llvm::ArrayRef<int64_t> inGrids,
146146
llvm::ArrayRef<int64_t> outGrids, bool isVnniFormat = false,
147147
bool isForDPASB = false) {
148+
148149
// handle based on the dim0, and save results into intermediates
149150
llvm::SmallVector<mlir::Value> intermediates;
150151
if (inBlkSizes[0] == outBlkSizes[0]) { // do nothing
@@ -269,9 +270,11 @@ class SgTileUnpackOpPattern
269270
// specific attention needed for vectors in vnni format,
270271
// which is applied to load for dpas.
271272
auto loadOp = op.getInVec().getDefiningOp<xetile::LoadTileOp>();
273+
auto elemTy = op.getInVec().getType().getElementType();
272274
bool isDpasA = loadOp && isForDPASA(loadOp);
273275
bool isDpasB = loadOp && isForDPASB(loadOp);
274-
bool isVnniFormat = isDpasA || isDpasB;
276+
bool isVnniFormat = (isDpasA || isDpasB) && elemTy.isIntOrFloat() &&
277+
elemTy.getIntOrFloatBitWidth() < 32;
275278

276279
llvm::ArrayRef<int64_t> outGrids;
277280
mlir::DenseI64ArrayAttr outBlkSizes;

lib/Dialect/XeTile/Transforms/Blocking.cpp

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -197,23 +197,13 @@ static llvm::SmallVector<unsigned int>
197197
getMMASize(mlir::Type elemTy, const int APrecision, const int BPrecision,
198198
const int CPrecision, const int DPrecision,
199199
std::shared_ptr<XeuArchInterface> uArchInterface) {
200-
assert(elemTy.isIntOrFloat());
201-
auto bits = elemTy.getIntOrFloatBitWidth();
202-
imex::DPASConfig dpasParams;
203-
llvm::SmallVector<unsigned int> result;
204-
switch (bits) {
205-
case 16:
206-
dpasParams = uArchInterface->getDPASConfig(APrecision, BPrecision,
207-
CPrecision, DPrecision);
208-
result = llvm::SmallVector<unsigned int>(
209-
{dpasParams.m, dpasParams.k, dpasParams.n});
210-
break;
211-
default:
212-
result = llvm::SmallVector<unsigned int>({8, 8, 8});
213-
break;
214-
}
215-
return result;
200+
assert(elemTy.isIntOrFloat() && "unsupported element type.");
201+
auto dpasParams = uArchInterface->getDPASConfig(APrecision, BPrecision,
202+
CPrecision, DPrecision);
203+
return llvm::SmallVector<unsigned int>(
204+
{dpasParams.m, dpasParams.k, dpasParams.n});
216205
}
206+
217207
// it blocks a constant dense value if it is used by XeTile operators,
218208
// e.g, tile_mma and store_tile. It currently extends a 2D vector into
219209
// 4D vector with the last 2 dim corresponding to block size.

lib/Utils/XeArch.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,10 @@ mlir::LogicalResult XeuArchInterface::isLegalDpasOp(mlir::Operation *op) {
218218
return op->emitOpError() << "Unsupported dpas config";
219219
}
220220

221-
if ((lhsRank != rhsRank) || (lhsRank != 3)) {
222-
return op->emitOpError()
223-
<< "lhs and rhs rank does not match for dpas op, or "
224-
<< "their rank is not 3. "
225-
<< "\n"
226-
<< "lhsRank: " << lhsRank << "\n"
227-
<< "rhsRank:" << rhsRank;
221+
if (lhsRank != rhsRank) {
222+
return op->emitOpError() << "lhs and rhs rank does not match for dpas op "
223+
<< "(lhsRank: " << lhsRank << ", "
224+
<< "rhsRank:" << rhsRank << ").\n";
228225
}
229226

230227
DPASConfig dpasParams =
@@ -241,7 +238,8 @@ mlir::LogicalResult XeuArchInterface::isLegalDpasOp(mlir::Operation *op) {
241238
<< " dpas config: mxnxk = " << M << "x" << N << "x" << K;
242239
}
243240

244-
unsigned int BNumElements = rhsShape[0] * rhsShape[1] * rhsShape[2];
241+
unsigned int BNumElements = std::accumulate(
242+
rhsShape.begin(), rhsShape.end(), 1, std::multiplies<unsigned>());
245243
// Execution size for matrix B should match dpas params
246244
if (BNumElements != K * N) {
247245
return op->emitOpError()

0 commit comments

Comments
 (0)