simple repro for asm generation for gemm #783
Draft: ftynse wants to merge 1 commit into main from users/ftynse/simple-gemm-asm
Signed-off-by: Alex Zinenko <[email protected]>
ftynse (Contributor, Author) commented:

The MLIR looks like this:

```mlir
#map = affine_map<()[s0, s1] -> (s0 + s1 * 32 - (s0 floordiv 16) * 16 + (s0 floordiv 64) * 16)>
#map1 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 16)>
#map2 = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 4)>
#map3 = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)>
#map4 = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)>
#map5 = affine_map<()[s0, s1] -> (s0 * 16 + ((s1 mod 64) floordiv 16) * 4)>
#map6 = affine_map<()[s0, s1] -> (s0 * 32 + (s1 floordiv 64) * 16 + ((s1 mod 64) floordiv 16) * 4)>
#map7 = affine_map<()[s0, s1] -> (s0 * 32 + (s1 floordiv 64) * 16 + ((s1 mod 64) floordiv 16) * 4 + 1)>
#map8 = affine_map<()[s0, s1] -> (s0 * 32 + (s1 floordiv 64) * 16 + ((s1 mod 64) floordiv 16) * 4 + 2)>
#map9 = affine_map<()[s0, s1] -> (s0 * 32 + (s1 floordiv 64) * 16 + ((s1 mod 64) floordiv 16) * 4 + 3)>
module attributes {gpu.container_module, transform.with_named_sequence} {
gpu.module @gpu_module {
gpu.func @gemm(%arg0: memref<f16> {llvm.inreg}, %arg1: memref<f16> {llvm.inreg}, %arg2: memref<f32> {llvm.inreg}) kernel attributes {known_block_size = array<i32: 128, 2, 1>} {
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c1280 = arith.constant 1280 : index
%cst = arith.constant dense<0.000000e+00> : vector<4xf32>
%c0 = arith.constant 0 : index
%block_id_x = gpu.block_id x upper_bound 2
%block_id_y = gpu.block_id y upper_bound 4
%thread_id_x = gpu.thread_id x upper_bound 128
%thread_id_y = gpu.thread_id y upper_bound 2
%reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [64, 64], strides: [64, 1] : memref<f16> to memref<64x64xf16, strided<[64, 1]>>
%reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<f16> to memref<128x64xf16, strided<[64, 1]>>
%reinterpret_cast_1 = memref.reinterpret_cast %arg2 to offset: [0], sizes: [64, 128], strides: [128, 1] : memref<f32> to memref<64x128xf32, strided<[128, 1]>>
%alloc = memref.alloc() : memref<2560xi8, #gpu.address_space<workgroup>>
%view = memref.view %alloc[%c0][] : memref<2560xi8, #gpu.address_space<workgroup>> to memref<32x20xf16, #gpu.address_space<workgroup>>
%view_2 = memref.view %alloc[%c1280][] : memref<2560xi8, #gpu.address_space<workgroup>> to memref<32x20xf16, #gpu.address_space<workgroup>>
%0 = affine.apply #map()[%thread_id_x, %block_id_x]
%1 = affine.apply #map1()[%thread_id_x]
%2 = affine.apply #map2()[%thread_id_x]
%3 = affine.apply #map3()[%thread_id_x, %block_id_y, %thread_id_y]
%4 = affine.apply #map4()[%thread_id_x, %thread_id_y]
%5 = scf.for %arg3 = %c0 to %c4 step %c1 iter_args(%arg4 = %cst) -> (vector<4xf32>) {
%15 = affine.apply #map5()[%arg3, %thread_id_x]
%16 = vector.load %reinterpret_cast[%0, %15] : memref<64x64xf16, strided<[64, 1]>>, vector<4xf16>
amdgpu.lds_barrier
vector.store %16, %view_2[%1, %2] : memref<32x20xf16, #gpu.address_space<workgroup>>, vector<4xf16>
%17 = vector.load %reinterpret_cast_0[%3, %15] : memref<128x64xf16, strided<[64, 1]>>, vector<4xf16>
vector.store %17, %view[%4, %2] : memref<32x20xf16, #gpu.address_space<workgroup>>, vector<4xf16>
amdgpu.lds_barrier
%18 = vector.load %view[%4, %2] : memref<32x20xf16, #gpu.address_space<workgroup>>, vector<4xf16>
%19 = vector.load %view_2[%1, %2] : memref<32x20xf16, #gpu.address_space<workgroup>>, vector<4xf16>
%20 = amdgpu.mfma 16x16x16 %19 * %18 + %arg4 blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
scf.yield %20 : vector<4xf32>
}
%6 = vector.extract_strided_slice %5 {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%7 = affine.apply #map6()[%block_id_x, %thread_id_x]
%8 = affine.apply #map3()[%thread_id_x, %block_id_y, %thread_id_y]
vector.store %6, %reinterpret_cast_1[%7, %8] : memref<64x128xf32, strided<[128, 1]>>, vector<1xf32>
%9 = vector.extract_strided_slice %5 {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%10 = affine.apply #map7()[%block_id_x, %thread_id_x]
vector.store %9, %reinterpret_cast_1[%10, %8] : memref<64x128xf32, strided<[128, 1]>>, vector<1xf32>
%11 = vector.extract_strided_slice %5 {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%12 = affine.apply #map8()[%block_id_x, %thread_id_x]
vector.store %11, %reinterpret_cast_1[%12, %8] : memref<64x128xf32, strided<[128, 1]>>, vector<1xf32>
%13 = vector.extract_strided_slice %5 {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
%14 = affine.apply #map9()[%block_id_x, %thread_id_x]
vector.store %13, %reinterpret_cast_1[%14, %8] : memref<64x128xf32, strided<[128, 1]>>, vector<1xf32>
gpu.return
}
}
func.func private @wave_get_buffer(!llvm.ptr) -> memref<?xi8> attributes {llvm.emit_c_interface}
func.func private @wave_get_dim(!llvm.ptr, i32) -> i64 attributes {llvm.emit_c_interface}
func.func private @wave_get_int64(!llvm.ptr) -> i64 attributes {llvm.emit_c_interface}
func.func private @wave_get_float64(!llvm.ptr) -> f64 attributes {llvm.emit_c_interface}
func.func @isolated_benchmark(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: !llvm.ptr) {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
%0 = call @wave_get_buffer(%arg1) : (!llvm.ptr) -> memref<?xi8>
%view = memref.view %0[%c0][] : memref<?xi8> to memref<f16>
%1 = call @wave_get_buffer(%arg2) : (!llvm.ptr) -> memref<?xi8>
%view_0 = memref.view %1[%c0][] : memref<?xi8> to memref<f16>
%2 = call @wave_get_buffer(%arg3) : (!llvm.ptr) -> memref<?xi8>
%view_1 = memref.view %2[%c0][] : memref<?xi8> to memref<f32>
gpu.launch_func @gpu_module::@gemm blocks in (%c2, %c4, %c1) threads in (%c128, %c2, %c1) args(%view : memref<f16>, %view_0 : memref<f16>, %view_1 : memref<f32>)
return
}
}
```
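For reference, the memref shapes above encode C[M, N] = A[M, K] · B[N, K]^T with M = 64, K = 64, N = 128, and the 2560-byte workgroup allocation holds two 32x20 f16 tiles (the K dimension padded from 16 to 20, presumably to avoid LDS bank conflicts). A minimal NumPy sketch of the same computation, under those assumptions (names are mine, not from the PR):

```python
import numpy as np

M, K, N = 64, 64, 128  # from the memref shapes above

# Host-side reference: C[M, N] = A[M, K] @ B[N, K]^T, accumulated in f32
# the way the MFMA accumulates f16 products into vector<4xf32>.
A = np.random.rand(M, K).astype(np.float16)
B = np.random.rand(N, K).astype(np.float16)
C = A.astype(np.float32) @ B.astype(np.float32).T

# The shared-memory views: two padded tiles of 32 rows x 20 f16 columns,
# 2 * 32 * 20 * 2 bytes = 2560, matching the memref.alloc size.
assert 2 * 32 * 20 * 2 == 2560
```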
ftynse (Contributor, Author) commented:

The assembly looks like this (test_gemm):

```asm
.amdgcn_target "amdgcn-amd-amdhsa--gfx942"
.text
.protected gemm
.globl gemm
.p2align 8
.type gemm,@function
.section .rodata,#alloc
.p2align 6
.amdhsa_kernel gemm
.amdhsa_user_sgpr_count 2
.amdhsa_user_sgpr_dispatch_ptr 0
.amdhsa_user_sgpr_queue_ptr 0
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_user_sgpr_dispatch_id 0
.amdhsa_user_sgpr_private_segment_size 0
.amdhsa_uses_dynamic_stack 0
.amdhsa_enable_private_segment 0
.amdhsa_accum_offset 24
.amdhsa_next_free_vgpr 24
.amdhsa_next_free_sgpr 24
.amdhsa_group_segment_fixed_size 2560
.amdhsa_private_segment_fixed_size 0
.amdhsa_system_sgpr_workgroup_id_x 1
.amdhsa_system_sgpr_workgroup_id_y 1
.amdhsa_system_sgpr_workgroup_id_z 0
.amdhsa_system_vgpr_workitem_id 1
.amdhsa_float_denorm_mode_32 3
.amdhsa_float_denorm_mode_16_64 3
.end_amdhsa_kernel
.text
# SRD upper word (gfx9xx): data_format=4 => 0x20000
.set Srd127_96, 131072
gemm:
s_load_dwordx2 s[4:5], s[0:1], 0 // Load base addr for arg0
s_load_dwordx2 s[8:9], s[0:1], 8 // Load base addr for arg1
s_load_dwordx2 s[12:13], s[0:1], 16 // Load base addr for arg2
s_waitcnt lgkmcnt(0) // wait for all SRD loads
s_mov_b32 s6, 0x2000 // SRD size for arg0
s_mov_b32 s7, 0x20000 // SRD stride for arg0
s_mov_b32 s10, 0x4000 // SRD size for arg1
s_mov_b32 s11, 0x20000 // SRD stride for arg1
s_mov_b32 s14, 0x8000 // SRD size for arg2
s_mov_b32 s15, 0x20000 // SRD stride for arg2
s_load_dwordx2 s[16:17], s[0:1], 0 // Load kernarg at offset 0
s_load_dwordx2 s[16:17], s[0:1], 8 // Load kernarg at offset 8
s_load_dwordx2 s[16:17], s[0:1], 16 // Load kernarg at offset 16
s_waitcnt lgkmcnt(0) // wait for all kernarg loads
// Initialize loop 0
s_mov_b32 s16, 0 // loop 0 counter = 0
s_mov_b32 s17, 1 // loop 0 step = 1
s_mov_b32 s18, 4 // loop 0 upper = 4
v_mov_b32 v4, 0
v_mov_b32 v5, 0
v_mov_b32 v6, 0
v_mov_b32 v7, 0 // Initialize accumulator 0 to 0.0
loop_0_header:
s_cmp_lt_u32 s16, s18 // compare loop 0 counter < upper
s_cbranch_scc1 loop_0_body
s_branch loop_0_exit
loop_0_body:
v_bfe_u32 v1, v0, 0, 10 // extract tid_x from flat_tid
v_and_b32 v2, 63, v1 // mod 64 (and)
v_lshrrev_b32 v3, 4, v2 // floor div by 16 (shift)
v_lshlrev_b32 v2, 3, v3 // floor((Mod(tid_x, 64))/16) << 3
v_mov_b32 v8, s16 // broadcast loop counter ks6 to VGPR
v_lshlrev_b32 v9, 5, v8 // ks6 << 5
v_or_b32 v8, v2, v9 // or (bits 3-4 + 5-20)
v_and_b32 v9, 15, v1 // mod 16 (and)
v_lshl_add_u32 v10, v9, 7, v8 // fused: (kv12 << 7) + kv11
v_lshrrev_b32 v8, 6, v1 // floor div by 64 (shift)
v_lshl_add_u32 v1, v8, 11, v10 // fused: (kv15 << 11) + kv14
v_mov_b32 v11, s2 // wgid_x from s2
v_lshl_add_u32 v12, v11, 12, v1 // fused: (kv18 << 12) + kv17
buffer_load_dwordx2 v[14:15], v12, s[4:7], 0 offen // load 8B @ offset 0
s_waitcnt lgkmcnt(0)
s_barrier // LDS barrier
v_mul_lo_u32 v1, v9, 40 // Mod(tid_x, 16) * 40
v_add_u32 v12, v2, v1 // add
v_mov_b32 v1, 0x280 // materialize 640
v_mul_lo_u32 v2, v8, v1 // floor(tid_x/64) * 640
v_add_u32 v13, v12, v2 // add
v_add_u32 v2, 0x500, v13 // + 1280 (inline literal)
s_waitcnt vmcnt(0) // wait for VMEM before LDS store
ds_write_b64 v2, v[14:15] // LDS store 8B @ offset 0
v_bfe_u32 v2, v0, 10, 10 // extract tid_y from flat_tid
v_lshl_add_u32 v14, v2, 11, v10 // fused: (kv30 << 11) + kv14
v_mov_b32 v10, s3 // wgid_y from s3
v_lshl_add_u32 v15, v10, 12, v14 // fused: (kv33 << 12) + kv32
buffer_load_dwordx2 v[16:17], v15, s[8:11], 0 offen // load 8B @ offset 0
v_mul_lo_u32 v14, v2, v1 // tid_y * 640
v_add_u32 v1, v12, v14 // add
s_waitcnt vmcnt(0) // wait for VMEM before LDS store
ds_write_b64 v1, v[16:17] // LDS store 8B @ offset 0
s_waitcnt lgkmcnt(0)
s_barrier // LDS barrier
ds_read_b64 v[14:15], v1 // LDS load 8B @ offset 0
ds_read_b64 v[16:17], v13 offset:1280 // LDS load 8B @ offset 1280
s_waitcnt lgkmcnt(0) // ticketing: wait for LGKM defs
v_mfma_f32_16x16x16_f16 v[4:7], v[16:17], v[14:15], v[4:7] // MFMA with accumulator (in-place)
loop_0_latch:
s_add_u32 s16, s16, s17 // loop 0 counter += step
s_branch loop_0_header
loop_0_exit:
v_lshlrev_b32 v1, 6, v2 // tid_y << 6
v_lshl_or_b32 v2, v9, 2, v1 // fused: (kv12 << 2) | kv46
v_lshl_add_u32 v1, v10, 7, v2 // fused: (kv33 << 7) + kv47
v_lshl_add_u32 v2, v3, 11, v1 // fused: (kv7 << 11) + kv49
v_lshl_add_u32 v1, v8, 13, v2 // fused: (kv15 << 13) + kv51
v_lshl_add_u32 v2, v11, 14, v1 // fused: (kv18 << 14) + kv53
buffer_store_dword v[4:4], v2, s[12:15], 0 offen // store 4B @ offset 0
buffer_store_dword v[5:5], v2, s[12:15], 0 offen offset:512 // store 4B @ offset 512
buffer_store_dword v[6:6], v2, s[12:15], 0 offen offset:1024 // store 4B @ offset 1024
buffer_store_dword v[7:7], v2, s[12:15], 0 offen offset:1536 // store 4B @ offset 1536
s_endpgm
.amdgpu_metadata
---
amdhsa.version:
- 1
- 2
amdhsa.kernels:
- .name: gemm
.symbol: 'gemm.kd'
.language: OpenCL C
.language_version:
- 1
- 2
.kernarg_segment_size: 24
.group_segment_fixed_size: 2560
.private_segment_fixed_size: 0
.kernarg_segment_align: 8
.wavefront_size: 64
.sgpr_count: 19
.vgpr_count: 18
.max_flat_workgroup_size: 256
.args:
- .name: arg0_ptr
.size: 8
.offset: 0
.value_kind: global_buffer
.value_type: i8*
- .name: arg1_ptr
.size: 8
.offset: 8
.value_kind: global_buffer
.value_type: i8*
- .name: arg2_ptr
.size: 8
.offset: 16
.value_kind: global_buffer
.value_type: i8*
...
.end_amdgpu_metadata
```
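The SRD size words written into s6/s10/s14 line up with the byte sizes of the three buffers implied by the memref shapes; a quick sanity check (my arithmetic, not from the PR):

```python
# Buffer sizes in bytes vs. the SRD size words set above.
assert 64 * 64 * 2 == 0x2000    # arg0: 64x64 f16  -> s6
assert 128 * 64 * 2 == 0x4000   # arg1: 128x64 f16 -> s10
assert 64 * 128 * 4 == 0x8000   # arg2: 64x128 f32 -> s14
# Srd127_96 = 131072 = 0x20000, i.e. data_format = 4 in the SRD upper word.
assert 131072 == 0x20000
# The four buffer_store_dword offsets (512/1024/1536) step by one row of C:
# 128 f32 columns * 4 bytes = 512 bytes per row.
assert 128 * 4 == 512
```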
ftynse (Contributor, Author) commented:

The high-level representation in the Wave dialect looks like this:

```mlir
module {
func.func @kernel(%arg0: !wave.tensor<[@M, @K] of f16, <global>>, %arg1: !wave.tensor<[@N, @K] of f16, <global>>, %arg2: !wave.tensor<[@M, @N] of f32, <global>>) attributes {wave.constraints = [#wave.workgroup_constraint<dim = <"M">, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = <x>>, #wave.workgroup_constraint<dim = <"N">, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = <y>>, #wave.tiling_constraint<dim = <"K">, tile_size = <[#wave.symbol<"BLOCK_K">] -> (BLOCK_K)>>, #wave.wave_constraint<dim = <"M">, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M ceildiv 2)>>, #wave.wave_constraint<dim = <"N">, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N ceildiv 2)>>, #wave.hardware_constraint<threads_per_wave = 64, waves_per_block = [2, 2, 1], mma_type = <f32_16x16x16_f16>>], wave.hyperparameters = #wave.hyperparameters<{BLOCK_K = 16 : i64, BLOCK_M = 32 : i64, BLOCK_N = 32 : i64, K = 64 : i64, M = 64 : i64, N = 128 : i64}>} {
%cst = arith.constant 0.000000e+00 : f32
%0 = wave.register %cst {_water_internal.id = "123211467026368"} : !wave.tensor<[@M, @N] of f32, <register>>
%1 = wave.iterate @K iter_args(%0) attributes {_water_internal.id = "123211465261136", _water_internal.result_ids = ["123211465262544"]} {
^bb0(%arg3: !wave.tensor<[@M, @N] of f32, <register>>):
%2 = wave.allocate {_water_internal.id = "123211465263248", distributed_shape = #wave.expr_list<[#wave.symbol<"BLOCK_N">, #wave.symbol<"BLOCK_K">] -> (BLOCK_N, BLOCK_K + 4)>} : !wave.tensor<[@N, @K] of f16, <shared>>
%3 = wave.allocate {_water_internal.id = "123211465262192", distributed_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_K">] -> (BLOCK_M, BLOCK_K + 4)>} : !wave.tensor<[@M, @K] of f16, <shared>>
%4 = wave.read %arg0 {_water_internal.id = "123211467018800"} : (!wave.tensor<[@M, @K] of f16, <global>>) -> !wave.tensor<[@M, @K] of f16, <register>>
wave.write %4, %3 {_water_internal.id = "123211465263072"} : !wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@M, @K] of f16, <shared>>
%5 = wave.read %arg1 {_water_internal.id = "123211465263600"} : (!wave.tensor<[@N, @K] of f16, <global>>) -> !wave.tensor<[@N, @K] of f16, <register>>
wave.write %5, %2 {_water_internal.id = "123211465264128"} : !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <shared>>
%6 = wave.read %2 {_water_internal.id = "123211465263776"} : (!wave.tensor<[@N, @K] of f16, <shared>>) -> !wave.tensor<[@N, @K] of f16, <register>>
%7 = wave.read %3 {_water_internal.id = "123211465262896"} : (!wave.tensor<[@M, @K] of f16, <shared>>) -> !wave.tensor<[@M, @K] of f16, <register>>
%8 = wave.mma %7, %6, %arg3 {_water_internal.id = "123211465261312"} : (!wave.tensor<[@M, @K] of f16, <register>>, !wave.tensor<[@N, @K] of f16, <register>>, !wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
wave.yield %8 : !wave.tensor<[@M, @N] of f32, <register>>
} : (!wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
wave.write %1, %arg2 {_water_internal.id = "123211465262016"} : !wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@M, @N] of f32, <global>>
return
}
}
```
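The hyperparameters fully determine the launch geometry seen in the lowered module; a quick check of the numbers (my arithmetic, not from the PR):

```python
M, N, K = 64, 128, 64
BLOCK_M, BLOCK_N, BLOCK_K = 32, 32, 16

assert M // BLOCK_M == 2  # workgroups in x, matching blocks in (%c2, %c4, %c1)
assert N // BLOCK_N == 4  # workgroups in y
assert K // BLOCK_K == 4  # trip count of wave.iterate / the scf.for over K
# waves_per_block = [2, 2, 1] at 64 threads per wave gives the
# known_block_size of (128, 2, 1): 2 waves * 64 lanes along x.
assert 2 * 64 == 128
```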
ftynse (Contributor, Author) commented:
Install wave following https://github.com/iree-org/wave?tab=readme-ov-file#for-developers. Then just run the file via Python.
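For example, assuming the reproducer from this PR is saved as simple_gemm_asm.py (the file name here is hypothetical):

```
python simple_gemm_asm.py
```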