Skip to content

Commit b7346e4

Browse files
some small test time optimizations
1 parent 748ed97 commit b7346e4

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiffEqBase"
22
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "6.34.2"
4+
version = "6.34.3"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/ensemble/basic_ensemble_solve.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ function solve_batch(prob,alg,ensemblealg::EnsembleThreads,II,pmap_batch_size;kw
215215

216216
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
217217
let
218-
if length(II) == 1
218+
if length(II) == 1 || Threads.nthreads() == 1
219219
for batch_idx in axes(batch_data, 1)
220220
batch_data[batch_idx] = multithreaded_batch(batch_idx)
221221
end
@@ -294,7 +294,7 @@ function thread_monte(prob,II,alg,procid;kwargs...)
294294
batch_data = Vector{Core.Compiler.return_type(multithreaded_batch,Tuple{typeof(first(II))})}(undef,length(II))
295295

296296
let
297-
if length(II) == 1
297+
if length(II) == 1 || Threads.nthreads() == 1
298298
for batch_idx in axes(batch_data, 1)
299299
batch_data[batch_idx] = multithreaded_batch(batch_idx)
300300
end

src/ensemble/ensemble_solutions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ function calculate_ensemble_errors(u;elapsedTime=0.0,converged=false,
7878
if weak_dense_errors
7979
densetimes = collect(range(u[1].t[1], stop=u[1].t[end], length=100))
8080
u_analytic = [[sol.prob.f(Val{:analytic},sol.prob.u0,sol.prob.p,densetimes[i],sol.W(densetimes[i])[1]) for i in eachindex(densetimes)] for sol in u]
81-
dense_weak_errors = [mean([u[j](densetimes)[i] - u_analytic[j][i] for j in 1:length(u)]) for i in eachindex(densetimes)]
81+
udense = [u[j](densetimes) for j in 1:length(u)]
82+
dense_weak_errors = [mean([udense[j][i] - u_analytic[j][i] for j in 1:length(u)]) for i in eachindex(densetimes)]
8283
dense_L2_errors = [sqrt.(sum(abs2,err)/length(err)) for err in dense_weak_errors]
8384
L2_tmp = sqrt(sum(abs2,dense_L2_errors)/length(dense_L2_errors))
8485
max_tmp = maximum([maximum(abs.(err)) for err in dense_weak_errors])

0 commit comments

Comments
 (0)