Commit 43deaf5

Merge pull request #120 from JuliaGPU/tb/zero_layout
Add zero layout to optimize alpha/beta=zero.
2 parents ac26708 + 1b2b59c commit 43deaf5
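
The point of the change, in one example: when beta == 0 the beta * C term vanishes, so C never needs to be read from global memory, and when alpha == 0 neither A nor B does. A minimal sketch of calls that benefit, following the gemmEx! signature added in src/blas.jl below; the array shapes and scalar values here are illustrative, not taken from the PR:

    using CUDA, LinearAlgebra
    using GemmKernels

    a = CuArray(rand(Float16, 256, 256))
    b = CuArray(rand(Float16, 256, 256) / sqrt(Float16(256)))
    c = CuArray(rand(Float32, 256, 256))

    # beta == 0: with this commit, C is mapped to Layout.Zero and never loaded,
    # so the previous contents of c do not affect the result.
    GemmKernels.BLAS.gemmEx!('N', 'N', 1.0f0, a, b, 0.0f0, c)

    # alpha == 0: A and B are mapped to Layout.Zero; the result is just beta * C.
    GemmKernels.BLAS.gemmEx!('N', 'N', 0.0f0, a, b, 2.0f0, c)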

File tree

3 files changed: 46 additions, 62 deletions


src/blas.jl

Lines changed: 33 additions & 31 deletions

@@ -4,22 +4,6 @@ using CUDA
 using GemmKernels
 using LinearAlgebra
 
-# Global layouts
-global_layout(::Type{<:CuArray{T}}, ::Val{false}) where {T} = Layout.AlignedColMajor{T}
-global_layout(::Type{<:CuArray{T}}, ::Val{true}) where {T} = Layout.AlignedRowMajor{T}
-global_layout(::Type{<:Diagonal{Float16, <:CuArray{Float16}}}, transpose) = Layout.Diagonal{Float16}
-
-# Shared layouts for A / B
-shared_layout_ab(typ::Type{<:CuArray{Float16}}, transpose) = Layout.Padded{global_layout(typ, transpose), 8}
-shared_layout_ab(::Type{<:Diagonal{Float16, <:CuArray{Float16, N}}}, transpose) where {N, P} = shared_layout_ab(CuArray{Float16, N}, transpose)
-
-# Shared layouts for C / D
-shared_layout_cd(typ::Type{<:CuArray{T}}, transpose) where {T} = global_layout(typ, transpose)
-
-# Convert matrix to type compatible with kernel
-convert_matrix(mat) = mat
-convert_matrix(mat::Diagonal{T, A}) where {T, A} = mat.diag
-
 # Select the best kernel
 kernel(layout_a, layout_b) = Kernel.matmul_singlestage
 kernel(::Type{Layout.AlignedColMajor{T}}, ::Type{Layout.AlignedColMajor{T}}) where {T} = Kernel.matmul_pipelined
@@ -28,7 +12,8 @@ kernel(::Type{Layout.AlignedRowMajor{T}}, ::Type{Layout.AlignedColMajor{T}}) whe
 kernel(::Type{Layout.AlignedRowMajor{T}}, ::Type{Layout.AlignedRowMajor{T}}) where {T} = Kernel.matmul_pipelined
 
 # Based on https://github.com/JuliaGPU/CUDA.jl/blob/bd5a2a8800e91eb6a7df89eb5dd4bb8fc503541d/lib/cublas/wrappers.jl#L743-L769
-function gemmEx!(transA::Char, transB::Char, alpha::Number, A, B, beta::Number, C)
+function gemmEx!(transA::Char, transB::Char, alpha::Number, A::CuMatrix, B::CuMatrix,
+                 beta::Number, C::CuMatrix)
     m = size(A, transA == 'N' ? 1 : 2)
     k = size(A, transA == 'N' ? 2 : 1)
     n = size(B, transB == 'N' ? 2 : 1)
@@ -40,31 +25,48 @@ function gemmEx!(transA::Char, transB::Char, alpha::Number, A, B, beta::Number,
     transpose_a = (transA == 'T')
     transpose_b = (transB == 'T')
 
-    a_layout = global_layout(typeof(A), Val(transpose_a))
-    b_layout = global_layout(typeof(B), Val(transpose_b))
+    a_layout_base = transpose_a ? Layout.AlignedRowMajor : Layout.AlignedColMajor
+    b_layout_base = transpose_b ? Layout.AlignedRowMajor : Layout.AlignedColMajor
+
+    # determine global memory layouts
+    ## if alpha is zero, we don't need to load A or B
+    if iszero(alpha)
+        global_a_layout = Layout.Zero{eltype(A)}
+        global_b_layout = Layout.Zero{eltype(B)}
+    else
+        global_a_layout = a_layout_base{eltype(A)}
+        global_b_layout = b_layout_base{eltype(B)}
+    end
+    ## if beta is zero, we don't need to load C
+    global_c_layout = if iszero(beta)
+        Layout.Zero{eltype(C)}
+    else
+        Layout.AlignedColMajor{eltype(C)}
+    end
+    global_d_layout = Layout.AlignedColMajor{eltype(C)}
 
-    conf = GemmKernels.get_config(
+    # determine shared memory layouts
+    ## padded to avoid bank conflicts
+    shared_a_layout = Layout.Padded{a_layout_base{eltype(A)}, 8}
+    shared_b_layout = Layout.Padded{b_layout_base{eltype(B)}, 8}
+    ## outputs are never transposed, and padding them doesn't seem worth it
+    shared_c_layout = shared_d_layout = Layout.AlignedColMajor{eltype(C)}
+
+    conf = GemmKernels.get_config(;
         gemm_shape = (M = m, N = n, K = k),
         operator = Operator.WMMAOp{16, 16, 16, eltype(C)},
 
-        global_a_layout = a_layout,
-        global_b_layout = b_layout,
-        global_c_layout = global_layout(typeof(C), Val(false)),
-        global_d_layout = global_layout(typeof(C), Val(false)),
-
-        shared_a_layout = shared_layout_ab(typeof(A), Val(transpose_a)),
-        shared_b_layout = shared_layout_ab(typeof(B), Val(transpose_b)),
-        shared_c_layout = shared_layout_cd(typeof(C), Val(false)),
-        shared_d_layout = shared_layout_cd(typeof(C), Val(false)),
+        global_a_layout, global_b_layout, global_c_layout, global_d_layout,
+        shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,
 
         is_a_col_major = !transpose_a,
        is_b_col_major = !transpose_b
    )
 
-    GemmKernels.matmul(convert_matrix(A), convert_matrix(B), convert_matrix(C), convert_matrix(C), conf;
+    GemmKernels.matmul(A, B, C, C, conf;
        transform_shared_to_regs_a = Transform.Elementwise(x -> x * alpha),
        transform_shared_to_regs_c = Transform.Elementwise(x -> x * beta),
-        kernel = kernel(a_layout, b_layout)
+        kernel = kernel(global_a_layout, global_b_layout)
    )
 end

src/layout.jl

Lines changed: 13 additions & 0 deletions

@@ -57,6 +57,19 @@ abstract type LayoutBase{T} end
 @inline eltype(::Type{<:LayoutBase{T}}) where {T} = T
 @inline physical_size(::Type{<:LayoutBase{T}}, logical_size::NamedTuple) where {T} = Tuple(logical_size)
 
+# ----
+# Zero
+# ----
+
+abstract type Zero{T} <: LayoutBase{T} end
+
+@inline function load(::Type{<:Zero{T}}, workspace, tile::Tile{size}) where {T, size}
+    N = 16 ÷ sizeof(T)
+    return ntuple(i -> VecElement{T}(zero(T)), Val(N))
+end
+
+@inline store!(::Type{<:Zero{T}}, workspace, value, tile::Tile) where {T} = return
+
 # --------------
 # Padded layouts
 # --------------
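
What the Zero layout produces: load ignores the workspace and returns a 16-byte fragment of zeros (N = 16 ÷ sizeof(T) elements), and store! is a no-op. A stand-alone illustration of the fragment shape for T = Float16, using only Base (no GemmKernels types):

    T = Float16
    N = 16 ÷ sizeof(T)                           # 16-byte vectorized access -> 8 elements
    frag = ntuple(i -> VecElement{T}(zero(T)), Val(N))

    @assert length(frag) == 8
    @assert all(x -> x.value === zero(T), frag)  # every lane is zero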

test/blas.jl

Lines changed: 0 additions & 31 deletions

@@ -36,35 +36,4 @@ CUDA.CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.CUBLAS_TENSOR_OP_MATH)
         @test all(isapprox.(Array(c_gemmkernels), Array(c_cublas); rtol=sqrt(eps(A_type))));
     end
 end
-
-@testset "WMMA GEMM (A = diagonal, B = $( !transpose_b ? 'N' : 'T' ))" for transpose_b = [false, true]
-    @testset "(M = $M, N = $N, K = $K)" for M in [128, 256],
-        N in [128, 256],
-        K in [M]
-
-        transpose_a = false
-
-        alpha = rand(Float32)
-        beta = rand(Float32)
-
-        a_h = rand(Float16, M);
-        b_h = rand(Float16, (K, N)) / sqrt(Float16(K))
-        c_h = rand(Float32, (M, N))
-
-        # Transpose input if necessary
-        a_h = transpose_a ? transpose(a_h) : a_h
-        b_h = transpose_b ? transpose(b_h) : b_h
-
-        a = Diagonal(CuArray(a_h))
-        b = CuArray(b_h)
-
-        c_gemmkernels = CuArray(c_h)
-        GemmKernels.BLAS.gemmEx!(!transpose_a ? 'N' : 'T', !transpose_b ? 'N' : 'T', alpha, a, b, beta, c_gemmkernels)
-
-        c_cublas = CuArray(c_h)
-        CUDA.CUBLAS.gemmEx!(!transpose_a ? 'N' : 'T', !transpose_b ? 'N' : 'T', alpha, CuArray(Array(Diagonal(a_h))), b, beta, c_cublas)
-
-        @test all(isapprox.(Array(c_gemmkernels), Array(c_cublas); rtol=sqrt(eps(Float16))));
-    end
-end
 end
