Skip to content

Commit a05b678

Browse files
authored
Merge pull request #509 from JuliaPhysics/field_multithreading
2 parents 415ccf2 + 4e2c897 commit a05b678

File tree

2 files changed

+14
-15
lines changed

2 files changed

+14
-15
lines changed

src/ElectricField/ElectricField.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,20 @@ Base.convert(T::Type{ElectricField}, x::NamedTuple) = T(x)
4949

5050

5151

52-
function ElectricField(epot::ElectricPotential{T, 3, S}, point_types::PointTypes{T}) where {T, S}
53-
return ElectricField{T, 3, S, typeof(grid.axes)}(get_electric_field_from_potential( epot, point_types ), epot.grid)
52+
function ElectricField(epot::ElectricPotential{T, 3, S}, point_types::PointTypes{T}; use_nthreads::Int = Base.Threads.nthreads()) where {T, S}
53+
return ElectricField{T, 3, S, typeof(grid.axes)}(get_electric_field_from_potential( epot, point_types; use_nthreads ), epot.grid)
5454
end
5555

5656

57-
function get_electric_field_from_potential(epot::ElectricPotential{T, 3, Cylindrical}, point_types::PointTypes{T}, fieldvector_coordinates=:xyz)::ElectricField{T, 3, Cylindrical} where {T <: SSDFloat}
57+
function get_electric_field_from_potential(epot::ElectricPotential{T, 3, Cylindrical}, point_types::PointTypes{T}; use_nthreads::Int = Base.Threads.threadid())::ElectricField{T, 3, Cylindrical} where {T <: SSDFloat}
5858
p = epot.data
5959
axr::Vector{T} = collect(epot.grid.axes[1])
6060
axφ::Vector{T} = collect(epot.grid.axes[2])
6161
axz::Vector{T} = collect(epot.grid.axes[3])
6262

6363
cyclic::T = epot.grid.axes[2].interval.right
6464
ef = Array{SVector{3, T}}(undef, size(p)...)
65-
for iz in 1:size(ef, 3)
65+
@onthreads 1:use_nthreads for iz in workpart(1:size(ef, 3), 1:use_nthreads, Base.Threads.threadid())
6666
forin 1:size(ef, 2)
6767
for ir in 1:size(ef, 1)
6868
### r ###
@@ -157,9 +157,7 @@ function get_electric_field_from_potential(epot::ElectricPotential{T, 3, Cylindr
157157
end
158158
end
159159
end
160-
if fieldvector_coordinates == :xyz
161-
ef = convert_field_vectors_to_xyz(ef, axφ)
162-
end
160+
ef = convert_field_vectors_to_xyz(ef, axφ)
163161
return ElectricField(ef, point_types.grid)
164162
end
165163

@@ -177,6 +175,7 @@ function convert_field_vectors_to_xyz(field::Array{SArray{Tuple{3},T,1,3},3}, φ
177175
end
178176

179177

178+
180179
function interpolated_scalarfield(spot::ScalarPotential{T, 3, Cylindrical}) where {T}
181180
@inbounds knots = spot.grid.axes[1].ticks, cat(spot.grid.axes[2].ticks,T(2π),dims=1), spot.grid.axes[3].ticks
182181
ext_data = cat(spot.data, spot.data[:,1:1,:], dims=2)
@@ -212,17 +211,17 @@ end
212211

213212

214213

215-
function get_electric_field_from_potential(epot::ElectricPotential{T, 3, Cartesian}, point_types::PointTypes{T})::ElectricField{T, 3, Cartesian} where {T <: SSDFloat}
214+
function get_electric_field_from_potential(epot::ElectricPotential{T, 3, Cartesian}, point_types::PointTypes{T}; use_nthreads::Int = Base.Threads.nthreads())::ElectricField{T, 3, Cartesian} where {T <: SSDFloat}
216215
axx::Vector{T} = collect(epot.grid.axes[1])
217216
axy::Vector{T} = collect(epot.grid.axes[2])
218217
axz::Vector{T} = collect(epot.grid.axes[3])
219-
axx_ext::Vector{T} = get_extended_ticks(epot.grid.axes[1])
220-
axy_ext::Vector{T} = get_extended_ticks(epot.grid.axes[2])
221-
axz_ext::Vector{T} = get_extended_ticks(epot.grid.axes[3])
218+
# axx_ext::Vector{T} = get_extended_ticks(epot.grid.axes[1])
219+
# axy_ext::Vector{T} = get_extended_ticks(epot.grid.axes[2])
220+
# axz_ext::Vector{T} = get_extended_ticks(epot.grid.axes[3])
222221

223222
ef::Array{SVector{3, T}} = Array{SVector{3, T}}(undef, size(epot.data))
224223

225-
for ix in eachindex(axx)
224+
@onthreads 1:use_nthreads for ix in workpart(eachindex(axx), 1:use_nthreads, Base.Threads.threadid())
226225
for iy in eachindex(axy)
227226
for iz in eachindex(axz)
228227
if ix - 1 < 1

src/Simulation/Simulation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,7 +1167,7 @@ Calculates the [`ElectricField`](@ref) from the [`ElectricPotential`](@ref) stor
11671167
!!! note
11681168
This method only works if `sim.electric_potential` has already been calculated and is not `missing`.
11691169
"""
1170-
function calculate_electric_field!(sim::Simulation{T, CS}; n_points_in_φ::Union{Missing, Int} = missing)::Nothing where {T <: SSDFloat, CS}
1170+
function calculate_electric_field!(sim::Simulation{T, CS}; n_points_in_φ::Union{Missing, Int} = missing, use_nthreads::Int = Base.Threads.nthreads())::Nothing where {T <: SSDFloat, CS}
11711171
@assert !ismissing(sim.electric_potential) "Electric potential has not been calculated yet. Please run `calculate_electric_potential!(sim)` first."
11721172
periodicity::T = width(sim.world.intervals[2])
11731173
e_pot, point_types = if CS == Cylindrical && periodicity == T(0) # 2D, only one point in φ
@@ -1181,15 +1181,15 @@ function calculate_electric_field!(sim::Simulation{T, CS}; n_points_in_φ::Union
11811181
end
11821182
end
11831183
get_2π_potential(sim.electric_potential, n_points_in_φ = n_points_in_φ),
1184-
get_2π_potential(sim.point_types, n_points_in_φ = n_points_in_φ);
1184+
get_2π_potential(sim.point_types, n_points_in_φ = n_points_in_φ)
11851185
elseif CS == Cylindrical
11861186
get_2π_potential(sim.electric_potential),
11871187
get_2π_potential(sim.point_types)
11881188
else
11891189
sim.electric_potential,
11901190
sim.point_types
11911191
end
1192-
sim.electric_field = get_electric_field_from_potential(e_pot, point_types);
1192+
sim.electric_field = get_electric_field_from_potential(e_pot, point_types; use_nthreads)
11931193
nothing
11941194
end
11951195

0 commit comments

Comments
 (0)