From d92cd3850b8a6c7abd1fd7c3401e8cc36c1a7344 Mon Sep 17 00:00:00 2001 From: "Ling, Liyang" Date: Tue, 19 Aug 2025 13:38:13 +0000 Subject: [PATCH] Fix fp4 fp4 scaled_dot does not use dpas issue --- third_party/intel/include/Analysis/DPAS.h | 1 + third_party/intel/lib/Analysis/DPAS.cpp | 4 ++-- .../lib/TritonIntelGPUTransforms/DecomposeScaledBlocked.cpp | 5 +++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/third_party/intel/include/Analysis/DPAS.h b/third_party/intel/include/Analysis/DPAS.h index 5787b8c16d..1bba3f8d16 100644 --- a/third_party/intel/include/Analysis/DPAS.h +++ b/third_party/intel/include/Analysis/DPAS.h @@ -38,6 +38,7 @@ class DPASAnalysis { FP32_FP32_FP4_BF16, FP32_FP32_FP4_FP16, FP32_FP32_FP4_FP8, + FP32_FP32_FP4_FP4, NOT_APPLICABLE }; diff --git a/third_party/intel/lib/Analysis/DPAS.cpp b/third_party/intel/lib/Analysis/DPAS.cpp index cdbd4c3f1a..a20cc53777 100644 --- a/third_party/intel/lib/Analysis/DPAS.cpp +++ b/third_party/intel/lib/Analysis/DPAS.cpp @@ -157,14 +157,12 @@ DPASAnalysis::getDPASType(OpTy op) { if (dElemTy.isF32()) { if (aElemTy.isBF16() && isa(bElemTy)) return DPASEngineType::FP32_FP32_BF16_FP8; - // 2 E2M1 are packed into 1 int8 if (aElemTy.isBF16() && bElemTy.isInteger(8)) return DPASEngineType::FP32_FP32_BF16_FP4; if (isa(aElemTy) && bElemTy.isBF16()) return DPASEngineType::FP32_FP32_FP8_BF16; if (aElemTy.isF16() && isa(bElemTy)) return DPASEngineType::FP32_FP32_FP16_FP8; - // 2 E2M1 are packed into 1 int8 if (aElemTy.isF16() && bElemTy.isInteger(8)) return DPASEngineType::FP32_FP32_FP16_FP4; if (isa(aElemTy) && bElemTy.isF16()) @@ -182,6 +180,8 @@ DPASAnalysis::getDPASType(OpTy op) { if (aElemTy.isInteger(8) && isa(bElemTy)) return DPASEngineType::FP32_FP32_FP4_FP8; + if (aElemTy.isInteger(8) && bElemTy.isInteger(8)) + return DPASEngineType::FP32_FP32_FP4_FP4; } } } diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/DecomposeScaledBlocked.cpp 
b/third_party/intel/lib/TritonIntelGPUTransforms/DecomposeScaledBlocked.cpp index f3302f47bd..7adf9aa1f7 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/DecomposeScaledBlocked.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/DecomposeScaledBlocked.cpp @@ -51,8 +51,9 @@ class DecomposeScaledBlocked : public OpRewritePattern { scaledA = cvtDotOperand(scaledA, 0); auto scaledB = scaleArg(rewriter, scaledDotOp, 1, computeType); scaledB = cvtDotOperand(scaledB, 1); - auto newDot = rewriter.create(scaledDotOp.getLoc(), scaledA, scaledB, - scaledDotOp.getC()); + auto newDot = + rewriter.create(scaledDotOp.getLoc(), scaledA, scaledB, + scaledDotOp.getC(), InputPrecision::TF32, 0); rewriter.replaceOpWithNewOp(scaledDotOp, scaledDotOp.getType(), newDot);