Commit 108e8a1

Fix mxfp8 mxfp4 matmul shape mismatch (#4060)
This PR fixes a shape mismatch caused by the float4 pack dimension differing after the `B_TRANS` and `PACK_B_ALONG_K` changes.
1 parent 1112516 commit 108e8a1
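
For intuition, here is a minimal Python sketch of the shape check the fix performs (the function name infer_pack_dim and the example shapes are illustrative assumptions, not backend code): an fp4 tensor stores two e2m1 values per byte, so whichever dimension it is packed along appears at half its logical size, and comparing the operand's non-K dimension against the dot result's shape reveals which dimension that was.

def infer_pack_dim(packed_shape, result_shape, op_idx, k_dim):
    """Pick the fp4 pack dimension by comparing shapes.

    Two e2m1 values share one byte, so the packed dim is half its
    logical size. If the operand's non-K dim no longer matches the
    dot result, the operand must be packed along that dim, not K.
    """
    rank = len(result_shape)
    pack_dim = k_dim
    if op_idx == 0:  # A: its M dim must match the result's M dim
        mismatch = result_shape[rank - 2] != packed_shape[rank - 2]
    else:            # B: its N dim must match the result's N dim
        mismatch = result_shape[rank - 1] != packed_shape[rank - 1]
    if mismatch:
        pack_dim = (pack_dim + 1) % 2  # flip to the other dimension
    return pack_dim

# B has logical shape (K=64, N=32); the dot result is (M=16, N=32).
print(infer_pack_dim((32, 32), (16, 32), op_idx=1, k_dim=0))  # 0: packed along K
print(infer_pack_dim((64, 16), (16, 32), op_idx=1, k_dim=0))  # 1: packed along N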

File tree

2 files changed: +9 -4 lines

python/test/unit/intel/test_mxfp_matmul.py

Lines changed: 0 additions & 2 deletions

@@ -111,8 +111,6 @@ def test_mxfp_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TRANS, PA
         pytest.skip("Float4 for both A and B has [ZE]0x78000011 error")
     if not PACK_B_ALONG_K and B_DATA_TYPE != "float4":
         pytest.xfail("Pack along K can only be False for float4")
-    if not PACK_B_ALONG_K and B_DATA_TYPE == "float4":
-        pytest.skip("Pack along K fix depends on https://github.com/intel/intel-xpu-backend-for-triton/pull/4060")
 
     if BLOCK_N == 256 and BLOCK_K == 256:
         NUM_STAGES = 2

third_party/intel/lib/TritonIntelGPUTransforms/DecomposeScaledBlocked.cpp

Lines changed: 9 additions & 2 deletions

@@ -185,6 +185,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
                                         DotScaledOp scaledDotOp, int opIdx,
                                         FloatType computeType) const {
     auto v = opIdx == 0 ? scaledDotOp.getA() : scaledDotOp.getB();
+    auto res = scaledDotOp.getD();
     auto scale = opIdx == 0 ? scaledDotOp.getAScale() : scaledDotOp.getBScale();
     auto isFp4 =
         ScaleDotElemType::E2M1 ==
@@ -199,8 +200,14 @@ class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
 
     // 0) Upcast value to computeType (fp16/bf16)
     if (isFp4) {
-      // We always pack along the fastest moving dimension, kDim
-      v = rewriter.create<Fp4ToFpOp>(loc, v, computeType, kDim);
+      auto resShape = res.getType().getShape();
+      auto vShape = v.getType().getShape();
+      auto packDim = kDim;
+      if ((opIdx == 0 && resShape[rank - 2] != vShape[rank - 2]) ||
+          (opIdx == 1 && resShape[rank - 1] != vShape[rank - 1])) {
+        packDim = (packDim + 1) % 2;
+      }
+      v = rewriter.create<Fp4ToFpOp>(loc, v, computeType, packDim);
     } else {
       auto vType16 = v.getType().clone(computeType);
       v = cast<TypedValue<RankedTensorType>>(
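
As a complement, a tiny Python sketch (an illustrative helper, not Fp4ToFpOp itself) of how unpacking doubles the chosen dimension, which is why picking the wrong packDim previously produced a shape mismatch at the dot:

def unpacked_shape(packed_shape, pack_dim):
    """Shape after an Fp4ToFp-style upcast: each byte expands into
    two fp16/bf16 values, so the pack dimension doubles."""
    shape = list(packed_shape)
    shape[pack_dim] *= 2
    return tuple(shape)

# B stored as (64, 16) because it was packed along N (logical (64, 32)):
print(unpacked_shape((64, 16), pack_dim=0))  # (128, 16) -- wrong, K doubled
print(unpacked_shape((64, 16), pack_dim=1))  # (64, 32)  -- matches logical (K, N)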
