Skip to content

Commit fc37801

Browse files
committed
Don't deepcopy accs
1 parent 80db9e2 commit fc37801

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

src/logdensityfunction.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ closure approach will be used. By default, this function returns `false`, i.e.
305305
the constant approach will be used.
306306
"""
307307
use_closure(::ADTypes.AbstractADType) = true
308+
use_closure(::ADTypes.AutoEnzyme) = false
308309

309310
"""
310311
getmodel(f)

src/varinfo.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ function typed_varinfo(vi::UntypedVarInfo)
287287
)
288288
end
289289
nt = NamedTuple{syms_tuple}(Tuple(new_metas))
290-
return VarInfo(nt, deepcopy(vi.accs))
290+
return VarInfo(nt, vi.accs)
291291
end
292292
function typed_varinfo(vi::NTVarInfo)
293293
# This function preserves the behaviour of typed_varinfo(vi) where vi is
@@ -348,7 +348,7 @@ single `VarNamedVector` as its metadata field.
348348
"""
349349
function untyped_vector_varinfo(vi::UntypedVarInfo)
350350
md = metadata_to_varnamedvector(vi.metadata)
351-
return VarInfo(md, deepcopy(vi.accs))
351+
return VarInfo(md, vi.accs)
352352
end
353353
function untyped_vector_varinfo(
354354
rng::Random.AbstractRNG,
@@ -391,12 +391,12 @@ NamedTuple of `VarNamedVector`s as its metadata field.
391391
"""
392392
function typed_vector_varinfo(vi::NTVarInfo)
393393
md = map(metadata_to_varnamedvector, vi.metadata)
394-
return VarInfo(md, deepcopy(vi.accs))
394+
return VarInfo(md, vi.accs)
395395
end
396396
function typed_vector_varinfo(vi::UntypedVectorVarInfo)
397397
new_metas = group_by_symbol(vi.metadata)
398398
nt = NamedTuple(new_metas)
399-
return VarInfo(nt, deepcopy(vi.accs))
399+
return VarInfo(nt, vi.accs)
400400
end
401401
function typed_vector_varinfo(
402402
rng::Random.AbstractRNG,
@@ -447,10 +447,7 @@ function unflatten(vi::VarInfo, x::AbstractVector)
447447
# The below line is finicky for type stability. For instance, assigning the eltype to
448448
# convert to into an intermediate variable makes this unstable (constant propagation)
449449
# fails. Take care when editing.
450-
accs = map(
451-
acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc),
452-
deepcopy(getaccs(vi)),
453-
)
450+
accs = map(acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), vi.accs)
454451
return VarInfo(md, accs)
455452
end
456453

@@ -533,7 +530,7 @@ end
533530

534531
function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName})
535532
metadata = subset(varinfo.metadata, vns)
536-
return VarInfo(metadata, deepcopy(varinfo.accs))
533+
return VarInfo(metadata, varinfo.accs)
537534
end
538535

539536
function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName})
@@ -622,7 +619,7 @@ end
622619

623620
function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
624621
metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)
625-
return VarInfo(metadata, deepcopy(varinfo_right.accs))
622+
return VarInfo(metadata, varinfo_right.accs)
626623
end
627624

628625
function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector)
@@ -1014,7 +1011,7 @@ istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn)
10141011
istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans")
10151012

10161013
getaccs(vi::VarInfo) = vi.accs
1017-
setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs
1014+
setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = VarInfo(vi.metadata, accs)
10181015

10191016
# Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple).
10201017
isempty(vi::VarInfo) = _isempty(vi.metadata)

0 commit comments

Comments
 (0)