Skip to content

Commit 6843370

Browse files
committed
Add multithreading to electric field calculation
1 parent 415ccf2 commit 6843370

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

src/ElectricField/ElectricField.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,20 @@ Base.convert(T::Type{ElectricField}, x::NamedTuple) = T(x)
4949

5050

5151

52-
function ElectricField(epot::ElectricPotential{T, 3, S}, point_types::PointTypes{T}) where {T, S}
53-
return ElectricField{T, 3, S, typeof(grid.axes)}(get_electric_field_from_potential( epot, point_types ), epot.grid)
52+
function ElectricField(epot::ElectricPotential{T, 3, S}, point_types::PointTypes{T}; use_nthreads::Int = Base.Threads.nthreads()) where {T, S}
53+
return ElectricField{T, 3, S, typeof(grid.axes)}(get_electric_field_from_potential( epot, point_types; use_nthreads ), epot.grid)
5454
end
5555

5656

57-
function get_electric_field_from_potential(epot::ElectricPotential{T, 3, Cylindrical}, point_types::PointTypes{T}, fieldvector_coordinates=:xyz)::ElectricField{T, 3, Cylindrical} where {T <: SSDFloat}
57+
function get_electric_field_from_potential(epot::ElectricPotential{T, 3, Cylindrical}, point_types::PointTypes{T}, fieldvector_coordinates=:xyz; use_nthreads::Int = Base.Threads.threadid())::ElectricField{T, 3, Cylindrical} where {T <: SSDFloat}
5858
p = epot.data
5959
axr::Vector{T} = collect(epot.grid.axes[1])
6060
axφ::Vector{T} = collect(epot.grid.axes[2])
6161
axz::Vector{T} = collect(epot.grid.axes[3])
6262

6363
cyclic::T = epot.grid.axes[2].interval.right
6464
ef = Array{SVector{3, T}}(undef, size(p)...)
65-
for iz in 1:size(ef, 3)
65+
@onthreads 1:use_nthreads for ix in workpart(1:size(ef, 3), 1:use_nthreads, Base.Threads.threadid())
6666
forin 1:size(ef, 2)
6767
for ir in 1:size(ef, 1)
6868
### r ###
@@ -212,7 +212,7 @@ end
212212

213213

214214

215-
function get_electric_field_from_potential(epot::ElectricPotential{T, 3, Cartesian}, point_types::PointTypes{T})::ElectricField{T, 3, Cartesian} where {T <: SSDFloat}
215+
function get_electric_field_from_potential(epot::ElectricPotential{T, 3, Cartesian}, point_types::PointTypes{T}; use_nthreads::Int = Base.Threads.nthreads())::ElectricField{T, 3, Cartesian} where {T <: SSDFloat}
216216
axx::Vector{T} = collect(epot.grid.axes[1])
217217
axy::Vector{T} = collect(epot.grid.axes[2])
218218
axz::Vector{T} = collect(epot.grid.axes[3])
@@ -222,7 +222,7 @@ function get_electric_field_from_potential(epot::ElectricPotential{T, 3, Cartesi
222222

223223
ef::Array{SVector{3, T}} = Array{SVector{3, T}}(undef, size(epot.data))
224224

225-
for ix in eachindex(axx)
225+
@onthreads 1:use_nthreads for ix in workpart(eachindex(axx), 1:use_nthreads, Base.Threads.threadid())
226226
for iy in eachindex(axy)
227227
for iz in eachindex(axz)
228228
if ix - 1 < 1

src/Simulation/Simulation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,7 +1167,7 @@ Calculates the [`ElectricField`](@ref) from the [`ElectricPotential`](@ref) stor
11671167
!!! note
11681168
This method only works if `sim.electric_potential` has already been calculated and is not `missing`.
11691169
"""
1170-
function calculate_electric_field!(sim::Simulation{T, CS}; n_points_in_φ::Union{Missing, Int} = missing)::Nothing where {T <: SSDFloat, CS}
1170+
function calculate_electric_field!(sim::Simulation{T, CS}; n_points_in_φ::Union{Missing, Int} = missing, use_nthreads::Int = Base.Threads.nthreads())::Nothing where {T <: SSDFloat, CS}
11711171
@assert !ismissing(sim.electric_potential) "Electric potential has not been calculated yet. Please run `calculate_electric_potential!(sim)` first."
11721172
periodicity::T = width(sim.world.intervals[2])
11731173
e_pot, point_types = if CS == Cylindrical && periodicity == T(0) # 2D, only one point in φ
@@ -1181,15 +1181,15 @@ function calculate_electric_field!(sim::Simulation{T, CS}; n_points_in_φ::Union
11811181
end
11821182
end
11831183
get_2π_potential(sim.electric_potential, n_points_in_φ = n_points_in_φ),
1184-
get_2π_potential(sim.point_types, n_points_in_φ = n_points_in_φ);
1184+
get_2π_potential(sim.point_types, n_points_in_φ = n_points_in_φ)
11851185
elseif CS == Cylindrical
11861186
get_2π_potential(sim.electric_potential),
11871187
get_2π_potential(sim.point_types)
11881188
else
11891189
sim.electric_potential,
11901190
sim.point_types
11911191
end
1192-
sim.electric_field = get_electric_field_from_potential(e_pot, point_types);
1192+
sim.electric_field = get_electric_field_from_potential(e_pot, point_types; use_nthreads)
11931193
nothing
11941194
end
11951195

0 commit comments

Comments
 (0)