@@ -4,22 +4,6 @@ using CUDA
using GemmKernels
using LinearAlgebra

-# Global layouts
-global_layout(::Type{<:CuArray{T}}, ::Val{false}) where {T} = Layout.AlignedColMajor{T}
-global_layout(::Type{<:CuArray{T}}, ::Val{true}) where {T} = Layout.AlignedRowMajor{T}
-global_layout(::Type{<:Diagonal{Float16, <:CuArray{Float16}}}, transpose) = Layout.Diagonal{Float16}
-
-# Shared layouts for A / B
-shared_layout_ab(typ::Type{<:CuArray{Float16}}, transpose) = Layout.Padded{global_layout(typ, transpose), 8}
-shared_layout_ab(::Type{<:Diagonal{Float16, <:CuArray{Float16, N}}}, transpose) where {N} = shared_layout_ab(CuArray{Float16, N}, transpose)
-
-# Shared layouts for C / D
-shared_layout_cd(typ::Type{<:CuArray{T}}, transpose) where {T} = global_layout(typ, transpose)
-
-# Convert matrix to type compatible with kernel
-convert_matrix(mat) = mat
-convert_matrix(mat::Diagonal{T, A}) where {T, A} = mat.diag
-
# Select the best kernel
kernel(layout_a, layout_b) = Kernel.matmul_singlestage
kernel(::Type{Layout.AlignedColMajor{T}}, ::Type{Layout.AlignedColMajor{T}}) where {T} = Kernel.matmul_pipelined
@@ -28,7 +12,8 @@ kernel(::Type{Layout.AlignedRowMajor{T}}, ::Type{Layout.AlignedColMajor{T}}) whe
kernel(::Type{Layout.AlignedRowMajor{T}}, ::Type{Layout.AlignedRowMajor{T}}) where {T} = Kernel.matmul_pipelined
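For context (not part of this commit): these methods select the kernel by dispatching on the two global layout types. A minimal sketch of how that dispatch resolves, assuming Float16 element types:

# Any pair of aligned row/col-major layouts hits a specific method and
# gets the pipelined kernel; everything else falls through to the
# generic method and the single-stage kernel.
kernel(Layout.AlignedColMajor{Float16}, Layout.AlignedColMajor{Float16})  # Kernel.matmul_pipelined
kernel(Layout.AlignedRowMajor{Float16}, Layout.AlignedColMajor{Float16})  # Kernel.matmul_pipelined
kernel(Layout.Zero{Float16}, Layout.AlignedColMajor{Float16})             # Kernel.matmul_singlestage (fallback)

This is why the Zero layouts introduced below for alpha == 0 end up on the single-stage kernel.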

# Based on https://github.com/JuliaGPU/CUDA.jl/blob/bd5a2a8800e91eb6a7df89eb5dd4bb8fc503541d/lib/cublas/wrappers.jl#L743-L769
-function gemmEx!(transA::Char, transB::Char, alpha::Number, A, B, beta::Number, C)
+function gemmEx!(transA::Char, transB::Char, alpha::Number, A::CuMatrix, B::CuMatrix,
+                 beta::Number, C::CuMatrix)
    m = size(A, transA == 'N' ? 1 : 2)
    k = size(A, transA == 'N' ? 2 : 1)
    n = size(B, transB == 'N' ? 2 : 1)
@@ -40,31 +25,48 @@ function gemmEx!(transA::Char, transB::Char, alpha::Number, A, B, beta::Number,
    transpose_a = (transA == 'T')
    transpose_b = (transB == 'T')

-   a_layout = global_layout(typeof(A), Val(transpose_a))
-   b_layout = global_layout(typeof(B), Val(transpose_b))
+   a_layout_base = transpose_a ? Layout.AlignedRowMajor : Layout.AlignedColMajor
+   b_layout_base = transpose_b ? Layout.AlignedRowMajor : Layout.AlignedColMajor
+
+   # determine global memory layouts
+   ## if alpha is zero, we don't need to load A or B
+   if iszero(alpha)
+       global_a_layout = Layout.Zero{eltype(A)}
+       global_b_layout = Layout.Zero{eltype(B)}
+   else
+       global_a_layout = a_layout_base{eltype(A)}
+       global_b_layout = b_layout_base{eltype(B)}
+   end
+   ## if beta is zero, we don't need to load C
+   global_c_layout = if iszero(beta)
+       Layout.Zero{eltype(C)}
+   else
+       Layout.AlignedColMajor{eltype(C)}
+   end
+   global_d_layout = Layout.AlignedColMajor{eltype(C)}

-   conf = GemmKernels.get_config(
+   # determine shared memory layouts
+   ## padded to avoid bank conflicts
+   shared_a_layout = Layout.Padded{a_layout_base{eltype(A)}, 8}
+   shared_b_layout = Layout.Padded{b_layout_base{eltype(B)}, 8}
+   ## outputs are never transposed, and padding them doesn't seem worth it
+   shared_c_layout = shared_d_layout = Layout.AlignedColMajor{eltype(C)}
+
+   conf = GemmKernels.get_config(;
        gemm_shape = (M = m, N = n, K = k),
        operator = Operator.WMMAOp{16, 16, 16, eltype(C)},

-       global_a_layout = a_layout,
-       global_b_layout = b_layout,
-       global_c_layout = global_layout(typeof(C), Val(false)),
-       global_d_layout = global_layout(typeof(C), Val(false)),
-
-       shared_a_layout = shared_layout_ab(typeof(A), Val(transpose_a)),
-       shared_b_layout = shared_layout_ab(typeof(B), Val(transpose_b)),
-       shared_c_layout = shared_layout_cd(typeof(C), Val(false)),
-       shared_d_layout = shared_layout_cd(typeof(C), Val(false)),
+       global_a_layout, global_b_layout, global_c_layout, global_d_layout,
+       shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,

        is_a_col_major = !transpose_a,
        is_b_col_major = !transpose_b
    )

-   GemmKernels.matmul(convert_matrix(A), convert_matrix(B), convert_matrix(C), convert_matrix(C), conf;
+   GemmKernels.matmul(A, B, C, C, conf;
        transform_shared_to_regs_a = Transform.Elementwise(x -> x * alpha),
        transform_shared_to_regs_c = Transform.Elementwise(x -> x * beta),
-       kernel = kernel(a_layout, b_layout)
+       kernel = kernel(global_a_layout, global_b_layout)
    )
end
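For reference, not part of this commit: a usage sketch of the rewritten gemmEx!, assuming mixed-precision WMMA (Float16 inputs, Float32 accumulator) and matrix dimensions that are multiples of the 16x16x16 operator shape used above:

using CUDA, GemmKernels, LinearAlgebra

A = CUDA.rand(Float16, 256, 256)
B = CUDA.rand(Float16, 256, 256)
C = CUDA.zeros(Float32, 256, 256)

# C = 1 * A * B' + 0 * C; beta == 0 means C is read through Layout.Zero
gemmEx!('N', 'T', 1, A, B, 0, C)

# loose tolerance, since the inputs are Float16
@assert isapprox(Array(C), Float32.(Array(A)) * transpose(Float32.(Array(B)));
                 rtol = sqrt(eps(Float16)))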