Skip to content

Commit 13495b0

Browse files
Merge pull request #123 from TARGENE/dependent_ps
Fix propensity score dependency structure
2 parents a075c0e + 162ffeb commit 13495b0

File tree

13 files changed

+299
-34
lines changed

13 files changed

+299
-34
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
1919
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
2020
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
2121
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
22-
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2322
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2423
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
2524
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -55,7 +54,6 @@ MLJModels = "0.15, 0.16, 0.17"
5554
MetaGraphsNext = "0.7"
5655
Missings = "1.0"
5756
OrderedCollections = "1.6.3"
58-
PrecompileTools = "1.1.1"
5957
SplitApplyCombine = "1.2.2"
6058
TableOperations = "1.2"
6159
Tables = "1.6"

docs/make.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ makedocs(;
2424
prettyurls=get(ENV, "CI", "false") == "true",
2525
canonical="https://TARGENE.github.io/TMLE.jl",
2626
assets=String["assets/logo.ico"],
27+
size_threshold=nothing
2728
),
2829
pages=[
2930
"Home" => "index.md",
@@ -32,7 +33,8 @@ makedocs(;
3233
("scm.md", "estimands.md", "estimation.md")],
3334
"Examples" => [
3435
joinpath("examples", "super_learning.md"),
35-
joinpath("examples", "double_robustness.md")
36+
joinpath("examples", "double_robustness.md"),
37+
joinpath("examples", "interactions_correlated.md"),
3638
],
3739
"Integrations" => "integrations.md",
3840
"Estimators' Cheat Sheet" => "estimators_cheatsheet.md",
86.6 KB
Loading
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
#=
2+
# Interaction Estimation
3+
4+
In this example we aim to estimate the average interaction effect of two, potentially correlated,
5+
treatment variables `T1` and `T2` on an outcome `Y`.
6+
7+
## Data Generating Process
8+
9+
Let's consider the following structural causal model where the shaded nodes represent the observed variables.
10+
11+
![interaction-graph](../assets/interaction_graph.png)
12+
13+
In other words, only one confounding variable is observed (`W1`). This would be a major problem if we wanted to estimate the
14+
average treatment effect of `T1` or `T2` on `Y` separately. However, here, we are interested in interactions and thus `W1` is
15+
a sufficient adjustment set. This artificial situation is ubiquitous in genetics, where two main sources of confounding exist.
16+
Ancestry can be estimated (here `W1`), while linkage disequilibrium is usually more challenging to address (here `W2`).
17+
18+
Let us first define some helper functions and import some libraries.
19+
=#
20+
using Distributions
21+
using Random
22+
using DataFrames
23+
using Statistics
24+
using CategoricalArrays
25+
using TMLE
26+
using CairoMakie
27+
using MLJXGBoostInterface
28+
using MLJ
29+
using MLJLinearModels
30+
Random.seed!(123)
31+
32+
μT(w) = [sum(w), sum(w)]
33+
34+
μY(t, w) = 1 + 10t[1] - 3t[2] * t[1] * w
35+
36+
#=
37+
We assume that `W1` and `W2`, the confounding variables, follow a uniform distribution.
38+
=#
39+
40+
generate_W(n) = rand(Uniform(0, 1), 2, n)
41+
42+
#=
43+
`T1` and `T2` are generated via a copula method through a multivariate normal to induce some statistical dependence (via σ).
44+
=#
45+
46+
function generate_T(W, n; σ=0.5, threshold=0)
47+
covariance = [
48+
1. σ
49+
σ 1.
50+
]
51+
T = zeros(Bool, 2, n)
52+
for i in 1:n
53+
dTi = MultivariateNormal(μT(W[:, i]), covariance)
54+
T[:, i] = rand(dTi) .> threshold
55+
end
56+
return T
57+
end
58+
59+
#=
60+
Finally, `Y` is generated through a simple linear model with an interaction term.
61+
=#
62+
63+
function generate_Y(T, W1, n; σY=1)
64+
Y = zeros(Float64, n)
65+
for i in 1:n
66+
dY = Normal(μY(T[:, i], W1[i]), σY)
67+
Y[i] = rand(dY)
68+
end
69+
return Y
70+
end
71+
72+
#=
73+
Importantly, the average interaction effect between `T1` and `T2` is thus ``-3 \mathbb{E}[W] = -1.5``.
74+
75+
We will generate a full dataset with the following function.
76+
=#
77+
78+
function generate_dataset(;n=1000, σ=0.5, threshold=0., σY=1)
79+
80+
W = generate_W(n)
81+
T = generate_T(W, n; σ=σ, threshold=threshold)
82+
83+
W = permutedims(W)
84+
W1 = W[:, 1]
85+
W2 = W[:, 2]
86+
87+
Y = generate_Y(T, W1, n; σY=σY)
88+
89+
T = permutedims(T)
90+
T1 = categorical(T[:, 1])
91+
T2 = categorical(T[:, 2])
92+
93+
return DataFrame(W1=W1, W2=W2, T1=T1, T2=T2, Y=Y)
94+
end
95+
96+
dataset = generate_dataset()
97+
98+
first(dataset, 5)
99+
#=
100+
Let's verify that each treatment level is sufficiently present in the dataset (≈positivity).
101+
=#
102+
103+
combine(groupby(dataset, [:T1, :T2]), proprow => :JOINT_TREATMENT_FREQ)
104+
105+
#=
106+
And that `T1` and `T2` are indeed correlated.
107+
=#
108+
109+
treatment_correlation(dataset) = cor(unwrap.(dataset.T1), unwrap.(dataset.T2))
110+
@assert treatment_correlation(dataset) > 0.2 #hide
111+
treatment_correlation(dataset)
112+
113+
#=
114+
## Estimation
115+
116+
We can now proceed to estimation using TMLE and default (linear) models.
117+
118+
Interactions are defined via the `AIE` function (note that we only set `W1` as a confounder).
119+
=#
120+
121+
Ψ = AIE(
122+
outcome=:Y,
123+
treatment_values= (
124+
T1=(case=1, control=0),
125+
T2=(case=1, control=0)
126+
),
127+
treatment_confounders = [:W1]
128+
)
129+
linear_models = default_models(G=LogisticClassifier(lambda=0), Q_continuous=LinearRegressor())
130+
estimator = TMLEE(models=linear_models, weighted=true)
131+
result, _ = estimator(Ψ, dataset; verbosity=0)
132+
@assert pvalue(significance_test(result, -1.5)) > 0.05 #hide
133+
significance_test(result)
134+
135+
#=
136+
The true effect size is thus covered by our confidence interval.
137+
138+
## Varying levels of correlation
139+
140+
We now vary the correlation level between `T1` and `T2` to observe how it affects the estimation results.
141+
First, let's see how the parameter σ affects the correlation between `T1` and `T2`.
142+
=#
143+
144+
function plot_correlations(;σs = 0.1:0.1:1, n=1000, threshold=0., σY=1.)
145+
fig = Figure()
146+
ax = Axis(fig[1, 1], xlabel="σ", ylabel="Correlation(T1, T2)")
147+
correlations = map(σs) do σ
148+
dataset = generate_dataset(;n=n, σ=σ, threshold=threshold, σY=σY)
149+
return treatment_correlation(dataset)
150+
end
151+
scatter!(ax, σs, correlations, color=:blue)
152+
return fig
153+
end
154+
155+
σs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99]
156+
plot_correlations(;σs=σs, n=10_000)
157+
158+
#=
159+
As expected, the correlation between `T1` and `T2` increases with σ. Let's see how this affects estimation,
160+
for this, we will vary both the dataset size and the correlation level.
161+
=#
162+
163+
function estimate_across_correlation_levels(σs; n=1000, estimator=TMLEE(weighted=true))
164+
results = []
165+
for σ in σs
166+
dataset = generate_dataset(n=n, σ=σ)
167+
result, _ = estimator(Ψ, dataset; verbosity=0)
168+
push!(results, result)
169+
end
170+
Ψ̂s = TMLE.estimate.(results)
171+
errors = last.(confint.(significance_test.(results))) .- Ψ̂s
172+
return Ψ̂s, errors
173+
end
174+
175+
function estimate_across_sample_sizes_and_correlation_levels(ns, σs; estimator=TMLEE(models=linear_models, weighted=true))
176+
results = []
177+
for n in ns
178+
Ψ̂s, errors = estimate_across_correlation_levels(σs; n=n, estimator=estimator)
179+
push!(results, (Ψ̂s, errors))
180+
end
181+
return results
182+
end
183+
184+
function plot_across_sample_sizes_and_correlation_levels(results, ns, σs; title="Estimation via TMLE (GLMs)")
185+
fig = Figure(size=(800, 800))
186+
for (index, n) in enumerate(ns)
187+
Ψ̂s, errors = results[index]
188+
ax = if n == last(ns)
189+
Axis(fig[index, 1], xlabel="σ", ylabel="AIE\n(n=$n)")
190+
else
191+
Axis(fig[index, 1], ylabel="AIE\n(n=$n)", xticklabelsvisible=false)
192+
end
193+
errorbars!(ax, σs, Ψ̂s, errors, color = :blue, whiskerwidth = 10)
194+
scatter!(ax, σs, Ψ̂s, color=:red, markersize=10)
195+
hlines!(ax, [-1.5], color=:green, linestyle=:dash)
196+
end
197+
Label(fig[0, :], title; tellwidth=false, fontsize=24)
198+
return fig
199+
end
200+
201+
ns = [100, 1000, 10_000, 100_000, 1_000_000]
202+
σs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.999]
203+
results = estimate_across_sample_sizes_and_correlation_levels(ns, σs; estimator=TMLEE(models=linear_models, weighted=true))
204+
plot_across_sample_sizes_and_correlation_levels(results, ns, σs; title="Estimation via TMLE (GLMs)")
205+
206+
#=
207+
First, notice that only extreme correlations (>0.9) tend to blow up the size of the confidence intervals. This implies that statistical power may be limited in such circumstances.
208+
209+
Furthermore, and perhaps unexpectedly, coverage decreases as sample size grows for larger correlations. Since we have used simple linear models until now,
210+
this could be due to model misspecification. We can verify this by using a more flexible modelling strategy. Here we will use XGBoost
211+
(with tree_method=`hist` to speed things up a little). Because this model is prone to overfitting we will also use cross-validation (this will take a few minutes).
212+
=#
213+
214+
xgboost_estimator = TMLEE(
215+
models=default_models(G=XGBoostClassifier(tree_method="hist"), Q_continuous=XGBoostRegressor(tree_method="hist")),
216+
weighted=true,
217+
resampling=StratifiedCV(nfolds=3)
218+
)
219+
xgboost_results = estimate_across_sample_sizes_and_correlation_levels(ns, σs, estimator=xgboost_estimator)
220+
plot_across_sample_sizes_and_correlation_levels(xgboost_results, ns, σs; title="Estimation via TMLE (XGboost)")
221+
222+
#=
223+
As expected, XGBoost improves estimation performance in the asymptotic regime; furthermore,
224+
the correlation between `T1` and `T2` seems harmless (except when σ > 0.9 as before).
225+
=#

src/TMLE.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ using Statistics
1313
using Distributions
1414
using Zygote
1515
using LogExpFunctions
16-
using PrecompileTools
1716
using Random
1817
using DifferentiationInterface
1918
using Graphs

src/counterfactual_mean_based/estimands.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,14 @@ outcome_mean(Ψ::StatisticalCMCompositeEstimand) = ExpectedValue(Ψ.outcome, Tup
124124

125125
outcome_mean_key(Ψ::StatisticalCMCompositeEstimand) = variables(outcome_mean(Ψ))
126126

127-
propensity_score(Ψ::StatisticalCMCompositeEstimand) = Tuple(ConditionalDistribution(T, Ψ.treatment_confounders[T]) for T in treatments(Ψ))
127+
function propensity_score(Ψ::StatisticalCMCompositeEstimand)
128+
Ψtreatments = TMLE.treatments(Ψ)
129+
return Tuple(map(eachindex(Ψtreatments)) do index
130+
T = Ψtreatments[index]
131+
confounders = (Ψ.treatment_confounders[T]..., Ψtreatments[index+1:end]...)
132+
ConditionalDistribution(T, confounders)
133+
end)
134+
end
128135

129136
propensity_score_key(Ψ::StatisticalCMCompositeEstimand) = Tuple(variables(x) for x ∈ propensity_score(Ψ))
130137

test/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
1212
MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
1313
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
1414
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
15+
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
1516
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1617
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1718
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -29,6 +30,7 @@ CSV = "0.10"
2930
DataFrames = "1.5"
3031
MLJLinearModels = "0.10"
3132
StableRNGs = "1.0"
32-
StatisticalMeasures = "0.1.3"
33+
StatisticalMeasures = "0.2"
34+
MLJXGBoostInterface = "0.3"
3335
StatsBase = "0.34"
3436
YAML = "0.4.9"

test/counterfactual_mean_based/3points_interactions.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,19 @@ end
3636
),
3737
treatment_confounders = (T₁=[:W], T₂=[:W], T₃=[:W])
3838
)
39+
## Check propensity score is well formed
40+
propensity_score = TMLE.propensity_score(Ψ)
41+
@test propensity_score[1] == TMLE.ConditionalDistribution(:T₁, (:T₂, :T₃, :W))
42+
@test propensity_score[2] == TMLE.ConditionalDistribution(:T₂, (:T₃, :W))
43+
@test propensity_score[3] == TMLE.ConditionalDistribution(:T₃, (:W,))
44+
## Define models
3945
models = Dict(
4046
:Y => with_encoder(InteractionTransformer(order=3) |> LinearRegressor()),
4147
:T₁ => LogisticClassifier(lambda=0),
4248
:T₂ => LogisticClassifier(lambda=0),
4349
:T₃ => LogisticClassifier(lambda=0)
4450
)
45-
51+
## Estimate
4652
tmle = TMLEE(models=models, machine_cache=true, max_iter=3, tol=0)
4753
result, cache = tmle(Ψ, dataset, verbosity=0);
4854
test_coverage(result, Ψ₀)
@@ -54,7 +60,7 @@ end
5460
test_coverage(result, Ψ₀)
5561
test_mean_inf_curve_almost_zero(result; atol=1e-10)
5662

57-
# CHecking cache accessors
63+
# Checking cache accessors
5864
@test length(gradients(cache)) == 3
5965
@test length(estimates(cache)) == 3
6066
@test length(epsilons(cache)) == 3

0 commit comments

Comments
 (0)