Skip to content

Commit 088258b

Browse files
committed
Push test script
1 parent 8f3c6bd commit 088258b

File tree

1 file changed

+177
-0
lines changed

1 file changed

+177
-0
lines changed

dev/flopscomp.jl

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# using Pkg
2+
# Pkg.activate(temp=true)
3+
# Pkg.add(url="https://github.com/christiangnrd/Metal.jl/", rev="MPSGraph")
4+
# Pkg.add(["GPUArrays", "Plots"])
5+
6+
# Uncomment if you want to compare with CPU
7+
# Pkg.add(["AppleAccelerate"])
8+
# using AppleAccelerate
9+
10+
using Metal, GPUArrays, LinearAlgebra, Printf
11+
12+
using Plots
13+
using Plots.Measures
14+
15+
# Fallback GPU-core count for the plot title if the system_profiler lookup
# below is removed.
n_gpu_cores = "??"
# Comment this out if scary. Please mention number of cores in your comment when uploading the figure
system_prof = read(`system_profiler SPDisplaysDataType`, String)
# NOTE(review): `only` assumes exactly one "Total Number of Cores" entry in the
# profiler output — confirm behavior on multi-GPU Macs.
n_gpu_cores = only(match(r"Total Number of Cores:\s*(\d+)", system_prof).captures)

# Used as the default figure title in `main`.
PLOT_TITLE = "Matmul peakflops for $(device().name) ($n_gpu_cores GPU cores)"
21+
22+
"""
    cpupeakflops(; n=4096, n_batch=1, inT=Float32, outT=inT, ntrials=4, verify=true)

Measure CPU matrix-multiply throughput via in-place `mul!` on `n`×`n`
matrices of ones, returning `2n^3` divided by the fastest of `ntrials`
runs (flops). `n_batch` exists for interface parity with the GPU
benchmarks; only 1 is supported. With `verify=true`, every result entry
is checked against the expected value `n`.
"""
function cpupeakflops(; n::Integer=4096,
                        n_batch::Integer=1,
                        inT::DataType=Float32,
                        outT::DataType=inT,
                        ntrials::Integer=4,
                        verify=true)
    t = Base.zeros(Float64, ntrials)
    n_batch == 1 || @warn "n_batch > 1 not supported for `mul!`, running with n_batch=1"
    n_batch = 1
    shape = (n, n)
    for i in 1:ntrials
        # Fresh matrices each trial so no run benefits from warmed caches
        # holding the previous result.
        c = zeros(outT, shape...)
        a = ones(inT, shape...)
        b = ones(inT, shape...)
        t[i] = @elapsed mul!(c, a, b)
        # ones * ones makes every entry equal n; check in place instead of
        # materializing the `Array(c)` copy + `unique` set the original
        # version allocated.
        verify && @assert all(==(n), c)
    end

    # 2n^3 flops per matmul; report the best (minimum-time) trial.
    return n_batch*2*Float64(n)^3 / minimum(t)
end
42+
"""
    _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true)

Shared GPU benchmark driver: time `ntrials` calls of `f(c, a, b)` on
Metal arrays of ones (batched when `n_batch > 1`) and return
`n_batch * 2n^3` divided by the fastest run.
"""
function _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true)
    dims = n_batch == 1 ? (n, n) : (n, n, n_batch)
    best = Inf
    for _ in 1:ntrials
        out = mtl(zeros(outT, dims...))
        lhs = mtl(ones(inT, dims...))
        rhs = mtl(ones(inT, dims...))
        # Metal.@sync so the elapsed time covers the full GPU execution,
        # not just kernel submission.
        elapsed = @elapsed Metal.@sync f(out, lhs, rhs)
        # ones * ones: every entry of the product must equal n.
        verify && @assert only(unique(Array(out))) == n
        best = min(best, elapsed)
    end
    return n_batch * 2 * Float64(n)^3 / best
end
55+
"""
    gpuarrpeakflops(; n=4096, n_batch=1, inT=Float32, outT=inT, ntrials=3, verify=true)

Benchmark the generic `GPUArrays.generic_matmatmul!` fallback kernel;
returns flops based on the fastest of `ntrials` runs.
"""
function gpuarrpeakflops(; n::Integer=4096,
                           n_batch::Integer=1,
                           inT::DataType=Float32,
                           outT::DataType=inT,
                           ntrials::Integer=3,
                           verify=true)
    if n_batch != 1
        @warn "n_batch > 1 not supported for `GPUArrays.generic_matmatmul!`, running with n_batch=1"
    end
    # alpha=1, beta=0: plain C = A*B with no accumulation into C.
    matmul = (c, a, b) -> GPUArrays.generic_matmatmul!(
        c, LinearAlgebra.wrap(a, 'N'), LinearAlgebra.wrap(b, 'N'), 1, 0)
    return _peakflops(matmul, n, 1, inT, outT, ntrials; verify)
end
66+
"""
    mpspeakflops(; n=4096, n_batch=1, inT=Float32, outT=inT, ntrials=3, verify=true)

Benchmark `MPS.matmul!` (Metal Performance Shaders); returns flops based
on the fastest of `ntrials` runs.
"""
function mpspeakflops(; n::Integer=4096,
                        n_batch::Integer=1,
                        inT::DataType=Float32,
                        outT::DataType=inT,
                        ntrials::Integer=3,
                        verify=true)
    return _peakflops(MPS.matmul!, n, n_batch, inT, outT, ntrials; verify=verify)
end
74+
"""
    graphpeakflops(; n=4096, n_batch=1, inT=Float32, outT=inT, ntrials=3, verify=true)

Benchmark `MPSGraphs.graph_matmul!` (MPSGraph backend); returns flops
based on the fastest of `ntrials` runs.
"""
function graphpeakflops(; n::Integer=4096,
                          n_batch::Integer=1,
                          inT::DataType=Float32,
                          outT::DataType=inT,
                          ntrials::Integer=3,
                          verify=true)
    return _peakflops(MPSGraphs.graph_matmul!, n, n_batch, inT, outT, ntrials; verify=verify)
end
82+
"""
    anepeakflops(; kwargs...)

Run `graphpeakflops` with a compilation descriptor that permits MPSGraph
to offload eligible operations; forwards all keyword arguments and
returns its result. The global default descriptor is always restored,
even if the benchmark throws.
"""
function anepeakflops(; kwargs...)
    # VERY HACKY: swap the global MPSGraph compilation descriptor for the
    # duration of the benchmark.
    newDesc = MPSGraphs.MPSGraphCompilationDescriptor()
    # Optimization level 1 allows operations to be moved to the neural
    # engine, which is what this benchmark is meant to measure.
    # NOTE(review): the original comment claimed level 0 *avoids* the ANE,
    # which matches the code only if the default descriptor uses level 0 —
    # confirm against MPSGraphs' defaults.
    newDesc.optimizationLevel = MPSGraphs.MPSGraphOptimizationLevel1

    oldDesc = MPSGraphs._default_exec_desc[].compilationDescriptor

    MPSGraphs._default_exec_desc[].compilationDescriptor = newDesc
    try
        return graphpeakflops(; kwargs...)
    finally
        # Restore the previous descriptor no matter how the benchmark exits;
        # the original version leaked the ANE descriptor on error.
        MPSGraphs._default_exec_desc[].compilationDescriptor = oldDesc
    end
end
96+
97+
"""
    compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)

Run every `(f, label)` benchmark pair in `Fs` over each matrix size in
`Ns` with element types `inT`/`outT`, returning a
`Dict{String, Vector{Float64}}` mapping label to flops-per-size.
Labels containing "ANE" are skipped unless `outT == Float16` or the
multiply is the mixed-precision `Float16 -> Float32` case.
"""
function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
    # ANE-backed entries are only meaningful for Float16 output or the
    # Float16 -> Float32 mixed-precision case; drop them otherwise.
    newFs = if (outT == Float16 || (outT == Float32 && inT == Float16))
        Fs
    else
        filter(x -> !occursin("ANE", x[2]), Fs)
    end

    # Concretely typed result container (was an untyped Dict()).
    results = Dict{String, Vector{Float64}}()
    for (_, info_str) in newFs
        results[info_str] = Float64[]
    end

    prefixstr = "\33[2K\r($inT, $outT) "  # ANSI clear-line + carriage return
    @time "$((inT, outT))" begin
        for n in Ns
            # Verification compares results against the exact integer n,
            # which is only valid when n is exactly representable in the
            # float types involved.
            verify = (n < maxintfloat(outT) && (inT != Float16 || (n < maxintfloat(inT))))
            n_str = "$n: "
            for (f, info_str) in newFs
                print(prefixstr, n_str, info_str)
                push!(results[info_str], f(; inT, outT, n, n_batch, ntrials, verify))
                # Keep one benchmark's garbage from distorting the next.
                GC.gc()
            end
        end
        print("\33[2K\r")
    end
    return results
end
125+
126+
"""
    main(; Ns, Fs, n_batch=1, ntrials=5, outpath="", outtype="svg", plt_title=PLOT_TITLE)

Benchmark every `(f, label)` pair in `Fs` across six fixed
(input, output) element-type combinations, draw the results in a 2×3
plot grid, optionally save the figure under `outpath` (skipped when
`outpath === nothing`), and return `(res, finalplot)` where `res` maps
each type pair to `(plt=..., results=...)`.
"""
function main(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000],
                Fs=[
                    (mpspeakflops, "MPS"),
                    (graphpeakflops, "MPSGraph"),
                    (anepeakflops, "MPSGraph (ANE)"),
                    # (gpuarrpeakflops, "GPUArrays"),
                    # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
                ],
                n_batch=1,
                ntrials=5,
                outpath="",
                outtype="svg",
                plt_title=PLOT_TITLE)
    # (input eltype, output eltype) combinations to benchmark.
    Ts = [
        (Int8, Float16),
        (Int8, Float32),
        (Int16, Float32),
        (Float16, Float16),
        (Float16, Float32),
        (Float32, Float32),
    ]

    res = Dict()

    ylim_upper = 9e12  # baseline y-axis limit; grows if any result exceeds it

    for (inT, outT) in Ts
        tmpres = compare(Ns, Fs, inT, outT; n_batch, ntrials)

        plt = plot(xlabel="N, n_batch=$(n_batch)", legendtitle="($inT, $outT)")
        # Iterate Fs directly instead of zip(tmpres, Fs): `compare` may
        # filter out ANE entries, and zipping a Dict (arbitrary order)
        # against Fs could mispair labels or KeyError on a missing one.
        for (_, info_str) in Fs
            haskey(tmpres, info_str) || continue
            flops = tmpres[info_str]
            peak = maximum(flops)
            peakf = @sprintf("%.3e", peak)
            if peak > ylim_upper
                ylim_upper = peak * 1.02
            end
            plot!(plt, Ns, flops; linewidth=1.5, label="$(peakf) peak: $info_str")
        end
        res[(inT, outT)] = (plt=plt, results=tmpres)
    end

    finalplot = plot(res[Ts[1]].plt, res[Ts[2]].plt, res[Ts[3]].plt, res[Ts[4]].plt, res[Ts[5]].plt, res[Ts[6]].plt; layout=(2,3),
                     ylim=(0, ylim_upper),
                     plot_title=plt_title,
                     tickfonthalign=:left,
                     bottommargin=15pt,
                     size=(2000, 1200))
    if !isnothing(outpath)
        savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(n_batch).$outtype"))
    end
    return res, finalplot
end

0 commit comments

Comments
 (0)