Skip to content

Commit 54b2992

Browse files
committed
Merge branch 'master' into tor/turing-diag-normal-extras
2 parents a2c2439 + d0a1063 commit 54b2992

File tree

4 files changed

+66
-82
lines changed

4 files changed

+66
-82
lines changed

Manifest.toml

Lines changed: 52 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
[[AbstractFFTs]]
44
deps = ["LinearAlgebra"]
5-
git-tree-sha1 = "380e36c66edfa099cd90116b24c1ce8cafccac40"
5+
git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716"
66
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
7-
version = "0.4.1"
7+
version = "0.5.0"
88

99
[[Adapt]]
1010
deps = ["LinearAlgebra"]
@@ -29,15 +29,9 @@ version = "0.8.10"
2929

3030
[[BinaryProvider]]
3131
deps = ["Libdl", "SHA"]
32-
git-tree-sha1 = "29995a7b317bbd06be147e1974a3541ce2502dca"
32+
git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c"
3333
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
34-
version = "0.5.7"
35-
36-
[[CSTParser]]
37-
deps = ["Tokenize"]
38-
git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
39-
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
40-
version = "0.6.2"
34+
version = "0.5.8"
4135

4236
[[Combinatorics]]
4337
deps = ["LinearAlgebra", "Polynomials", "Test"]
@@ -63,12 +57,6 @@ git-tree-sha1 = "9a11d428dcdc425072af4aea19ab1e8c3e01c032"
6357
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
6458
version = "1.3.0"
6559

66-
[[Crayons]]
67-
deps = ["Test"]
68-
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
69-
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
70-
version = "4.0.0"
71-
7260
[[DataAPI]]
7361
git-tree-sha1 = "674b67f344687a88310213ddfa8a2b3c76cc4252"
7462
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
@@ -95,50 +83,44 @@ uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
9583
version = "0.0.4"
9684

9785
[[DiffRules]]
98-
deps = ["Random", "Test"]
99-
git-tree-sha1 = "dc0869fb2f5b23466b32ea799bd82c76480167f7"
86+
deps = ["NaNMath", "Random", "SpecialFunctions"]
87+
git-tree-sha1 = "f734b5f6bc9c909027ef99f6d91d5d9e4b111eed"
10088
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
101-
version = "0.0.10"
89+
version = "0.1.0"
10290

10391
[[Distributed]]
10492
deps = ["Random", "Serialization", "Sockets"]
10593
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
10694

10795
[[Distributions]]
10896
deps = ["LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
109-
git-tree-sha1 = "058e5e39ceee0f92ccec70c5bf31c90ffb374669"
97+
git-tree-sha1 = "ce189b71fac635d6ec9582dc0f208887db25e6d3"
11098
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
111-
version = "0.21.5"
99+
version = "0.21.8"
112100

113101
[[FFTW]]
114-
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
115-
git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f"
102+
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport"]
103+
git-tree-sha1 = "4cfd3d43819228b9e73ab46600d0af0aa5cedceb"
116104
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
117-
version = "1.0.1"
105+
version = "1.1.0"
118106

119107
[[FillArrays]]
120108
deps = ["LinearAlgebra", "Random", "SparseArrays"]
121-
git-tree-sha1 = "6827a8f73ff12707f209c920d204238a16892b55"
109+
git-tree-sha1 = "b2cf74f09216cfe3c241e8484178ec0ea941870f"
122110
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
123-
version = "0.8.0"
124-
125-
[[FiniteDifferences]]
126-
deps = ["LinearAlgebra", "Printf"]
127-
git-tree-sha1 = "98ae83a564ce5c4066d3ef45ef2310f089fdf99f"
128-
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
129-
version = "0.7.2"
111+
version = "0.8.1"
130112

131113
[[ForwardDiff]]
132-
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
133-
git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
114+
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
115+
git-tree-sha1 = "4407e7b76999eca2646abdb68203bd4302476168"
134116
uuid = "f6369f11-7733-5829-9624-2563aa707210"
135-
version = "0.10.3"
117+
version = "0.10.6"
136118

137119
[[IRTools]]
138120
deps = ["InteractiveUtils", "MacroTools", "Test"]
139-
git-tree-sha1 = "e23faa71b8f54c3fdc99b230b9c2906cafdddca5"
121+
git-tree-sha1 = "72421971e60917b8cd7737f9577c4f0f87eab306"
140122
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
141-
version = "0.2.3"
123+
version = "0.3.0"
142124

143125
[[InteractiveUtils]]
144126
deps = ["Markdown"]
@@ -164,10 +146,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
164146
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
165147

166148
[[MacroTools]]
167-
deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"]
168-
git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76"
149+
deps = ["Compat", "DataStructures", "Test"]
150+
git-tree-sha1 = "82921f0e3bde6aebb8e524efc20f4042373c0c06"
169151
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
170-
version = "0.5.1"
152+
version = "0.5.2"
171153

172154
[[Markdown]]
173155
deps = ["Base64"]
@@ -202,35 +184,35 @@ version = "1.1.0"
202184

203185
[[PDMats]]
204186
deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"]
205-
git-tree-sha1 = "9d6a9b3e19634612fb1edcafc4b1d75242b24bde"
187+
git-tree-sha1 = "035f8d60ba2a22cb1d2580b1e0e5ce0cb05e4563"
206188
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
207-
version = "0.9.9"
189+
version = "0.9.10"
208190

209191
[[Parsers]]
210192
deps = ["Dates", "Test"]
211-
git-tree-sha1 = "ef0af6c8601db18c282d092ccbd2f01f3f0cd70b"
193+
git-tree-sha1 = "a23968e107c0544aca91bfab6f7dd34de1206a54"
212194
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
213-
version = "0.3.7"
195+
version = "0.3.9"
214196

215197
[[Pkg]]
216198
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
217199
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
218200

219201
[[Polynomials]]
220202
deps = ["LinearAlgebra", "RecipesBase"]
221-
git-tree-sha1 = "f7c0c07e82798aef542d60a6e6e85e39f4590750"
203+
git-tree-sha1 = "ae71c2329790af97b7682b11241b3609e4d48626"
222204
uuid = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
223-
version = "0.5.3"
205+
version = "0.6.0"
224206

225207
[[Printf]]
226208
deps = ["Unicode"]
227209
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
228210

229211
[[QuadGK]]
230-
deps = ["DataStructures", "LinearAlgebra", "Test"]
231-
git-tree-sha1 = "3ce467a8e76c6030d4c3786e7d3a73442017cdc0"
212+
deps = ["DataStructures", "LinearAlgebra"]
213+
git-tree-sha1 = "1af46bf083b9630a5b27d4fd94f496c5fca642a8"
232214
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
233-
version = "2.0.3"
215+
version = "2.1.1"
234216

235217
[[REPL]]
236218
deps = ["InteractiveUtils", "Markdown", "Sockets"]
@@ -258,10 +240,10 @@ uuid = "ae029012-a4dd-5104-9daa-d747884805df"
258240
version = "0.5.2"
259241

260242
[[Rmath]]
261-
deps = ["BinaryProvider", "Libdl", "Random", "Statistics", "Test"]
262-
git-tree-sha1 = "9a6c758cdf73036c3239b0afbea790def1dabff9"
243+
deps = ["BinaryProvider", "Libdl", "Random", "Statistics"]
244+
git-tree-sha1 = "9825383d3453f4606d77f0a5722495f38001c09e"
263245
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
264-
version = "0.5.0"
246+
version = "0.5.1"
265247

266248
[[SHA]]
267249
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
@@ -287,16 +269,16 @@ deps = ["LinearAlgebra", "Random"]
287269
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
288270

289271
[[SpecialFunctions]]
290-
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
291-
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
272+
deps = ["BinDeps", "BinaryProvider", "Libdl"]
273+
git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e"
292274
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
293-
version = "0.7.2"
275+
version = "0.8.0"
294276

295277
[[StaticArrays]]
296278
deps = ["LinearAlgebra", "Random", "Statistics"]
297-
git-tree-sha1 = "1085ffbf5fd48fdba64ef8e902ca429c4e1212d3"
279+
git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c"
298280
uuid = "90137ffa-7385-5640-81b9-e52037218182"
299-
version = "0.11.1"
281+
version = "0.12.1"
300282

301283
[[Statistics]]
302284
deps = ["LinearAlgebra", "SparseArrays"]
@@ -309,35 +291,30 @@ uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
309291
version = "0.32.0"
310292

311293
[[StatsFuns]]
312-
deps = ["Rmath", "SpecialFunctions", "Test"]
313-
git-tree-sha1 = "b3a4e86aa13c732b8a8c0ba0c3d3264f55e6bb3e"
294+
deps = ["Rmath", "SpecialFunctions"]
295+
git-tree-sha1 = "67745a79d8e83a83737a7e17a383c54720a97f41"
314296
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
315-
version = "0.8.0"
297+
version = "0.9.0"
316298

317299
[[SuiteSparse]]
318-
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
300+
deps = ["Libdl", "LinearAlgebra", "SparseArrays"]
319301
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
320302

321303
[[Test]]
322304
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
323305
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
324306

325307
[[TimerOutputs]]
326-
deps = ["Crayons", "Printf", "Test", "Unicode"]
327-
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
308+
deps = ["Printf"]
309+
git-tree-sha1 = "311765af81bbb48d7bad01fb016d9c328c6ede03"
328310
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
329-
version = "0.5.0"
330-
331-
[[Tokenize]]
332-
git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf"
333-
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
334-
version = "0.5.6"
311+
version = "0.5.3"
335312

336313
[[Tracker]]
337314
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
338-
git-tree-sha1 = "1aa443d3b4bfa91a8aec32f169a479cb87309910"
315+
git-tree-sha1 = "439e3a4f6d54739bb17c36aa1b5855acec22fc1e"
339316
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
340-
version = "0.2.3"
317+
version = "0.2.5"
341318

342319
[[URIParser]]
343320
deps = ["Test", "Unicode"]
@@ -360,12 +337,12 @@ version = "1.1.3"
360337

361338
[[Zygote]]
362339
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
363-
git-tree-sha1 = "d21e86576e25e4adc09631b34651798775fba99a"
340+
git-tree-sha1 = "e4245b9c5362346e154b62842a89a18e0210b92b"
364341
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
365-
version = "0.3.4"
342+
version = "0.4.1"
366343

367344
[[ZygoteRules]]
368345
deps = ["MacroTools"]
369-
git-tree-sha1 = "def5f96ac2895fd9b48435f6b97020979ee0a4c6"
346+
git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8"
370347
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
371-
version = "0.1.0"
348+
version = "0.2.0"

Project.toml

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,32 @@
11
name = "DistributionsAD"
22
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
3-
version = "0.1.1"
3+
version = "0.1.2"
44

55
[deps]
66
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
8-
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
98
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
109
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1110
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1211
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1312
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1413
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
15-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1614
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1715
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1816

1917
[compat]
18+
Combinatorics = "0.7"
19+
Distributions = "0.21.6"
20+
ForwardDiff = "0.10.6"
21+
PDMats = "0.9"
22+
StatsFuns = "0.8, 0.9"
23+
Tracker = "0.2.5"
24+
Zygote = "0.4.1"
2025
julia = "1"
2126

2227
[extras]
23-
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
2428
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2529
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
30+
31+
[targets]
32+
test = ["Test", "FiniteDifferences"]

src/common.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function turing_chol(A::AbstractMatrix, check)
4141
end
4242
turing_chol(A::Tracker.TrackedMatrix, check) = Tracker.track(turing_chol, A, check)
4343
Tracker.@grad function turing_chol(A::AbstractMatrix, check)
44-
C, back = Zygote.forward(unsafe_cholesky, Tracker.data(A), Tracker.data(check))
44+
C, back = Zygote.pullback(unsafe_cholesky, Tracker.data(A), Tracker.data(check))
4545
return (C.factors, C.info), Δ->back((factors=Tracker.data(Δ[1]),))
4646
end
4747

@@ -104,7 +104,7 @@ function zygote_ldiv(A::Tracker.TrackedMatrix, B::AbstractVecOrMat)
104104
end
105105
zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = Tracker.track(zygote_ldiv, A, B)
106106
Tracker.@grad function zygote_ldiv(A, B)
107-
Y, back = Zygote.forward(\, Tracker.data(A), Tracker.data(B))
107+
Y, back = Zygote.pullback(\, Tracker.data(A), Tracker.data(B))
108108
return Y, Δ->back(Tracker.data(Δ))
109109
end
110110

test/test_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ end
9696
9797
Check that the reverse-mode sensitivities produced by an AD library are correct for `f`
9898
at `x...`, given sensitivity `ȳ` w.r.t. `y = f(x...)` up to `rtol` and `atol`.
99-
`forward` should be either `Tracker.forward` or `Zygote.forward`.
99+
`forward` should be either `Tracker.forward` or `Zygote.pullback`.
100100
"""
101101
function test_reverse_mode_ad(forward, f, ȳ, x...; rtol=1e-8, atol=1e-8)
102102

0 commit comments

Comments
 (0)