-
Notifications
You must be signed in to change notification settings - Fork 846
Description
Steps to reproduce
Compile
# Baseline:
iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx950 --iree-opt-level=O3 \
--iree-dispatch-creation-enable-fuse-padding-into-linalg-consumer-ops \
--iree-dispatch-creation-enable-split-reduction \
--iree-llvmgpu-test-load-to-transpose-load=false \
'--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-convert-conv-filter-to-channels-last)' \
--iree-hal-dump-executable-files-to=files_false \
input.mlir -o output_false.vmfb
# Transpose load:
iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx950 --iree-opt-level=O3 \
--iree-dispatch-creation-enable-fuse-padding-into-linalg-consumer-ops \
--iree-dispatch-creation-enable-split-reduction \
--iree-llvmgpu-test-load-to-transpose-load=true \
'--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-convert-conv-filter-to-channels-last)' \
--iree-hal-dump-executable-files-to=files_true \
input.mlir -o output_true.vmfb
Benchmark
iree-benchmark-module --module=output_false.vmfb --device=hip --benchmark_repetitions=10
iree-benchmark-module --module=output_true.vmfb --device=hip --benchmark_repetitions=10
Results on MI355X:
| Case | Flag | Median Time |
|---|---|---|
| Baseline | false |
502.0 us |
| Transpose load | true |
654.2 us |
Input MLIR
Conv backward weight, 1x1 filter, stride 2, f16, n=32 c=1024 H=50 W=50 k=2048. See input.mlir in the attached zip.
input.mlir (click to expand)
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1 + d5 * 2, d2 + d6 * 2, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d0)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
module {
func.func public @conv1x1_bwd_weight(%arg0: !torch.vtensor<[32,25,25,2048],f16>, %arg1: !torch.vtensor<[32,50,50,1024],f16>, %arg2: !torch.vtensor<[2048,1,1,1024],f16>) -> !torch.vtensor<[2048,1,1,1024],f16> {
%int0 = torch.constant.int 0
%int3 = torch.constant.int 3
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int0, %int3, %int1, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[32,25,25,2048],f16>, !torch.list<int> -> !torch.vtensor<[32,2048,25,25],f16>
%int0_0 = torch.constant.int 0
%int3_1 = torch.constant.int 3
%int1_2 = torch.constant.int 1
%int2_3 = torch.constant.int 2
%2 = torch.prim.ListConstruct %int0_0, %int3_1, %int1_2, %int2_3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.permute %arg1, %2 : !torch.vtensor<[32,50,50,1024],f16>, !torch.list<int> -> !torch.vtensor<[32,1024,50,50],f16>
%int0_4 = torch.constant.int 0
%int3_5 = torch.constant.int 3
%int1_6 = torch.constant.int 1
%int2_7 = torch.constant.int 2
%4 = torch.prim.ListConstruct %int0_4, %int3_5, %int1_6, %int2_7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%5 = torch.aten.permute %arg2, %4 : !torch.vtensor<[2048,1,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[2048,1024,1,1],f16>
%int0_8 = torch.constant.int 0
%int2_9 = torch.constant.int 2
%int3_10 = torch.constant.int 3
%int1_11 = torch.constant.int 1
%6 = torch.prim.ListConstruct %int0_8, %int2_9, %int3_10, %int1_11 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%7 = torch.aten.permute %1, %6 : !torch.vtensor<[32,2048,25,25],f16>, !torch.list<int> -> !torch.vtensor<[32,25,25,2048],f16>
%int0_12 = torch.constant.int 0
%int2_13 = torch.constant.int 2
%int3_14 = torch.constant.int 3
%int1_15 = torch.constant.int 1
%8 = torch.prim.ListConstruct %int0_12, %int2_13, %int3_14, %int1_15 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%9 = torch.aten.permute %3, %8 : !torch.vtensor<[32,1024,50,50],f16>, !torch.list<int> -> !torch.vtensor<[32,50,50,1024],f16>
%int0_16 = torch.constant.int 0
%int2_17 = torch.constant.int 2
%int3_18 = torch.constant.int 3
%int1_19 = torch.constant.int 1
%10 = torch.prim.ListConstruct %int0_16, %int2_17, %int3_18, %int1_19 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%11 = torch.aten.permute %5, %10 : !torch.vtensor<[2048,1024,1,1],f16>, !torch.list<int> -> !torch.vtensor<[2048,1,1,1024],f16>
%12 = torch_c.to_builtin_tensor %7 : !torch.vtensor<[32,25,25,2048],f16> -> tensor<32x25x25x2048xf16>
%13 = torch_c.to_builtin_tensor %9 : !torch.vtensor<[32,50,50,1024],f16> -> tensor<32x50x50x1024xf16>
%14 = torch_c.to_builtin_tensor %11 : !torch.vtensor<[2048,1,1,1024],f16> -> tensor<2048x1x1x1024xf16>
%15 = call @conv1x1_bwd_impl(%12, %13, %14) : (tensor<32x25x25x2048xf16>, tensor<32x50x50x1024xf16>, tensor<2048x1x1x1024xf16>) -> tensor<2048x1x1x1024xf16>
%16 = torch_c.from_builtin_tensor %15 : tensor<2048x1x1x1024xf16> -> !torch.vtensor<[2048,1,1,1024],f16>
%none = torch.constant.none
%int0_20 = torch.constant.int 0
%int3_21 = torch.constant.int 3
%int1_22 = torch.constant.int 1
%int2_23 = torch.constant.int 2
%17 = torch.prim.ListConstruct %int0_20, %int3_21, %int1_22, %int2_23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%18 = torch.aten.permute %16, %17 : !torch.vtensor<[2048,1,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[2048,1024,1,1],f16>
%int0_24 = torch.constant.int 0
%int2_25 = torch.constant.int 2
%int3_26 = torch.constant.int 3
%int1_27 = torch.constant.int 1
%19 = torch.prim.ListConstruct %int0_24, %int2_25, %int3_26, %int1_27 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%20 = torch.aten.permute %18, %19 : !torch.vtensor<[2048,1024,1,1],f16>, !torch.list<int> -> !torch.vtensor<[2048,1,1,1024],f16>
return %20 : !torch.vtensor<[2048,1,1,1024],f16>
}
func.func private @conv1x1_bwd_impl(%arg0: tensor<32x25x25x2048xf16>, %arg1: tensor<32x50x50x1024xf16>, %arg2: tensor<2048x1x1x1024xf16>) -> tensor<2048x1x1x1024xf16> attributes {torch.assume_strict_symbolic_shapes} {
%0 = torch_c.from_builtin_tensor %arg1 : tensor<32x50x50x1024xf16> -> !torch.vtensor<[32,50,50,1024],f16>
%1 = torch_c.from_builtin_tensor %arg0 : tensor<32x25x25x2048xf16> -> !torch.vtensor<[32,25,25,2048],f16>
%false = torch.constant.bool false
%int5 = torch.constant.int 5
%none = torch.constant.none
%int0 = torch.constant.int 0
%2 = torch.prim.ListConstruct %int0, %int0, %int0, %int0, %int0, %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.constant_pad_nd %0, %2, %int0 : !torch.vtensor<[32,50,50,1024],f16>, !torch.list<int>, !torch.int -> !torch.vtensor<[32,50,50,1024],f16>
%4 = torch_c.to_builtin_tensor %3 : !torch.vtensor<[32,50,50,1024],f16> -> tensor<32x50x50x1024xf16>
%5 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[32,25,25,2048],f16> -> tensor<32x25x25x2048xf16>
%6 = util.call @conv1x1_bwd_generic(%4, %5) : (tensor<32x50x50x1024xf16>, tensor<32x25x25x2048xf16>) -> tensor<2048x1x1x1024xf32>
%7 = torch_c.from_builtin_tensor %6 : tensor<2048x1x1x1024xf32> -> !torch.vtensor<[2048,1,1,1024],f32>
%8 = torch.aten.to.dtype %7, %int5, %false, %false, %none : !torch.vtensor<[2048,1,1,1024],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2048,1,1,1024],f16>
%9 = torch_c.to_builtin_tensor %8 : !torch.vtensor<[2048,1,1,1024],f16> -> tensor<2048x1x1x1024xf16>
return %9 : tensor<2048x1x1x1024xf16>
}
util.func private @conv1x1_bwd_generic(%arg0: tensor<32x50x50x1024xf16>, %arg1: tensor<32x25x25x2048xf16>) -> tensor<2048x1x1x1024xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = tensor.empty() : tensor<2048x1x1x1024xf32>
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2048x1x1x1024xf32>) -> tensor<2048x1x1x1024xf32>
%2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<32x50x50x1024xf16>, tensor<32x25x25x2048xf16>) outs(%1 : tensor<2048x1x1x1024xf32>) {
^bb0(%in: f16, %in_0: f16, %out: f32):
%3 = arith.extf %in : f16 to f32
%4 = arith.extf %in_0 : f16 to f32
%5 = arith.mulf %3, %4 : f32
%6 = arith.addf %out, %5 : f32
linalg.yield %6 : f32
} -> tensor<2048x1x1x1024xf32>
util.return %2 : tensor<2048x1x1x1024xf32>
}
}
What happened?
With --iree-llvmgpu-test-load-to-transpose-load=true, the global loads and shared memory stores are not separated, and the loads are immediately waited on, which breaks the pipelining optimization.
The instruction reordering appears to be broken by the fact that we have transpose load instructions instead of regular loads. This is confirmed by a hack (branch: https://github.com/Max191/iree/tree/transpose-load-hack) that applies the same index remapping as the transpose load pass but emits regular vector.transfer_read instead of amdgpu.transpose_load (producing numerically incorrect results, but estimating what the performance should be). In this case, the shared memory store is separated from the global load.
Traces from rocprof also confirm that this bad scheduling is the bottleneck.
ISA analysis (inner loop .LBB0_2, comparing with vs without the hack)
With the hack (regular ds_read2_b64), the inner loop is properly software-pipelined -- global loads at the top, waited on at the bottom after ~60 instructions of useful work:
; === Hack (regular loads) -- proper pipelining ===
.LBB0_2:
buffer_load_dwordx4 ... ; global loads issued at TOP
buffer_load_dwordx4 ...
buffer_load_dwordx4 ...
buffer_load_dwordx4 ...
s_waitcnt lgkmcnt(0)
s_barrier
ds_read2_b64 ... ; 8x LDS reads
; ... permutes ...
v_mfma_f32_16x16x32_f16 ... ; 16x MFMAs
s_barrier
s_waitcnt vmcnt(3) ; wait for global loads at BOTTOM
ds_write2_b64 ... ; (~60+ instructions after the loads)
s_waitcnt vmcnt(2)
ds_write2_b64 ...
s_waitcnt vmcnt(1)
ds_write2_b64 ...
s_waitcnt vmcnt(0)
ds_write2_b64 ...
Without the hack (transpose ds_read_b64_tr_b16), global loads are issued mid-loop and waited on immediately -- no latency hiding:
; === No hack (transpose loads) -- broken pipelining ===
.LBB0_2:
ds_read_b64_tr_b16 ... ; 16x transpose reads at TOP
; ... permutes ...
v_mfma_f32_16x16x32_f16 ... ; only 8 MFMAs before global loads
buffer_load_dwordx4 v[112:115], ... ; global load A
buffer_load_dwordx4 v[116:119], ... ; global load B
buffer_load_dwordx4 v[120:123], ... ; global load C
v_mfma_f32_16x16x32_f16 ... ; 1 MFMA
s_waitcnt vmcnt(2) ; *** ~3 instructions after load A ***
ds_write2_b64 v73, v[112:113], v[114:115] offset1:1
v_mfma_f32_16x16x32_f16 ...
buffer_load_dwordx4 v[88:91], ... ; global load D
s_waitcnt vmcnt(0) ; *** 0 instructions after load D ***
ds_write2_b64 v72, v[88:89], v[90:91] offset1:1
; ... remaining MFMAs + ds_writes ...Version information
IREE at commit 88d1a2faa6, MI355X (gfx950).
Additional context
- Attached executable_files.zip has .rocmasm, .optimized.ll, .linked.ll, .mlir, and .hsaco for all three cases
- The hack branch has one commit replacing
amdgpu::TransposeLoadOp::create with vector::TransferReadOp::create in ROCDLLoadToTransposeLoad.cpp