Skip to content

Commit e70717d

Browse files
Allow general RNG in random_* functions (#106)
* allow general RNG in random_* functions * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix missing method * add another missing method * add defaults * remove default * remove default * add another default... --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 09a32ee commit e70717d

File tree

3 files changed

+63
-25
lines changed

3 files changed

+63
-25
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
88
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
11+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
ReadVTK = "dc215faf-f008-4882-a9f7-a79a826fadc3"
1213
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1314
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -32,6 +33,7 @@ ForwardDiff = "0.10.36"
3233
LinearAlgebra = "1"
3334
Meshes = "0.52.1"
3435
Printf = "1"
36+
Random = "1"
3537
ReadVTK = "0.2"
3638
RecipesBase = "1.3.4"
3739
Reexport = "1.2"

src/KernelInterpolation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ using DiffEqCallbacks: PeriodicCallback, PeriodicCallbackAffect
1515
using ForwardDiff: ForwardDiff
1616
using LinearAlgebra: Symmetric, I, norm, tr, muladd, dot, diagind
1717
using Printf: @sprintf
18+
using Random: Random
1819
using ReadVTK: VTKFile, get_points, get_point_data, get_data
1920
using RecipesBase: RecipesBase, @recipe, @series
2021
using SciMLBase: ODEFunction, ODEProblem, ODESolution, DiscreteCallback, u_modified!

src/nodes.jl

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -220,65 +220,80 @@ end
220220

221221
# Some convenience function to create some specific `NodeSet`s
222222
"""
223-
random_hypercube(n, x_min = ntuple(_ -> 0.0, dim), x_max = ntuple(_ -> 1.0, dim); [dim])
223+
random_hypercube([rng], n, x_min = ntuple(_ -> 0.0, dim), x_max = ntuple(_ -> 1.0, dim); [dim])
224224
225225
Create a [`NodeSet`](@ref) with `n` random nodes each of dimension `dim` inside a hypercube defined by
226226
the bounds `x_min` and `x_max`. If the bounds are given as single values, they are applied for
227227
each dimension. If they are `Tuple`s of size `dim` the hypercube has the according bounds.
228228
If `dim` is not given explicitly, it is inferred by the lengths of `x_min` and `x_max` if possible.
229+
Optionally, pass a random number generator `rng`.
229230
"""
230-
function random_hypercube(n::Int, x_min::Real = 0.0, x_max::Real = 1.0; dim = 1)
231-
nodes = x_min .+ (x_max - x_min) .* rand(n, dim)
231+
function random_hypercube(n, x_min = 0.0, x_max = 1.0; kwargs...)
232+
random_hypercube(Random.default_rng(), n, x_min, x_max; kwargs...)
233+
end
234+
235+
function random_hypercube(rng::Random.AbstractRNG, n::Int, x_min::Real = 0.0,
236+
x_max::Real = 1.0; dim = 1)
237+
nodes = x_min .+ (x_max - x_min) .* rand(rng, n, dim)
232238
return NodeSet(nodes)
233239
end
234240

235-
function random_hypercube(n::Int, x_min::NTuple{Dim}, x_max::NTuple{Dim};
241+
function random_hypercube(rng::Random.AbstractRNG, n::Int, x_min::NTuple{Dim},
242+
x_max::NTuple{Dim};
236243
dim = Dim) where {Dim}
237244
@assert dim == Dim
238-
nodes = rand(n, dim)
245+
nodes = rand(rng, n, dim)
239246
for i in 1:dim
240247
nodes[:, i] = x_min[i] .+ (x_max[i] - x_min[i]) .* view(nodes, :, i)
241248
end
242249
return NodeSet(nodes)
243250
end
244251

245252
"""
246-
random_hypercube_boundary(n, x_min = ntuple(_ -> 0.0, dim), x_max = ntuple(_ -> 1.0, dim); [dim])
253+
random_hypercube_boundary([rng], n, x_min = ntuple(_ -> 0.0, dim), x_max = ntuple(_ -> 1.0, dim); [dim])
247254
248255
Create a [`NodeSet`](@ref) with `n` random nodes each of dimension `dim` on the boundary of a hypercube
249256
defined by the bounds `x_min` and `x_max`. If the bounds are given as single values, they are
250257
applied for each dimension. If they are `Tuple`s of size `dim` the hypercube has the according bounds.
251258
If `dim` is not given explicitly, it is inferred by the lengths of `x_min` and `x_max` if possible.
259+
Optionally, pass a random number generator `rng`.
252260
"""
253-
function random_hypercube_boundary(n::Int, x_min::Real = 0.0, x_max::Real = 1.0; dim = 1)
254-
random_hypercube_boundary(n, ntuple(_ -> x_min, dim), ntuple(_ -> x_max, dim))
261+
function random_hypercube_boundary(n, x_min = 0.0, x_max = 1.0; kwargs...)
262+
random_hypercube_boundary(Random.default_rng(), n, x_min, x_max; kwargs...)
263+
end
264+
265+
function random_hypercube_boundary(rng::Random.AbstractRNG, n::Int, x_min::Real = 0.0,
266+
x_max::Real = 1.0; dim = 1)
267+
random_hypercube_boundary(rng, n, ntuple(_ -> x_min, dim), ntuple(_ -> x_max, dim))
255268
end
256269

257-
function project_on_hypercube_boundary!(nodeset::NodeSet{Dim}, x_min::NTuple{Dim},
270+
function project_on_hypercube_boundary!(rng::Random.AbstractRNG, nodeset::NodeSet{Dim},
271+
x_min::NTuple{Dim},
258272
x_max::NTuple{Dim}) where {Dim}
259273
for i in eachindex(nodeset)
260274
# j = argmin([abs.(nodeset[i] .- x_min); abs.(nodeset[i] .- x_max)])
261275
# Project to random axis
262-
j = rand(1:Dim)
263-
if rand([1, 2]) == 1
276+
j = rand(rng, 1:Dim)
277+
if rand(rng, [1, 2]) == 1
264278
nodeset[i][j] = x_min[j]
265279
else
266280
nodeset[i][j] = x_max[j]
267281
end
268282
end
269283
end
270284

271-
function random_hypercube_boundary(n::Int, x_min::NTuple{Dim}, x_max::NTuple{Dim};
285+
function random_hypercube_boundary(rng::Random.AbstractRNG, n::Int, x_min::NTuple{Dim},
286+
x_max::NTuple{Dim};
272287
dim = Dim) where {Dim}
273288
@assert dim == Dim
274289
if dim == 1 && n >= 2
275290
@warn "For one dimension the boundary of the hypercube consists only of 2 points"
276291
return NodeSet([x_min[1], x_max[1]])
277292
end
278293
# First, create random nodes *inside* hypercube
279-
nodeset = random_hypercube(n, x_min, x_max)
294+
nodeset = random_hypercube(rng, n, x_min, x_max)
280295
# Then, project all the nodes on the boundary
281-
project_on_hypercube_boundary!(nodeset, x_min, x_max)
296+
project_on_hypercube_boundary!(rng, nodeset, x_min, x_max)
282297
return nodeset
283298
end
284299

@@ -406,44 +421,64 @@ function homogeneous_hypercube_boundary(n::NTuple{Dim},
406421
end
407422

408423
"""
409-
random_hypersphere(n, r = 1.0, center = zeros(dim); [dim])
424+
random_hypersphere([rng], n, r = 1.0, center = zeros(dim); [dim])
410425
411426
Create a [`NodeSet`](@ref) with `n` random nodes each of dimension `dim` inside a hypersphere with
412427
radius `r` around the center `center`.
413428
If `dim` is not given explicitly, it is inferred by the length of `center` if possible.
429+
Optionally, pass a random number generator `rng`.
414430
"""
415-
function random_hypersphere(n::Int, r = 1.0; dim = 2)
416-
random_hypersphere(n, r, zeros(dim))
431+
function random_hypersphere(n, r = 1.0; kwargs...)
432+
random_hypersphere(Random.default_rng(), n, r; kwargs...)
433+
end
434+
435+
function random_hypersphere(n, r, center; kwargs...)
436+
random_hypersphere(Random.default_rng(), n, r, center; kwargs...)
417437
end
418438

419-
function random_hypersphere(n::Int, r::Real, center::AbstractVector; dim = length(center))
439+
function random_hypersphere(rng::Random.AbstractRNG, n, r = 1.0; dim = 2)
440+
random_hypersphere(rng, n, r, zeros(dim))
441+
end
442+
443+
function random_hypersphere(rng::Random.AbstractRNG, n, r,
444+
center; dim = length(center))
420445
@assert length(center) == dim
421-
nodes = randn(n, dim)
446+
nodes = randn(rng, n, dim)
422447
for i in 1:n
423-
nodes[i, :] .= center .+ r .* nodes[i, :] ./ norm(nodes[i, :]) * rand()^(1 / dim)
448+
nodes[i, :] .= center .+ r .* nodes[i, :] ./ norm(nodes[i, :]) * rand(rng)^(1 / dim)
424449
end
425450
return NodeSet(nodes)
426451
end
427452

428453
"""
429-
random_hypersphere_boundary(n, r = 1.0, center = zeros(dim); [dim])
454+
random_hypersphere_boundary([rng], n, r = 1.0, center = zeros(dim); [dim])
430455
431456
Create a [`NodeSet`](@ref) with `n` random nodes each of dimension `dim` at the boundary of a
432457
hypersphere with radius `r` around the center `center`.
433458
If `dim` is not given explicitly, it is inferred by the length of `center` if possible.
459+
Optionally, pass a random number generator `rng`.
434460
"""
435-
function random_hypersphere_boundary(n::Int, r = 1.0; dim = 2)
436-
random_hypersphere_boundary(n, r, zeros(dim))
461+
function random_hypersphere_boundary(n, r = 1.0; kwargs...)
462+
random_hypersphere_boundary(Random.default_rng(), n, r; kwargs...)
463+
end
464+
465+
function random_hypersphere_boundary(n, r, center; kwargs...)
466+
random_hypersphere_boundary(Random.default_rng(), n, r, center; kwargs...)
467+
end
468+
469+
function random_hypersphere_boundary(rng::Random.AbstractRNG, n, r = 1.0; dim = 2)
470+
random_hypersphere_boundary(rng, n, r, zeros(dim))
437471
end
438472

439-
function random_hypersphere_boundary(n::Int, r::Real, center::AbstractVector;
473+
function random_hypersphere_boundary(rng::Random.AbstractRNG, n, r,
474+
center;
440475
dim = length(center))
441476
@assert length(center) == dim
442477
if dim == 1 && n >= 2
443478
@warn "For one dimension the boundary of the hypersphere consists only of 2 points"
444479
return NodeSet([-r, r])
445480
end
446-
nodes = randn(n, dim)
481+
nodes = randn(rng, n, dim)
447482
for i in 1:n
448483
nodes[i, :] .= center .+ r .* nodes[i, :] ./ norm(nodes[i, :])
449484
end

0 commit comments

Comments
 (0)