Skip to content

Commit ee0199f

Browse files
authored
Various improvements to peakflops() (#49833)
* Various improvements to peakflops: use 4096 as the default matrix size; add a kwarg to pick the type of elements in the matrix; add a kwarg for the number of trials and pick the best time.
1 parent 520b639 commit ee0199f

File tree

3 files changed

+22
-14
lines changed

3 files changed

+22
-14
lines changed

stdlib/InteractiveUtils/src/InteractiveUtils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ end
301301
# TODO: @deprecate peakflops to LinearAlgebra
302302
export peakflops
303303
"""
304-
peakflops(n::Integer=2000; parallel::Bool=false)
304+
peakflops(n::Integer=4096; eltype::DataType=Float64, ntrials::Integer=3, parallel::Bool=false)
305305
306306
`peakflops` computes the peak flop rate of the computer by using (double precision by default)
307307
[`gemm!`](@ref LinearAlgebra.BLAS.gemm!). For more information see
@@ -311,12 +311,12 @@ export peakflops
311311
This function will be moved from `InteractiveUtils` to `LinearAlgebra` in the
312312
future. In Julia 1.1 and later it is available as `LinearAlgebra.peakflops`.
313313
"""
314-
function peakflops(n::Integer=2000; parallel::Bool=false)
315-
# Base.depwarn("`peakflop`s have moved to the LinearAlgebra module, " *
314+
function peakflops(n::Integer=4096; eltype::DataType=Float64, ntrials::Integer=3, parallel::Bool=false)
315+
# Base.depwarn("`peakflops` has moved to the LinearAlgebra module, " *
316316
# "add `using LinearAlgebra` to your imports.", :peakflops)
317317
let LinearAlgebra = Base.require(Base.PkgId(
318318
Base.UUID((0x37e2e46d_f89d_539d,0xb4ee_838fcccc9c8e)), "LinearAlgebra"))
319-
return LinearAlgebra.peakflops(n; parallel = parallel)
319+
return LinearAlgebra.peakflops(n, eltype=eltype, ntrials=ntrials, parallel=parallel)
320320
end
321321
end
322322

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -557,14 +557,20 @@ end
557557
ldiv(F, B)
558558

559559
"""
560-
LinearAlgebra.peakflops(n::Integer=2000; parallel::Bool=false)
560+
LinearAlgebra.peakflops(n::Integer=4096; eltype::DataType=Float64, ntrials::Integer=3, parallel::Bool=false)
561561
562562
`peakflops` computes the peak flop rate of the computer by using (double precision by default)
563563
[`gemm!`](@ref LinearAlgebra.BLAS.gemm!). By default, if no arguments are specified, it
564-
multiplies a matrix of size `n x n`, where `n = 2000`. If the underlying BLAS is using
564+
multiplies two `Float64` matrices of size `n x n`, where `n = 4096`. If the underlying BLAS is using
565565
multiple threads, higher flop rates are realized. The number of BLAS threads can be set with
566566
[`BLAS.set_num_threads(n)`](@ref).
567567
568+
If the keyword argument `eltype` is provided, `peakflops` will construct matrices with elements
569+
of type `eltype` for calculating the peak flop rate.
570+
571+
By default, `peakflops` will use the best timing from 3 trials. If the `ntrials` keyword argument
572+
is provided, `peakflops` will use that many trials for picking the best timing.
573+
568574
If the keyword argument `parallel` is set to `true`, `peakflops` is run in parallel on all
569575
the worker processors. The flop rate of the entire parallel computer is returned. When
570576
running in parallel, only 1 BLAS thread is used. The argument `n` still refers to the size
@@ -574,19 +580,21 @@ of the problem that is solved on each processor.
574580
This function requires at least Julia 1.1. In Julia 1.0 it is available from
575581
the standard library `InteractiveUtils`.
576582
"""
577-
function peakflops(n::Integer=2000; parallel::Bool=false)
578-
a = fill(1.,100,100)
579-
t = @elapsed a2 = a*a
580-
a = fill(1.,n,n)
581-
t = @elapsed a2 = a*a
582-
@assert a2[1,1] == n
583+
function peakflops(n::Integer=4096; eltype::DataType=Float64, ntrials::Integer=3, parallel::Bool=false)
584+
t = zeros(Float64, ntrials)
585+
for i=1:ntrials
586+
a = ones(eltype,n,n)
587+
t[i] = @elapsed a2 = a*a
588+
@assert a2[1,1] == n
589+
end
590+
583591
if parallel
584592
let Distributed = Base.require(Base.PkgId(
585593
Base.UUID((0x8ba89e20_285c_5b6f, 0x9357_94700520ee1b)), "Distributed"))
586594
return sum(Distributed.pmap(peakflops, fill(n, Distributed.nworkers())))
587595
end
588596
else
589-
return 2*Float64(n)^3 / t
597+
return 2*Float64(n)^3 / minimum(t)
590598
end
591599
end
592600

stdlib/LinearAlgebra/test/generic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ end
558558
end
559559

560560
@testset "peakflops" begin
561-
@test LinearAlgebra.peakflops() > 0
561+
@test LinearAlgebra.peakflops(1024, eltype=Float32, ntrials=2) > 0
562562
end
563563

564564
@testset "NaN handling: Issue 28972" begin

0 commit comments

Comments
 (0)