Skip to content

Commit 05db8c2

Browse files
committed
Allow gradient on LDF to be called with e.g. views
1 parent e1cb399 commit 05db8c2

File tree

4 files changed

+32
-1
lines changed

4 files changed

+32
-1
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# DynamicPPL Changelog
22

3+
## 0.39.1
4+
5+
`LogDensityFunction` now allows you to call `logdensity_and_gradient(ldf, x)` with `AbstractVector`s `x` that are not plain Vectors (they will be converted internally before calculating the gradient).
6+
37
## 0.39.0
48

59
### Breaking changes

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.39.0"
3+
version = "0.39.1"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/logdensityfunction.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ struct LogDensityFunction{
148148
F<:Function,
149149
N<:NamedTuple,
150150
ADP<:Union{Nothing,DI.GradientPrep},
151+
# type of the vector passed to logdensity functions
152+
X<:AbstractVector,
151153
}
152154
model::M
153155
adtype::AD
@@ -202,12 +204,17 @@ struct LogDensityFunction{
202204
typeof(getlogdensity),
203205
typeof(all_iden_ranges),
204206
typeof(prep),
207+
typeof(x),
205208
}(
206209
model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim
207210
)
208211
end
209212
end
210213

214+
function _get_input_vector_type(::LogDensityFunction{T,M,A,G,I,P,X}) where {T,M,A,G,I,P,X}
215+
return X
216+
end
217+
211218
###################################
212219
# LogDensityProblems.jl interface #
213220
###################################
@@ -265,6 +272,7 @@ end
265272
function LogDensityProblems.logdensity_and_gradient(
266273
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
267274
) where {Tlink}
275+
params = convert(_get_input_vector_type(ldf), params)
268276
return DI.value_and_gradient(
269277
LogDensityAt{Tlink}(
270278
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges

test/logdensityfunction.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,25 @@ end
186186
end
187187
end
188188

189+
@testset "logdensity_and_gradient with views" begin
190+
# This test ensures that you can call `logdensity_and_gradient` with an array
191+
# type that isn't the same as the one used in the gradient preparation.
192+
@model function f()
193+
x ~ Normal()
194+
return y ~ Normal()
195+
end
196+
@testset "$adtype" for adtype in test_adtypes
197+
x = randn(2)
198+
ldf = LogDensityFunction(f(); adtype)
199+
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
200+
logp_view, grad_view = LogDensityProblems.logdensity_and_gradient(
201+
ldf, (@view x[:])
202+
)
203+
@test logp == logp_view
204+
@test grad == grad_view
205+
end
206+
end
207+
189208
# Test that various different ways of specifying array types as arguments work with all
190209
# ADTypes.
191210
@testset "Array argument types" begin

0 commit comments

Comments
 (0)