Skip to content

Commit 7f94609 — [AMD] Support splatted scale in MFMA (#7270)
Parent: dfbea72 (1 parent)

File tree

2 files changed: +4 additions, −4 deletions

python/test/unit/language/test_matmul.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -885,8 +885,6 @@ def test_mxfp8_mxfp4_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, B_TR
885885
pytest.skip("Pack along M/N is not enabled on AMD backend")
886886
if not is_hip_cdna4():
887887
pytest.skip("Scaled mxfp4 & mxfp8 matmul is only natively supported on CDNA4")
888-
if CONST_SCALE:
889-
pytest.skip("Constant scale is not supported in AMD backend for now")
890888
if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64):
891889
pytest.skip(f"CDNA4 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants")
892890
if (A_DATA_TYPE == 'float4' and not WITH_A_SCALE) or (B_DATA_TYPE == 'float4' and not WITH_B_SCALE):

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,8 +558,10 @@ struct ScaledDotOpMFMAConversionHelper : DotOpMFMAConversionHelper {
558558
}
559559

560560
bool existBothScales = aScale && bScale;
561-
bool isAScaleConstant = aScale && aScale.getDefiningOp<arith::ConstantOp>();
562-
bool isBScaleConstant = bScale && bScale.getDefiningOp<arith::ConstantOp>();
561+
bool isAScaleConstant = aScale && isa<arith::ConstantOp, triton::SplatOp>(
562+
aScale.getDefiningOp());
563+
bool isBScaleConstant = bScale && isa<arith::ConstantOp, triton::SplatOp>(
564+
bScale.getDefiningOp());
563565
Value d = op.getD();
564566
auto aTensorTy = cast<RankedTensorType>(a.getType());
565567
auto bTensorTy = cast<RankedTensorType>(b.getType());

Comments (0)