
Commit 6cb654f

Enable DPAS when sub-group-size=32 on GPU arch Xe+ and later. (#4869)
This PR enables DPAS (Dot Product Accumulate Systolic) when sub-group-size=32 on GPU by introducing a new environment variable, TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32. When this variable is set, DPAS operations are allowed with warp sizes of 16 or 32 threads instead of being restricted to the minimum sub-group size.

Key changes:
- Adds environment variable support to conditionally enable DPAS for warp size 32
- Modifies the DPAS analysis logic to support larger warp sizes when the flag is enabled
- Updates the threads-per-warp setting logic to respect the new environment variable

Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 5c914b6 commit 6cb654f
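
In short, the gating rule this commit introduces (distilled from the DPAS.cpp change below) can be summarized by the following standalone C++ sketch. It is illustrative only; the function name canUseDPASFor is not part of the actual code.

#include <cassert>

enum class Result { True, False };

// Illustrative sketch of the decision: with the flag set, DPAS is accepted
// for warp sizes of 16 or 32 on Xe+ parts; otherwise the warp size must
// match the device's minimum sub-group size.
Result canUseDPASFor(unsigned threadsPerWarp, unsigned minSGSize,
                     bool enableWarp32ViaEnv) {
  assert((minSGSize == 8 || minSGSize == 16 || minSGSize == 32) &&
         "unexpected minimum sub-group size");
  if (enableWarp32ViaEnv && minSGSize != 8)
    return (threadsPerWarp == 16 || threadsPerWarp == 32) ? Result::True
                                                          : Result::False;
  return (threadsPerWarp == minSGSize) ? Result::True : Result::False;
}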

File tree

include/triton/Tools/Sys/GetEnv.hpp
test/TritonIntelGPU/accelerate-matmul-pvc.mlir
third_party/intel/lib/Analysis/DPAS.cpp
third_party/intel/lib/TritonAnnotateModule/TritonAnnotateModule.cpp

4 files changed: +59 −20 lines

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
     "TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",
+    "TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32",
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
     "TRITON_INTEL_ENABLE_INSTR_SCHED",
     "TRITON_INTEL_FAST_MATH",

test/TritonIntelGPU/accelerate-matmul-pvc.mlir

Lines changed: 20 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: env TRITON_INTEL_DECOMPOSE_SCALED_BLOCKED=1 triton-opt %s -split-input-file --tritonintelgpu-accelerate-matmul | FileCheck %s
+// RUN: env TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32=1 TRITON_INTEL_DECOMPOSE_SCALED_BLOCKED=1 triton-opt %s -split-input-file --tritonintelgpu-accelerate-matmul | FileCheck %s

 // CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [4, 1], A = [32, 16], B = [16, 16], C = [32, 16]}>
 // CHECK: #[[$DPAS_1:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>

@@ -368,3 +368,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+// CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 32, warpsPerCTA = [1, 1], repCluster = [4, 1], A = [32, 8], B = [8, 16], C = [32, 16]}>
+#blocked = #ttg.blocked<{sizePerThread = [4, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32, "ttig.min_sg_size" = 16 : i32, "ttig.support_dpas"} {
+  // CHECK-LABEL: dpas_sub_group_size_32
+  tt.func @dpas_sub_group_size_32(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+    %zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #blocked>
+    %a = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+    %b = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+
+    // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<128x16xf32, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 1}>> -> tensor<128x16xf32, #[[$DPAS]]>
+    %result = tt.dot %a, %b, %zero_f32, inputPrecision = tf32 : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf32, #blocked>
+    %result_ptr = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x16x!tt.ptr<f32>, #blocked>
+    tt.store %result_ptr, %result : tensor<128x16x!tt.ptr<f32>, #blocked>
+    tt.return
+  }
+}

third_party/intel/lib/Analysis/DPAS.cpp

Lines changed: 12 additions & 0 deletions
@@ -2,6 +2,7 @@
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include <triton/Tools/Sys/GetEnv.hpp>
 #include <type_traits>

 namespace mlir::triton::gpu::intel {

@@ -66,6 +67,17 @@ DPASAnalysis::canUseDPAS(FunctionOpInterface funcOp) const {
   unsigned minSGSize = mod->getAttrOfType<IntegerAttr>(
                               TritonIntelGPUDialect::getMinSGSizeAttrName())
                            .getInt();
+  bool enableWarp32 =
+      tools::getBoolEnv("TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32");
+  assert((minSGSize == 8 || minSGSize == 16 || minSGSize == 32) &&
+         "Unexpected minimum subgroup size");
+
+  if (enableWarp32 && minSGSize != 8) {
+    // We can support threads_per_warp=16 or 32 on Xe+ and later architectures.
+    return (threadsPerWarp == 16 || threadsPerWarp == 32) ? Result::True
+                                                          : Result::False;
+  }
+
   return (threadsPerWarp == minSGSize) ? Result::True : Result::False;
 }

third_party/intel/lib/TritonAnnotateModule/TritonAnnotateModule.cpp

Lines changed: 26 additions & 19 deletions
@@ -1,6 +1,7 @@
 #include "intel/include/Analysis/DPAS.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/TritonAnnotateModule/Passes.h"
+#include <triton/Tools/Sys/GetEnv.hpp>

 namespace mlir::triton::gpu::intel {
 #define GEN_PASS_DEF_TRITONANNOTATEMODULE

@@ -53,25 +54,31 @@ struct TritonAnnotateModule
   void setThreadsPerWarp(ModuleOp &mod,
                          const DPASAnalysis &dpasAnalysis) const {
     Builder builder(mod);
-    mod.walk([&](FunctionOpInterface funcOp) {
-      // FIXME: DPAS lowering only implemented for 16 threads per warp, i.e.,
-      // DPAS is not used for devices like ATS.
-      constexpr unsigned supportedThreadsPerWarp = 16;
-      if (minSGSize != supportedThreadsPerWarp)
-        return WalkResult::interrupt();
-
-      if (dpasAnalysis.canUseDPAS(funcOp) == DPASAnalysis::Result::Maybe) {
-        // Set the threads per warp attribute to allow dot operation to be
-        // lowered to DPAS instructions.
-        mod->setAttr(AttrNumThreadsPerWarp,
-                     builder.getI32IntegerAttr(minSGSize));
-        assert(dpasAnalysis.canUseDPAS(funcOp) == DPASAnalysis::Result::True &&
-               "DPASAnalysis should report that dot operations can be "
-               "lowered to DPAS instructions");
-        return WalkResult::interrupt();
-      }
-      return WalkResult::advance();
-    });
+
+    bool enableWarp32 = mlir::triton::tools::getBoolEnv(
+        "TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32");
+    if (!enableWarp32) {
+      mod.walk([&](FunctionOpInterface funcOp) {
+        // DPAS lowering only implemented for 16 threads per warp, i.e., DPAS is
+        // not used for devices like ATS.
+        constexpr unsigned supportedThreadsPerWarp = 16;
+        if (minSGSize != supportedThreadsPerWarp)
+          return WalkResult::interrupt();
+
+        if (dpasAnalysis.canUseDPAS(funcOp) == DPASAnalysis::Result::Maybe) {
+          // Set the threads per warp attribute to allow dot operation to be
+          // lowered to DPAS instructions.
+          mod->setAttr(AttrNumThreadsPerWarp,
+                       builder.getI32IntegerAttr(minSGSize));
+          assert(dpasAnalysis.canUseDPAS(funcOp) ==
+                     DPASAnalysis::Result::True &&
+                 "DPASAnalysis should report that dot operations can be "
+                 "lowered to DPAS instructions");
+          return WalkResult::interrupt();
+        }
+        return WalkResult::advance();
+      });
+    }

     // If the threads per warp attribute was not set, use the option value.
     if (!mod->hasAttr(AttrNumThreadsPerWarp))
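
The practical effect of the new guard: when the environment variable is set, setThreadsPerWarp no longer pins threads-per-warp to the minimum sub-group size, and the value falls through to the pass option. A rough model of the resulting selection (illustrative only; names such as threadsPerWarpOption are placeholders, not the actual pass option):

// Simplified sketch of how the warp size ends up being chosen after this change.
unsigned pickThreadsPerWarp(bool enableWarp32, unsigned minSGSize,
                            bool dpasMaybeUsable,
                            unsigned threadsPerWarpOption) {
  if (!enableWarp32 && minSGSize == 16 && dpasMaybeUsable)
    return minSGSize;            // legacy path: force the minimum sub-group size
  return threadsPerWarpOption;   // otherwise honor the requested warp size
}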
