Commit 248bc05: Merge pull request #112 from TARGENE/treatment_values

For 0.17.0 release

2 parents 5d4a8e9 + e9aca43

36 files changed: +773 -684 lines

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.6'
+          - '1.10'
          - '1'
        os:
          - ubuntu-latest

Project.toml

Lines changed: 7 additions & 3 deletions

@@ -1,10 +1,11 @@
 name = "TMLE"
 uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
 authors = ["Olivier Labayle"]
-version = "0.16.1"
+version = "0.17.0"
 
 [deps]
 AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
+AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -18,6 +19,7 @@ MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
 MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
 MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
 Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
+OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
@@ -46,7 +48,7 @@ JSON = "0.21.4"
 LogExpFunctions = "0.3"
 MLJBase = "1.0.1"
 MLJGLMInterface = "0.3.4"
-MLJModels = "0.15, 0.16"
+MLJModels = "0.15, 0.16, 0.17"
 MetaGraphsNext = "0.7"
 Missings = "1.0"
 PrecompileTools = "1.1.1"
@@ -55,7 +57,9 @@ TableOperations = "1.2"
 Tables = "1.6"
 YAML = "0.4.9"
 Zygote = "0.6.69"
-julia = "1.6, 1.7, 1"
+OrderedCollections = "1.6.3"
+AutoHashEquals = "2.1.0"
+julia = "1.10, 1"
 
 [extras]
 JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"

docs/src/user_guide/estimands.md

Lines changed: 27 additions & 15 deletions

@@ -119,11 +119,7 @@ statisticalΨ = ATE(
 )
 ```
 
-- Factorial Treatments
-
-It is possible to generate a `ComposedEstimand` containing all linearly independent IATEs from a set of treatment values or from a dataset. For that purpose, use the `factorialEstimand` function.
-
-## The Interaction Average Treatment Effect
+## The Average Interaction Effect
 
 - Causal Question:
 
@@ -136,14 +132,14 @@ For a general higher-order definition, please refer to [Higher-order interaction
 For two points interaction with both treatment and control levels ``0`` and ``1`` for ease of notation:
 
 ```math
-IATE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}[Y|do(T_1=1, T_2=1)] - \mathbb{E}[Y|do(T_1=1, T_2=0)] \\
+AIE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}[Y|do(T_1=1, T_2=1)] - \mathbb{E}[Y|do(T_1=1, T_2=0)] \\
 - \mathbb{E}[Y|do(T_1=0, T_2=1)] + \mathbb{E}[Y|do(T_1=0, T_2=0)]
 ```
 
 - Statistical Estimand (via backdoor adjustment):
 
 ```math
-IATE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}_{\textbf{W}}[\mathbb{E}[Y|T_1=1, T_2=1, \textbf{W}] - \mathbb{E}[Y|T_1=1, T_2=0, \textbf{W}] \\
+AIE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}_{\textbf{W}}[\mathbb{E}[Y|T_1=1, T_2=1, \textbf{W}] - \mathbb{E}[Y|T_1=1, T_2=0, \textbf{W}] \\
 - \mathbb{E}[Y|T_1=0, T_2=1, \textbf{W}] + \mathbb{E}[Y|T_1=0, T_2=0, \textbf{W}]]
 ```
 
@@ -152,7 +148,7 @@ IATE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}_{\textbf{W}}[\mathbb{E}[
 A causal estimand is given by:
 
 ```@example estimands
-causalΨ = IATE(
+causalΨ = AIE(
     outcome=:Y,
     treatment_values=(
         T₁=(case=1, control=0),
@@ -170,7 +166,7 @@ statisticalΨ = identify(causalΨ, scm)
 or defined directly:
 
 ```@example estimands
-statisticalΨ = IATE(
+statisticalΨ = AIE(
     outcome=:Y,
     treatment_values=(
         T₁=(case=1, control=0),
@@ -182,13 +178,11 @@ statisticalΨ = IATE(
 
 - Factorial Treatments
 
-It is possible to generate a `ComposedEstimand` containing all linearly independent IATEs from a set of treatment values or from a dataset. For that purpose, use the `factorialEstimand` function.
+It is possible to generate a `JointEstimand` containing all linearly independent AIEs from a set of treatment values or from a dataset. For that purpose, use the `factorialEstimand` function.
 
-## Composed Estimands
+## Joint And Composed Estimands
 
-As a result of Julia's automatic differentiation facilities, given a set of predefined estimands ``(\Psi_1, ..., \Psi_k)``, we can automatically compute an estimator for $f(\Psi_1, ..., \Psi_k)$. This is done via the `ComposedEstimand` type.
+A `JointEstimand` is simply a list of one dimensional estimands that are grouped together. For instance, for a treatment `T` taking three possible values ``(0, 1, 2)``, we can define the two following Average Treatment Effects and a corresponding `JointEstimand`:
 
-For example, the difference in ATE for a treatment with 3 levels (0, 1, 2) can be defined as follows:
 
 ```julia
 ATE₁ = ATE(
@@ -201,5 +195,23 @@ ATE₂ = ATE(
     treatment_values = (T = (control = 1, case = 2),),
     treatment_confounders = [:W]
 )
-ATEdiff = ComposedEstimand(-, (ATE₁, ATE₂))
+joint_estimand = JointEstimand(ATE₁, ATE₂)
+```
+
+You can easily generate joint estimands corresponding to Counterfactual Means, Average Treatment Effects or Average Interaction Effects by using the `factorialEstimand` function.
+
+To estimate a joint estimand you can use any of the estimators defined in this package, exactly as you would for a one dimensional estimand.
+
+There are two main use cases for them, which we now describe.
+
+### Joint Testing
+
+In some cases, like in factorial analyses where multiple versions of a treatment are tested, it may be of interest to know whether any of the versions has had an effect. This can be done via a Hotelling's T2 Test, which is simply a multivariate generalisation of the Student's T test. This is the default returned by the `significance_test` function provided in TMLE.jl, and the result of the test is also printed to the REPL for any joint estimate.
+
+### Composition
+
+Once you have estimated a `JointEstimand` and have a `JointEstimate`, you may be interested in asking further questions, for instance whether two treatment versions have the same effect. This question is typically answered by testing whether the difference in Average Treatment Effects is 0. Using the Delta Method and Julia's automatic differentiation, you don't need to explicitly define a semi-parametric estimator for it. You can simply call `compose`:
+
+```julia
+ATEdiff = compose(-, joint_estimate)
+```
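Taken together, the changes to this page replace the `ComposedEstimand`-based workflow with `JointEstimand`. The following condensed sketch strings the new pieces together; the estimator choice (`TMLEE`) and the `dataset` variable are assumptions for illustration, not part of this diff:

```julia
using TMLE

# Two partial ATEs for a three-level treatment T, grouped into a joint estimand.
ATE₁ = ATE(
    outcome = :Y,
    treatment_values = (T = (control = 0, case = 1),),
    treatment_confounders = [:W]
)
ATE₂ = ATE(
    outcome = :Y,
    treatment_values = (T = (control = 1, case = 2),),
    treatment_confounders = [:W]
)
joint_estimand = JointEstimand(ATE₁, ATE₂)

# Estimated exactly like a one dimensional estimand (dataset assumed to exist):
tmle = TMLEE()
joint_estimate, cache = tmle(joint_estimand, dataset, verbosity=0)

# Joint testing (Hotelling's T2 by default) and Delta-Method composition:
significance_test(joint_estimate)
ATEdiff = compose(-, joint_estimate)
```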

docs/src/user_guide/estimation.md

Lines changed: 38 additions & 27 deletions

@@ -110,7 +110,7 @@ Again, required nuisance functions are fitted and stored in the cache.
 
 ## Specifying Models
 
-By default, TMLE.jl uses generalized linear models for the estimation of relevant and nuisance factors such as the outcome mean and the propensity score. However, this is not the recommended usage since the estimators' performance is closely related to how well we can estimate these factors. More sophisticated models can be provided using the `models` keyword argument of each estimator which is essentially a `NamedTuple` mapping variables' names to their respective model.
+By default, TMLE.jl uses generalized linear models for the estimation of relevant and nuisance factors such as the outcome mean and the propensity score. However, this is not the recommended usage since the estimators' performance is closely related to how well we can estimate these factors. More sophisticated models can be provided using the `models` keyword argument of each estimator, which is a `Dict{Symbol, Model}` mapping variables' names to their respective models.
 
 Rather than specifying a specific model for each variable it may be easier to override the default models using the `default_models` function:
 
@@ -121,9 +121,9 @@ using MLJXGBoostInterface
 xgboost_regressor = XGBoostRegressor()
 xgboost_classifier = XGBoostClassifier()
 models = default_models(
-    Q_binary=xgboost_classifier,
-    Q_continuous=xgboost_regressor,
-    G=xgboost_classifier
+    Q_binary = xgboost_classifier,
+    Q_continuous = xgboost_regressor,
+    G = xgboost_classifier
 )
 tmle_gboost = TMLEE(models=models)
 ```
@@ -140,19 +140,18 @@ stack_binary = Stack(
     lr=lr
 )
 
-models = (
-    T₁ = with_encoder(xgboost_classifier), # T₁ with XGBoost prepended with a Continuous Encoder
-    default_models( # For all other variables use the following defaults
-        Q_binary=stack_binary, # A Super Learner
-        Q_continuous=xgboost_regressor, # An XGBoost
+models = default_models( # For all non-specified variables use the following defaults
+    Q_binary = stack_binary, # A Super Learner
+    Q_continuous = xgboost_regressor, # An XGBoost
+    # T₁ with XGBoost prepended with a Continuous Encoder
+    T₁ = xgboost_classifier
     # Unspecified G defaults to Logistic Regression
-    )...
 )
 
 tmle_custom = TMLEE(models=models)
 ```
 
-Notice that `with_encoder` is simply a shorthand to construct a pipeline with a `ContinuousEncoder` and that the resulting `models` is simply a `NamedTuple`.
+Notice that `with_encoder` is simply a shorthand to construct a pipeline with a `ContinuousEncoder` and that the resulting `models` is simply a `Dict`.
 
 ## CV-Estimation
 
@@ -196,10 +195,10 @@ result₃
 nothing # hide
 ```
 
-This time only the model for `Y` is fitted again while reusing the models for `T₁` and `T₂`. Finally, let's see what happens if we estimate the `IATE` between `T₁` and `T₂`.
+This time only the model for `Y` is fitted again while reusing the models for `T₁` and `T₂`. Finally, let's see what happens if we estimate the `AIE` between `T₁` and `T₂`.
 
 ```@example estimation
-Ψ₄ = IATE(
+Ψ₄ = AIE(
     outcome=:Y,
     treatment_values=(
         T₁=(case=true, control=false),
@@ -218,18 +217,20 @@ nothing # hide
 
 All nuisance functions have been reused, only the fluctuation is fitted!
 
-## Composing Estimands
+## Joint Estimands and Composition
 
-By leveraging the multivariate Central Limit Theorem and Julia's automatic differentiation facilities, we can estimate any estimand which is a function of already estimated estimands. By default, TMLE.jl will use [Zygote](https://fluxml.ai/Zygote.jl/latest/) but since we are using [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) you can change the backend to your favorite AD system.
+As explained in [Joint And Composed Estimands](@ref), a joint estimand is simply a collection of estimands. Here, we will illustrate that an Average Interaction Effect is also defined as a difference in partial Average Treatment Effects.
 
-For instance, by definition of the ``IATE``, we should be able to retrieve:
+More precisely, we would like to see if the left-hand side of this equation is equal to the right-hand side:
 
 ```math
-IATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} = ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} - ATE_{T_1=0, T_2=0 \rightarrow 1} - ATE_{T_1=0 \rightarrow 1, T_2=0}
+AIE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} = ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} - ATE_{T_1=0, T_2=0 \rightarrow 1} - ATE_{T_1=0 \rightarrow 1, T_2=0}
 ```
 
+For that, we need to define a joint estimand of three components:
+
 ```@example estimation
-first_ate = ATE(
+ATE₁ = ATE(
     outcome=:Y,
     treatment_values=(
         T₁=(case=true, control=false),
@@ -239,9 +240,7 @@ first_ate = ATE(
         T₂=[:W₂₁, :W₂₂],
     ),
 )
-first_ate_result, cache = tmle(first_ate, dataset, cache=cache, verbosity=0);
-
-second_ate = ATE(
+ATE₂ = ATE(
     outcome=:Y,
     treatment_values=(
         T₁=(case=false, control=false),
@@ -251,15 +250,27 @@ second_ate = ATE(
         T₂=[:W₂₁, :W₂₂],
     ),
 )
-second_ate_result, cache = tmle(second_ate, dataset, cache=cache, verbosity=0);
+joint_estimand = JointEstimand(Ψ₃, ATE₁, ATE₂)
+```
 
-composed_iate_result = compose(
-    (x, y, z) -> x - y - z,
-    result₃, first_ate_result, second_ate_result
-)
+where the interaction `Ψ₃` was defined earlier. This joint estimand can be estimated like any other estimand using our estimator of choice:
+
+```@example estimation
+joint_estimate, cache = tmle(joint_estimand, dataset, cache=cache, verbosity=0);
+joint_estimate
+```
+
+The printed output is the result of a Hotelling's T2 Test, which is the multivariate counterpart of the Student's T Test. It tells us whether any of the components of this joint estimand is different from 0.
+
+Then we can formally test our hypothesis by leveraging the multivariate Central Limit Theorem and Julia's automatic differentiation.
+
+```@example estimation
+composed_result = compose((x, y, z) -> x - y - z, joint_estimate)
 isapprox(
     estimate(result₄),
-    estimate(composed_iate_result),
+    first(estimate(composed_result)),
     atol=0.1
 )
 ```
+
+By default, TMLE.jl will use [Zygote](https://fluxml.ai/Zygote.jl/latest/) but since we are using [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) you can change the backend to your favorite AD system.
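For intuition about the Hotelling's T2 Test mentioned above: the statistic is a Mahalanobis-type distance of the joint estimate from the zero vector. The following is a minimal hand-rolled sketch, not TMLE.jl's internal implementation; the covariance argument is assumed to come from the estimated influence functions:

```julia
using LinearAlgebra

# T² statistic for H₀: Ψ = 0, given a point estimate Ψ̂, the estimated
# covariance Σ̂ of the influence functions and the sample size n.
# Σ̂ / n approximates the covariance of the estimator itself.
hotelling_t2(Ψ̂, Σ̂, n) = n * dot(Ψ̂, Σ̂ \ Ψ̂)

# Example with a two dimensional estimate; in practice the statistic is
# compared against a scaled F distribution to obtain a p-value.
hotelling_t2([0.1, -0.2], [1.0 0.2; 0.2 1.0], 100)
```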

docs/src/walk_through.md

Lines changed: 6 additions & 4 deletions

@@ -108,10 +108,10 @@ marginal_ate_t1 = ATE(
 )
 ```
 
-- The Interaction Average Treatment Effect:
+- The Average Interaction Effect:
 
 ```@example walk-through
-iate = IATE(
+aie = AIE(
     outcome = :Y,
     treatment_values = (
         T₁=(case=1, control=0),
@@ -125,7 +125,7 @@ iate = IATE(
 Identification is the process by which a Causal Estimand is turned into a Statistical Estimand, that is, a quantity we may estimate from data. This is done via the `identify` function which also takes in the ``SCM``:
 
 ```@example walk-through
-statistical_iate = identify(iate, scm)
+statistical_aie = identify(aie, scm)
 ```
 
 Alternatively, you can also directly define the statistical parameters (see [Estimands](@ref)).
@@ -149,7 +149,7 @@ Statistical Estimands can be estimated without a ``SCM``, let's use the One-Step
 
 ```@example walk-through
 ose = OSE()
-result, cache = ose(statistical_iate, dataset)
+result, cache = ose(statistical_aie, dataset)
 result
 ```
 
@@ -160,3 +160,5 @@ Both TMLE and OSE asymptotically follow a Normal distribution. It means we can p
 ```@example walk-through
 OneSampleTTest(result)
 ```
+
+If the estimate is high-dimensional, a `OneSampleHotellingT2Test` should be performed instead. Alternatively, the `significance_test` function will automatically select the appropriate test for the estimate and return its result.
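The `significance_test` dispatch described in the added paragraph can be sketched as follows. This is hypothetical usage: `result` is the scalar estimate from the walk-through page, and `joint_result` stands in for an assumed joint estimate that does not appear in this diff:

```julia
using TMLE

test = significance_test(result)              # scalar estimate → OneSampleTTest
pvalue(test)

joint_test = significance_test(joint_result)  # joint estimate → OneSampleHotellingT2Test
pvalue(joint_test)
```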

examples/double_robustness.jl

Lines changed: 3 additions & 3 deletions

@@ -157,9 +157,9 @@ function tmle_inference(data)
         treatment_values=(Tcat=(case=1.0, control=0.0),),
         treatment_confounders=(Tcat=[:W],)
     )
-    models = (
-        Y = with_encoder(LinearRegressor()),
-        Tcat = with_encoder(LinearBinaryClassifier())
+    models = Dict(
+        :Y => with_encoder(LinearRegressor()),
+        :Tcat => with_encoder(LinearBinaryClassifier())
     )
     tmle = TMLEE(models=models)
     result, _ = tmle(Ψ, data; verbosity=0)

src/TMLE.jl

Lines changed: 6 additions & 4 deletions

@@ -20,17 +20,19 @@ using Graphs
 using MetaGraphsNext
 using Combinatorics
 using SplitApplyCombine
+using OrderedCollections
+using AutoHashEquals
 
 # #############################################################################
 # EXPORTS
 # #############################################################################
 
 export SCM, StaticSCM, add_equations!, add_equation!, parents, vertices
-export CM, ATE, IATE
+export CM, ATE, AIE
 export AVAILABLE_ESTIMANDS
 export factorialEstimand, factorialEstimands
 export TMLEE, OSE, NAIVE
-export ComposedEstimand
+export JointEstimand, ComposedEstimand
 export var, estimate, pvalue, confint, emptyIC
 export significance_test, OneSampleTTest, OneSampleZTest, OneSampleHotellingT2Test
 export compose
@@ -48,8 +50,8 @@ include("utils.jl")
 include("scm.jl")
 include("adjustment.jl")
 include("estimands.jl")
-include("estimators.jl")
 include("estimates.jl")
+include("estimators.jl")
 include("treatment_transformer.jl")
 include("estimand_ordering.jl")
 
@@ -61,6 +63,6 @@ include("counterfactual_mean_based/clever_covariate.jl")
 include("counterfactual_mean_based/gradient.jl")
 
 include("configuration.jl")
-
+include("testing.jl")
 
 end

src/configuration.jl

Lines changed: 2 additions & 2 deletions

@@ -35,11 +35,11 @@ from_dict!(x) = x
 from_dict!(v::AbstractVector) = [from_dict!(x) for x in v]
 
 """
-    from_dict!(d::Dict)
+    from_dict!(d::AbstractDict)
 
 Converts a dictionary to a TMLE struct.
 """
-function from_dict!(d::Dict{T, Any}) where T
+function from_dict!(d::AbstractDict{T, Any}) where T
     haskey(d, T(:type)) || return Dict(key => from_dict!(val) for (key, val) in d)
     constructor = eval(Meta.parse(pop!(d, :type)))
     return constructor(;(key => from_dict!(val) for (key, val) in d)...)
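Widening the signature from `Dict` to `AbstractDict` matters because of the new `OrderedCollections` dependency: a configuration whose key order must be preserved is naturally stored as an `OrderedDict`, which is not a `Dict` but is an `AbstractDict`. A small illustration of the dispatch (the keys below are hypothetical, not a real TMLE configuration):

```julia
using OrderedCollections

d = OrderedDict{Symbol, Any}(:type => "ATE", :outcome => :Y)
d isa Dict          # false: the old Dict-only method would not have matched
d isa AbstractDict  # true: the widened from_dict!(d::AbstractDict{T, Any}) applies
```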
