@@ -68,6 +68,16 @@ function gpuarrpeakflops(; n::Integer=4096,
6868 GPUArrays. generic_matmatmul! (c, LinearAlgebra. wrap (a, ' N' ), LinearAlgebra. wrap (b, ' N' ), 1 , 0 )
6969 end
7070end
"""
    defaultpeakflops(; n=4096, n_batch=1, inT=Float32, outT=inT, ntrials=3, verify=true)

Benchmark entry point that hands `LinearAlgebra.generic_matmatmul!` (with both
operands untransposed, `'N'`/`'N'`, and scalar arguments `1` and `0`) to the
shared `_peakflops` driver defined elsewhere in this file.

# Keywords
- `n`: size argument forwarded to `_peakflops`.
- `n_batch`: accepted so this function has the same keyword set as the other
  `*peakflops` entries, but the batch count passed to `_peakflops` is always `1`.
  A warning is logged when any other value is requested.
- `inT`, `outT`: type arguments forwarded to `_peakflops`.
- `ntrials`: trial count forwarded to `_peakflops`.
- `verify`: forwarded to `_peakflops` as a keyword.

Returns whatever `_peakflops` returns.
"""
function defaultpeakflops(; n::Integer=4096,
                            n_batch::Integer=1,
                            inT::DataType=Float32,
                            outT::DataType=inT,
                            ntrials::Integer=3,
                            verify=true)
    # The call below hard-codes a batch count of 1; tell the caller instead of
    # silently benchmarking something other than what was asked for.
    n_batch == 1 || @warn "`defaultpeakflops` does not support n_batch != 1; running with n_batch=1" n_batch
    _peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b
        LinearAlgebra.generic_matmatmul!(c, 'N', 'N', a, b, 1, 0)
    end
end
7181function mpspeakflops (; n:: Integer = 4096 ,
7282 n_batch:: Integer = 1 ,
7383 inT:: DataType = Float32,
@@ -128,11 +138,13 @@ function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
128138 return results
129139end
130140
131- function runcomparison (; Ns= [50 , 64 , 100 , 128 , 250 , 256 , 500 , 512 , 1000 , 1024 , 2000 , 2048 , 4000 , 4096 , 6000 , 6144 , 8000 , 8192 ],# , 10000],
141+ # function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000],
142+ function runcomparison (; Ns= [50 , 64 , 100 , 128 , 250 , 256 , 500 , 512 , 1000 , 1024 : 20 : 2000 ... , 2000 , 2048 : 20 : 3000 ... , 4000 , 4096 : 20 : 6000 ... , 6000 , 6144 , 8000 , 8192 ],# , 10000],
132143 Fs= [
133144 (mpspeakflops, " MPS" ),
134145 (graphpeakflops, " MPSGraph" ),
135- (anepeakflops, " MPSGraph (ANE)" ),
146+ (defaultpeakflops, " Default" ),
147+ # (anepeakflops, "MPSGraph (ANE)"),
136148 # (gpuarrpeakflops, "GPUArrays"),
137149 # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
138150 ],
@@ -146,7 +158,7 @@ function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2
146158 return res
147159end
148160
149- function plot_results (res, Fs= [" MPS" , " MPSGraph" , " MPSGraph (ANE) " ]; outpath= nothing , outtype= " svg" , plt_title= PLOT_TITLE)
161+ function plot_results (res, Fs= [" MPS" , " MPSGraph" , " Default " ]; outpath= nothing , outtype= " svg" , plt_title= PLOT_TITLE)
150162 ylim_upper = 9e12
151163 resplts = []
152164
@@ -164,7 +176,7 @@ function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=not
164176 if maximum (flops) > ylim_upper
165177 ylim_upper = maximum (flops) * 1.02
166178 end
167- plot! (plt, Ns, tmpres[info_str]; linewidth= 1.5 , label= " $(peakf) peak: $info_str " )
179+ plot! (plt, Ns, tmpres[info_str]; linewidth= 1.5 , label= " $(peakf) peak: $info_str " , α = 0.8 )
168180 end
169181 push! (resplts, plt)
170182 push! (n_batches, n_batch)
# Script driver. When `testing` (defined outside this view) is true, run the
# comparison on a reduced list of matrix sizes.
185197if testing
186198 runcomparison (Ns= [50 , 64 , 100 , 128 , 250 , 256 , 500 , 512 ])
187199end
200+
# Run the comparison with its default arguments, keep the result in the global
# `res`, and pass it to `plot_results` with an explicit `outpath`.
# NOTE(review): these two statements are outside the `if testing` block, so they
# also execute when `testing` is true — confirm that is intended.
201+ res = runcomparison ()
202+ plot_results (res; outpath= " ." )
0 commit comments