Skip to content

Commit 0e53f1a

Browse files
committed
remove tapir extension
1 parent 8c19984 commit 0e53f1a

File tree

4 files changed

+25
-42
lines changed

4 files changed

+25
-42
lines changed

Project.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2121
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2222
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2323
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
24-
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
2524
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2625
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2726

@@ -31,7 +30,6 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3130
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3231
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3332
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
34-
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
3533
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3634

3735
[extensions]
@@ -40,7 +38,6 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
4038
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4139
DynamicPPLMCMCChainsExt = ["MCMCChains"]
4240
DynamicPPLReverseDiffExt = ["ReverseDiff"]
43-
DynamicPPLTapirExt = ["Tapir"]
4441
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
4542

4643
[compat]
@@ -67,7 +64,6 @@ Random = "1.6"
6764
Requires = "1"
6865
ReverseDiff = "1"
6966
Test = "1.6"
70-
Tapir = "0.2.40"
7167
ZygoteRules = "0.2"
7268
julia = "1.6"
7369

ext/DynamicPPLTapirExt.jl

Lines changed: 0 additions & 35 deletions
This file was deleted.

src/DynamicPPL.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,9 +214,6 @@ end
214214
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
215215
"../ext/DynamicPPLReverseDiffExt.jl"
216216
)
217-
@require Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" include(
218-
"../ext/DynamicPPLTapirExt.jl"
219-
)
220217
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
221218
"../ext/DynamicPPLZygoteRulesExt.jl"
222219
)

src/logdensityfunction.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,31 @@ function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
116116
return Accessors.@set f.model = model
117117
end
118118

119+
# TODO: special case for Tapir, should move to a package extension once Julia compat is updated
120+
function DynamicPPL.setmodel(
121+
f::LogDensityProblemsAD.ADGradientWrapper,
122+
model::DynamicPPL.Model,
123+
adtype::ADTypes.AutoTapir,
124+
)
125+
if !hasfield(typeof(f), :rule)
126+
@warn "ADGradientWrapper does not have a `rule` field. Please check Tapir version. It is also possible that `adtype` mismatch `ADGradientWrapper` type."
127+
@warn "Using default rule."
128+
return LogDensityProblemsAD.ADgradient(
129+
Val(:Tapir),
130+
DynamicPPL.setmodel(LogDensityProblemsAD.parent(f), model);
131+
safety_on=adtype.safe_mode,
132+
rule=nothing,
133+
)
134+
else
135+
return LogDensityProblemsAD.ADgradient(
136+
Val(:Tapir),
137+
DynamicPPL.setmodel(LogDensityProblemsAD.parent(f), model);
138+
safety_on=adtype.safe_mode,
139+
rule=f.rule,
140+
)
141+
end
142+
end
143+
119144
# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time
120145
# we need to define these annoying methods to ensure that we stay compatible with everything.
121146
getsampler(f::LogDensityFunction) = getsampler(getcontext(f))

0 commit comments

Comments
 (0)