Skip to content

Commit 425fe79

Browse files
Merge pull request #106 from hexaeder/master
allow for explicit size specification
2 parents 1f8db64 + c335a54 commit 425fe79

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

ext/PreallocationToolsReverseDiffExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ using PreallocationTools
44
isdefined(Base, :get_extension) ? (import ReverseDiff) : (import ..ReverseDiff)
55

66
# PreallocationTools https://github.com/SciML/PreallocationTools.jl/issues/39
7-
function Base.getindex(b::PreallocationTools.LazyBufferCache, u::ReverseDiff.TrackedArray)
8-
s = b.sizemap(size(u)) # required buffer size
7+
function Base.getindex(b::PreallocationTools.LazyBufferCache,
8+
u::ReverseDiff.TrackedArray, s = b.sizemap(size(u)))
99
T = ReverseDiff.TrackedArray
1010
buf = get!(b.bufs, (T, s)) do
1111
# declare type since b.bufs dictionary is untyped

src/PreallocationTools.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ end
202202
A lazily allocated buffer object. Given an array `u`, `b[u]` returns an array of the
203203
same type and size `f(size(u))` (defaulting to the same size), which is allocated as
204204
needed and then cached within `b` for subsequent usage.
205+
206+
Optionally, the size can be explicitly given at calltime using `b[u,s]`, which will
207+
return a cache of size `s`.
205208
"""
206209
struct LazyBufferCache{F <: Function}
207210
bufs::Dict{Any, Any} # a dictionary mapping (type, size) pairs to buffers
@@ -216,15 +219,18 @@ function similar_type(x::AbstractArray{T}, s::NTuple{N, Integer}) where {T, N}
216219
typeof(similar(x, ntuple(Returns(1), N)))
217220
end
218221

219-
function get_tmp(b::LazyBufferCache, u::T) where {T <: AbstractArray}
220-
s = b.sizemap(size(u)) # required buffer size
222+
function get_tmp(
223+
b::LazyBufferCache, u::T, s = b.sizemap(size(u))) where {T <: AbstractArray}
221224
get!(b.bufs, (T, s)) do
222225
similar(u, s) # buffer to allocate if it was not found in b.bufs
223226
end::similar_type(u, s) # declare type since b.bufs dictionary is untyped
224227
end
225228

226229
# override the [] method
227-
Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray} = get_tmp(b, u)
230+
function Base.getindex(
231+
b::LazyBufferCache, u::T, s = b.sizemap(size(u))) where {T <: AbstractArray}
232+
get_tmp(b, u, s)
233+
end
228234

229235
# GeneralLazyBufferCache
230236

test/general_lbc.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,20 @@ y = view(x, 1:900)
4242
@test 0 == @allocated cache[y]
4343
@test cache[y] === get_tmp(cache, y)
4444

45+
@inferred cache[x, 1111]
46+
@test 0 == @allocated cache[x, 1111]
47+
@test size(cache[x, 1111]) == (1111,)
48+
4549
cache_17 = LazyBufferCache(Returns(17))
4650
x = 1:10
4751
@inferred cache_17[x]
4852
@test 0 == @allocated cache_17[x]
4953
@test size(cache_17[x]) == (17,)
5054

55+
@inferred cache_17[x, 1111]
56+
@test 0 == @allocated cache_17[x, 1111]
57+
@test size(cache_17[x, 1111]) == (1111,)
58+
5159
cache = GeneralLazyBufferCache(T -> Vector{T}(undef, 1000))
5260
# GeneralLazyBufferCache is documented not to infer.
5361
# @inferred cache[Float64]

0 commit comments

Comments
 (0)