Skip to content

Commit 945faad

Browse files
authored
implement ML model depending on thetaP
* allow constraining mean of PBL parameters to initial values * inspect elbo components of theta-contrained vs unconstrained HVI fits * move CUDA dependency to extension and remove GPUDataHandler, replace by MLDataDevices * factor predict_ζf out of predict_gf ** for inspecting NUTS generated samples on unconstrained scale * fix error in confusing parameter positions ** depending on scenario-templates need to construct different Interpreters passed to forward model * compare plots of marginal posterior densities of several sites * compare HVI to HMC inversion * implement ML model depending on thetaP * default predict_gf for all sites of dataloader of problem * return updated problem from solvers
1 parent 17d7bdc commit 945faad

35 files changed

+2101
-735
lines changed

.github/workflows/CompatHelper.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ jobs:
1414
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
1515
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
1616
run: julia -e 'using CompatHelper; CompatHelper.main()'
17-
keepalive-job:
18-
name: Keepalive Workflow
19-
runs-on: ubuntu-latest
20-
permissions:
21-
actions: write
22-
steps:
23-
- uses: actions/checkout@v4
24-
with:
25-
ref: 'keepalive' # The branch, tag or SHA to checkout.
26-
- uses: gautamkrishnar/keepalive-workflow@v2
17+
# keepalive-job:
18+
# name: Keepalive Workflow
19+
# runs-on: ubuntu-latest
20+
# permissions:
21+
# actions: write
22+
# steps:
23+
# - uses: actions/checkout@v4
24+
# with:
25+
# ref: 'keepalive' # The branch, tag or SHA to checkout.
26+
# - uses: gautamkrishnar/keepalive-workflow@v2

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
test/Manifest*.toml
88
dev/Manifest*.toml
99
tmp/
10+
**/tmp.svg
11+
dev/intermediate/*
12+
dev/tmp.pdf

Project.toml

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,55 +6,60 @@ version = "1.0.0-DEV"
66
[deps]
77
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
88
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
9-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
109
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1110
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1211
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
1312
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1413
DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921"
1514
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
15+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1616
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1717
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
1919
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
2020
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
2121
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
22+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
23+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2224
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2325
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
24-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2526

2627
[weakdeps]
28+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2729
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
2830
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
2931
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
3032

3133
[extensions]
34+
HybridVariationalInferenceCUDAExt = "CUDA"
3235
HybridVariationalInferenceFluxExt = "Flux"
3336
HybridVariationalInferenceLuxExt = "Lux"
3437
HybridVariationalInferenceSimpleChainsExt = "SimpleChains"
3538

3639
[compat]
37-
Bijectors = "0.15.4"
38-
BlockDiagonals = "0.1.42"
39-
CUDA = "5.5.2"
40+
Bijectors = "0.14, 0.15"
41+
BlockDiagonals = "0.1.42, 0.2"
42+
CUDA = "5.7"
4043
ChainRulesCore = "1.25"
4144
Combinatorics = "1.0.2"
4245
CommonSolve = "0.2.4"
4346
ComponentArrays = "0.15.19"
4447
DistributionFits = "0.3.9"
4548
Distributions = "0.25.117"
46-
Flux = "v0.15.2, 0.16"
49+
Flux = "0.14, 0.15, 0.16"
50+
Functors = "0.4, 0.5"
4751
GPUArraysCore = "0.1, 0.2"
48-
LinearAlgebra = "1.10.0"
52+
LinearAlgebra = "1.10"
4953
Lux = "1.4.2"
50-
MLDataDevices = "1.6.9"
54+
MLDataDevices = "1.5, 1.6"
5155
MLUtils = "0.4.5"
5256
Optimization = "3.19.3, 4"
5357
Random = "1.10.0"
5458
SimpleChains = "0.4"
59+
StableRNGs = "1.0.2"
60+
StaticArrays = "1.9.13"
5561
StatsBase = "0.34.4"
5662
StatsFuns = "1.3.2"
57-
Zygote = "0.6.73, 0.7"
5863
julia = "1.10"
5964

6065
[workspace]

_typos.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
[default.extend-words]
2-
SOM = "SOM"
2+
SOM = "SOM"
3+
negLogLik = "negLogLik"

dev/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,24 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
44
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
55
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
66
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
7+
DistributionFits = "45214091-1ed4-4409-9bcf-fdb48a05e921"
78
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
89
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
910
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1011
HybridVariationalInference = "a108c475-a4e2-4021-9a84-cfa7df242f64"
12+
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
13+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
14+
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
1115
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1216
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
1317
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
18+
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1419
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1520
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
1621
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1722
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1823
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
24+
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
1925
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
2026
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2127
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

0 commit comments

Comments
 (0)