
Conversation

@ftynse (Contributor) commented Jan 29, 2026

Install Wave following https://github.com/iree-org/wave?tab=readme-ov-file#for-developers, then just run the file via Python.

@ftynse (Contributor, Author) commented Jan 29, 2026

The MLIR looks like this:

#map = affine_map<()[s0, s1] -> (s0 + s1 * 32 - (s0 floordiv 16) * 16 + (s0 floordiv 64) * 16)>
#map1 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 16)>
#map2 = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 4)>
#map3 = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)>
#map4 = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)>
#map5 = affine_map<()[s0, s1] -> (s0 * 16 + ((s1 mod 64) floordiv 16) * 4)>
#map6 = affine_map<()[s0, s1] -> (s0 * 32 + (s1 floordiv 64) * 16 + ((s1 mod 64) floordiv 16) * 4)>
#map7 = affine_map<()[s0, s1] -> (s0 * 32 + (s1 floordiv 64) * 16 + ((s1 mod 64) floordiv 16) * 4 + 1)>
#map8 = affine_map<()[s0, s1] -> (s0 * 32 + (s1 floordiv 64) * 16 + ((s1 mod 64) floordiv 16) * 4 + 2)>
#map9 = affine_map<()[s0, s1] -> (s0 * 32 + (s1 floordiv 64) * 16 + ((s1 mod 64) floordiv 16) * 4 + 3)>
module attributes {gpu.container_module, transform.with_named_sequence} {
  gpu.module @gpu_module {
    gpu.func @gemm(%arg0: memref<f16> {llvm.inreg}, %arg1: memref<f16> {llvm.inreg}, %arg2: memref<f32> {llvm.inreg}) kernel attributes {known_block_size = array<i32: 128, 2, 1>} {
      %c1 = arith.constant 1 : index
      %c4 = arith.constant 4 : index
      %c1280 = arith.constant 1280 : index
      %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
      %c0 = arith.constant 0 : index
      %block_id_x = gpu.block_id  x upper_bound 2
      %block_id_y = gpu.block_id  y upper_bound 4
      %thread_id_x = gpu.thread_id  x upper_bound 128
      %thread_id_y = gpu.thread_id  y upper_bound 2
      %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64, 64], strides: [64, 1] : memref<f16> to memref<64x64xf16, strided<[64, 1]>>
      %reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<f16> to memref<128x64xf16, strided<[64, 1]>>
      %reinterpret_cast_1 = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64, 128], strides: [128, 1] : memref<f32> to memref<64x128xf32, strided<[128, 1]>>
      %alloc = memref.alloc() : memref<2560xi8, #gpu.address_space<workgroup>>
      %view = memref.view %alloc[%c0][] : memref<2560xi8, #gpu.address_space<workgroup>> to memref<32x20xf16, #gpu.address_space<workgroup>>
      %view_2 = memref.view %alloc[%c1280][] : memref<2560xi8, #gpu.address_space<workgroup>> to memref<32x20xf16, #gpu.address_space<workgroup>>
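      // Two 32x20 f16 LDS tiles (inner dimension padded from 16 to 20), 1280 bytes each, carved out of the single 2560-byte workgroup allocation.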
      %0 = affine.apply #map()[%thread_id_x, %block_id_x]
      %1 = affine.apply #map1()[%thread_id_x]
      %2 = affine.apply #map2()[%thread_id_x]
      %3 = affine.apply #map3()[%thread_id_x, %block_id_y, %thread_id_y]
      %4 = affine.apply #map4()[%thread_id_x, %thread_id_y]
      %5 = scf.for %arg3 = %c0 to %c4 step %c1 iter_args(%arg4 = %cst) -> (vector<4xf32>) {
        %15 = affine.apply #map5()[%arg3, %thread_id_x]
        %16 = vector.load %reinterpret_cast[%0, %15] : memref<64x64xf16, strided<[64, 1]>>, vector<4xf16>
        amdgpu.lds_barrier
        vector.store %16, %view_2[%1, %2] : memref<32x20xf16, #gpu.address_space<workgroup>>, vector<4xf16>
        %17 = vector.load %reinterpret_cast_0[%3, %15] : memref<128x64xf16, strided<[64, 1]>>, vector<4xf16>
        vector.store %17, %view[%4, %2] : memref<32x20xf16, #gpu.address_space<workgroup>>, vector<4xf16>
        amdgpu.lds_barrier
        %18 = vector.load %view[%4, %2] : memref<32x20xf16, #gpu.address_space<workgroup>>, vector<4xf16>
        %19 = vector.load %view_2[%1, %2] : memref<32x20xf16, #gpu.address_space<workgroup>>, vector<4xf16>
        %20 = amdgpu.mfma 16x16x16 %19 * %18 + %arg4 blgp =  none : vector<4xf16>, vector<4xf16>, vector<4xf32>
        scf.yield %20 : vector<4xf32>
      }
      %6 = vector.extract_strided_slice %5 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
      %7 = affine.apply #map6()[%block_id_x, %thread_id_x]
      %8 = affine.apply #map3()[%thread_id_x, %block_id_y, %thread_id_y]
      vector.store %6, %reinterpret_cast_1[%7, %8] : memref<64x128xf32, strided<[128, 1]>>, vector<1xf32>
      %9 = vector.extract_strided_slice %5 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
      %10 = affine.apply #map7()[%block_id_x, %thread_id_x]
      vector.store %9, %reinterpret_cast_1[%10, %8] : memref<64x128xf32, strided<[128, 1]>>, vector<1xf32>
      %11 = vector.extract_strided_slice %5 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
      %12 = affine.apply #map8()[%block_id_x, %thread_id_x]
      vector.store %11, %reinterpret_cast_1[%12, %8] : memref<64x128xf32, strided<[128, 1]>>, vector<1xf32>
      %13 = vector.extract_strided_slice %5 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
      %14 = affine.apply #map9()[%block_id_x, %thread_id_x]
      vector.store %13, %reinterpret_cast_1[%14, %8] : memref<64x128xf32, strided<[128, 1]>>, vector<1xf32>
      gpu.return
    }
  }
  func.func private @wave_get_buffer(!llvm.ptr) -> memref<?xi8> attributes {llvm.emit_c_interface}
  func.func private @wave_get_dim(!llvm.ptr, i32) -> i64 attributes {llvm.emit_c_interface}
  func.func private @wave_get_int64(!llvm.ptr) -> i64 attributes {llvm.emit_c_interface}
  func.func private @wave_get_float64(!llvm.ptr) -> f64 attributes {llvm.emit_c_interface}
  func.func @isolated_benchmark(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: !llvm.ptr) {
    %c0 = arith.constant 0 : index
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %c1 = arith.constant 1 : index
    %c128 = arith.constant 128 : index
    %0 = call @wave_get_buffer(%arg1) : (!llvm.ptr) -> memref<?xi8>
    %view = memref.view %0[%c0][] : memref<?xi8> to memref<f16>
    %1 = call @wave_get_buffer(%arg2) : (!llvm.ptr) -> memref<?xi8>
    %view_0 = memref.view %1[%c0][] : memref<?xi8> to memref<f16>
    %2 = call @wave_get_buffer(%arg3) : (!llvm.ptr) -> memref<?xi8>
    %view_1 = memref.view %2[%c0][] : memref<?xi8> to memref<f32>
    gpu.launch_func  @gpu_module::@gemm blocks in (%c2, %c4, %c1) threads in (%c128, %c2, %c1)  args(%view : memref<f16>, %view_0 : memref<f16>, %view_1 : memref<f32>)
    return
  }
}
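
As a quick sanity check of the indexing above, here is a small Python sketch (reviewer-written, not generated by Wave; the helper names are mine) that re-evaluates the key affine maps for an arbitrary lane:

# Re-evaluates #map, #map3, #map5 and #map6..#map9 from the IR above in plain
# Python; the example coordinates are arbitrary.

def a_row(tid_x, blk_x):         # #map: global M row of the A slice a thread loads
    return tid_x % 16 + (tid_x // 64) * 16 + blk_x * 32

def b_row(tid_x, tid_y, blk_y):  # #map3: global N row of the B slice / C column
    return tid_x % 16 + tid_y * 16 + blk_y * 32

def k_col(tid_x, k_iter):        # #map5: global K column of the 4xf16 slice
    return k_iter * 16 + ((tid_x % 64) // 16) * 4

def c_rows(tid_x, blk_x):        # #map6..#map9: the four M rows a lane writes back
    base = blk_x * 32 + (tid_x // 64) * 16 + ((tid_x % 64) // 16) * 4
    return [base + i for i in range(4)]

# Lane 17 (thread_id_x = 17, thread_id_y = 1) in block (1, 3):
print(a_row(17, 1))      # 33
print(b_row(17, 1, 3))   # 113
print(k_col(17, 2))      # 36  (third K iteration)
print(c_rows(17, 1))     # [36, 37, 38, 39]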

@ftynse (Contributor, Author) commented Jan 29, 2026

The assembly looks like this:

test_gemm
.amdgcn_target "amdgcn-amd-amdhsa--gfx942"
.text
.protected gemm
.globl gemm
.p2align 8
.type gemm,@function

.section .rodata,#alloc
.p2align 6
.amdhsa_kernel gemm
  .amdhsa_user_sgpr_count 2
  .amdhsa_user_sgpr_dispatch_ptr 0
  .amdhsa_user_sgpr_queue_ptr 0
  .amdhsa_user_sgpr_kernarg_segment_ptr 1
  .amdhsa_user_sgpr_dispatch_id 0
  .amdhsa_user_sgpr_private_segment_size 0
  .amdhsa_uses_dynamic_stack 0
  .amdhsa_enable_private_segment 0
  .amdhsa_accum_offset 24
  .amdhsa_next_free_vgpr 24
  .amdhsa_next_free_sgpr 24
  .amdhsa_group_segment_fixed_size 2560
  .amdhsa_private_segment_fixed_size 0
  .amdhsa_system_sgpr_workgroup_id_x 1
  .amdhsa_system_sgpr_workgroup_id_y 1
  .amdhsa_system_sgpr_workgroup_id_z 0
  .amdhsa_system_vgpr_workitem_id 1
  .amdhsa_float_denorm_mode_32 3
  .amdhsa_float_denorm_mode_16_64 3
.end_amdhsa_kernel
.text

# SRD upper word (gfx9xx): data_format=4 => 0x20000
.set Srd127_96, 131072

gemm:
    s_load_dwordx2 s[4:5], s[0:1], 0  // Load base addr for arg0
    s_load_dwordx2 s[8:9], s[0:1], 8  // Load base addr for arg1
    s_load_dwordx2 s[12:13], s[0:1], 16  // Load base addr for arg2
    s_waitcnt lgkmcnt(0)  // wait for all SRD loads
    s_mov_b32 s6, 0x2000  // SRD size for arg0
    s_mov_b32 s7, 0x20000  // SRD stride for arg0
    s_mov_b32 s10, 0x4000  // SRD size for arg1
    s_mov_b32 s11, 0x20000  // SRD stride for arg1
    s_mov_b32 s14, 0x8000  // SRD size for arg2
    s_mov_b32 s15, 0x20000  // SRD stride for arg2
    s_load_dwordx2 s[16:17], s[0:1], 0  // Load kernarg at offset 0
    s_load_dwordx2 s[16:17], s[0:1], 8  // Load kernarg at offset 8
    s_load_dwordx2 s[16:17], s[0:1], 16  // Load kernarg at offset 16
    s_waitcnt lgkmcnt(0)  // wait for all kernarg loads
    // Initialize loop 0
    s_mov_b32 s16, 0  // loop 0 counter = 0
    s_mov_b32 s17, 1  // loop 0 step = 1
    s_mov_b32 s18, 4  // loop 0 upper = 4
    v_mov_b32 v4, 0
    v_mov_b32 v5, 0
    v_mov_b32 v6, 0
    v_mov_b32 v7, 0  // Initialize accumulator 0 to 0.0
loop_0_header:
    s_cmp_lt_u32 s16, s18  // compare loop 0 counter < upper
    s_cbranch_scc1 loop_0_body
    s_branch loop_0_exit
loop_0_body:
    v_bfe_u32 v1, v0, 0, 10  // extract tid_x from flat_tid
    v_and_b32 v2, 63, v1  // mod 64 (and)
    v_lshrrev_b32 v3, 4, v2  // floor div by 16 (shift)
    v_lshlrev_b32 v2, 3, v3  // floor((Mod(tid_x, 64))/16) << 3
    v_mov_b32 v8, s16  // broadcast loop counter ks6 to VGPR
    v_lshlrev_b32 v9, 5, v8  // ks6 << 5
    v_or_b32 v8, v2, v9  // or (bits 3-4 + 5-20)
    v_and_b32 v9, 15, v1  // mod 16 (and)
    v_lshl_add_u32 v10, v9, 7, v8  // fused: (kv12 << 7) + kv11
    v_lshrrev_b32 v8, 6, v1  // floor div by 64 (shift)
    v_lshl_add_u32 v1, v8, 11, v10  // fused: (kv15 << 11) + kv14
    v_mov_b32 v11, s2  // wgid_x from s2
    v_lshl_add_u32 v12, v11, 12, v1  // fused: (kv18 << 12) + kv17
    buffer_load_dwordx2 v[14:15], v12, s[4:7], 0 offen  // load 8B @ offset 0
    s_waitcnt lgkmcnt(0)
    s_barrier  // LDS barrier
    v_mul_lo_u32 v1, v9, 40  // Mod(tid_x, 16) * 40
    v_add_u32 v12, v2, v1  // add
    v_mov_b32 v1, 0x280  // materialize 640
    v_mul_lo_u32 v2, v8, v1  // floor(tid_x/64) * 640
    v_add_u32 v13, v12, v2  // add
    v_add_u32 v2, 0x500, v13  // + 1280 (inline literal)
    s_waitcnt vmcnt(0)  // wait for VMEM before LDS store
    ds_write_b64 v2, v[14:15]  // LDS store 8B @ offset 0
    v_bfe_u32 v2, v0, 10, 10  // extract tid_y from flat_tid
    v_lshl_add_u32 v14, v2, 11, v10  // fused: (kv30 << 11) + kv14
    v_mov_b32 v10, s3  // wgid_y from s3
    v_lshl_add_u32 v15, v10, 12, v14  // fused: (kv33 << 12) + kv32
    buffer_load_dwordx2 v[16:17], v15, s[8:11], 0 offen  // load 8B @ offset 0
    v_mul_lo_u32 v14, v2, v1  // tid_y * 640
    v_add_u32 v1, v12, v14  // add
    s_waitcnt vmcnt(0)  // wait for VMEM before LDS store
    ds_write_b64 v1, v[16:17]  // LDS store 8B @ offset 0
    s_waitcnt lgkmcnt(0)
    s_barrier  // LDS barrier
    ds_read_b64 v[14:15], v1  // LDS load 8B @ offset 0
    ds_read_b64 v[16:17], v13 offset:1280  // LDS load 8B @ offset 1280
    s_waitcnt lgkmcnt(0)  // ticketing: wait for LGKM defs
    v_mfma_f32_16x16x16_f16 v[4:7], v[16:17], v[14:15], v[4:7]  // MFMA with accumulator (in-place)
loop_0_latch:
    s_add_u32 s16, s16, s17  // loop 0 counter += step
    s_branch loop_0_header
loop_0_exit:
    v_lshlrev_b32 v1, 6, v2  // tid_y << 6
    v_lshl_or_b32 v2, v9, 2, v1  // fused: (kv12 << 2) | kv46
    v_lshl_add_u32 v1, v10, 7, v2  // fused: (kv33 << 7) + kv47
    v_lshl_add_u32 v2, v3, 11, v1  // fused: (kv7 << 11) + kv49
    v_lshl_add_u32 v1, v8, 13, v2  // fused: (kv15 << 13) + kv51
    v_lshl_add_u32 v2, v11, 14, v1  // fused: (kv18 << 14) + kv53
    buffer_store_dword v[4:4], v2, s[12:15], 0 offen  // store 4B @ offset 0
    buffer_store_dword v[5:5], v2, s[12:15], 0 offen offset:512  // store 4B @ offset 512
    buffer_store_dword v[6:6], v2, s[12:15], 0 offen offset:1024  // store 4B @ offset 1024
    buffer_store_dword v[7:7], v2, s[12:15], 0 offen offset:1536  // store 4B @ offset 1536
    s_endpgm


.amdgpu_metadata
---
amdhsa.version:
  - 1
  - 2
amdhsa.kernels:
  - .name: gemm
    .symbol: 'gemm.kd'
    .language: OpenCL C
    .language_version:
      - 1
      - 2
    .kernarg_segment_size: 24
    .group_segment_fixed_size: 2560
    .private_segment_fixed_size: 0
    .kernarg_segment_align: 8
    .wavefront_size: 64
    .sgpr_count: 19
    .vgpr_count: 18
    .max_flat_workgroup_size: 256
    .args:
      - .name: arg0_ptr
        .size: 8
        .offset: 0
        .value_kind: global_buffer
        .value_type: i8*
      - .name: arg1_ptr
        .size: 8
        .offset: 8
        .value_kind: global_buffer
        .value_type: i8*
      - .name: arg2_ptr
        .size: 8
        .offset: 16
        .value_kind: global_buffer
        .value_type: i8*
...
.end_amdgpu_metadata
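
A few constants in the listing can be cross-checked against the tensor shapes taken from the MLIR above; a small Python sketch (reviewer-written, not part of the generated output):

# Cross-checks SRD sizes, store offsets and LDS usage against
# M = 64, K = 64, N = 128 with f16 inputs and an f32 output.

M, K, N = 64, 64, 128

assert M * K * 2 == 0x2000   # s_mov_b32 s6:  SRD size for arg0 (A, f16)
assert N * K * 2 == 0x4000   # s_mov_b32 s10: SRD size for arg1 (B, f16)
assert M * N * 4 == 0x8000   # s_mov_b32 s14: SRD size for arg2 (C, f32)

# The four buffer_store_dword offsets are one C row (N * 4 bytes) apart.
assert [i * N * 4 for i in range(4)] == [0, 512, 1024, 1536]

# Two 32x20 f16 LDS tiles account for .amdhsa_group_segment_fixed_size.
assert 2 * 32 * 20 * 2 == 2560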

@ftynse (Contributor, Author) commented Jan 29, 2026

The high-level representation in the Wave dialect looks like this:

module {
  func.func @kernel(%arg0: !wave.tensor<[@M, @K] of f16, <global>>, %arg1: !wave.tensor<[@N, @K] of f16, <global>>, %arg2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.constraints = [#wave.workgroup_constraint<dim = <"M">, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = <x>>, #wave.workgroup_constraint<dim = <"N">, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = <y>>, #wave.tiling_constraint<dim = <"K">, tile_size = <[#wave.symbol<"BLOCK_K">] -> (BLOCK_K)>>, #wave.wave_constraint<dim = <"M">, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M ceildiv 2)>>, #wave.wave_constraint<dim = <"N">, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N ceildiv 2)>>, #wave.hardware_constraint<threads_per_wave = 64, waves_per_block = [2, 2, 1], mma_type = <f32_16x16x16_f16>>], wave.hyperparameters = #wave.hyperparameters<{BLOCK_K = 16 : i64, BLOCK_M = 32 : i64, BLOCK_N = 32 : i64, K = 64 : i64, M = 64 : i64, N = 128 : i64}>} {
    %cst = arith.constant 0.000000e+00 : f32
    %0 = wave.register %cst {_water_internal.id = "123211467026368"} : !wave.tensor<[@M, @N] of f32, <register>>
    %1 = wave.iterate @K iter_args(%0) attributes {_water_internal.id = "123211465261136", _water_internal.result_ids = ["123211465262544"]} {
    ^bb0(%arg3: !wave.tensor<[@M, @N] of f32, <register>>):
      %2 = wave.allocate {_water_internal.id = "123211465263248", distributed_shape = #wave.expr_list<[#wave.symbol<"BLOCK_N">, #wave.symbol<"BLOCK_K">] -> (BLOCK_N, BLOCK_K + 4)>} : !wave.tensor<[@N, @K] of f16, <shared>>
      %3 = wave.allocate {_water_internal.id = "123211465262192", distributed_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_K">] -> (BLOCK_M, BLOCK_K + 4)>} : !wave.tensor<[@M, @K] of f16, <shared>>
      %4 = wave.read %arg0 {_water_internal.id = "123211467018800"} : (!wave.tensor<[@M, @K] of f16, <global>>) -> !wave.tensor<[@M, @K] of f16, <register>>
      wave.write %4, %3 {_water_internal.id = "123211465263072"} : !wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@M, @K] of f16, <shared>>
      %5 = wave.read %arg1 {_water_internal.id = "123211465263600"} : (!wave.tensor<[@N, @K] of f16, <global>>) -> !wave.tensor<[@N, @K] of f16, <register>>
      wave.write %5, %2 {_water_internal.id = "123211465264128"} : !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <shared>>
      %6 = wave.read %2 {_water_internal.id = "123211465263776"} : (!wave.tensor<[@N, @K] of f16, <shared>>) -> !wave.tensor<[@N, @K] of f16, <register>>
      %7 = wave.read %3 {_water_internal.id = "123211465262896"} : (!wave.tensor<[@M, @K] of f16, <shared>>) -> !wave.tensor<[@M, @K] of f16, <register>>
      %8 = wave.mma %7, %6, %arg3 {_water_internal.id = "123211465261312"} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
      wave.yield %8 : !wave.tensor<[@M, @N] of f32, <register>>
    } : (!wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
    wave.write %1, %arg2 {_water_internal.id = "123211465262016"} : !wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@M, @N] of f32, <global>>
    return
  }
}
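
The launch configuration and trip count of the lowered kernel follow from these hyperparameters and constraints; a quick Python sketch (reviewer-written, not Wave API code) checking that against the gpu.launch_func and scf.for shown earlier:

# Derives grid/block sizes and the K trip count from the constraints above and
# checks them against the lowered IR.

hyper = dict(M=64, N=128, K=64, BLOCK_M=32, BLOCK_N=32, BLOCK_K=16)
threads_per_wave, waves_per_block = 64, (2, 2, 1)

grid = (hyper["M"] // hyper["BLOCK_M"],   # workgroup_constraint on M -> blocks in x
        hyper["N"] // hyper["BLOCK_N"],   # workgroup_constraint on N -> blocks in y
        1)
block = (threads_per_wave * waves_per_block[0], waves_per_block[1], 1)
k_iters = hyper["K"] // hyper["BLOCK_K"]  # tiling_constraint on K

assert grid == (2, 4, 1)     # blocks in (%c2, %c4, %c1)
assert block == (128, 2, 1)  # threads in (%c128, %c2, %c1), known_block_size
assert k_iters == 4          # scf.for %c0 to %c4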

@ftynse (Contributor, Author) commented Jan 29, 2026

cc @fabianmcg @nicolasvasilache
