1- #  using Pkg
2- #  Pkg.activate(temp=true)
3- #  Pkg.add(url="https://github.com/christiangnrd/Metal.jl/", rev="MPSGraph")
4- #  Pkg.add(["GPUArrays", "Plots"])
51
6- #  Uncomment if you want to compare with CPU
7- #  Pkg.add(["AppleAccelerate"])
8- #  using AppleAccelerate
2+ using  Metal, GPUArrays, LinearAlgebra, Printf, AppleAccelerate
93
10- using  Metal, GPUArrays, LinearAlgebra, Printf
4+ @static  if  ! haskey (ENV , " CI"  )
5+     using  Plots
6+     using  Plots. Measures
7+ end 
118
12- using  Plots
13- using  Plots. Measures
9+ const  Ts= [
10+         (Int8, Float16),
11+         (Int8, Float32),
12+         (Int16, Float32),
13+         (Float16, Float16),
14+         (Float16, Float32),
15+         (Float32, Float32),
16+     ]
1417
1518n_gpu_cores =  " ??" 
1619#  Comment this out if scary. Please mention number of cores in your comment when uploading the figure
@@ -123,7 +126,7 @@ function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
123126    return  results
124127end 
125128
126- function  main (; Ns= [50 , 64 , 100 , 128 , 250 , 256 , 500 , 512 , 1000 , 1024 , 2000 , 2048 , 4000 , 4096 , 6000 , 6144 , 8000 , 8192 ],# , 10000],
129+ function  runcomparison (; Ns= [50 , 64 , 100 , 128 , 250 , 256 , 500 , 512 , 1000 , 1024 , 2000 , 2048 , 4000 , 4096 , 6000 , 6144 , 8000 , 8192 ],# , 10000],
127130                Fs= [
128131                    (mpspeakflops, " MPS"  ),
129132                    (graphpeakflops, " MPSGraph"  ),
@@ -132,46 +135,51 @@ function main(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048
132135                    #  (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
133136                   ],
134137                n_batch= 1 ,
135-                 ntrials= 5 ,
136-                 outpath= " "  ,
137-                 outtype= " svg"  ,
138-                 plt_title= PLOT_TITLE)
139-     Ts= [
140-         (Int8, Float16),
141-         (Int8, Float32),
142-         (Int16, Float32),
143-         (Float16, Float16),
144-         (Float16, Float32),
145-         (Float32, Float32),
146-         ]
147- 
138+                 ntrials= 5 )
148139    res =  Dict ()
149140
141+     for  (inT, outT) in  Ts
142+         res[(inT,outT)] =  (n_batch, Ns, compare (Ns, Fs, inT, outT; n_batch, ntrials))
143+     end 
144+     return  res
145+ end 
146+ 
147+ function  plot_results (res, Fs= [" MPS"  , " MPSGraph"  , " MPSGraph (ANE)"  ]; outpath= nothing , outtype= " svg"  , plt_title= PLOT_TITLE)
150148    ylim_upper =  9e12 
149+     resplts =  []
150+ 
151+     n_batches =  []
151152
152153    for  (inT, outT) in  Ts
153-         tmpres  =   compare ( Ns, Fs,  inT,  outT; n_batch, ntrials) 
154+         n_batch,  Ns, tmpres  =  res[( inT,outT)] 
154155
155156        plt =  plot (xlabel= " N, n_batch=$(n_batch) "  , legendtitle= " ($inT , $outT )"  )
156-         for  (res, (_, info_str)) in  zip (tmpres,Fs)
157+         for  info_str in  Fs
158+             haskey (tmpres, info_str) ||  continue 
159+ 
157160            flops =  tmpres[info_str]
158161            peakf =  @sprintf (" %.3e"  , maximum (flops))
159162            if  maximum (flops) >  ylim_upper
160163                ylim_upper =  maximum (flops) *  1.02 
161164            end 
162165            plot! (plt, Ns, tmpres[info_str]; linewidth= 1.5 , label= " $(peakf)  peak: $info_str "  )
163166        end 
164-         res[(inT,outT)] =  (plt= plt, results= tmpres)
167+         push! (resplts, plt)
168+         push! (n_batches, n_batch)
165169    end 
166170
167-     finalplot =  plot (res[Ts[ 1 ]] . plt, res[Ts[ 2 ]] . plt, res[Ts[ 3 ]] . plt, res[Ts[ 4 ]] . plt, res[Ts[ 5 ]] . plt, res[Ts[ 6 ]] . plt ; layout= (2 ,3 ),
171+     finalplot =  plot (resplts ... ; layout= (2 ,3 ),
168172                     ylim= (0 ,ylim_upper),
169173                     plot_title= plt_title,
170174                     tickfonthalign= :left ,
171175                     bottommargin= 15 pt,
172176                     size= (2000 ,1200 ))
173177    if  ! isnothing (outpath)
174-         savefig (plot (finalplot, dpi= 500 ), joinpath (outpath, " bench_all_$(n_batch ) .$outtype "  ))
178+         savefig (plot (finalplot, dpi= 500 ), joinpath (outpath, " bench_all_$(first (n_batches) ) .$outtype "  ))
175179    end 
176-     return  res, finalplot
180+     return  finalplot
181+ end 
182+ 
183+ if  haskey (ENV , " CI"  )
184+     runcomparison (Ns= [50 , 64 , 100 , 128 , 250 , 256 , 500 , 512 ])
177185end 
0 commit comments