Skip to content

Commit 62094e1

Browse files
authored
[PROTON] Refactor the finalizeOp to reduce buffer write overhead (#8635)
Each warp only writes its own data and utilizes multiple threads when permitted. Simple kernels have overhead reduced from 2x to 1.1-1.0x without tuning `buffer_size`. Long running kernels have overhead reduced from 2.6x to 1.8-1.3x without tuning `buffer_size`.
1 parent a5b948c commit 62094e1

File tree

9 files changed

+394
-125
lines changed

9 files changed

+394
-125
lines changed

Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,9 @@ test-interpret: all
7373

7474
.PHONY: test-proton
7575
test-proton: all
76-
$(PYTEST) --tb=short -s -n 8 third_party/proton/test --ignore=third_party/proton/test/test_override.py
76+
$(PYTEST) --tb=short -s -n 8 third_party/proton/test --ignore=third_party/proton/test/test_override.py -k "not test_overhead"
7777
$(PYTEST) --tb=short -s third_party/proton/test/test_override.py
78+
$(PYTEST) --tb=short -s third_party/proton/test/test_instrumentation.py::test_overhead
7879

7980
.PHONY: test-python
8081
test-python: test-unit test-regression test-interpret test-proton

test/Proton/amd/protongpu_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ module attributes {"ttg.num-warps" = 8 : i32, ttg.profile_scratch_memory_alignme
139139
// CHECK-LABEL: convert_smem_finalize
140140
// CONVERT-BUILTIN: llvm.call_intrinsic "llvm.amdgcn.s.memrealtime"() : () -> i64
141141
// CONVERT-BUILTIN: llvm.store %{{.*}}, %{{.*}} : i64, !llvm.ptr<1>
142-
// CONVERT-BUILTIN: llvm.br ^bb{{.*}}(%{{.*}} : i32)
142+
// CONVERT-BUILTIN: llvm.cond_br %{{.*}}, ^bb{{.*}}, ^bb{{.*}}
143143
// CONVERT-BUILTIN: llvm.call_intrinsic "llvm.amdgcn.s.memrealtime"() : () -> i64
144144
// CONVERT-BUILTIN: llvm.store %{{.*}}, %{{.*}} : i64, !llvm.ptr<1>
145145
// CONVERT-BUILTIN: llvm.br ^bb{{.*}}

test/Proton/nvidia/protongpu_to_llvm.mlir

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -174,33 +174,37 @@ module attributes {"ttg.num-warps" = 8 : i32, ttg.profile_scratch_memory_alignme
174174
#smem = #ttg.shared_memory
175175
module attributes {"ttg.num-warps" = 8 : i32, ttg.profile_scratch_memory_alignment = 128 : i32, ttg.profile_scratch_memory_size = 384 : i32} {
176176
// CHECK-LABEL: convert_smem_finalize
177-
// CHECK-DAG: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<3>, i32)>
178-
// CHECK-DAG: llvm.store
179-
// CHECK-DAG: llvm.cond_br %{{.*}}, ^bb1, ^bb3
180-
// CHECK-DAG: %[[BUFFER_SIZE:.*]] = llvm.mlir.constant(2048 : i32) : i32
181-
// CHECK-DAG: %[[BUFFER_SIZE_OFFSET:.*]] = llvm.mlir.constant(3 : i32) : i32
182-
// CHECK-DAG: %[[BUFFER_SIZE_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[BUFFER_SIZE_OFFSET]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>
183-
// CHECK-DAG: llvm.store %[[BUFFER_SIZE]], %[[BUFFER_SIZE_PTR]] : i32, !llvm.ptr<1>
184-
// CHECK-DAG: %[[PRE_FINAL_TIME:.*]] = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.globaltimer"() : () -> i64
185-
// CHECK-DAG: %[[PRE_FINAL_TIME_OFFSET:.*]] = llvm.mlir.constant(6 : i32) : i32
186-
// CHECK-DAG: %[[PRE_FINAL_TIME_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[PRE_FINAL_TIME_OFFSET]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1
187-
// CHECK-DAG: llvm.store %[[PRE_FINAL_TIME]], %[[PRE_FINAL_TIME_PTR]] : i64, !llvm.ptr<1>
188-
// CHECK-DAG: ^bb1:
189-
// CHECK-DAG: llvm.br ^bb2
190-
// CHECK-DAG: ^bb2(%[[I:.*]]: i32):
191-
// CHECK-DAG: llvm.store
192-
// CHECK-DAG: llvm.store
193-
// CHECK-DAG: %[[UPPER:.*]] = llvm.mlir.constant(510 : i32) : i32
194-
// CHECK-DAG: %[[P2:.*]] = llvm.icmp "slt" %[[I]], %[[UPPER]] : i32
195-
// CHECK-DAG: %[[STEP:.*]] = llvm.mlir.constant(2 : i32) : i32
196-
// CHECK-DAG: %[[I_NEW:.*]] = llvm.add %[[I]], %[[STEP]] : i32
197-
// CHECK-DAG: llvm.cond_br %[[P2]], ^bb2(%[[I_NEW]] : i32), ^bb3
198-
// CHECK-DAG: ^bb3:
199-
// CHECK-DAG: %[[POST_FINAL_TIME:.*]] = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.globaltimer"() : () -> i64
200-
// CHECK-DAG: %[[POST_FINAL_TIME_OFFSET:.*]] = llvm.mlir.constant(8 : i32) : i32
201-
// CHECK-DAG: %[[POST_FINAL_TIME_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[POST_FINAL_TIME_OFFSET]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1
202-
// CHECK-DAG: llvm.store %[[POST_FINAL_TIME]], %[[POST_FINAL_TIME_PTR]] : i64, !llvm.ptr<1>
203-
// CHECK-DAG: llvm.return
177+
// CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr<3>, i32)>
178+
// CHECK: llvm.store
179+
// CHECK: llvm.cond_br %{{.*}}, ^bb1, ^bb2
180+
// CHECK: ^bb1: // pred: ^bb0
181+
// CHECK: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr<1>
182+
// CHECK: llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.globaltimer"() : () -> i64
183+
// CHECK: llvm.store %{{.*}}, %{{.*}} : i64, !llvm.ptr<1>
184+
// CHECK: llvm.br ^bb2
185+
// CHECK: ^bb2: // 2 preds: ^bb0, ^bb1
186+
// CHECK: llvm.cond_br %{{.*}}, ^bb3, ^bb4
187+
// CHECK: ^bb3: // pred: ^bb2
188+
// CHECK: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr<1>
189+
// CHECK: llvm.br ^bb4
190+
// CHECK: ^bb4: // 2 preds: ^bb2, ^bb3
191+
// CHECK: llvm.cond_br %{{.*}}, ^[[LOOP_HEAD:bb[0-9]+]](%{{.*}} : i32), ^[[EXIT:bb[0-9]+]]
192+
// CHECK: ^[[LOOP_HEAD]](%{{.*}}: i32):
193+
// CHECK: llvm.cond_br %{{.*}}, ^[[LOOP_BODY:bb[0-9]+]](%{{.*}} : i32), ^[[EXIT]]
194+
// CHECK: ^[[LOOP_BODY]](%{{.*}}: i32):
195+
// CHECK: llvm.getelementptr
196+
// CHECK: llvm.store
197+
// CHECK: llvm.store
198+
// CHECK: ^[[EXIT]]:
199+
// CHECK: llvm.cond_br %{{.*}}, ^[[POST:bb[0-9]+]], ^[[RET:bb[0-9]+]]
200+
// CHECK: ^[[POST]]:
201+
// CHECK: %{{.*}} = llvm.mlir.constant(8 : i32) : i32
202+
// CHECK: %[[POST_FINAL_TIME_PTR:.*]] = llvm.getelementptr %{{.*}}{{\[}}%{{.*}}{{\]}} : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i32
203+
// CHECK: %[[POST_FINAL_TIME:.*]] = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.globaltimer"() : () -> i64
204+
// CHECK: llvm.store %[[POST_FINAL_TIME]], %[[POST_FINAL_TIME_PTR]] : i64, !llvm.ptr<1>
205+
// CHECK: llvm.br ^[[RET]]
206+
// CHECK: ^[[RET]]:
207+
// CHECK: llvm.return
204208
llvm.func @convert_smem_finalize(%arg: !llvm.ptr<1>) attributes {noinline = false, nvvm.kernel = 1 : ui1} {
205209
%0 = ttg.local_alloc : () -> !ttg.memdesc<512xi32, #shared, #smem, mutable>
206210
%1 = proton_gpu.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32, offset = 0 : i32} : !tt.ptr<i32>

test/Proton/proton_to_protongpu.mlir

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ module attributes {"ttg.num-warps" = 8 : i32} {
136136
// CHECK: proton_gpu.circular_store start %[[SEGMENT2]], %[[COUNTER4]] {scopeId = 2 : i32} : !proton_gpu.segment<1024, #smem, warp>, i32
137137
// CHECK: %[[COUNTER5:.*]] = proton_gpu.read_counter : i32
138138
// CHECK: proton_gpu.circular_store end %[[SEGMENT2]], %[[COUNTER5]] {scopeId = 2 : i32} : !proton_gpu.segment<1024, #smem, warp>, i32
139-
// CHECK: proton_gpu.save_ctx %[[SEGMENT2]], %[[ARG1]] : !proton_gpu.segment<1024, #smem, warp>, !tt.ptr<i32>
140139
// CHECK: ttg.warp_return
141140
// CHECK: } : (!ttg.memdesc<256xi32, #shared, #smem, mutable>, !tt.ptr<i32>) -> ()
142141
// CHECK: %[[COUNTER6:.*]] = proton_gpu.read_counter : i32

0 commit comments

Comments
 (0)