Skip to content

Commit 5e575f7

Browse files
correct behavior to do first order interactions iteratively across listed extra_treatments
1 parent 8d63364 commit 5e575f7

File tree

4 files changed

+34
-95
lines changed

4 files changed

+34
-95
lines changed

src/inputs_from_config.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,26 @@ function genome_wide_estimands(
7171
if isempty(extra_treatments)
7272
treatments = treatments_from_variant(variant, dataset)
7373
else
74-
treatments = Dict(
75-
treatments_from_variant(variant, dataset)...,
76-
(pair for extra in extra_treatments for pair in treatments_from_variant(string(extra), dataset))...
77-
)
74+
for treatment in extra_treatments
75+
treatments = Dict(treatments_from_variant(variant, dataset)..., treatments_from_variant(string(treatment), dataset)...)
76+
local Ψ
77+
try
78+
Ψ = factorialEstimands(
79+
estimand_constructor, treatments, outcomes;
80+
confounders=confounders,
81+
dataset=dataset,
82+
outcome_extra_covariates=outcome_extra_covariates,
83+
positivity_constraint=positivity_constraint,
84+
verbosity=verbosity-1)
85+
86+
catch e
87+
if !(e == ArgumentError("No component passed the positivity constraint."))
88+
throw(e)
89+
end
90+
else
91+
append!(estimands, Ψ)
92+
end
93+
end
7894
end
7995

8096
local Ψ

test/data/config_gweis_first_order.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
type: gweis
22

33
extra_treatments:
4+
- TREAT_1
45
- 22001
56

67
outcome_extra_covariates:

test/data/config_gweis_higher_order.yaml

Lines changed: 0 additions & 12 deletions
This file was deleted.

test/inputs_from_gweis_config.jl

Lines changed: 13 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using DataFrames
88
using Serialization
99
using TMLE
1010
using CSV
11+
using YAML
1112

1213
TESTDIR = joinpath(pkgdir(TargeneCore), "test")
1314

@@ -20,10 +21,16 @@ function get_summary_stats(estimands)
2021
end
2122

2223
function check_estimands_levels_interactions(estimands)
24+
extra_treatments = YAML.load_file(joinpath(TESTDIR, "data", "config_gweis_first_order.yaml"))["extra_treatments"]
25+
for (i,x) in enumerate(extra_treatments)
26+
extra_treatments[i]=Symbol(x)
27+
end
28+
2329
for Ψ in estimands
2430
# If the two components are present, the first is the 0 -> 1 and the second is the 1 -> 2
2531
# The variant should always be the last key
26-
variant = last(collect(keys(Ψ.args[1].treatment_values)))
32+
treatment_set = collect(keys(Ψ.args[1].treatment_values))
33+
variant = setdiff(treatment_set, extra_treatments)[1]
2734
if length(Ψ.args) == 2
2835
@test Ψ.args[1].treatment_values[variant] == (control = 0x00, case = 0x01)
2936
@test Ψ.args[2].treatment_values[variant] == (control = 0x01, case = 0x02)
@@ -66,10 +73,10 @@ end
6673
# There are 875 variants in the dataset
6774
summary_stats = get_summary_stats(estimands)
6875
@test summary_stats == DataFrame(
69-
OUTCOME = [:BINARY_1, :BINARY_2, :CONTINUOUS_1, :CONTINUOUS_2, :TREAT_1],
70-
nrow = repeat([875], 5)
76+
OUTCOME = [:BINARY_1, :BINARY_2, :CONTINUOUS_1, :CONTINUOUS_2],
77+
nrow = repeat([2625], 4)
7178
)
72-
79+
println(estimands[1])
7380
check_estimands_levels_interactions(estimands)
7481
end
7582

@@ -102,86 +109,13 @@ end
102109
@test all(e isa JointEstimand for e in estimands)
103110
summary_stats = get_summary_stats(estimands)
104111
@test summary_stats == DataFrame(
105-
OUTCOME = [:BINARY_1, :BINARY_2, :CONTINUOUS_1, :CONTINUOUS_2, :TREAT_1],
106-
nrow = repeat([142], 5)
112+
OUTCOME = [:BINARY_1, :BINARY_2, :CONTINUOUS_1, :CONTINUOUS_2],
113+
nrow = repeat([430], 4)
107114
)
108115

109116
check_estimands_levels_interactions(estimands)
110117
end
111118

112-
@testset "Test inputs_from_config gweis: no positivity constraint and four-point interaction" begin
113-
tmpdir = mktempdir()
114-
copy!(ARGS, [
115-
"estimation-inputs",
116-
joinpath(TESTDIR, "data", "config_gweis_higher_order.yaml"),
117-
string("--traits-file=", joinpath(TESTDIR, "data", "ukbb_traits.csv")),
118-
string("--pcs-file=", joinpath(TESTDIR, "data", "ukbb_pcs.csv")),
119-
string("--genotypes-prefix=", joinpath(TESTDIR, "data", "ukbb", "genotypes" , "ukbb_1.")),
120-
string("--outprefix=", joinpath(tmpdir, "final")),
121-
"--batchsize=5",
122-
"--verbosity=0",
123-
"--positivity-constraint=0"
124-
])
125-
TargeneCore.julia_main()
126-
# Check dataset
127-
dataset = DataFrame(Arrow.Table(joinpath(tmpdir, "final.data.arrow")))
128-
@test size(dataset) == (1940, 886)
129-
130-
# Check estimands
131-
estimands = []
132-
for file in readdir(tmpdir, join=true)
133-
if endswith(file, "jls")
134-
append!(estimands, deserialize(file).estimands)
135-
end
136-
end
137-
@test all(e isa JointEstimand for e in estimands)
138-
139-
# There are 875 variants in the dataset
140-
summary_stats = get_summary_stats(estimands)
141-
@test summary_stats == DataFrame(
142-
OUTCOME = [:CONTINUOUS_1, :CONTINUOUS_2, :TREAT_1],
143-
nrow = repeat([875], 3)
144-
)
145-
146-
check_estimands_levels_interactions(estimands)
147-
end
148-
149-
@testset "Test inputs_from_config gweis: positivity constraint and four-point interaction" begin
150-
tmpdir = mktempdir()
151-
copy!(ARGS, [
152-
"estimation-inputs",
153-
joinpath(TESTDIR, "data", "config_gweis_higher_order.yaml"),
154-
string("--traits-file=", joinpath(TESTDIR, "data", "ukbb_traits.csv")),
155-
string("--pcs-file=", joinpath(TESTDIR, "data", "ukbb_pcs.csv")),
156-
string("--genotypes-prefix=", joinpath(TESTDIR, "data", "ukbb", "genotypes" , "ukbb_1.")),
157-
string("--outprefix=", joinpath(tmpdir, "final")),
158-
"--batchsize=5",
159-
"--verbosity=0",
160-
"--positivity-constraint=0.02"
161-
])
162-
TargeneCore.julia_main()
163-
# Check dataset
164-
dataset = DataFrame(Arrow.Table(joinpath(tmpdir, "final.data.arrow")))
165-
@test size(dataset) == (1940, 886)
166-
167-
# Check estimands
168-
estimands = []
169-
for file in readdir(tmpdir, join=true)
170-
if endswith(file, "jls")
171-
append!(estimands, deserialize(file).estimands)
172-
end
173-
end
174-
@test all(e isa JointEstimand for e in estimands)
175-
176-
# There are 784 treatments in the dataset after positivity_constraint
177-
summary_stats = get_summary_stats(estimands)
178-
@test summary_stats == DataFrame(
179-
OUTCOME = [:CONTINUOUS_1, :CONTINUOUS_2, :TREAT_1],
180-
nrow = repeat([784], 3)
181-
)
182-
183-
check_estimands_levels_interactions(estimands)
184-
end
185119

186120
end
187121
true

0 commit comments

Comments
 (0)