Commit 06eaead
[Codegen] Run LoopCoalescingPass at the end of warp reduce (#19950)
I observed that this improved some of the punet dispatches that use warp reduction from 80µs to 60µs ([example](https://gist.github.com/IanWood1/5b0e4fcb4e90a02525b94ea4347145f5)). The gain is hard to separate from CI benchmark noise, but locally I measured a ~0.7ms end-to-end improvement on punet.

Signed-off-by: Ian Wood <[email protected]>

2 files changed: +82 -0

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

Lines changed: 2 additions & 0 deletions
@@ -943,6 +943,8 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
       /*expandSubgroupReduction=*/true));
   funcPassManager.addPass(createCanonicalizerPass());
   funcPassManager.addPass(createCSEPass());
+  funcPassManager.addPass(affine::createLoopCoalescingPass());
+  funcPassManager.addPass(createCanonicalizerPass());
 }
 
 void addGPUPackUnPackPasses(OpPassManager &funcPassManager) {
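
For context, affine::createLoopCoalescingPass creates upstream MLIR's affine-loop-coalescing pass, which also collapses perfectly nested scf.for loops with static, unit-step bounds into a single loop, recovering the original induction variables with affine.delinearize_index; the trailing canonicalizer then cleans up the rewritten index arithmetic. A minimal standalone sketch of the transformation (the function below is illustrative, not part of this commit):

    // Before: a perfect 15x15 scf.for nest, the shape the new test expects
    // to be coalesced (compare the CHECK lines below: 225 = 15 * 15).
    func.func @nested(%buf: memref<15x15xf32>) {
      %c0 = arith.constant 0 : index
      %c1 = arith.constant 1 : index
      %c15 = arith.constant 15 : index
      %cst = arith.constant 1.000000e+00 : f32
      scf.for %i = %c0 to %c15 step %c1 {
        scf.for %j = %c0 to %c15 step %c1 {
          memref.store %cst, %buf[%i, %j] : memref<15x15xf32>
        }
      }
      return
    }

    // After coalescing, the nest becomes a single loop, roughly:
    //   %c225 = arith.constant 225 : index
    //   scf.for %iv = %c0 to %c225 step %c1 {
    //     %ij:2 = affine.delinearize_index %iv into (15, 15) : index, index
    //     memref.store %cst, %buf[%ij#0, %ij#1] : memref<15x15xf32>
    //   }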

compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir

Lines changed: 80 additions & 0 deletions
@@ -333,3 +333,83 @@ hal.executable private @matvec_fp16 {
 // CDNA3-COUNT-24: gpu.shuffle xor
 // CDNA3: scf.if
 // CDNA3: vector.transfer_write {{.+}} : vector<8xf16>, memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
+
+// -----
+
+#pipeline_layout = #hal.pipeline.layout<bindings = [
+  #hal.pipeline.binding<storage_buffer>,
+  #hal.pipeline.binding<storage_buffer>
+]>
+hal.executable public @multi_reduction {
+  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
+    hal.executable.export public @multi_reduction ordinal(0) layout(#pipeline_layout) {
+    ^bb0(%arg0: !hal.device):
+      %x, %y, %z = flow.dispatch.workgroup_count_from_slice
+      hal.return %x, %y, %z : index, index, index
+    }
+    builtin.module {
+      func.func @multi_reduction() {
+        %cst = arith.constant 0.000000e+00 : f32
+        %cst_0 = arith.constant 2.304000e+05 : f32
+        %cst_1 = arith.constant 9.99999974E-6 : f32
+        %c85483008 = arith.constant 85483008 : index
+        %c165416448 = arith.constant 165416448 : index
+        %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c85483008) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<2x32x60x3840xf16>>
+        %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c165416448) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<2x32x60x3840xf32>>
+        %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 32, 60, 3840], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x32x60x3840xf16>> -> tensor<2x32x60x3840xf16>
+        %3 = tensor.empty() : tensor<2x32x60x3840xf32>
+        %4 = tensor.empty() : tensor<2x32xf32>
+        %5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<2x32x60x3840xf16>) outs(%3 : tensor<2x32x60x3840xf32>) {
+        ^bb0(%in: f16, %out: f32):
+          %11 = arith.extf %in : f16 to f32
+          linalg.yield %11 : f32
+        } -> tensor<2x32x60x3840xf32>
+        %6 = linalg.fill ins(%cst : f32) outs(%4 : tensor<2x32xf32>) -> tensor<2x32xf32>
+        %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%5 : tensor<2x32x60x3840xf32>) outs(%6 : tensor<2x32xf32>) {
+        ^bb0(%in: f32, %out: f32):
+          %11 = arith.addf %in, %out : f32
+          linalg.yield %11 : f32
+        } -> tensor<2x32xf32>
+        %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%7 : tensor<2x32xf32>) outs(%4 : tensor<2x32xf32>) {
+        ^bb0(%in: f32, %out: f32):
+          %11 = arith.divf %in, %cst_0 : f32
+          linalg.yield %11 : f32
+        } -> tensor<2x32xf32>
+        %9 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%5, %8 : tensor<2x32x60x3840xf32>, tensor<2x32xf32>) outs(%6 : tensor<2x32xf32>) {
+        ^bb0(%in: f32, %in_2: f32, %out: f32):
+          %11 = arith.subf %in, %in_2 : f32
+          %12 = arith.mulf %11, %11 : f32
+          %13 = arith.addf %12, %out : f32
+          linalg.yield %13 : f32
+        } -> tensor<2x32xf32>
+        %10 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2, %8, %9 : tensor<2x32x60x3840xf16>, tensor<2x32xf32>, tensor<2x32xf32>) outs(%3 : tensor<2x32x60x3840xf32>) {
+        ^bb0(%in: f16, %in_2: f32, %in_3: f32, %out: f32):
+          %11 = arith.divf %in_3, %cst_0 : f32
+          %12 = arith.addf %11, %cst_1 : f32
+          %13 = math.rsqrt %12 : f32
+          %14 = arith.extf %in : f16 to f32
+          %15 = arith.subf %14, %in_2 : f32
+          %16 = arith.mulf %15, %13 : f32
+          linalg.yield %16 : f32
+        } -> tensor<2x32x60x3840xf32>
+        flow.dispatch.tensor.store %10, %1, offsets = [0, 0, 0, 0], sizes = [2, 32, 60, 3840], strides = [1, 1, 1, 1] : tensor<2x32x60x3840xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x32x60x3840xf32>>
+        return
+      }
+    }
+  }
+}
+
+// Check that all loops are singly nested.
+//
+// CHECK: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [64, 1, 1] subgroup_size = 64>
+// CHECK: func.func @multi_reduction()
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C225:.+]] = arith.constant 225 : index
+// CHECK: %[[RES0:.+]] = scf.for %[[ARG0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C225]] step %[[C1]]
+// CHECK-NEXT: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (15, 15) : index, index
+// CHECK: %[[RES1:.+]] = scf.for %[[ARG0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C225]] step %[[C1]]
+// CHECK-NEXT: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (15, 15) : index, index
+// CHECK: scf.for %[[ARG0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C225]] step %[[C1]]
+// CHECK-NEXT: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (15, 15) : index, index
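
A note on the expectations above: each coalesced scf.for runs from 0 to 225 with step 1, and 225 = 15 × 15, so the affine.delinearize_index into (15, 15) recovers the induction variables of what were previously two nested 15-iteration loops. Assuming an upstream MLIR build and an input file (here called nested.mlir) containing a perfect nest like the sketch earlier, the standalone rewrite can be reproduced with:

    mlir-opt --affine-loop-coalescing nested.mlir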
