Skip to content

Commit b464a47

Browse files
committed
Move flopscomp to examples
1 parent ab8785b commit b464a47

File tree

2 files changed

+39
-30
lines changed

2 files changed

+39
-30
lines changed

dev/flopscomp.jl renamed to examples/flopscomp.jl

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1-
# using Pkg
2-
# Pkg.activate(temp=true)
3-
# Pkg.add(url="https://github.com/christiangnrd/Metal.jl/", rev="MPSGraph")
4-
# Pkg.add(["GPUArrays", "Plots"])
51

6-
# Uncomment if you want to compare with CPU
7-
# Pkg.add(["AppleAccelerate"])
8-
# using AppleAccelerate
2+
using Metal, GPUArrays, LinearAlgebra, Printf, AppleAccelerate
93

10-
using Metal, GPUArrays, LinearAlgebra, Printf
4+
@static if !haskey(ENV, "CI")
5+
using Plots
6+
using Plots.Measures
7+
end
118

12-
using Plots
13-
using Plots.Measures
9+
const Ts=[
10+
(Int8, Float16),
11+
(Int8, Float32),
12+
(Int16, Float32),
13+
(Float16, Float16),
14+
(Float16, Float32),
15+
(Float32, Float32),
16+
]
1417

1518
n_gpu_cores = "??"
1619
# Comment this out if you would rather not run it. Please mention the number of GPU cores in your comment when uploading the figure.
@@ -123,7 +126,7 @@ function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
123126
return results
124127
end
125128

126-
function main(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000],
129+
function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000],
127130
Fs=[
128131
(mpspeakflops, "MPS"),
129132
(graphpeakflops, "MPSGraph"),
@@ -132,46 +135,51 @@ function main(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048
132135
# (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
133136
],
134137
n_batch=1,
135-
ntrials=5,
136-
outpath="",
137-
outtype="svg",
138-
plt_title=PLOT_TITLE)
139-
Ts=[
140-
(Int8, Float16),
141-
(Int8, Float32),
142-
(Int16, Float32),
143-
(Float16, Float16),
144-
(Float16, Float32),
145-
(Float32, Float32),
146-
]
147-
138+
ntrials=5)
148139
res = Dict()
149140

141+
for (inT, outT) in Ts
142+
res[(inT,outT)] = (n_batch, Ns, compare(Ns, Fs, inT, outT; n_batch, ntrials))
143+
end
144+
return res
145+
end
146+
147+
function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)
150148
ylim_upper = 9e12
149+
resplts = []
150+
151+
n_batches = []
151152

152153
for (inT, outT) in Ts
153-
tmpres = compare(Ns, Fs, inT, outT; n_batch, ntrials)
154+
n_batch, Ns, tmpres = res[(inT,outT)]
154155

155156
plt = plot(xlabel="N, n_batch=$(n_batch)", legendtitle="($inT, $outT)")
156-
for (res, (_, info_str)) in zip(tmpres,Fs)
157+
for info_str in Fs
158+
haskey(tmpres, info_str) || continue
159+
157160
flops = tmpres[info_str]
158161
peakf = @sprintf("%.3e", maximum(flops))
159162
if maximum(flops) > ylim_upper
160163
ylim_upper = maximum(flops) * 1.02
161164
end
162165
plot!(plt, Ns, tmpres[info_str]; linewidth=1.5, label="$(peakf) peak: $info_str")
163166
end
164-
res[(inT,outT)] = (plt=plt, results=tmpres)
167+
push!(resplts, plt)
168+
push!(n_batches, n_batch)
165169
end
166170

167-
finalplot = plot(res[Ts[1]].plt, res[Ts[2]].plt, res[Ts[3]].plt, res[Ts[4]].plt, res[Ts[5]].plt, res[Ts[6]].plt; layout=(2,3),
171+
finalplot = plot(resplts...; layout=(2,3),
168172
ylim=(0,ylim_upper),
169173
plot_title=plt_title,
170174
tickfonthalign=:left,
171175
bottommargin=15pt,
172176
size=(2000,1200))
173177
if !isnothing(outpath)
174-
savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(n_batch).$outtype"))
178+
savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(first(n_batches)).$outtype"))
175179
end
176-
return res, finalplot
180+
return finalplot
181+
end
182+
183+
if haskey(ENV, "CI")
184+
runcomparison(Ns=[50, 64, 100, 128, 250, 256, 500, 512])
177185
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3+
AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924"
34
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
45
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
56
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"

0 commit comments

Comments
 (0)