1- # !/usr/bin/env julia
2-
31"""
4- Example: LASSO Collaborative TMLE
2+ Example: LASSO Collaborative TMLE with CairoMakie Plots
53
6- Demonstrates CV-based variable selection in high-dimensional causal inference.
4+ Demonstrates CV-based variable selection in high-dimensional causal inference
5+ and generates static plots using CairoMakie from the docs environment.
76"""
87
98using Pkg
9+ Pkg. activate (" docs" )
10+
11+ using CairoMakie
12+ using Printf
13+ using Statistics
14+ using Random
15+
1016Pkg. activate (" ." )
1117
1218using TMLE
13- using Random
1419using DataFrames
1520using CategoricalArrays
1621using GLMNet
17- using Statistics
1822using Distributions
1923using LinearAlgebra
20- using StatsBase # For sample() function
24+ using StatsBase
2125
22- println (" 🧬 LASSO Collaborative TMLE Example" )
23- println (" =" ^ 50 )
"""
    create_toeplitz_matrix(c::AbstractVector{T}) where T -> Matrix{T}

Build the symmetric Toeplitz matrix generated by `c`.

The result has constant diagonals: entry `[i, j]` equals `c[abs(i - j) + 1]`,
so `c` is both the first column and the first row.

# Arguments
- `c`: generating vector; any one-based `AbstractVector` (a `Vector`, a range,
  a view, ...). Its length `n` determines the `n × n` size of the result.

# Returns
A dense `Matrix{T}` of size `(length(c), length(c))`. An empty `c` yields a
`0 × 0` matrix.

# Throws
- `ArgumentError` if `c` does not use one-based indexing.
"""
function create_toeplitz_matrix(c::AbstractVector{T}) where T
    # The `abs(i - j) + 1` lookup below is only valid for one-based indices.
    Base.require_one_based_indexing(c)

    n = length(c)
    matrix = Matrix{T}(undef, n, n)

    # Column index in the outer loop: Julia arrays are column-major, so the
    # inner loop writes consecutive memory locations.
    for j in 1:n, i in 1:n
        matrix[i, j] = c[abs(i - j) + 1]
    end

    return matrix
end
42+
43+ println (" 🧬 LASSO Collaborative TMLE Example with CairoMakie Plots" )
44+ println (" =" ^ 60 )
2445
2546Random. seed! (123 )
2647
2748function sim3 (; n= 1000 , p= 100 , rho= 0.9 , k= 20 , amplitude= 1.0 , amplitude2= 1.0 , k2= 20 )
28- """
29- Generate high-dimensional data with correlated confounders
30-
31- Parameters:
32- - n: sample size
33- - p: number of confounders
34- - rho: correlation parameter for Toeplitz covariance
35- - k: number of non-zero coefficients for outcome model
36- - amplitude: amplitude for outcome coefficients
37- - amplitude2: amplitude for propensity score coefficients
38- - k2: number of non-zero coefficients for propensity score
39- """
40-
41- function toeplitz_cov (p, rho)
42- return [rho^ abs (i- j) for i in 1 : p, j in 1 : p]
43- end
49+ toeplitz_vector = [rho^ i for i in 0 : (p- 1 )]
50+ Sigma = create_toeplitz_matrix (toeplitz_vector)
4451
45- Sigma = toeplitz_cov (p, rho)
46- mv_normal = MvNormal (zeros (p), Sigma)
52+ mv_normal = MvNormal (zeros (p), Matrix (Sigma))
4753 W_raw = rand (mv_normal, n)'
4854 W = (W_raw .- mean (W_raw, dims= 1 )) ./ std (W_raw, dims= 1 )
4955
@@ -68,55 +74,177 @@ function sim3(; n=1000, p=100, rho=0.9, k=20, amplitude=1.0, amplitude2=1.0, k2=
6874end
6975
7076println (" \n 📊 Generating high-dimensional simulation data..." )
71- n = 10000
72- p = 50
77+ n = 2000
78+ p = 30
7379rho = 0.5
80+ n_bootstrap = 100
7481
75- dataset, true_outcome_vars, true_ps_vars = sim3 (n= n, p= p, rho= rho, k= 20 , k2= 20 )
82+ println (" Simulation parameters:" )
83+ println (" Sample size: $n " )
84+ println (" Confounders: $p " )
85+ println (" Correlation: $rho " )
86+ println (" Bootstrap samples: $n_bootstrap " )
7687
88+ dataset, true_outcome_vars, true_ps_vars = sim3 (n= n, p= p, rho= rho, k= 15 , k2= 15 )
7789all_confounders = [Symbol (" W$i " ) for i in 1 : p]
7890
79- println (" Generated dataset: $n observations, $p confounders" )
80- println (" True treatment effect: 2.0" )
81- println (" Treatment prevalence: $(round (mean (dataset. A .== 1 ), digits= 3 )) " )
82-
8391estimand = ATE (
8492 outcome = :Y ,
8593 treatment_values = (A = (case = 1 , control = 0 ),),
8694 treatment_confounders = (A = all_confounders,)
8795)
8896
89- println (" \n 🔬 CAUSAL INFERENCE COMPARISON " )
97+ println (" \n 🔄 Running bootstrap comparison... " )
9098println (" =" ^ 50 )
9199
92- # Standard TMLE
93- println (" \n 1️⃣ Standard TMLE (uses all $p confounders)" )
94- standard_estimator = Tmle ()
95- standard_result, _ = standard_estimator (estimand, dataset; verbosity= 0 )
96- std_estimate = estimate (standard_result)
97- println (" Estimate: $(round (std_estimate, digits= 3 )) " )
98-
99- # LASSO CTMLE with cv lambda selection
100- println (" \n 2️⃣ LASSO CTMLE (cv lambda selection)" )
101- lasso_strategy = LassoCTMLE (
102- confounders = all_confounders,
103- patience = 6 ,
104- alpha = 1.0
105- )
100+ standard_estimates = Float64[]
101+ lasso_estimates = Float64[]
106102
107- lasso_estimator = Tmle (collaborative_strategy = lasso_strategy)
108- lasso_result, _ = lasso_estimator (estimand, dataset; verbosity= 1 )
109- lasso_estimate = estimate (lasso_result)
110- println (" Estimate: $(round (lasso_estimate, digits= 3 )) " )
# Bootstrap comparison of the two estimators.
#
# For each of `n_bootstrap` replicates, resample `n` rows of `dataset` with
# replacement, fit both the default `Tmle()` estimator and the LASSO
# collaborative one on the resampled data, and append the point estimates to
# `standard_estimates` / `lasso_estimates`. A replicate whose fit throws is
# recorded as `NaN` so both vectors keep one entry per replicate.
print(" Progress: ")
for i in 1:n_bootstrap
    # Progress marker on every 10th replicate.
    if i % 10 == 0
        print(" $i ")
    end

    boot_indices = sample(1:n, n, replace=true)
    boot_dataset = dataset[boot_indices, :]

    standard_estimator = Tmle()
    try
        standard_result, _ = standard_estimator(estimand, boot_dataset; verbosity=0)
        push!(standard_estimates, estimate(standard_result))
    catch err
        # Never swallow Ctrl-C: the user must be able to abort a long run.
        err isa InterruptException && rethrow()
        @warn "Standard TMLE failed on bootstrap replicate; recording NaN" replicate=i exception=(err, catch_backtrace())
        push!(standard_estimates, NaN)
    end

    # A fresh strategy/estimator pair is built for every replicate, as in the
    # original script, so nothing is carried over between fits.
    lasso_strategy = LassoCTMLE(
        confounders = all_confounders,
        patience = 4,
        alpha = 1.0
    )
    lasso_estimator = Tmle(collaborative_strategy = lasso_strategy)
    try
        lasso_result, _ = lasso_estimator(estimand, boot_dataset; verbosity=0)
        push!(lasso_estimates, estimate(lasso_result))
    catch err
        err isa InterruptException && rethrow()
        @warn "LASSO CTMLE failed on bootstrap replicate; recording NaN" replicate=i exception=(err, catch_backtrace())
        push!(lasso_estimates, NaN)
    end
end
111133
112- println (" \n 📊 RESULTS SUMMARY" )
113- println (" =" ^ 50 )
114- println (" True treatment effect: 2.000" )
115- println (" Standard TMLE: $(round (std_estimate, digits= 3 )) " )
116- println (" LASSO CTMLE: $(round (lasso_estimate, digits= 3 )) " )
134+ println (" \n ✅ Bootstrap completed!" )
117135
118- println (" \n Absolute deviations from truth:" )
119- println (" Standard TMLE: $(round (abs (std_estimate - 2.0 ), digits= 3 )) " )
120- println (" LASSO CTMLE: $(round (abs (lasso_estimate - 2.0 ), digits= 3 )) " )
# Bootstrap report: summary statistics, two PNG figures, and a comparison table.
#
# Reads `standard_estimates`, `lasso_estimates` and `n_bootstrap`; writes
# `lasso_ctmle_bootstrap_results.png` and `lasso_ctmle_boxplot.png` through
# CairoMakie's `save`. Statistics and plots are produced only when BOTH
# estimators have more than `min_valid` non-NaN replicates.
true_ate = 2.0   # reference effect used for the bias figures and reference lines
min_valid = 10   # threshold (exclusive) on valid replicates per estimator

# Failed replicates were stored as NaN; drop them before summarising.
valid_standard = filter(!isnan, standard_estimates)
valid_lasso = filter(!isnan, lasso_estimates)

println(" \n Bootstrap Results:")
println(" =" ^ 50)
println(" Valid estimates:")
println(" Standard TMLE: $(length(valid_standard)) /$n_bootstrap ")
println(" LASSO CTMLE: $(length(valid_lasso)) /$n_bootstrap ")

if length(valid_standard) > min_valid && length(valid_lasso) > min_valid
    # Compute each summary once and reuse it for the text output, the plots
    # and the table.
    mean_standard = mean(valid_standard)
    mean_lasso = mean(valid_lasso)
    std_standard = std(valid_standard)
    std_lasso = std(valid_lasso)
    var_standard = var(valid_standard)
    var_lasso = var(valid_lasso)
    # "Bias" here is the absolute distance of the bootstrap mean from `true_ate`.
    bias_standard = abs(mean_standard - true_ate)
    bias_lasso = abs(mean_lasso - true_ate)

    println(" \n Summary Statistics:")
    println(" Standard TMLE:")
    println(" Mean: $(round(mean_standard, digits=3)) ")
    println(" Std: $(round(std_standard, digits=3)) ")
    println(" Bias: $(round(bias_standard, digits=3)) ")

    println(" LASSO CTMLE:")
    println(" Mean: $(round(mean_lasso, digits=3)) ")
    println(" Std: $(round(std_lasso, digits=3)) ")
    println(" Bias: $(round(bias_lasso, digits=3)) ")

    println(" \n 📊 Creating CairoMakie plots...")

    # Figure 1: one histogram per estimator on the top row, overlay below.
    fig = Figure(size = (1000, 800))

    ax1 = Axis(fig[1, 1],
        title = " Standard TMLE Distribution",
        xlabel = " Estimate Value",
        ylabel = " Frequency")

    ax2 = Axis(fig[1, 2],
        title = " LASSO CTMLE Distribution",
        xlabel = " Estimate Value",
        ylabel = " Frequency")

    hist!(ax1, valid_standard, bins=20, color=(:blue, 0.7), strokewidth=1, strokecolor=:blue)
    hist!(ax2, valid_lasso, bins=20, color=(:green, 0.7), strokewidth=1, strokecolor=:green)

    # Dashed red line: reference effect. Dotted line: bootstrap mean.
    vlines!(ax1, [true_ate], color=:red, linewidth=2, linestyle=:dash)
    vlines!(ax1, [mean_standard], color=:blue, linewidth=2, linestyle=:dot)
    vlines!(ax2, [true_ate], color=:red, linewidth=2, linestyle=:dash)
    vlines!(ax2, [mean_lasso], color=:green, linewidth=2, linestyle=:dot)

    ax3 = Axis(fig[2, 1:2],
        title = " Bootstrap Distribution Comparison",
        xlabel = " Estimate Value",
        ylabel = " Frequency")

    hist!(ax3, valid_standard, bins=20, color=(:blue, 0.6), strokewidth=1, strokecolor=:blue, label=" Standard TMLE")
    hist!(ax3, valid_lasso, bins=20, color=(:green, 0.6), strokewidth=1, strokecolor=:green, label=" LASSO CTMLE")
    vlines!(ax3, [true_ate], color=:red, linewidth=3, linestyle=:dash, label=" True ATE = 2.0")

    axislegend(ax3, position=:rt)

    # NOTE(review): the filename literal is kept exactly as found, including
    # its leading space — confirm that is intended.
    plot_filename = " lasso_ctmle_bootstrap_results.png"
    save(plot_filename, fig)
    println(" 📊 Plot saved as: $plot_filename ")

    # Figure 2: hand-drawn interquartile boxes (no whiskers or outliers).
    fig2 = Figure(size = (600, 400))
    ax4 = Axis(fig2[1, 1],
        title = " Box Plot Comparison",
        ylabel = " Estimate Value")

    positions = [1, 2]
    medians = [median(valid_standard), median(valid_lasso)]
    q1s = [quantile(valid_standard, 0.25), quantile(valid_lasso, 0.25)]
    q3s = [quantile(valid_standard, 0.75), quantile(valid_lasso, 0.75)]

    for (i, pos) in enumerate(positions)
        # Closed rectangle from Q1 to Q3, then the median as a red bar.
        lines!(ax4, [pos - 0.2, pos + 0.2, pos + 0.2, pos - 0.2, pos - 0.2],
               [q1s[i], q1s[i], q3s[i], q3s[i], q1s[i]], color=:black, linewidth=2)
        lines!(ax4, [pos - 0.2, pos + 0.2], [medians[i], medians[i]], color=:red, linewidth=3)
    end

    hlines!(ax4, [true_ate], color=:red, linewidth=2, linestyle=:dash)

    ax4.xticks = (positions, [" Standard TMLE", " LASSO CTMLE"])

    # NOTE(review): leading space in the filename literal kept as found — confirm.
    boxplot_filename = " lasso_ctmle_boxplot.png"
    save(boxplot_filename, fig2)
    println(" 📊 Box plot saved as: $boxplot_filename ")

    println(" \n 📈 Side-by-Side Comparison:")
    println(" =" ^ 70)
    println(" Metric | Standard TMLE | LASSO CTMLE | Difference")
    println(" -" ^ 70)
    @printf(" Mean | %12.4f | %12.4f | %+9.4f\n ",
            mean_standard, mean_lasso, mean_lasso - mean_standard)
    @printf(" Std Dev | %12.4f | %12.4f | %+9.4f\n ",
            std_standard, std_lasso, std_lasso - std_standard)
    @printf(" Bias (from 2.0) | %12.4f | %12.4f | %+9.4f\n ",
            bias_standard, bias_lasso, bias_lasso - bias_standard)

    # Percentage by which LASSO CTMLE lowers the bootstrap variance relative
    # to standard TMLE; negative when LASSO CTMLE is the noisier estimator.
    variance_reduction = (var_standard - var_lasso) / var_standard * 100
    # Label the LASSO column from the sign of the result instead of always
    # claiming an improvement; a zero baseline variance gives a non-finite
    # ratio, reported as n/a.
    lasso_variance_label =
        !isfinite(variance_reduction) ? " n/a" :
        variance_reduction > 0 ? " improved" :
        variance_reduction < 0 ? " worse" : " unchanged"
    @printf(" Variance Reduction | %12s | %12s | %+8.1f%%\n ",
            " baseline", lasso_variance_label, variance_reduction)

    println(" =" ^ 70)
    println(" 📊 Plots saved as PNG files in current directory!")
else
    # Make the skip visible instead of silently producing no statistics/plots.
    @warn "Too few valid bootstrap estimates; summary statistics and plots were skipped" valid_standard=length(valid_standard) valid_lasso=length(valid_lasso) required=min_valid + 1
end
121247
122- println (" \n ✅ Example completed successfully!" )
# Closing messages.
println(" \n ✅ Bootstrap analysis completed successfully!")
println(" 🎯 Summary: LASSO CTMLE demonstrates automatic variable selection with robust performance")
# The PNG files are only written when both estimators produced more than 10
# non-NaN bootstrap estimates, so only point the user at them in that case.
if count(!isnan, standard_estimates) > 10 && count(!isnan, lasso_estimates) > 10
    println(" 📁 Check the generated PNG files for visualization results!")
end
0 commit comments