Skip to content

[LLVMGPU] Backwards weight conv regression with --iree-llvmgpu-test-load-to-transpose-load=true #23421

@Max191

Description

@Max191

Steps to reproduce

Compile

# Baseline:
iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx950 --iree-opt-level=O3 \
  --iree-dispatch-creation-enable-fuse-padding-into-linalg-consumer-ops \
  --iree-dispatch-creation-enable-split-reduction \
  --iree-llvmgpu-test-load-to-transpose-load=false \
  '--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-convert-conv-filter-to-channels-last)' \
  --iree-hal-dump-executable-files-to=files_false \
  input.mlir -o output_false.vmfb

# Transpose load:
iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx950 --iree-opt-level=O3 \
  --iree-dispatch-creation-enable-fuse-padding-into-linalg-consumer-ops \
  --iree-dispatch-creation-enable-split-reduction \
  --iree-llvmgpu-test-load-to-transpose-load=true \
  '--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-convert-conv-filter-to-channels-last)' \
  --iree-hal-dump-executable-files-to=files_true \
  input.mlir -o output_true.vmfb

Benchmark

iree-benchmark-module --module=output_false.vmfb --device=hip --benchmark_repetitions=10
iree-benchmark-module --module=output_true.vmfb --device=hip --benchmark_repetitions=10

Results on MI355X:

Case Flag Median Time
Baseline false 502.0 us
Transpose load true 654.2 us

Input MLIR

Conv backward weight, 1x1 filter, stride 2, f16, n=32 c=1024 H=50 W=50 k=2048. See input.mlir in the attached zip.

input.mlir (click to expand)
#map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d1 + d5 * 2, d2 + d6 * 2, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d0)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
module {
  func.func public @conv1x1_bwd_weight(%arg0: !torch.vtensor<[32,25,25,2048],f16>, %arg1: !torch.vtensor<[32,50,50,1024],f16>, %arg2: !torch.vtensor<[2048,1,1,1024],f16>) -> !torch.vtensor<[2048,1,1,1024],f16> {
    %int0 = torch.constant.int 0
    %int3 = torch.constant.int 3
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %0 = torch.prim.ListConstruct %int0, %int3, %int1, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[32,25,25,2048],f16>, !torch.list<int> -> !torch.vtensor<[32,2048,25,25],f16>
    %int0_0 = torch.constant.int 0
    %int3_1 = torch.constant.int 3
    %int1_2 = torch.constant.int 1
    %int2_3 = torch.constant.int 2
    %2 = torch.prim.ListConstruct %int0_0, %int3_1, %int1_2, %int2_3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %3 = torch.aten.permute %arg1, %2 : !torch.vtensor<[32,50,50,1024],f16>, !torch.list<int> -> !torch.vtensor<[32,1024,50,50],f16>
    %int0_4 = torch.constant.int 0
    %int3_5 = torch.constant.int 3
    %int1_6 = torch.constant.int 1
    %int2_7 = torch.constant.int 2
    %4 = torch.prim.ListConstruct %int0_4, %int3_5, %int1_6, %int2_7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %5 = torch.aten.permute %arg2, %4 : !torch.vtensor<[2048,1,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[2048,1024,1,1],f16>
    %int0_8 = torch.constant.int 0
    %int2_9 = torch.constant.int 2
    %int3_10 = torch.constant.int 3
    %int1_11 = torch.constant.int 1
    %6 = torch.prim.ListConstruct %int0_8, %int2_9, %int3_10, %int1_11 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %7 = torch.aten.permute %1, %6 : !torch.vtensor<[32,2048,25,25],f16>, !torch.list<int> -> !torch.vtensor<[32,25,25,2048],f16>
    %int0_12 = torch.constant.int 0
    %int2_13 = torch.constant.int 2
    %int3_14 = torch.constant.int 3
    %int1_15 = torch.constant.int 1
    %8 = torch.prim.ListConstruct %int0_12, %int2_13, %int3_14, %int1_15 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %9 = torch.aten.permute %3, %8 : !torch.vtensor<[32,1024,50,50],f16>, !torch.list<int> -> !torch.vtensor<[32,50,50,1024],f16>
    %int0_16 = torch.constant.int 0
    %int2_17 = torch.constant.int 2
    %int3_18 = torch.constant.int 3
    %int1_19 = torch.constant.int 1
    %10 = torch.prim.ListConstruct %int0_16, %int2_17, %int3_18, %int1_19 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %11 = torch.aten.permute %5, %10 : !torch.vtensor<[2048,1024,1,1],f16>, !torch.list<int> -> !torch.vtensor<[2048,1,1,1024],f16>
    %12 = torch_c.to_builtin_tensor %7 : !torch.vtensor<[32,25,25,2048],f16> -> tensor<32x25x25x2048xf16>
    %13 = torch_c.to_builtin_tensor %9 : !torch.vtensor<[32,50,50,1024],f16> -> tensor<32x50x50x1024xf16>
    %14 = torch_c.to_builtin_tensor %11 : !torch.vtensor<[2048,1,1,1024],f16> -> tensor<2048x1x1x1024xf16>
    %15 = call @conv1x1_bwd_impl(%12, %13, %14) : (tensor<32x25x25x2048xf16>, tensor<32x50x50x1024xf16>, tensor<2048x1x1x1024xf16>) -> tensor<2048x1x1x1024xf16>
    %16 = torch_c.from_builtin_tensor %15 : tensor<2048x1x1x1024xf16> -> !torch.vtensor<[2048,1,1,1024],f16>
    %none = torch.constant.none
    %int0_20 = torch.constant.int 0
    %int3_21 = torch.constant.int 3
    %int1_22 = torch.constant.int 1
    %int2_23 = torch.constant.int 2
    %17 = torch.prim.ListConstruct %int0_20, %int3_21, %int1_22, %int2_23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %18 = torch.aten.permute %16, %17 : !torch.vtensor<[2048,1,1,1024],f16>, !torch.list<int> -> !torch.vtensor<[2048,1024,1,1],f16>
    %int0_24 = torch.constant.int 0
    %int2_25 = torch.constant.int 2
    %int3_26 = torch.constant.int 3
    %int1_27 = torch.constant.int 1
    %19 = torch.prim.ListConstruct %int0_24, %int2_25, %int3_26, %int1_27 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %20 = torch.aten.permute %18, %19 : !torch.vtensor<[2048,1024,1,1],f16>, !torch.list<int> -> !torch.vtensor<[2048,1,1,1024],f16>
    return %20 : !torch.vtensor<[2048,1,1,1024],f16>
  }
  func.func private @conv1x1_bwd_impl(%arg0: tensor<32x25x25x2048xf16>, %arg1: tensor<32x50x50x1024xf16>, %arg2: tensor<2048x1x1x1024xf16>) -> tensor<2048x1x1x1024xf16> attributes {torch.assume_strict_symbolic_shapes} {
    %0 = torch_c.from_builtin_tensor %arg1 : tensor<32x50x50x1024xf16> -> !torch.vtensor<[32,50,50,1024],f16>
    %1 = torch_c.from_builtin_tensor %arg0 : tensor<32x25x25x2048xf16> -> !torch.vtensor<[32,25,25,2048],f16>
    %false = torch.constant.bool false
    %int5 = torch.constant.int 5
    %none = torch.constant.none
    %int0 = torch.constant.int 0
    %2 = torch.prim.ListConstruct %int0, %int0, %int0, %int0, %int0, %int0, %int0, %int0 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %3 = torch.aten.constant_pad_nd %0, %2, %int0 : !torch.vtensor<[32,50,50,1024],f16>, !torch.list<int>, !torch.int -> !torch.vtensor<[32,50,50,1024],f16>
    %4 = torch_c.to_builtin_tensor %3 : !torch.vtensor<[32,50,50,1024],f16> -> tensor<32x50x50x1024xf16>
    %5 = torch_c.to_builtin_tensor %1 : !torch.vtensor<[32,25,25,2048],f16> -> tensor<32x25x25x2048xf16>
    %6 = util.call @conv1x1_bwd_generic(%4, %5) : (tensor<32x50x50x1024xf16>, tensor<32x25x25x2048xf16>) -> tensor<2048x1x1x1024xf32>
    %7 = torch_c.from_builtin_tensor %6 : tensor<2048x1x1x1024xf32> -> !torch.vtensor<[2048,1,1,1024],f32>
    %8 = torch.aten.to.dtype %7, %int5, %false, %false, %none : !torch.vtensor<[2048,1,1,1024],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[2048,1,1,1024],f16>
    %9 = torch_c.to_builtin_tensor %8 : !torch.vtensor<[2048,1,1,1024],f16> -> tensor<2048x1x1x1024xf16>
    return %9 : tensor<2048x1x1x1024xf16>
  }
  util.func private @conv1x1_bwd_generic(%arg0: tensor<32x50x50x1024xf16>, %arg1: tensor<32x25x25x2048xf16>) -> tensor<2048x1x1x1024xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<2048x1x1x1024xf32>
    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2048x1x1x1024xf32>) -> tensor<2048x1x1x1024xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<32x50x50x1024xf16>, tensor<32x25x25x2048xf16>) outs(%1 : tensor<2048x1x1x1024xf32>) {
    ^bb0(%in: f16, %in_0: f16, %out: f32):
      %3 = arith.extf %in : f16 to f32
      %4 = arith.extf %in_0 : f16 to f32
      %5 = arith.mulf %3, %4 : f32
      %6 = arith.addf %out, %5 : f32
      linalg.yield %6 : f32
    } -> tensor<2048x1x1x1024xf32>
    util.return %2 : tensor<2048x1x1x1024xf32>
  }
}

What happened?

With --iree-llvmgpu-test-load-to-transpose-load=true, the global loads and shared memory stores are not separated, and the loads are immediately waited on, which breaks the pipelining optimization.

The instruction reordering appears to be broken by the fact that we have transpose load instructions instead of regular loads. This is confirmed by a hack (branch: https://github.com/Max191/iree/tree/transpose-load-hack) that applies the same index remapping as the transpose load pass but emits regular vector.transfer_read instead of amdgpu.transpose_load (producing numerically incorrect results, but estimating what the performance should be). In this case, the shared memory store is separated from the global load.

Traces from rocprof also confirm that this bad scheduling is the bottleneck.

ISA analysis (inner loop .LBB0_2, comparing with vs without the hack)

With the hack (regular ds_read2_b64), the inner loop is properly software-pipelined -- global loads at the top, waited on at the bottom after ~60 instructions of useful work:

; === Hack (regular loads) -- proper pipelining ===
.LBB0_2:
    buffer_load_dwordx4 ...   ; global loads issued at TOP
    buffer_load_dwordx4 ...
    buffer_load_dwordx4 ...
    buffer_load_dwordx4 ...
    s_waitcnt lgkmcnt(0)
    s_barrier
    ds_read2_b64 ...          ; 8x LDS reads
    ; ... permutes ...
    v_mfma_f32_16x16x32_f16 ... ; 16x MFMAs
    s_barrier
    s_waitcnt vmcnt(3)        ; wait for global loads at BOTTOM
    ds_write2_b64 ...         ; (~60+ instructions after the loads)
    s_waitcnt vmcnt(2)
    ds_write2_b64 ...
    s_waitcnt vmcnt(1)
    ds_write2_b64 ...
    s_waitcnt vmcnt(0)
    ds_write2_b64 ...

Without the hack (transpose ds_read_b64_tr_b16), global loads are issued mid-loop and waited on immediately -- no latency hiding:

; === No hack (transpose loads) -- broken pipelining ===
.LBB0_2:
    ds_read_b64_tr_b16 ...    ; 16x transpose reads at TOP
    ; ... permutes ...
    v_mfma_f32_16x16x32_f16 ... ; only 8 MFMAs before global loads
    buffer_load_dwordx4 v[112:115], ...   ; global load A
    buffer_load_dwordx4 v[116:119], ...   ; global load B
    buffer_load_dwordx4 v[120:123], ...   ; global load C
    v_mfma_f32_16x16x32_f16 ...           ; 1 MFMA
    s_waitcnt vmcnt(2)                     ; *** ~3 instructions after load A ***
    ds_write2_b64 v73, v[112:113], v[114:115] offset1:1
    v_mfma_f32_16x16x32_f16 ...
    buffer_load_dwordx4 v[88:91], ...     ; global load D
    s_waitcnt vmcnt(0)                     ; *** 0 instructions after load D ***
    ds_write2_b64 v72, v[88:89], v[90:91] offset1:1
    ; ... remaining MFMAs + ds_writes ...

Version information

IREE at commit 88d1a2faa6, MI355X (gfx950).

Additional context

  • Attached executable_files.zip has .rocmasm, .optimized.ll, .linked.ll, .mlir, and .hsaco for all three cases
  • The hack branch has one commit replacing amdgpu::TransposeLoadOp::create with vector::TransferReadOp::create in ROCDLLoadToTransposeLoad.cpp

Metadata

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions