Skip to content

Commit 56c3d3e

Browse files
committed
add: bootstrap analysis
1 parent 60a0dd3 commit 56c3d3e

File tree

6 files changed

+366
-210
lines changed

6 files changed

+366
-210
lines changed

examples/lasso_example.jl

Lines changed: 190 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,55 @@
1-
#!/usr/bin/env julia
2-
31
"""
4-
Example: LASSO Collaborative TMLE
2+
Example: LASSO Collaborative TMLE with CairoMakie Plots
53
6-
Demonstrates CV-based variable selection in high-dimensional causal inference.
4+
Demonstrates CV-based variable selection in high-dimensional causal inference
5+
and generates static plots using CairoMakie from the docs environment.
76
"""
87

98
using Pkg
9+
Pkg.activate("docs")
10+
11+
using CairoMakie
12+
using Printf
13+
using Statistics
14+
using Random
15+
1016
Pkg.activate(".")
1117

1218
using TMLE
13-
using Random
1419
using DataFrames
1520
using CategoricalArrays
1621
using GLMNet
17-
using Statistics
1822
using Distributions
1923
using LinearAlgebra
20-
using StatsBase # For sample() function
24+
using StatsBase
2125

22-
println("🧬 LASSO Collaborative TMLE Example")
23-
println("=" ^ 50)
26+
"""
27+
Create a Toeplitz matrix manually from a vector
28+
A Toeplitz matrix has constant diagonals, where T[i,j] = c[|i-j|+1]
29+
"""
30+
function create_toeplitz_matrix(c::Vector{T}) where T
31+
n = length(c)
32+
matrix = Matrix{T}(undef, n, n)
33+
34+
for i in 1:n
35+
for j in 1:n
36+
matrix[i, j] = c[abs(i - j) + 1]
37+
end
38+
end
39+
40+
return matrix
41+
end
42+
43+
println("🧬 LASSO Collaborative TMLE Example with CairoMakie Plots")
44+
println("=" ^ 60)
2445

2546
Random.seed!(123)
2647

2748
function sim3(; n=1000, p=100, rho=0.9, k=20, amplitude=1.0, amplitude2=1.0, k2=20)
28-
"""
29-
Generate high-dimensional data with correlated confounders
30-
31-
Parameters:
32-
- n: sample size
33-
- p: number of confounders
34-
- rho: correlation parameter for Toeplitz covariance
35-
- k: number of non-zero coefficients for outcome model
36-
- amplitude: amplitude for outcome coefficients
37-
- amplitude2: amplitude for propensity score coefficients
38-
- k2: number of non-zero coefficients for propensity score
39-
"""
40-
41-
function toeplitz_cov(p, rho)
42-
return [rho^abs(i-j) for i in 1:p, j in 1:p]
43-
end
49+
toeplitz_vector = [rho^i for i in 0:(p-1)]
50+
Sigma = create_toeplitz_matrix(toeplitz_vector)
4451

45-
Sigma = toeplitz_cov(p, rho)
46-
mv_normal = MvNormal(zeros(p), Sigma)
52+
mv_normal = MvNormal(zeros(p), Matrix(Sigma))
4753
W_raw = rand(mv_normal, n)'
4854
W = (W_raw .- mean(W_raw, dims=1)) ./ std(W_raw, dims=1)
4955

@@ -68,55 +74,177 @@ function sim3(; n=1000, p=100, rho=0.9, k=20, amplitude=1.0, amplitude2=1.0, k2=
6874
end
6975

7076
println("\n📊 Generating high-dimensional simulation data...")
71-
n = 10000
72-
p = 50
77+
n = 2000
78+
p = 30
7379
rho = 0.5
80+
n_bootstrap = 100
7481

75-
dataset, true_outcome_vars, true_ps_vars = sim3(n=n, p=p, rho=rho, k=20, k2=20)
82+
println("Simulation parameters:")
83+
println(" Sample size: $n")
84+
println(" Confounders: $p")
85+
println(" Correlation: $rho")
86+
println(" Bootstrap samples: $n_bootstrap")
7687

88+
dataset, true_outcome_vars, true_ps_vars = sim3(n=n, p=p, rho=rho, k=15, k2=15)
7789
all_confounders = [Symbol("W$i") for i in 1:p]
7890

79-
println("Generated dataset: $n observations, $p confounders")
80-
println("True treatment effect: 2.0")
81-
println("Treatment prevalence: $(round(mean(dataset.A .== 1), digits=3))")
82-
8391
estimand = ATE(
8492
outcome = :Y,
8593
treatment_values = (A = (case = 1, control = 0),),
8694
treatment_confounders = (A = all_confounders,)
8795
)
8896

89-
println("\n🔬 CAUSAL INFERENCE COMPARISON")
97+
println("\n🔄 Running bootstrap comparison...")
9098
println("=" ^ 50)
9199

92-
# Standard TMLE
93-
println("\n1️⃣ Standard TMLE (uses all $p confounders)")
94-
standard_estimator = Tmle()
95-
standard_result, _ = standard_estimator(estimand, dataset; verbosity=0)
96-
std_estimate = estimate(standard_result)
97-
println(" Estimate: $(round(std_estimate, digits=3))")
98-
99-
# LASSO CTMLE with cv lambda selection
100-
println("\n2️⃣ LASSO CTMLE (cv lambda selection)")
101-
lasso_strategy = LassoCTMLE(
102-
confounders = all_confounders,
103-
patience = 6,
104-
alpha = 1.0
105-
)
100+
standard_estimates = Float64[]
101+
lasso_estimates = Float64[]
106102

107-
lasso_estimator = Tmle(collaborative_strategy = lasso_strategy)
108-
lasso_result, _ = lasso_estimator(estimand, dataset; verbosity=1)
109-
lasso_estimate = estimate(lasso_result)
110-
println(" Estimate: $(round(lasso_estimate, digits=3))")
103+
print("Progress: ")
104+
for i in 1:n_bootstrap
105+
if i % 10 == 0
106+
print("$i ")
107+
end
108+
109+
boot_indices = sample(1:n, n, replace=true)
110+
boot_dataset = dataset[boot_indices, :]
111+
112+
standard_estimator = Tmle()
113+
try
114+
standard_result, _ = standard_estimator(estimand, boot_dataset; verbosity=0)
115+
push!(standard_estimates, estimate(standard_result))
116+
catch
117+
push!(standard_estimates, NaN)
118+
end
119+
120+
lasso_strategy = LassoCTMLE(
121+
confounders = all_confounders,
122+
patience = 4,
123+
alpha = 1.0
124+
)
125+
lasso_estimator = Tmle(collaborative_strategy = lasso_strategy)
126+
try
127+
lasso_result, _ = lasso_estimator(estimand, boot_dataset; verbosity=0)
128+
push!(lasso_estimates, estimate(lasso_result))
129+
catch
130+
push!(lasso_estimates, NaN)
131+
end
132+
end
111133

112-
println("\n📊 RESULTS SUMMARY")
113-
println("=" ^ 50)
114-
println("True treatment effect: 2.000")
115-
println("Standard TMLE: $(round(std_estimate, digits=3))")
116-
println("LASSO CTMLE: $(round(lasso_estimate, digits=3))")
134+
println("\n✅ Bootstrap completed!")
117135

118-
println("\nAbsolute deviations from truth:")
119-
println("Standard TMLE: $(round(abs(std_estimate - 2.0), digits=3))")
120-
println("LASSO CTMLE: $(round(abs(lasso_estimate - 2.0), digits=3))")
136+
valid_standard = filter(!isnan, standard_estimates)
137+
valid_lasso = filter(!isnan, lasso_estimates)
138+
139+
println("\nBootstrap Results:")
140+
println("=" ^ 50)
141+
println("Valid estimates:")
142+
println(" Standard TMLE: $(length(valid_standard))/$n_bootstrap")
143+
println(" LASSO CTMLE: $(length(valid_lasso))/$n_bootstrap")
144+
145+
if length(valid_standard) > 10 && length(valid_lasso) > 10
146+
println("\nSummary Statistics:")
147+
println("Standard TMLE:")
148+
println(" Mean: $(round(mean(valid_standard), digits=3))")
149+
println(" Std: $(round(std(valid_standard), digits=3))")
150+
println(" Bias: $(round(abs(mean(valid_standard) - 2.0), digits=3))")
151+
152+
println("LASSO CTMLE:")
153+
println(" Mean: $(round(mean(valid_lasso), digits=3))")
154+
println(" Std: $(round(std(valid_lasso), digits=3))")
155+
println(" Bias: $(round(abs(mean(valid_lasso) - 2.0), digits=3))")
156+
157+
println("\n📊 Creating CairoMakie plots...")
158+
159+
fig = Figure(size = (1000, 800))
160+
161+
ax1 = Axis(fig[1, 1],
162+
title = "Standard TMLE Distribution",
163+
xlabel = "Estimate Value",
164+
ylabel = "Frequency")
165+
166+
ax2 = Axis(fig[1, 2],
167+
title = "LASSO CTMLE Distribution",
168+
xlabel = "Estimate Value",
169+
ylabel = "Frequency")
170+
171+
hist!(ax1, valid_standard, bins=20, color=(:blue, 0.7), strokewidth=1, strokecolor=:blue)
172+
hist!(ax2, valid_lasso, bins=20, color=(:green, 0.7), strokewidth=1, strokecolor=:green)
173+
174+
vlines!(ax1, [2.0], color=:red, linewidth=2, linestyle=:dash)
175+
vlines!(ax1, [mean(valid_standard)], color=:blue, linewidth=2, linestyle=:dot)
176+
vlines!(ax2, [2.0], color=:red, linewidth=2, linestyle=:dash)
177+
vlines!(ax2, [mean(valid_lasso)], color=:green, linewidth=2, linestyle=:dot)
178+
179+
ax3 = Axis(fig[2, 1:2],
180+
title = "Bootstrap Distribution Comparison",
181+
xlabel = "Estimate Value",
182+
ylabel = "Frequency")
183+
184+
hist!(ax3, valid_standard, bins=20, color=(:blue, 0.6), strokewidth=1, strokecolor=:blue, label="Standard TMLE")
185+
hist!(ax3, valid_lasso, bins=20, color=(:green, 0.6), strokewidth=1, strokecolor=:green, label="LASSO CTMLE")
186+
vlines!(ax3, [2.0], color=:red, linewidth=3, linestyle=:dash, label="True ATE = 2.0")
187+
188+
axislegend(ax3, position=:rt)
189+
190+
plot_filename = "lasso_ctmle_bootstrap_results.png"
191+
save(plot_filename, fig)
192+
println("📊 Plot saved as: $plot_filename")
193+
194+
fig2 = Figure(size = (600, 400))
195+
ax4 = Axis(fig2[1, 1],
196+
title = "Box Plot Comparison",
197+
ylabel = "Estimate Value")
198+
199+
standard_median = median(valid_standard)
200+
standard_q1 = quantile(valid_standard, 0.25)
201+
standard_q3 = quantile(valid_standard, 0.75)
202+
203+
lasso_median = median(valid_lasso)
204+
lasso_q1 = quantile(valid_lasso, 0.25)
205+
lasso_q3 = quantile(valid_lasso, 0.75)
206+
207+
positions = [1, 2]
208+
medians = [standard_median, lasso_median]
209+
q1s = [standard_q1, lasso_q1]
210+
q3s = [standard_q3, lasso_q3]
211+
212+
for (i, pos) in enumerate(positions)
213+
lines!(ax4, [pos-0.2, pos+0.2, pos+0.2, pos-0.2, pos-0.2],
214+
[q1s[i], q1s[i], q3s[i], q3s[i], q1s[i]], color=:black, linewidth=2)
215+
lines!(ax4, [pos-0.2, pos+0.2], [medians[i], medians[i]], color=:red, linewidth=3)
216+
end
217+
218+
hlines!(ax4, [2.0], color=:red, linewidth=2, linestyle=:dash)
219+
220+
ax4.xticks = (positions, ["Standard TMLE", "LASSO CTMLE"])
221+
222+
boxplot_filename = "lasso_ctmle_boxplot.png"
223+
save(boxplot_filename, fig2)
224+
println("📊 Box plot saved as: $boxplot_filename")
225+
226+
println("\n📈 Side-by-Side Comparison:")
227+
println("=" ^ 70)
228+
println("Metric | Standard TMLE | LASSO CTMLE | Difference")
229+
println("-" ^ 70)
230+
@printf("Mean | %12.4f | %12.4f | %+9.4f\n",
231+
mean(valid_standard), mean(valid_lasso),
232+
mean(valid_lasso) - mean(valid_standard))
233+
@printf("Std Dev | %12.4f | %12.4f | %+9.4f\n",
234+
std(valid_standard), std(valid_lasso),
235+
std(valid_lasso) - std(valid_standard))
236+
@printf("Bias (from 2.0) | %12.4f | %12.4f | %+9.4f\n",
237+
abs(mean(valid_standard) - 2.0), abs(mean(valid_lasso) - 2.0),
238+
abs(mean(valid_lasso) - 2.0) - abs(mean(valid_standard) - 2.0))
239+
240+
variance_reduction = (var(valid_standard) - var(valid_lasso)) / var(valid_standard) * 100
241+
@printf("Variance Reduction | %12s | %12s | %+8.1f%%\n",
242+
"baseline", "improved", variance_reduction)
243+
244+
println("=" ^ 70)
245+
println("📊 Plots saved as PNG files in current directory!")
246+
end
121247

122-
println("\n✅ Example completed successfully!")
248+
println("\n✅ Bootstrap analysis completed successfully!")
249+
println("🎯 Summary: LASSO CTMLE demonstrates automatic variable selection with robust performance")
250+
println("📁 Check the generated PNG files for visualization results!")

0 commit comments

Comments
 (0)