Skip to content

Commit 6e2578d

Browse files
committed
use batch memory operations
1 parent 0b6a796 commit 6e2578d

File tree

2 files changed

+11
-41
lines changed

2 files changed

+11
-41
lines changed

src/C_JetReconstruction/C_JetReconstruction.jl

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -59,38 +59,10 @@ exception_to_enum(::UndefRefError) = Cint(StatusCode.UndefRefError)
5959
exception_to_enum(::UndefVarError) = Cint(StatusCode.UndefVarError)
6060
exception_to_enum(::StringIndexError) = Cint(StatusCode.StringIndexError)
6161

62-
"""
63-
unsafe_wrap_c_array(ptr::Ptr{T}, array_length::Csize_t) where {T}
64-
65-
Wraps a C array into a Julia `Vector` for both bits and non-bits types.
66-
67-
# Arguments
68-
- `ptr::Ptr{T}`: A pointer to the C array.
69-
- `array_length::Csize_t`: The length of the C array.
70-
71-
# Returns
72-
- A Julia `Vector{T}` containing the elements of the C array.
73-
74-
# Safety
75-
This function use 'unsafe' methods and has undefined behaviour
76-
if pointer isn't valid or length isn't correct.
77-
"""
78-
function unsafe_wrap_c_array(ptr::Ptr{T}, array_length::Csize_t) where {T}
79-
if isbitstype(T)
80-
return unsafe_wrap(Vector{T}, ptr, array_length)
81-
end
82-
83-
vec = Vector{T}(undef, array_length)
84-
for i in eachindex(vec)
85-
@inbounds vec[i] = unsafe_load(ptr, i)
86-
end
87-
return vec
88-
end
89-
9062
"""
9163
make_c_array(v::Vector{T}) where {T}
9264
93-
Helper function for converting a Julia vector to a C-style array.
65+
Helper function for converting a Julia vector of isbits objects to a C-style array.
9466
A C-style array is dynamically allocated and contents of input vector copied to it.
9567
9668
# Arguments
@@ -105,16 +77,15 @@ A C-style array is dynamically allocated and contents of input vector copied to
10577
10678
# Notes
10779
- The caller is responsible for freeing the allocated memory using `Libc.free(ptr)`.
80+
- The `T` type must be an isbits type.
10881
"""
10982
function make_c_array(v::Vector{T}) where {T}
11083
len = length(v)
11184
ptr = Ptr{T}(Libc.malloc(len * sizeof(T)))
11285
if ptr == C_NULL
11386
throw(OutOfMemoryError("Libc.malloc failed to allocate memory"))
11487
end
115-
for i in 1:len
116-
@inbounds unsafe_store!(ptr, v[i], i)
117-
end
88+
unsafe_copyto!(ptr, pointer(v), len)
11889
return ptr, Csize_t(len)
11990
end
12091

@@ -217,8 +188,8 @@ Convert a `C_ClusterSequence` object to a `ClusterSequence` object.
217188
- The input object must remain valid while the output object is used.
218189
"""
219190
function ClusterSequence{T}(c::C_ClusterSequence{T}) where {T}
220-
jets_v = unsafe_wrap_c_array(c.jets, c.jets_length)
221-
history_v = unsafe_wrap_c_array(c.history, c.history_length)
191+
jets_v = unsafe_wrap(Vector{T}, c.jets, c.jets_length)
192+
history_v = unsafe_wrap(Vector{HistoryElement}, c.history, c.history_length)
222193
return ClusterSequence{T}(c.algorithm, c.power, c.R, c.strategy, jets_v,
223194
c.n_initial_jets,
224195
history_v, c.Qtot)
@@ -279,7 +250,7 @@ function c_jet_reconstruct(particles::Ptr{T},
279250
strategy::RecoStrategy.Strategy,
280251
result::Ptr{C_ClusterSequence{U}}) where {T, U}
281252
try
282-
particles_v = unsafe_wrap_c_array(particles, particles_length)
253+
particles_v = unsafe_wrap(Vector{T}, particles, particles_length)
283254
clusterseq = jet_reconstruct(particles_v; p = nothing, algorithm = algorithm, R = R,
284255
strategy = strategy)
285256
c_clusterseq = C_ClusterSequence{U}(clusterseq)

test/test-c-interface.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@ function compare_results(ptr::Ptr{C_JetReconstruction.C_ClusterSequence{T}},
1717
@test c_cluster_seq.R cluster_seq.R
1818
@test c_cluster_seq.strategy == cluster_seq.strategy
1919
@test c_cluster_seq.jets_length == length(cluster_seq.jets)
20-
c_jets = C_JetReconstruction.unsafe_wrap_c_array(c_cluster_seq.jets,
21-
c_cluster_seq.jets_length)
20+
c_jets = unsafe_wrap(Vector{T}, c_cluster_seq.jets,
21+
c_cluster_seq.jets_length)
2222
@test all(struct_approx_equal.(c_jets, cluster_seq.jets))
2323
@test c_cluster_seq.n_initial_jets == cluster_seq.n_initial_jets
2424
@test c_cluster_seq.history_length == length(cluster_seq.history)
25-
c_history = C_JetReconstruction.unsafe_wrap_c_array(c_cluster_seq.history,
26-
c_cluster_seq.history_length)
25+
c_history = unsafe_wrap(Vector{JetReconstruction.HistoryElement}, c_cluster_seq.history,
26+
c_cluster_seq.history_length)
2727
@test all(struct_approx_equal.(c_history, cluster_seq.history))
2828
@test c_cluster_seq.Qtot cluster_seq.Qtot
2929
end
@@ -33,8 +33,7 @@ function compare_results(ptr::Ptr{C_JetReconstruction.C_JetsResult{T}},
3333
@test ptr != C_NULL
3434
c_results = unsafe_load(ptr)
3535
@test c_results.length == length(jets)
36-
c_data = C_JetReconstruction.unsafe_wrap_c_array(c_results.data,
37-
c_results.length)
36+
c_data = unsafe_wrap(Vector{T}, c_results.data, c_results.length)
3837
@test all(struct_approx_equal.(c_data, jets))
3938
end
4039

0 commit comments

Comments
 (0)