Skip to content

Commit 58148df

Browse files
add support for composed estimands in from_param_file (#190)
* add support for composed estimands in from_param_file * up manifest
1 parent 7795fe7 commit 58148df

File tree

7 files changed

+235
-274
lines changed

7 files changed

+235
-274
lines changed

Manifest.toml

Lines changed: 73 additions & 97 deletions
Large diffs are not rendered by default.

src/tl_inputs/from_actors.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ function control_case_settings(::Type{TMLE.StatisticalATE}, treatments, data)
124124
end
125125

126126
function addEstimands!(estimands, treatments, variables, data; positivity_constraint=0.)
127-
freqs = TargeneCore.frequency_table(data, treatments)
127+
freqs = TMLE.frequency_table(data, treatments)
128128
# This loop adds all ATE estimands where all other treatments than
129129
# the bQTL are fixed, at the order 1, this is the simple bQTL's ATE
130130
for setting in control_case_settings(TMLE.StatisticalATE, treatments, data)
@@ -134,7 +134,7 @@ function addEstimands!(estimands, treatments, variables, data; positivity_constr
134134
treatment_confounders = NamedTuple{keys(setting)}([variables.confounders for key in keys(setting)]),
135135
outcome_extra_covariates = variables.covariates
136136
)
137-
if satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
137+
if TMLE.satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
138138
update_estimands_from_outcomes!(estimands, Ψ, variables.targets)
139139
end
140140
end
@@ -147,7 +147,7 @@ function addEstimands!(estimands, treatments, variables, data; positivity_constr
147147
treatment_confounders = NamedTuple{keys(setting)}([variables.confounders for key in keys(setting)]),
148148
outcome_extra_covariates = variables.covariates
149149
)
150-
if satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
150+
if TMLE.satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
151151
update_estimands_from_outcomes!(estimands, Ψ, variables.targets)
152152
end
153153
end

src/tl_inputs/from_param_files.jl

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ MismatchedCaseControlEncodingError() =
1616

1717
NoRemainingParamsError(positivity_constraint) = ArgumentError(string("No parameter passed the given positivity constraint: ", positivity_constraint))
1818

19+
MismatchedVariableError(variable) = ArgumentError(string("Each component of a ComposedEstimand should contain the same ", variable, " variables."))
1920

2021
function check_genotypes_encoding(val::NamedTuple, type)
2122
if !(typeof(val.case) <: type && typeof(val.control) <: type)
@@ -27,17 +28,66 @@ check_genotypes_encoding(val::T, type) where T =
2728
T <: type || throw(MismatchedCaseControlEncodingError())
2829

2930

31+
get_treatments(Ψ) = keys(Ψ.treatment_values)
32+
33+
function get_treatments(Ψ::ComposedEstimand)
34+
treatments = get_treatments(first(Ψ.args))
35+
if length(Ψ.args) > 1
36+
for arg in Ψ.args[2:end]
37+
get_treatments(arg) == treatments || throw(MismatchedVariableError("treatments"))
38+
end
39+
end
40+
return treatments
41+
end
42+
43+
get_confounders(Ψ) = Tuple(Iterators.flatten((Tconf for Tconf in Ψ.treatment_confounders)))
44+
45+
function get_confounders(Ψ::ComposedEstimand)
46+
confounders = get_confounders(first(Ψ.args))
47+
if length(Ψ.args) > 1
48+
for arg in Ψ.args[2:end]
49+
get_confounders(arg) == confounders || throw(MismatchedVariableError("confounders"))
50+
end
51+
end
52+
return confounders
53+
end
54+
55+
get_outcome_extra_covariates(Ψ) = Ψ.outcome_extra_covariates
56+
57+
function get_outcome_extra_covariates(Ψ::ComposedEstimand)
58+
outcome_extra_covariates = get_outcome_extra_covariates(first(Ψ.args))
59+
if length(Ψ.args) > 1
60+
for arg in Ψ.args[2:end]
61+
get_outcome_extra_covariates(arg) == outcome_extra_covariates || throw(MismatchedVariableError("outcome extra covariates"))
62+
end
63+
end
64+
return outcome_extra_covariates
65+
end
66+
67+
get_outcome(Ψ) = Ψ.outcome
68+
69+
function get_outcome(Ψ::ComposedEstimand)
70+
outcome = get_outcome(first(Ψ.args))
71+
if length(Ψ.args) > 1
72+
for arg in Ψ.args[2:end]
73+
get_outcome(arg) == outcome || throw(MismatchedVariableError("outcome"))
74+
end
75+
end
76+
return outcome
77+
end
78+
3079
function get_variables(estimands, traits, pcs)
3180
genetic_variants = Set{Symbol}()
3281
others = Set{Symbol}()
3382
pcs = Set{Symbol}(filter(x -> x != :SAMPLE_ID, propertynames(pcs)))
3483
alltraits = Set{Symbol}(filter(x -> x != :SAMPLE_ID, propertynames(traits)))
3584
for Ψ in estimands
36-
treatments = keys(Ψ.treatment_values)
37-
confounders = Iterators.flatten((Tconf for Tconf in Ψ.treatment_confounders))
85+
treatments = get_treatments(Ψ)
86+
confounders = get_confounders(Ψ)
87+
outcome_extra_covariates = get_outcome_extra_covariates(Ψ)
3888
push!(
3989
others,
40-
Ψ.outcome_extra_covariates...,
90+
outcome_extra_covariates...,
4191
confounders...,
4292
treatments...
4393
)
@@ -123,6 +173,8 @@ function adjust_parameter_sections(Ψ::T, variants_alleles, pcs) where T<:TMLE.E
123173
return T(outcome=Ψ.outcome, treatment_values=treatments, treatment_confounders=confounders, outcome_extra_covariates=Ψ.outcome_extra_covariates)
124174
end
125175

176+
adjust_parameter_sections(Ψ::ComposedEstimand, variants_alleles, pcs) =
177+
ComposedEstimand(Ψ.f, Tuple(adjust_parameter_sections(arg, variants_alleles, pcs) for arg in Ψ.args))
126178

127179
function append_from_valid_estimands!(
128180
estimands::Vector{<:TMLE.Estimand},
@@ -136,29 +188,28 @@ function append_from_valid_estimands!(
136188
# Update treatment's and confounders's sections of Ψ
137189
Ψ = adjust_parameter_sections(Ψ, variants_alleles, variables.pcs)
138190
# Update frequency tables with current treatments
139-
treatments = sorted_treatment_names(Ψ)
191+
treatments = get_treatments(Ψ)
140192
if !haskey(frequency_tables, treatments)
141-
frequency_tables[treatments] = TargeneCore.frequency_table(data, collect(treatments))
193+
frequency_tables[treatments] = TMLE.frequency_table(data, treatments)
142194
end
143195
# Check if parameter satisfies positivity
144-
satisfies_positivity(Ψ, frequency_tables[treatments];
145-
positivity_constraint=positivity_constraint) || return
146-
# Expand wildcard to all outcomes
147-
if Ψ.outcome === :ALL
148-
update_estimands_from_outcomes!(estimands, Ψ, variables.outcomes)
149-
else
150-
# Ψ.target || MissingVariableError(variable)
151-
push!(estimands, Ψ)
196+
if TMLE.satisfies_positivity(Ψ, frequency_tables[treatments]; positivity_constraint=positivity_constraint)
197+
# Expand wildcard to all outcomes
198+
if get_outcome(Ψ) === :ALL
199+
update_estimands_from_outcomes!(estimands, Ψ, variables.outcomes)
200+
else
201+
push!(estimands, Ψ)
202+
end
152203
end
153204
end
154205

155206
function adjusted_estimands(estimands, variables, data; positivity_constraint=0.)
156207
final_estimands = TMLE.Estimand[]
157208
variants_alleles = Dict(v => Set(unique(skipmissing(data[!, v]))) for v in variables.genetic_variants)
158-
freqency_tables = Dict()
209+
frequency_tables = Dict()
159210
for Ψ in estimands
160211
# If the genotypes encoding is a string representation make sure they match the actual genotypes
161-
append_from_valid_estimands!(final_estimands, freqency_tables, Ψ, data, variants_alleles, variables; positivity_constraint=positivity_constraint)
212+
append_from_valid_estimands!(final_estimands, frequency_tables, Ψ, data, variants_alleles, variables; positivity_constraint=positivity_constraint)
162213
end
163214

164215
length(final_estimands) > 0 || throw(NoRemainingParamsError(positivity_constraint))

src/tl_inputs/tl_inputs.jl

Lines changed: 21 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ NotAllVariantsFoundError(rsids) =
6464
ArgumentError(string("Some variants were not found in the genotype files: ", join(rsids, ", ")))
6565

6666
NotBiAllelicOrUnphasedVariantError(rsid) = ArgumentError(string("Variant: ", rsid, " is not bi-allelic or not unphased."))
67+
6768
"""
6869
bgen_files(snps, bgen_prefix)
6970
@@ -103,47 +104,8 @@ function call_genotypes(bgen_prefix::String, query_rsids::Set{<:AbstractString},
103104
return genotypes
104105
end
105106

106-
sorted_treatment_names(Ψ) = tuple(sort(collect(keys(Ψ.treatment_values)))...)
107-
108-
function setting_iterator(Ψ::TMLE.StatisticalIATE)
109-
treatments = sorted_treatment_names(Ψ)
110-
return (
111-
NamedTuple{treatments}(collect(Tval)) for
112-
Tval in Iterators.product((values(Ψ.treatment_values[T]) for T in treatments)...)
113-
)
114-
end
115-
116-
function setting_iterator(Ψ::TMLE.StatisticalATE)
117-
treatments = sorted_treatment_names(Ψ)
118-
return (
119-
NamedTuple{treatments}([(Ψ.treatment_values[T][c]) for T in treatments])
120-
for c in (:case, :control)
121-
)
122-
end
123-
124-
function setting_iterator(Ψ::TMLE.StatisticalCM)
125-
treatments = sorted_treatment_names(Ψ)
126-
return (NamedTuple{treatments}(Ψ.treatment_values[T] for T in treatments), )
127-
end
128-
129-
function satisfies_positivity(Ψ::TMLE.Estimand, freqs; positivity_constraint=0.01)
130-
for base_setting in setting_iterator(Ψ)
131-
if !haskey(freqs, base_setting) || freqs[base_setting] < positivity_constraint
132-
return false
133-
end
134-
end
135-
return true
136-
end
137-
138-
function frequency_table(data, treatments::AbstractVector)
139-
treatments = sort(treatments)
140-
freqs = Dict()
141-
N = nrow(data)
142-
for (key, group) in pairs(groupby(data, treatments; skipmissing=true))
143-
freqs[NamedTuple(key)] = nrow(group) / N
144-
end
145-
return freqs
146-
end
107+
TMLE.satisfies_positivity(Ψ::ComposedEstimand, freqs; positivity_constraint=0.01) =
108+
all(TMLE.satisfies_positivity(arg, freqs; positivity_constraint=positivity_constraint) for arg in Ψ.args)
147109

148110
read_txt_file(path::Nothing) = nothing
149111
read_txt_file(path) = CSV.read(path, DataFrame, header=false)[!, 1]
@@ -164,15 +126,27 @@ function merge(traits, pcs, genotypes)
164126
)
165127
end
166128

129+
estimand_with_new_outcome(Ψ::T, outcome) where T = T(
130+
outcome=outcome,
131+
treatment_values=Ψ.treatment_values,
132+
treatment_confounders=Ψ.treatment_confounders,
133+
outcome_extra_covariates=Ψ.outcome_extra_covariates
134+
)
135+
167136
function update_estimands_from_outcomes!(estimands, Ψ::T, outcomes) where T
168137
for outcome in outcomes
169138
push!(
170-
estimands,
171-
T(
172-
outcome=outcome,
173-
treatment_values=Ψ.treatment_values,
174-
treatment_confounders=Ψ.treatment_confounders,
175-
outcome_extra_covariates=Ψ.outcome_extra_covariates)
139+
estimands,
140+
estimand_with_new_outcome(Ψ, outcome)
141+
)
142+
end
143+
end
144+
145+
function update_estimands_from_outcomes!(estimands, Ψ::ComposedEstimand, outcomes)
146+
for outcome in outcomes
147+
push!(
148+
estimands,
149+
ComposedEstimand(Ψ.f, Tuple(estimand_with_new_outcome(arg, outcome) for arg in Ψ.args))
176150
)
177151
end
178152
end

test/tl_inputs/from_param_files.jl

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,40 @@ include(joinpath(TESTDIR, "tl_inputs", "test_utils.jl"))
2424
pcs = TargeneCore.read_csv_file(joinpath(TESTDIR, "data", "pcs.csv"))
2525
# extraW, extraT, extraC are parsed from all param_files
2626
estimands = make_estimands_configuration().estimands
27+
# get_treatments, get_outcome, ...
28+
## Simple Estimand
29+
Ψ = estimands[1]
30+
@test TargeneCore.get_outcome(Ψ) == :ALL
31+
@test TargeneCore.get_treatments(Ψ) == keys(Ψ.treatment_values)
32+
@test TargeneCore.get_confounders(Ψ) == ()
33+
@test TargeneCore.get_outcome_extra_covariates(Ψ) == ()
34+
## ComposedEstimand
35+
Ψ = estimands[5]
36+
@test TargeneCore.get_outcome(Ψ) == :ALL
37+
@test TargeneCore.get_treatments(Ψ) == keys(Ψ.args[1].treatment_values)
38+
@test TargeneCore.get_confounders(Ψ) == ()
39+
@test TargeneCore.get_outcome_extra_covariates(Ψ) == (Symbol("22001"), )
40+
## Bad ComposedEstimand
41+
Ψ = ComposedEstimand(
42+
TMLE.joint_estimand, (
43+
CM(
44+
outcome = "Y1",
45+
treatment_values = (RSID_3 = "GG", RSID_198 = "AG"),
46+
treatment_confounders = (RSID_3 = [], RSID_198 = []),
47+
outcome_extra_covariates = [22001]
48+
),
49+
CM(
50+
outcome = "Y2",
51+
treatment_values = (RSID_2 = "AA", RSID_198 = "AG"),
52+
treatment_confounders = (RSID_2 = [:PC1], RSID_198 = []),
53+
outcome_extra_covariates = []
54+
))
55+
)
56+
@test_throws ArgumentError TargeneCore.get_outcome(Ψ) == :ALL
57+
@test_throws ArgumentError TargeneCore.get_treatments(Ψ)
58+
@test_throws ArgumentError TargeneCore.get_confounders(Ψ)
59+
@test_throws ArgumentError TargeneCore.get_outcome_extra_covariates(Ψ)
60+
# get_variables
2761
variables = TargeneCore.get_variables(estimands, traits, pcs)
2862
@test variables.genetic_variants == Set([:RSID_198, :RSID_2])
2963
@test variables.outcomes == Set([:BINARY_1, :CONTINUOUS_2, :CONTINUOUS_1, :BINARY_2])
@@ -38,8 +72,9 @@ end
3872
)
3973
pcs = Set([:PC1, :PC2])
4074
variants_alleles = Dict(:RSID_198 => Set(genotypes.RSID_198))
41-
# AG is not in the genotypes but GA is
42-
Ψ = make_estimands_configuration().estimands[4]
75+
estimands = make_estimands_configuration().estimands
76+
# RS198 AG is not in the genotypes but GA is
77+
Ψ = estimands[4]
4378
@test Ψ.treatment_values.RSID_198 == (case="AG", control="AA")
4479
new_Ψ = TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs)
4580
@test new_Ψ.outcome == Ψ.outcome
@@ -50,6 +85,19 @@ end
5085
RSID_2 = (case = "AA", control = "GG")
5186
)
5287

88+
# ComposedEstimand
89+
Ψ = estimands[5]
90+
@test Ψ.args[1].treatment_values == (RSID_198 = "AG", RSID_2 = "GG")
91+
@test Ψ.args[2].treatment_values == (RSID_198 = "AG", RSID_2 = "AA")
92+
new_Ψ = TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs)
93+
for index in 1:length(Ψ.args)
94+
@test new_Ψ.args[index].outcome == Ψ.args[index].outcome
95+
@test new_Ψ.args[index].outcome_extra_covariates == (Symbol(22001),)
96+
@test new_Ψ.args[index].treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2),)
97+
end
98+
@test new_Ψ.args[1].treatment_values == (RSID_198 = "GA", RSID_2 = "GG")
99+
@test new_Ψ.args[2].treatment_values == (RSID_198 = "GA", RSID_2 = "AA")
100+
53101
# If the allele is not present
54102
variants_alleles = Dict(:RSID_198 => Set(["AA"]))
55103
@test_throws TargeneCore.AbsentAlleleError("RSID_198", "AG") TargeneCore.adjust_parameter_sections(Ψ, variants_alleles, pcs)
@@ -95,8 +143,8 @@ end
95143

96144
## Estimands file:
97145
output_estimands = deserialize("final.estimands.jls").estimands
98-
# There are 5 initial estimands containing a *
99-
# Those are duplicated for each of the 4 targets.
146+
# There are 5 initial estimands containing a :ALL
147+
# Those are duplicated for each of the 4 outcomes.
100148
@test length(output_estimands) == 20
101149
# In all cases the PCs are appended to the confounders.
102150
for Ψ in output_estimands
@@ -120,10 +168,11 @@ end
120168
@test Ψ.outcome_extra_covariates == (Symbol("22001"),)
121169

122170
# Input Estimand 5: GA is corrected to AG to match the data
123-
elseif Ψ isa TMLE.StatisticalCM && Ψ.treatment_values == (RSID_198 = "AG", RSID_2 = "GG")
124-
@test Ψ.treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2))
125-
@test Ψ.outcome_extra_covariates == (Symbol("22001"),)
126-
171+
elseif Ψ isa TMLE.ComposedEstimand
172+
@test Ψ.args[1].treatment_values == (RSID_198 = "AG", RSID_2 = "GG")
173+
@test Ψ.args[2].treatment_values == (RSID_198 = "AG", RSID_2 = "AA")
174+
@test Ψ.args[1].treatment_confounders == Ψ.args[2].treatment_confounders == (RSID_198 = (:PC1, :PC2), RSID_2 = (:PC1, :PC2))
175+
@test Ψ.args[1].outcome_extra_covariates == Ψ.args[2].outcome_extra_covariates == (Symbol("22001"),)
127176
else
128177
throw(AssertionError(string("Which input did this output come from: ", Ψ)))
129178
end
@@ -142,7 +191,7 @@ end
142191
tl_inputs(parsed_args)
143192
# The IATEs are the most sensitive
144193
outestimands = deserialize("final.estimands.jls").estimands
145-
@test all(Ψ isa Union{TMLE.StatisticalCM, TMLE.StatisticalATE} for Ψ in outestimands)
194+
@test all(Ψ isa Union{TMLE.StatisticalCM, TMLE.StatisticalATE, ComposedEstimand} for Ψ in outestimands)
146195
@test size(outestimands, 1) == 16
147196

148197
cleanup()

test/tl_inputs/test_utils.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ function cleanup(;prefix="final.")
66
end
77
end
88

9-
109
function make_estimands_configuration()
1110
estimands = [
1211
IATE(
@@ -32,11 +31,20 @@ function make_estimands_configuration()
3231
treatment_confounders = (RSID_2 = [], RSID_198 = []),
3332
outcome_extra_covariates = [22001]
3433
),
35-
CM(
36-
outcome = "ALL",
37-
treatment_values = (RSID_2 = "GG", RSID_198 = "GA"),
38-
treatment_confounders = (RSID_2 = [], RSID_198 = []),
39-
outcome_extra_covariates = [22001]
34+
ComposedEstimand(
35+
TMLE.joint_estimand, (
36+
CM(
37+
outcome = "ALL",
38+
treatment_values = (RSID_2 = "GG", RSID_198 = "AG"),
39+
treatment_confounders = (RSID_2 = [], RSID_198 = []),
40+
outcome_extra_covariates = [22001]
41+
),
42+
CM(
43+
outcome = "ALL",
44+
treatment_values = (RSID_2 = "AA", RSID_198 = "AG"),
45+
treatment_confounders = (RSID_2 = [], RSID_198 = []),
46+
outcome_extra_covariates = [22001]
47+
))
4048
)
4149
]
4250
return Configuration(estimands=estimands)

0 commit comments

Comments
 (0)