Some of the parameter choices are currently far from optimal, as can be quickly explored using the following script:
```julia
using CUDA, GemmKernels
using Hyperopt
using Octavian

# we don't need super-accurate timings
const samples = 250

function main()
    M = K = N = 4096

    A = CUDA.rand(Float32, M, K)
    B = CUDA.rand(Float32, K, N)
    C = CUDA.zeros(Float32, M, N)

    C_h = zeros(Float32, M, N)
    Octavian.matmul!(C_h, Array(A), Array(B))

    # pow2-sized, 128-bit aligned inputs, so we can use aligned layouts.
    # we don't have transposed inputs, so everything is column major.
    @assert stride(A, 2) % 16 == 0
    global_a_layout = Layout.UnsafeAlignedColMajor{eltype(A)}
    @assert stride(B, 2) % 16 == 0
    global_b_layout = Layout.UnsafeAlignedColMajor{eltype(B)}

    # we want to do a simple C = A * B, so no need to load C first.
    global_c_layout = Layout.Zero{eltype(C)}

    @assert stride(C, 2) % 16 == 0
    global_d_layout = Layout.UnsafeAlignedColMajor{eltype(C)}

    # shared layouts are similar.
    # the frequently-accessed a/b shmems are padded to avoid bank conflicts.
    shared_a_layout = Layout.Padded{Layout.UnsafeAlignedColMajor{eltype(A)}, 8}
    shared_b_layout = Layout.Padded{Layout.UnsafeAlignedColMajor{eltype(B)}, 8}
    shared_c_layout = shared_d_layout = Layout.UnsafeAlignedColMajor{eltype(C)}

    # we use the single-stage kernel, for simplicity
    kernel = Kernel.matmul_singlestage

    # TODO: compute_warp is partially hardcoded in config.jl, requiring M >= 4 and N >= 2
    # TODO: tune warps_per_block (which may affect correctness)

    total = 0
    attempts = 0
    benchmarks = 0
    ho = @hyperopt for i = 1000,
                       OPERATOR_M = 2 .^ (1:4),
                       OPERATOR_N = 2 .^ (1:4),
                       OPERATOR_K = 2 .^ (1:4),
                       BLOCK_M = 2 .^ (1:8),
                       BLOCK_N = 2 .^ (1:8),
                       BLOCK_K = 2 .^ (1:8)
        op_shape = (M = OPERATOR_M, N = OPERATOR_N, K = OPERATOR_K)
        block_shape = (M = BLOCK_M, N = BLOCK_N, K = BLOCK_K)
        total += 1

        # validate the operator shape
        ## may not be larger than the block shape
        if op_shape.M > block_shape.M ||
           op_shape.N > block_shape.N ||
           op_shape.K > block_shape.K
            return Inf
        end
        ## the FPU operator's base shape is 4x8x1. can we relax this?
        if op_shape.M < 4 || op_shape.M % 4 != 0 ||
           op_shape.N < 8 || op_shape.N % 8 != 0
            return Inf
        end
        ## LocalArray size limits (these are the ways FPUOp instantiates them)
        if op_shape.M÷4 * op_shape.K >= 32 ||
           op_shape.K * op_shape.N÷8 >= 32 ||
           op_shape.M÷4 * op_shape.N÷8 >= 32
            # in isolation, i.e. https://github.com/JuliaGPU/GemmKernels.jl/issues/99,
            # a LocalArray of 32 elements is fine, but in the context of the kernel,
            # it's too large. I don't know why.
            return Inf
        end

        # validate the block shape
        ## needs to exactly cover the inputs, so that we can use the unsafe layouts.
        if M % block_shape.M != 0 || N % block_shape.N != 0 || K % block_shape.K != 0
            return Inf
        end
        ## needs to be 128-bit aligned so that we can perform vectorized loads
        # XXX: is this correct?
        if block_shape.M * sizeof(eltype(A)) % 16 != 0 ||
           block_shape.N * sizeof(eltype(B)) % 16 != 0 ||
           block_shape.K * sizeof(eltype(C)) % 16 != 0
            return Inf
        end

        compute_type = promote_type(eltype(A), eltype(B))
        operator = Operator.FPUOp{OPERATOR_M, OPERATOR_N, OPERATOR_K, compute_type, eltype(C)}

        conf = GemmKernels.get_config(;
            gemm_shape = (; M, N, K), block_shape, operator,
            global_a_layout, global_b_layout, global_c_layout, global_d_layout,
            shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,
            is_a_col_major = true,
            is_b_col_major = true
        )

        ## another LocalArray size limit, these are in the kernel
        num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M
        num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N
        if num_fragments_m * num_fragments_n >= 32
            return Inf
        end

        try
            # warm-up & correctness check
            attempts += 1
            C .= 0
            GemmKernels.matmul(conf, A, B, C, C; kernel)
            if !(Array(C) ≈ C_h)
                @warn "Configuration produced invalid result: $conf"
                return Inf
            end

            # benchmark
            benchmarks += 1
            device_synchronize()
            GC.gc(true)
            timings = zeros(samples)
            for i in 1:samples
                synchronize(stream())
                timings[i] = CUDA.@elapsed GemmKernels.matmul(conf, A, B, C, C; kernel)
            end
            minimum(timings)
        catch err
            if isa(err, CuError)
                @error "Configuration failed: $conf"
                rethrow()
            end
            @info "Skipping configuration: $conf\n" * sprint(Base.showerror, err)
            # TODO: introduce GemmKernels.ConfigError, to differentiate from e.g.
            # compilation errors, which we want to report verbosely.
            Inf
        end
    end

    skips = total - attempts
    errors = attempts - benchmarks
    println("Out of $total configurations, $skips ($(round(100*skips/total; digits=1))%) were skipped, $errors ($(round(100*errors/total; digits=1))%) errored, and $benchmarks ($(round(100*benchmarks/total; digits=1))%) were actually tested.")

    ho
end

isinteractive() || println(main())
```

For example, let's do a 256x256 GEMM, FP32xFP32=FP32, using the FPU operator. On my system (RTX 6000 Ada), the default configuration (an 8x8x1 (N, M, K) operator shape and a 128x128x32 block shape) yields:
```julia
julia> main()
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  53.589 μs … 115.339 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     55.609 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   55.744 μs ±   1.117 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

              ▂▂▄▆▇▆█▆█▄▃▄▁
  ▁▁▁▁▁▂▂▃▄▃▄▅▄▅▆▇▇█████████████▇▇▆▅▅▄▄▅▄▄▄▃▃▃▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁ ▄
  53.6 μs         Histogram: frequency by time         58.6 μs <

 Memory estimate: 2.98 KiB, allocs estimate: 50.
```
The script above optimizes this to 4x16x8 and 16x32x256, which yields:
```julia
julia> main()
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  22.549 μs …  87.980 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     24.030 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   24.212 μs ± 974.562 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

             ▁▃▃▇▇██▅▆▅▃▃▂
  ▂▁▁▂▂▂▂▂▃▃▃▄▅▆█████████████▇█▆▆▆▅▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂ ▄
  22.5 μs         Histogram: frequency by time         26.5 μs <

 Memory estimate: 2.98 KiB, allocs estimate: 50.
```
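
To reuse the tuned parameters without re-running the search, the configuration can be constructed directly through the same `get_config` call as in the script. A minimal sketch, assuming the inputs, layouts, and problem sizes defined above, and interpreting the tuned operator shape as (M, N, K) = 4x16x8; these values are specific to this GPU and problem size:

```julia
# sketch: hard-code the tuned parameters found by the search above.
# assumes A, B, C, M/N/K, and the global_*/shared_* layouts from the script.
operator = Operator.FPUOp{4, 16, 8, Float32, Float32}  # vs. the default 8x8x1
block_shape = (M = 16, N = 32, K = 256)                # vs. the default 128x128x32
conf = GemmKernels.get_config(;
    gemm_shape = (; M, N, K), block_shape, operator,
    global_a_layout, global_b_layout, global_c_layout, global_d_layout,
    shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,
    is_a_col_major = true,
    is_b_col_major = true
)
GemmKernels.matmul(conf, A, B, C, C; kernel = Kernel.matmul_singlestage)
```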
For reference, CUBLAS (LinearAlgebra's `mul!`, benchmarked with BenchmarkTools):
```julia
julia> @benchmark CUDA.@sync mul!(C, A, B)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  19.850 μs …  75.970 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     20.790 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   20.884 μs ± 735.922 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

         ▂▂▆███▃▃▂▁
  ▂▂▁▂▂▂▂▂▂▃▃▄▆▇███████████▇▆▅▅▅▄▄▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▄
  19.8 μs         Histogram: frequency by time         22.7 μs <

 Memory estimate: 592 bytes, allocs estimate: 20.
```
So a more than 2x improvement (55.6 μs → 24.0 μs median), getting us much closer to CUBLAS (20.8 μs).
One problem is that the current implementation makes many implicit assumptions about the parameter values, so many configurations have to be skipped because they error out or even produce incorrect results. This should be fixed before we can fully explore the parameter space.
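
Relatedly, the TODO in the script's catch block hints at one possible cleanup: a dedicated exception type for invalid configurations. A hypothetical sketch (`ConfigError` does not exist in GemmKernels.jl yet):

```julia
# hypothetical: a dedicated exception type for invalid configurations,
# thrown by get_config, so that tuning scripts can skip them silently
# while still surfacing genuine compilation/runtime errors verbosely.
struct ConfigError <: Exception
    msg::String
end
Base.showerror(io::IO, err::ConfigError) = print(io, "ConfigError: ", err.msg)

# the tuning loop's catch block could then shrink to:
#   catch err
#       isa(err, ConfigError) || rethrow()
#       Inf
#   end
```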