Commit 5d4a8e9

Merge pull request #109 from TARGENE/cv_check_and_doc
Cv check and doc
2 parents 3d27968 + d3ba349 commit 5d4a8e9

File tree: 17 files changed, +623 -124 lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "TMLE"
 uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
 authors = ["Olivier Labayle"]
-version = "0.16.0"
+version = "0.16.1"

 [deps]
 AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ makedocs(;
     joinpath("examples", "super_learning.md"),
     joinpath("examples", "double_robustness.md")
 ],
+"Estimators' Cheat Sheet" => "estimators_cheatsheet.md",
 "Resources" => "resources.md",
 "API Reference" => "api.md"
 ],

docs/src/estimators_cheatsheet.md

Lines changed: 304 additions & 0 deletions
Large diffs are not rendered by default.

docs/src/index.md

Lines changed: 36 additions & 4 deletions
@@ -6,7 +6,7 @@ CurrentModule = TMLE

 ## Overview

-TMLE.jl is a Julia implementation of the Targeted Minimum Loss-Based Estimation ([TMLE](https://link.springer.com/book/10.1007/978-1-4419-9782-1)) framework. If you are interested in efficient and unbiased estimation of causal effects, you are in the right place. Since TMLE uses machine-learning methods to estimate nuisance estimands, the present package is based upon [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/).
+TMLE.jl is a Julia implementation of the Targeted Minimum Loss-Based Estimation ([TMLE](https://link.springer.com/book/10.1007/978-1-4419-9782-1)) framework. If you are interested in leveraging the power of modern machine-learning methods while preserving interpretability and statistical inference guarantees, you are in the right place. TMLE.jl is compatible with any [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/)-compliant algorithm and any dataset respecting the [Tables](https://tables.juliadata.org/stable/) interface.

 ## Installation

@@ -20,7 +20,7 @@ Pkg> add TMLE

 To run an estimation procedure, we need 3 ingredients:

-1. A dataset: here a simulation dataset.
+### 1. A dataset: here a simulation dataset

 For illustration, assume we know the actual data generating process is as follows:

@@ -52,7 +52,7 @@ dataset = (Y=Y, T=categorical(T), W=W)
 nothing # hide
 ```

-2. A quantity of interest: here the Average Treatment Effect (ATE).
+### 2. A quantity of interest: here the Average Treatment Effect (ATE)

 The Average Treatment Effect of ``T`` on ``Y`` confounded by ``W`` is defined as:

@@ -64,7 +64,7 @@ The Average Treatment Effect of ``T`` on ``Y`` confounded by ``W`` is defined as
 )
 ```

-3. An estimator: here a Targeted Maximum Likelihood Estimator (TMLE).
+### 3. An estimator: here a Targeted Maximum Likelihood Estimator (TMLE)

 ```@example quick-start
 tmle = TMLEE()
@@ -79,3 +79,35 @@ using Test # hide
 @test pvalue(OneSampleTTest(result, 2.5)) > 0.05 # hide
 nothing # hide
 ```
+
+## Scope and Distinguishing Features
+
+The goal of this package is to provide an entry point for semi-parametric, asymptotically unbiased and efficient estimation in Julia. The two main general estimators known to achieve these properties are the One-Step estimator and the Targeted Maximum-Likelihood estimator. Most of the current effort has centered on estimands that are composites of the counterfactual mean.
+
+Distinguishing features:
+
+- Estimands: counterfactual mean, Average Treatment Effect, interactions, any composition thereof.
+- Estimators: TMLE and One-Step, in both canonical and cross-validated versions.
+- Machine-Learning: any [MLJ](https://alan-turing-institute.github.io/MLJ.jl/stable/)-compatible model.
+- Dataset: any dataset respecting the [Tables](https://tables.juliadata.org/stable/) interface (e.g. [DataFrames.jl](https://dataframes.juliadata.org/stable/)).
+- Factorial treatment variables:
+  - Multiple treatments.
+  - Categorical treatment values.
+
+## Citing TMLE.jl
+
+If you use TMLE.jl for your own work and would like to cite us, here are the BibTeX and APA formats:
+
+- BibTeX
+
+```bibtex
+@software{Labayle_TMLE_jl,
+author = {Labayle, Olivier and Beentjes, Sjoerd and Khamseh, Ava and Ponting, Chris},
+title = {{TMLE.jl}},
+url = {https://github.com/olivierlabayle/TMLE.jl}
+}
+```
+
+- APA
+
+Labayle, O., Beentjes, S., Khamseh, A., & Ponting, C. TMLE.jl [Computer software]. https://github.com/olivierlabayle/TMLE.jl
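The quick-start above follows the plugin (G-computation) idea underlying the counterfactual mean: fit an outcome model and average the counterfactual predictions over the confounder distribution. Here is a minimal, language-agnostic sketch of that idea in Python; the data generating process and the effect size of 2.5 are illustrative assumptions, not the documentation's actual example, and none of this is TMLE.jl API.

```python
import random

random.seed(0)

# Hypothetical data generating process (illustrative only):
# W ~ Bernoulli(0.5), T ~ Bernoulli(0.3 + 0.4*W), Y = 2.5*T + W + noise
n = 100_000
W = [random.random() < 0.5 for _ in range(n)]
T = [random.random() < (0.3 + 0.4 * w) for w in W]
Y = [2.5 * t + 1.0 * w + random.gauss(0, 1) for t, w in zip(T, W)]

def mean(xs):
    return sum(xs) / len(xs)

def q_hat(t, w):
    # Empirical outcome mean E[Y | T=t, W=w] within the (t, w) stratum
    return mean([y for y, ti, wi in zip(Y, T, W) if ti == t and wi == w])

# Plugin (G-computation) ATE: average Q(1, W) - Q(0, W) over the sample
diff = {w: q_hat(True, w) - q_hat(False, w) for w in (False, True)}
ate = mean([diff[w] for w in W])  # close to the true effect of 2.5
```

With strong confounding (T depends on W), the naive difference `mean(Y | T=1) - mean(Y | T=0)` would be biased upward; adjusting through the outcome model recovers the causal effect.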

docs/src/resources.md

Lines changed: 7 additions & 1 deletion
@@ -9,9 +9,15 @@ These are two very clear introductions to causal inference and semi-parametric e
 - [Introduction to Modern Causal Inference](https://alejandroschuler.github.io/mci/) (Alejandro Schuler, Mark J. van der Laan).
 - [A Ride in Targeted Learning Territory](https://achambaz.github.io/tlride/) (David Benkeser, Antoine Chambaz).

-## Text Books
+## Youtube
+
+- [Targeted Learning Webinar Series](https://youtube.com/playlist?list=PLy_CaFomwGGGH10tbq9zSyfHVrdklMaLe&si=BfJZ2fvDtGUZwELy)
+- [TL Briefs](https://youtube.com/playlist?list=PLy_CaFomwGGFMxFtf4gkmC70dP9J6Q3Wa&si=aBZUnjJtOidIjhwR)
+
+## Books and Lecture Notes

 - [Targeted Learning](https://link.springer.com/book/10.1007/978-1-4419-9782-1) (Mark J. van der Laan, Sherri Rose).
+- [STATS 361: Causal Inference](https://web.stanford.edu/~swager/stats361.pdf)

 ## Journal articles

docs/src/user_guide/estimation.md

Lines changed: 68 additions & 20 deletions
@@ -4,7 +4,7 @@ CurrentModule = TMLE

 # Estimation

-## Estimating a single Estimand
+## Constructing and Using Estimators

 ```@setup estimation
 using Random
@@ -51,11 +51,12 @@ scm = SCM([
 )
 ```

-Once a statistical estimand has been defined, we can proceed with estimation. At the moment, we provide 3 main types of estimators:
+Once a statistical estimand has been defined, we can proceed with estimation. There are two semi-parametric efficient estimators in TMLE.jl:

-- Targeted Maximum Likelihood Estimator (`TMLEE`)
-- One-Step Estimator (`OSE`)
-- Naive Plugin Estimator (`NAIVE`)
+- The Targeted Maximum-Likelihood Estimator (`TMLEE`)
+- The One-Step Estimator (`OSE`)
+
+While they have similar asymptotic properties, their finite sample performance may differ. They also differ in one important respect: the TMLE is a plugin estimator, which means it respects the natural bounds of the estimand of interest. In contrast, the OSE may in theory report values outside these bounds. In practice this is rarely an issue, and the estimand of interest may not impose any restriction on its domain anyway.

 Drawing from the example dataset and `SCM` from the Walk Through section, we can estimate the ATE for `T₁`. Let's use TMLE:

@@ -72,27 +73,25 @@ result₁
 nothing # hide
 ```

-We see that both models corresponding to variables `Y` and `T₁` were fitted in the process but that the model for `T₂` was not because it was not necessary to estimate this estimand.
-
-The `cache` contains estimates for the nuisance functions that were necessary to estimate the ATE. For instance, we can see what is the value of ``\epsilon`` corresponding to the clever covariate.
+The `cache` (see below) contains estimates for the nuisance functions that were necessary to estimate the ATE. For instance, we can inspect the value of ``\epsilon`` corresponding to the clever covariate.

 ```@example estimation
 ϵ = last_fluctuation_epsilon(cache)
 ```

-The `result₁` structure corresponds to the estimation result and should report 3 main elements:
+The `result₁` structure corresponds to the estimation result and will display the result of a T-Test including:

 - A point estimate.
 - A 95% confidence interval.
 - A p-value (Corresponding to the test that the estimand is different than 0).

-This is only summary statistics but since both the TMLE and OSE are asymptotically linear estimators, standard Z/T tests from [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) can be performed.
+Because both the TMLE and OSE are asymptotically linear estimators, standard Z/T tests from [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) can be performed, and the `confint` and `pvalue` methods can be used.

 ```@example estimation
-tmle_test_result₁ = OneSampleTTest(result₁)
+tmle_test_result₁ = pvalue(OneSampleTTest(result₁))
 ```

-We could now get an interest in the Average Treatment Effect of `T₂` that we will estimate with an `OSE`:
+Let us now turn to the Average Treatment Effect of `T₂`, which we will estimate with an `OSE`:

 ```@example estimation
 Ψ₂ = ATE(
@@ -109,24 +108,73 @@ nothing # hide

 Again, required nuisance functions are fitted and stored in the cache.

-## CV-Estimation
+## Specifying Models

-Both TMLE and OSE can be used with sample-splitting, which, for an additional computational cost, further reduces the assumptions we need to make regarding our data generating process ([see here](https://arxiv.org/abs/2203.06469)). Note that this sample-splitting procedure should not be confused with the sample-splitting happening in Super Learning. Using both CV-TMLE and Super-Learning will result in two nested sample-splitting loops.
+By default, TMLE.jl uses generalized linear models to estimate the relevant and nuisance factors such as the outcome mean and the propensity score. This is not the recommended usage, however, since the estimators' performance is closely tied to how well these factors are estimated. More sophisticated models can be provided via the `models` keyword argument of each estimator, which is essentially a `NamedTuple` mapping variables' names to their respective models.

-To leverage sample-splitting, simply specify a `resampling` strategy when building an estimator:
+Rather than specifying a model for each variable, it may be easier to override the default models using the `default_models` function. For example, one can override all default models with XGBoost models from `MLJXGBoostInterface`:

 ```@example estimation
-cvtmle = TMLEE(resampling=CV())
-cvresult₁, _ = cvtmle(Ψ₁, dataset);
+using MLJXGBoostInterface
+xgboost_regressor = XGBoostRegressor()
+xgboost_classifier = XGBoostClassifier()
+models = default_models(
+    Q_binary=xgboost_classifier,
+    Q_continuous=xgboost_regressor,
+    G=xgboost_classifier
+)
+tmle_gboost = TMLEE(models=models)
 ```

-Similarly, one could build CV-OSE:
+The advantage of using `default_models` is that it automatically prepends each model with a [ContinuousEncoder](https://alan-turing-institute.github.io/MLJ.jl/dev/transformers/#MLJModels.ContinuousEncoder) to make sure the correct types are passed to the downstream models.
+
+Super Learning ([Stack](https://alan-turing-institute.github.io/MLJ.jl/dev/model_stacking/#Model-Stacking)) as well as variable-specific models can be defined as well. Here is a more customized version:
+
+```@example estimation
+lr = LogisticClassifier(lambda=0.)
+stack_binary = Stack(
+    metalearner=lr,
+    xgboost=xgboost_classifier,
+    lr=lr
+)
+
+models = (
+    T₁ = with_encoder(xgboost_classifier), # T₁ with XGBoost prepended with a Continuous Encoder
+    default_models( # For all other variables use the following defaults
+        Q_binary=stack_binary, # A Super Learner
+        Q_continuous=xgboost_regressor, # An XGBoost
+        # Unspecified G defaults to Logistic Regression
+    )...
+)
+
+tmle_custom = TMLEE(models=models)
+```
+
+Notice that `with_encoder` is simply a shorthand to construct a pipeline with a `ContinuousEncoder`, and that the resulting `models` is simply a `NamedTuple`.
+
+## CV-Estimation
+
+Canonical TMLE/OSE essentially use the dataset twice: once to estimate the nuisance functions and once to estimate the parameter of interest. This means there is a risk of over-fitting and residual bias ([see here](https://arxiv.org/abs/2203.06469) for some discussion). One way to address this limitation is a technique called sample-splitting or cross-validation. To activate the sample-splitting mode, simply provide an `MLJ.ResamplingStrategy` via the `resampling` keyword argument:
+
+```@example estimation
+TMLEE(resampling=StratifiedCV());
+```
+
+or

 ```julia
-cvose = OSE(resampling=CV(nfolds=3))
+OSE(resampling=StratifiedCV(nfolds=3));
 ```

-## Caching model fits
+There are some practical considerations:
+
+- Choice of `resampling` strategy: the theory behind sample-splitting requires the nuisance functions to be sufficiently well estimated on **each and every** fold. A practical consequence is that each fold should contain a sample representative of the dataset. In particular, when the treatment and outcome variables are categorical, it is important to make sure the proportions are preserved. This is typically done using `StratifiedCV`.
+- Computational complexity: sample-splitting results in ``K`` fits of the nuisance functions, drastically increasing computational cost. In particular, if the nuisance functions are estimated using (P-fold) Super Learning, this results in two nested cross-validation loops and ``K \times P`` fits.
+- Caching of nuisance functions: because the `resampling` strategy typically needs to preserve the outcome and treatment proportions, very little reuse of cached models is possible (see [Caching Models](@ref)).
+
+## Caching Models

 Let's now see how the `cache` can be reused with a new estimand, say the Total Average Treatment Effect of both `T₁` and `T₂`.
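The `StratifiedCV` recommendation in the documentation above can be illustrated without MLJ: a stratified K-fold split distributes each label's indices round-robin across folds, so every fold preserves the label proportions. This is a hypothetical pure-Python sketch of the idea, not TMLE.jl or MLJ API.

```python
import random
from collections import Counter

def stratified_kfold(labels, k, seed=0):
    """Return k folds (lists of indices) preserving label proportions."""
    rng = random.Random(seed)
    by_label = {}
    for i, y in enumerate(labels):
        by_label.setdefault(y, []).append(i)
    folds = [[] for _ in range(k)]
    for idxs in by_label.values():
        rng.shuffle(idxs)
        # Round-robin assignment keeps each label evenly spread over folds
        for j, i in enumerate(idxs):
            folds[j % k].append(i)
    return folds

labels = ["a"] * 80 + ["b"] * 20
folds = stratified_kfold(labels, k=5)
# Each fold holds exactly 16 "a" and 4 "b": the 80/20 split is preserved
counts = [Counter(labels[i] for i in f) for f in folds]
```

A plain (unstratified) K-fold split would let the rarer label fluctuate across folds, which is exactly what the sample-splitting theory asks us to avoid.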

src/counterfactual_mean_based/estimators.jl

Lines changed: 17 additions & 9 deletions
@@ -221,14 +221,17 @@ function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict()
         machine_cache=tmle.machine_cache
     )
     # Estimation results after TMLE
-    IC, Ψ̂ = gradient_and_estimate(Ψ, targeted_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound)
+    IC, Ψ̂ = gradient_and_estimate(tmle, Ψ, targeted_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound)
     σ̂ = std(IC)
     n = size(IC, 1)
     verbosity >= 1 && @info "Done."
     # update!(cache, relevant_factors, targeted_factors_estimate)
     return TMLEstimate(Ψ, Ψ̂, σ̂, n, IC), cache
 end

+gradient_and_estimate(::TMLEE, Ψ, factors, dataset; ps_lowerbound=1e-8) =
+    gradient_and_plugin_estimate(Ψ, factors, dataset; ps_lowerbound=ps_lowerbound)
+
 #####################################################################
 ### OSE ###
 #####################################################################
@@ -267,14 +270,14 @@ ose = OSE()
 OSE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, machine_cache=false) =
     OSE(models, resampling, ps_lowerbound, machine_cache)

-function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1)
+function (ose::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1)
     # Check the estimand against the dataset
     check_treatment_levels(Ψ, dataset)
     # Initial fit of the SCM's relevant factors
     initial_factors = get_relevant_factors(Ψ)
     nomissing_dataset = nomissing(dataset, variables(initial_factors))
-    initial_factors_dataset = choose_initial_dataset(dataset, nomissing_dataset, estimator.resampling)
-    initial_factors_estimator = CMRelevantFactorsEstimator(estimator.resampling, estimator.models)
+    initial_factors_dataset = choose_initial_dataset(dataset, nomissing_dataset, ose.resampling)
+    initial_factors_estimator = CMRelevantFactorsEstimator(ose.resampling, ose.models)
     initial_factors_estimate = initial_factors_estimator(
         initial_factors,
         initial_factors_dataset;
@@ -283,16 +286,21 @@ function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dic
     )
     # Get propensity score truncation threshold
     n = nrows(nomissing_dataset)
-    ps_lowerbound = ps_lower_bound(n, estimator.ps_lowerbound)
+    ps_lowerbound = ps_lower_bound(n, ose.ps_lowerbound)

     # Gradient and estimate
-    IC, Ψ̂ = gradient_and_estimate(Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound)
-    IC_mean = mean(IC)
-    IC .-= IC_mean
+    IC, Ψ̂ = gradient_and_estimate(ose, Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound)
     σ̂ = std(IC)
     n = size(IC, 1)
     verbosity >= 1 && @info "Done."
-    return OSEstimate(Ψ, Ψ̂ + IC_mean, σ̂, n, IC), cache
+    return OSEstimate(Ψ, Ψ̂, σ̂, n, IC), cache
+end
+
+function gradient_and_estimate(::OSE, Ψ, factors, dataset; ps_lowerbound=1e-8)
+    IC, Ψ̂ = gradient_and_plugin_estimate(Ψ, factors, dataset; ps_lowerbound=ps_lowerbound)
+    IC_mean = mean(IC)
+    IC .-= IC_mean
+    return IC, Ψ̂ + IC_mean
 end

 #####################################################################
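The diff above moves the one-step arithmetic into an OSE-specific `gradient_and_estimate` method: the reported estimate is the plugin estimate shifted by the mean of the influence curve, and the influence curve returned for variance estimation is the centered one. A hypothetical Python sketch of just that arithmetic (the function name and inputs are illustrative, not the package's API):

```python
def one_step(plugin_estimate, ic):
    """One-step correction: shift the plugin estimate by the mean of the
    influence curve (IC), then center the IC for variance estimation."""
    ic_mean = sum(ic) / len(ic)
    centered_ic = [x - ic_mean for x in ic]
    return plugin_estimate + ic_mean, centered_ic

# Toy numbers: plugin estimate 1.0, IC with mean 0.4
estimate, centered = one_step(1.0, [0.4, 0.6, 0.2])
# estimate ≈ 1.4, and the centered IC has (near-)zero mean
```

Centering matters because the standard error is computed from `std(IC)`; returning the already-corrected estimate together with the centered curve is exactly what makes `OSEstimate(Ψ, Ψ̂, σ̂, n, IC)` work unchanged for both estimators.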

src/counterfactual_mean_based/gradient.jl

Lines changed: 3 additions & 6 deletions
@@ -24,10 +24,7 @@ function counterfactual_aggregate(Ψ::StatisticalCMCompositeEstimand, Q, dataset
     return ctf_agg
 end

-compute_estimate(ctf_aggregate, ::Nothing) = mean(ctf_aggregate)
-
-compute_estimate(ctf_aggregate, train_validation_indices) =
-    mean(compute_estimate(ctf_aggregate[val_indices], nothing) for (_, val_indices) in train_validation_indices)
+plugin_estimate(ctf_aggregate) = mean(ctf_aggregate)


 """
@@ -53,11 +50,11 @@ function ∇YX(Ψ::StatisticalCMCompositeEstimand, Q, G, dataset; ps_lowerbound=
 end


-function gradient_and_estimate(Ψ::StatisticalCMCompositeEstimand, factors, dataset; ps_lowerbound=1e-8)
+function gradient_and_plugin_estimate(Ψ::StatisticalCMCompositeEstimand, factors, dataset; ps_lowerbound=1e-8)
     Q = factors.outcome_mean
     G = factors.propensity_score
     ctf_agg = counterfactual_aggregate(Ψ, Q, dataset)
-    Ψ̂ = compute_estimate(ctf_agg, train_validation_indices_from_factors(factors))
+    Ψ̂ = plugin_estimate(ctf_agg)
     IC = ∇YX(Ψ, Q, G, dataset; ps_lowerbound = ps_lowerbound) .+ ∇W(ctf_agg, Ψ̂)
     return IC, Ψ̂
 end
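The renamed `gradient_and_plugin_estimate` averages the counterfactual aggregate into a plugin estimate and assembles the influence curve from a `∇YX` residual term and a `∇W` term. For the ATE specifically, the standard efficient influence curve is `IC_i = H(t_i, w_i)(y_i - Q(t_i, w_i)) + Q(1, w_i) - Q(0, w_i) - Ψ̂`, with clever covariate `H(t, w) = t/g(w) - (1-t)/(1-g(w))`. A hypothetical, self-contained Python analogue (the toy factors and data are illustrative, not the package's API):

```python
def gradient_and_plugin_estimate(Q, G, data, ps_lowerbound=1e-8):
    """Q(t, w): outcome mean; G(w): propensity score P(T=1|W=w);
    data: list of (t, w, y) rows. Returns (IC, plugin estimate) for the ATE."""
    ctf_agg = [Q(1, w) - Q(0, w) for (_, w, _) in data]
    psi_hat = sum(ctf_agg) / len(ctf_agg)  # plugin_estimate: mean of the aggregate
    ic = []
    for (t, w, y), agg in zip(data, ctf_agg):
        g = min(max(G(w), ps_lowerbound), 1 - ps_lowerbound)  # truncation
        h = t / g - (1 - t) / (1 - g)                          # clever covariate
        ic.append(h * (y - Q(t, w)) + agg - psi_hat)           # ∇YX-like + ∇W-like
    return ic, psi_hat

# Toy factors: Q exactly linear in t, constant propensity score
Q = lambda t, w: 2.0 * t + w
G = lambda w: 0.5
data = [(1, 0.0, 2.0), (0, 1.0, 1.0)]  # (t, w, y) rows
ic, psi = gradient_and_plugin_estimate(Q, G, data)
# psi is 2.0; with a perfectly fitted Q, every IC entry is 0.0
```

When the outcome model fits the data exactly, the residual term vanishes and the influence curve is identically zero, which is why well-estimated nuisance functions shrink the estimator's variance.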
