1- # using Pkg
2- # Pkg.activate(temp=true)
3- # Pkg.add(url="https://github.com/christiangnrd/Metal.jl/", rev="MPSGraph")
4- # Pkg.add(["GPUArrays", "Plots"])
51
6- # Uncomment if you want to compare with CPU
7- # Pkg.add(["AppleAccelerate"])
8- # using AppleAccelerate
2+ using Metal, GPUArrays, LinearAlgebra, Printf, AppleAccelerate
93
10- using Metal, GPUArrays, LinearAlgebra, Printf
4+ @static if ! haskey (ENV , " CI" )
5+ using Plots
6+ using Plots. Measures
7+ end
118
12- using Plots
13- using Plots. Measures
# (input eltype, output eltype) combinations to benchmark.
# `runcomparison` and `plot_results` both iterate this list as `(inT, outT)` pairs, and
# `plot_results` produces one subplot per entry.
const Ts = [
    (Int8, Float16),
    (Int8, Float32),
    (Int16, Float32),
    (Float16, Float16),
    (Float16, Float32),
    (Float32, Float32),
]
1417
1518n_gpu_cores = " ??"
1619# Comment this out if scary. Please mention number of cores in your comment when uploading the figure
@@ -123,7 +126,7 @@ function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
123126 return results
124127end
125128
126- function main (; Ns= [50 , 64 , 100 , 128 , 250 , 256 , 500 , 512 , 1000 , 1024 , 2000 , 2048 , 4000 , 4096 , 6000 , 6144 , 8000 , 8192 ],# , 10000],
129+ function runcomparison (; Ns= [50 , 64 , 100 , 128 , 250 , 256 , 500 , 512 , 1000 , 1024 , 2000 , 2048 , 4000 , 4096 , 6000 , 6144 , 8000 , 8192 ],# , 10000],
127130 Fs= [
128131 (mpspeakflops, " MPS" ),
129132 (graphpeakflops, " MPSGraph" ),
@@ -132,46 +135,51 @@ function main(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048
132135 # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
133136 ],
134137 n_batch= 1 ,
135- ntrials= 5 ,
136- outpath= " " ,
137- outtype= " svg" ,
138- plt_title= PLOT_TITLE)
139- Ts= [
140- (Int8, Float16),
141- (Int8, Float32),
142- (Int16, Float32),
143- (Float16, Float16),
144- (Float16, Float32),
145- (Float32, Float32),
146- ]
147-
138+ ntrials= 5 )
148139 res = Dict ()
149140
141+ for (inT, outT) in Ts
142+ res[(inT,outT)] = (n_batch, Ns, compare (Ns, Fs, inT, outT; n_batch, ntrials))
143+ end
144+ return res
145+ end
146+
"""
    plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"];
                 outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)

Build one subplot per `(inT, outT)` pair in `Ts` and combine them into a single figure.

# Arguments
- `res`: collection indexed by `(inT, outT)`. Each value must destructure into
  `(n_batch, Ns, results)`, where `results` is indexed by the labels in `Fs`
  (this is the shape `runcomparison` stores).
- `Fs`: labels of the series to draw. A label with no entry in `results` is skipped.

# Keywords
- `outpath`: directory to save the figure in; nothing is saved when it is `nothing`.
- `outtype`: file extension used for the saved figure.
- `plt_title`: title of the combined figure.

# Returns
The combined plot.
"""
function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"];
                      outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)
    ylim_upper = 9e12
    resplts = []
    n_batches = []

    for (inT, outT) in Ts
        n_batch, Ns, tmpres = res[(inT, outT)]

        plt = plot(xlabel="N, n_batch=$(n_batch)", legendtitle="($inT, $outT)")
        for info_str in Fs
            haskey(tmpres, info_str) || continue

            flops = tmpres[info_str]
            peak = maximum(flops)
            peakf = @sprintf("%.3e", peak)
            # Grow the shared y-limit so the tallest curve keeps 2% headroom.
            if peak > ylim_upper
                ylim_upper = peak * 1.02
            end
            plot!(plt, Ns, flops; linewidth=1.5, label="$(peakf) peak: $info_str")
        end
        push!(resplts, plt)
        push!(n_batches, n_batch)
    end

    finalplot = plot(resplts...; layout=(2, 3),
                     ylim=(0, ylim_upper),
                     plot_title=plt_title,
                     tickfonthalign=:left,
                     bottommargin=15pt,
                     size=(2000, 1200))
    if !isnothing(outpath)
        # The file name carries the batch size of the first (inT, outT) entry.
        savefig(plot(finalplot, dpi=500),
                joinpath(outpath, "bench_all_$(first(n_batches)).$outtype"))
    end
    return finalplot
end
182+
# When the CI environment variable is set, run the comparison on a reduced list of sizes.
# Plots is not loaded in that case (see the `@static` guard around `using Plots`), so no
# figure is produced here.
if haskey(ENV, "CI")
    runcomparison(Ns=[50, 64, 100, 128, 250, 256, 500, 512])
end
0 commit comments