1-
2- using Metal, GPUArrays, LinearAlgebra, Printf, AppleAccelerate
1+ using Metal, GPUArrays, LinearAlgebra, Printf# , AppleAccelerate
32
43testing = (@isdefined TESTING) && TESTING
54
@@ -8,14 +7,15 @@ testing = (@isdefined TESTING) && TESTING
87 using Plots. Measures
98end
109
11- const Ts= [
12- (Int8, Float16),
13- (Int8, Float32),
14- (Int16, Float32),
15- (Float16, Float16),
16- (Float16, Float32),
17- (Float32, Float32),
18- ]
# (input element type, output element type) pairs. `runcomparison` iterates over
# this list and calls `compare` once per pair.
# NOTE(review): this side of the diff drops the `const` qualifier, so `Ts` is a
# plain, reassignable global — confirm that is intentional.
10+ Ts= [
11+ (Int8, Float16),
12+ (Int8, Float32),
13+ (Int16, Float32),
14+ (Float16, Float16),
15+ (Float16, Float32),
16+ (Float32, Float32),
17+ ]
# Default value of the `Ns` keyword of `runcomparison`; forwarded unchanged to `compare`.
# NOTE(review): presumably the problem sizes to benchmark — confirm against `compare`.
18+ DEFAULT_NS = [50 , 64 , 100 , 128 , 250 , 256 , 500 , 512 , 1000 , 1024 , 1500 , 2000 , 2048 , 2500 , 3000 , 4000 , 4096 , 5000 , 6000 , 6144 , 8000 , 8192 ]
1919
2020n_gpu_cores = " ??"
2121# Comment this out if scary. Please mention number of cores in your comment when uploading the figure
@@ -138,27 +138,25 @@ function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
138138 return results
139139end
140140
141- # function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000],
142- function runcomparison (; Ns= [50 , 64 , 100 , 128 , 250 , 256 , 500 , 512 , 1000 , 1024 : 100 : 2000 ... , 2000 , 2048 : 100 : 3000 ... , 4000 , 4096 : 100 : 6000 ... , 6000 , 6144 , 8000 , 8192 ],# , 10000],
143- Fs= [
144- (mpspeakflops, " MPS" ),
145- (graphpeakflops, " MPSGraph" ),
146- (defaultpeakflops, " Default" ),
147- # (anepeakflops, "MPSGraph (ANE)"),
148- # (gpuarrpeakflops, "GPUArrays"),
149- # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
150- ],
151- n_batch= 1 ,
152- ntrials= 5 )
153- res = Dict ()
# Default list of (function, label) pairs. Used as the default `Fs` of
# `runcomparison`, and as the default `Fs` of `plot_results`, which keeps only
# element 2 (the label) of each pair.
# The functions named here are defined outside this excerpt.
# Commented-out entries are disabled alternatives that can be re-enabled by uncommenting.
141+ DEFAULT_FS = [
142+ (mpspeakflops, " MPS" ),
143+ (graphpeakflops, " MPSGraph" ),
144+ (defaultpeakflops, " Default" ),
145+ # (anepeakflops, "MPSGraph (ANE)"),
146+ # (gpuarrpeakflops, "GPUArrays"),
147+ # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
148+ ]
154149
# Run `compare` once for every (inT, outT) pair in `Ts` and collect the results.
#
# Keyword arguments:
#   Ns      - forwarded to `compare` (defaults to DEFAULT_NS)
#   Fs      - forwarded to `compare` (defaults to DEFAULT_FS)
#   n_batch - forwarded to `compare` as a keyword and also stored in each result entry (default 1)
#   ntrials - forwarded to `compare` as a keyword (default 5)
#
# Returns a Dict keyed by (inT, outT); each value is the tuple
# (n_batch, Ns, <return value of compare>).
150+ function runcomparison (; Ns= DEFAULT_NS, Fs= DEFAULT_FS, n_batch= 1 , ntrials= 5 )
# Untyped Dict: keys are type pairs, values are 3-tuples (see above).
151+ res = Dict ()
155152 for (inT, outT) in Ts
# `; n_batch, ntrials` passes both locals as same-named keyword arguments.
156153 res[(inT,outT)] = (n_batch, Ns, compare (Ns, Fs, inT, outT; n_batch, ntrials))
157154 end
158155 return res
159156end
160157
161- function plot_results (res, Fs= [" MPS" , " MPSGraph" , " Default" ]; outpath= nothing , outtype= " svg" , plt_title= PLOT_TITLE)
158+ function plot_results (res, Fs= DEFAULT_FS; outpath= nothing , outtype= " svg" , plt_title= PLOT_TITLE)
159+ Fs = get .(Fs, 2 , " You shouldn't be reading this" )
162160 ylim_upper = 9e12
163161 resplts = []
164162
196194
# Script entry point.
# When `testing` is true (TESTING is defined and true), run the comparison on a
# reduced list of Ns; the result is not stored or plotted.
197195if testing
198196 runcomparison (Ns= [50 , 64 , 100 , 128 , 250 , 256 , 500 , 512 ])
199- else
200- res = runcomparison (Ns = [ 50 , 64 , 100 , 128 , 250 , 256 , 500 , 512 , 1000 , 1024 , 1500 , 2000 ]) # , 2048, 2500, 3000, 4000, 4096, 5000, 6000, 6144, 8000, 8192] )
# Added side: run the full comparison only when this file is the program being
# executed (PROGRAM_FILE resolves to this file's path), not when it is included
# from another file. The result is then passed to `plot_results`.
197+ elseif abspath ( PROGRAM_FILE ) == @__FILE__
198+ res = runcomparison ()
201199 plot_results (res; outpath= " ." )
202200end
0 commit comments