Skip to content

Commit 4210274

Browse files
authored
[AMD] NFC: Drop v2 Suffix from Stream Pipeline (#5251)
Since StreamPipelineV2 has been the default for a while, this commit promoted StreamPipelineV2 to the general StreamPipeline by removing 'v2' suffix.
1 parent e1ebeed commit 4210274

File tree

10 files changed

+21
-21
lines changed

10 files changed

+21
-21
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6262
mlir::registerTritonAMDGPUAccelerateMatmul();
6363
mlir::registerTritonAMDGPUOptimizeEpilogue();
6464
mlir::registerTritonAMDGPUReorderInstructions();
65-
mlir::registerTritonAMDGPUStreamPipelineV2();
65+
mlir::registerTritonAMDGPUStreamPipeline();
6666
mlir::registerTritonAMDGPUCanonicalizePointers();
6767
mlir::registerTritonAMDGPUConvertToBufferOps();
6868
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();

test/TritonGPU/amd/amd-instruction-sched.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-0' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP0
22
// RUN: triton-opt %s -split-input-file -triton-amdgpu-insert-instruction-sched-hints -triton-amdgpu-lower-insert-instruction-sched-hints='variant=llvm-iglp-1' -verify-diagnostics | FileCheck %s -check-prefix=INSERT_IGLP1
3-
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
4-
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
5-
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
6-
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
7-
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline-v2='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2
3+
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS1
4+
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics | FileCheck %s -check-prefix=INSTR_COUNT_NS2
5+
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritonamdgpu-accelerate-matmul='arch-generation-name=gfx942 matrix-instruction-size=32 kPack=1' -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' -triton-amdgpu-insert-instruction-sched-hints -decompose-unsupported-amd-conversions -optimize-amd-lds-usage='target-arch=gfx942' -convert-scf-to-cf -convert-index-to-llvm -allocate-shared-memory -convert-triton-amdgpu-to-llvm='arch=gfx942' -triton-amdgpu-lower-insert-instruction-sched-hints='variant=local-prefetch' -debug-only='lower-insert-instruction-sched-hints' -verify-diagnostics 2>&1 | FileCheck %s -check-prefix=USE_LOCAL_PREFETCH_GLOBAL_LOAD
6+
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=1' | FileCheck %s -check-prefix=LABELING_PS_1
7+
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=hip:gfx942 num-ctas=1 num-warps=4 threads-per-warp=64' -tritongpu-coalesce -tritongpu-remove-layout-conversions -tritonamdgpu-stream-pipeline='num_stages=2' | FileCheck %s -check-prefix=LABELING_PS_2
88

99
module {
1010
// INSERT_IGLP0-LABEL: @test_dot_op

test/TritonGPU/loop-pipeline-hip.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s
1+
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s
22

33
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 4], order = [0, 1]}>
44
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>

test/TritonGPU/loop-pipeline.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling=num-stages=3 -tritongpu-pipeline=num-stages=3 -canonicalize | FileCheck %s --check-prefixes=COMMON,CHECK
2-
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD
3-
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline-v2="num_stages=2 prefetch=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD_PREFETCH
2+
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline=num_stages=2 -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD
3+
// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline="num_stages=2 prefetch=1" -canonicalize | FileCheck %s --check-prefixes=COMMON,AMD_PREFETCH
44

55
// 4 warps
66
// matmul: 128x32 @ 32x128 -> 128x128

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def make_ttgir(mod, metadata, options):
241241
"num_stages == 0. Now it will not happen anymore; "
242242
"please update to use num_stages == 2 for "
243243
"equivalent behavior in the past.")
244-
amd.passes.ttgpuir.add_stream_pipelinev2(pm, options.num_stages, stream_prefetch)
244+
amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, stream_prefetch)
245245
passes.common.add_canonicalizer(pm)
246246
amd.passes.ttgpuir.insert_instruction_sched_hints(pm)
247247
passes.ttgpuir.add_optimize_dot_operands(pm, True)

third_party/amd/include/TritonAMDGPUTransforms/Passes.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
namespace mlir {
99

10-
std::unique_ptr<Pass> createTritonAMDGPUStreamPipelineV2Pass(int numStages = 2,
11-
int prefetch = 0);
10+
std::unique_ptr<Pass> createTritonAMDGPUStreamPipelinePass(int numStages = 2,
11+
int prefetch = 0);
1212

1313
std::unique_ptr<Pass>
1414
createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(),

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33

44
include "mlir/Pass/PassBase.td"
55

6-
def TritonAMDGPUStreamPipelineV2 : Pass<"tritonamdgpu-stream-pipeline-v2", "mlir::ModuleOp"> {
6+
def TritonAMDGPUStreamPipeline : Pass<"tritonamdgpu-stream-pipeline", "mlir::ModuleOp"> {
77
let summary = "pipeline";
88

99
let description = [{
1010
Pipeline global loads through registers to shared memory while computing on previous
1111
tile
1212
}];
1313

14-
let constructor = "mlir::createTritonAMDGPUStreamPipelineV2Pass()";
14+
let constructor = "mlir::createTritonAMDGPUStreamPipelinePass()";
1515

1616
let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"];
1717

third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ add_triton_library(TritonAMDGPUTransforms
44
ConvertToBufferOps.cpp
55
OptimizeEpilogue.cpp
66
ReorderInstructions.cpp
7-
StreamPipelineV2.cpp
7+
StreamPipeline.cpp
88
MfmaGroup.cpp
99

1010
DEPENDS

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp renamed to third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
#define GEN_PASS_CLASSES
2626
#include "TritonAMDGPUTransforms/Passes.h.inc"
2727

28-
#define DEBUG_TYPE "tritonamdgpu-stream-pipeline-v2"
28+
#define DEBUG_TYPE "tritonamdgpu-stream-pipeline"
2929
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
3030
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
3131

@@ -857,7 +857,7 @@ void labelLoadOpsForTritonDot(scf::ForOp forOp) {
857857
}
858858
}
859859

860-
struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base<PipelinePass> {
860+
struct PipelinePass : public TritonAMDGPUStreamPipelineBase<PipelinePass> {
861861
PipelinePass() = default;
862862
PipelinePass(int32_t numStages, int32_t prefetch) {
863863
this->numStages = numStages;
@@ -893,7 +893,7 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineV2Base<PipelinePass> {
893893
};
894894
} // anonymous namespace
895895

896-
std::unique_ptr<Pass>
897-
mlir::createTritonAMDGPUStreamPipelineV2Pass(int numStages, int prefetch) {
896+
std::unique_ptr<Pass> mlir::createTritonAMDGPUStreamPipelinePass(int numStages,
897+
int prefetch) {
898898
return std::make_unique<PipelinePass>(numStages, prefetch);
899899
}

third_party/amd/python/triton_amd.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) {
7272
mlir::createTritonAMDGPUConvertToBufferOpsPass);
7373
ADD_PASS_WRAPPER_0("add_reorder_instructions",
7474
mlir::createTritonAMDGPUReorderInstructionsPass);
75-
ADD_PASS_WRAPPER_2("add_stream_pipelinev2",
76-
mlir::createTritonAMDGPUStreamPipelineV2Pass, int, int);
75+
ADD_PASS_WRAPPER_2("add_stream_pipeline",
76+
mlir::createTritonAMDGPUStreamPipelinePass, int, int);
7777
}
7878

7979
void addControlConstant(llvm::Module *module, const char *name,

0 commit comments

Comments
 (0)