@@ -8,6 +8,7 @@ using DataFrames
8
8
using Serialization
9
9
using TMLE
10
10
using CSV
11
+ using YAML
11
12
12
13
TESTDIR = joinpath (pkgdir (TargeneCore), " test" )
13
14
@@ -20,10 +21,16 @@ function get_summary_stats(estimands)
20
21
end
21
22
22
23
function check_estimands_levels_interactions (estimands)
24
+ extra_treatments = YAML. load_file (joinpath (TESTDIR, " data" , " config_gweis_first_order.yaml" ))[" extra_treatments" ]
25
+ for (i,x) in enumerate (extra_treatments)
26
+ extra_treatments[i]= Symbol (x)
27
+ end
28
+
23
29
for Ψ in estimands
24
30
# If the two components are present, the first is the 0 -> 1 and the second is the 1 -> 2
25
31
# The variant should always be the last key
26
- variant = last (collect (keys (Ψ. args[1 ]. treatment_values)))
32
+ treatment_set = collect (keys (Ψ. args[1 ]. treatment_values))
33
+ variant = setdiff (treatment_set, extra_treatments)[1 ]
27
34
if length (Ψ. args) == 2
28
35
@test Ψ. args[1 ]. treatment_values[variant] == (control = 0x00 , case = 0x01 )
29
36
@test Ψ. args[2 ]. treatment_values[variant] == (control = 0x01 , case = 0x02 )
66
73
# There are 875 variants in the dataset
67
74
summary_stats = get_summary_stats (estimands)
68
75
@test summary_stats == DataFrame (
69
- OUTCOME = [:BINARY_1 , :BINARY_2 , :CONTINUOUS_1 , :CONTINUOUS_2 , :TREAT_1 ],
70
- nrow = repeat ([875 ], 5 )
76
+ OUTCOME = [:BINARY_1 , :BINARY_2 , :CONTINUOUS_1 , :CONTINUOUS_2 ],
77
+ nrow = repeat ([2625 ], 4 )
71
78
)
72
-
79
+ println (estimands[ 1 ])
73
80
check_estimands_levels_interactions (estimands)
74
81
end
75
82
@@ -102,86 +109,13 @@ end
102
109
@test all (e isa JointEstimand for e in estimands)
103
110
summary_stats = get_summary_stats (estimands)
104
111
@test summary_stats == DataFrame (
105
- OUTCOME = [:BINARY_1 , :BINARY_2 , :CONTINUOUS_1 , :CONTINUOUS_2 , :TREAT_1 ],
106
- nrow = repeat ([142 ], 5 )
112
+ OUTCOME = [:BINARY_1 , :BINARY_2 , :CONTINUOUS_1 , :CONTINUOUS_2 ],
113
+ nrow = repeat ([430 ], 4 )
107
114
)
108
115
109
116
check_estimands_levels_interactions (estimands)
110
117
end
111
118
112
- @testset " Test inputs_from_config gweis: no positivity constraint and four-point interaction" begin
113
- tmpdir = mktempdir ()
114
- copy! (ARGS , [
115
- " estimation-inputs" ,
116
- joinpath (TESTDIR, " data" , " config_gweis_higher_order.yaml" ),
117
- string (" --traits-file=" , joinpath (TESTDIR, " data" , " ukbb_traits.csv" )),
118
- string (" --pcs-file=" , joinpath (TESTDIR, " data" , " ukbb_pcs.csv" )),
119
- string (" --genotypes-prefix=" , joinpath (TESTDIR, " data" , " ukbb" , " genotypes" , " ukbb_1." )),
120
- string (" --outprefix=" , joinpath (tmpdir, " final" )),
121
- " --batchsize=5" ,
122
- " --verbosity=0" ,
123
- " --positivity-constraint=0"
124
- ])
125
- TargeneCore. julia_main ()
126
- # Check dataset
127
- dataset = DataFrame (Arrow. Table (joinpath (tmpdir, " final.data.arrow" )))
128
- @test size (dataset) == (1940 , 886 )
129
-
130
- # Check estimands
131
- estimands = []
132
- for file in readdir (tmpdir, join= true )
133
- if endswith (file, " jls" )
134
- append! (estimands, deserialize (file). estimands)
135
- end
136
- end
137
- @test all (e isa JointEstimand for e in estimands)
138
-
139
- # There are 875 variants in the dataset
140
- summary_stats = get_summary_stats (estimands)
141
- @test summary_stats == DataFrame (
142
- OUTCOME = [:CONTINUOUS_1 , :CONTINUOUS_2 , :TREAT_1 ],
143
- nrow = repeat ([875 ], 3 )
144
- )
145
-
146
- check_estimands_levels_interactions (estimands)
147
- end
148
-
149
- @testset " Test inputs_from_config gweis: positivity constraint and four-point interaction" begin
150
- tmpdir = mktempdir ()
151
- copy! (ARGS , [
152
- " estimation-inputs" ,
153
- joinpath (TESTDIR, " data" , " config_gweis_higher_order.yaml" ),
154
- string (" --traits-file=" , joinpath (TESTDIR, " data" , " ukbb_traits.csv" )),
155
- string (" --pcs-file=" , joinpath (TESTDIR, " data" , " ukbb_pcs.csv" )),
156
- string (" --genotypes-prefix=" , joinpath (TESTDIR, " data" , " ukbb" , " genotypes" , " ukbb_1." )),
157
- string (" --outprefix=" , joinpath (tmpdir, " final" )),
158
- " --batchsize=5" ,
159
- " --verbosity=0" ,
160
- " --positivity-constraint=0.02"
161
- ])
162
- TargeneCore. julia_main ()
163
- # Check dataset
164
- dataset = DataFrame (Arrow. Table (joinpath (tmpdir, " final.data.arrow" )))
165
- @test size (dataset) == (1940 , 886 )
166
-
167
- # Check estimands
168
- estimands = []
169
- for file in readdir (tmpdir, join= true )
170
- if endswith (file, " jls" )
171
- append! (estimands, deserialize (file). estimands)
172
- end
173
- end
174
- @test all (e isa JointEstimand for e in estimands)
175
-
176
- # There are 784 treatments in the dataset after positivity_constraint
177
- summary_stats = get_summary_stats (estimands)
178
- @test summary_stats == DataFrame (
179
- OUTCOME = [:CONTINUOUS_1 , :CONTINUOUS_2 , :TREAT_1 ],
180
- nrow = repeat ([784 ], 3 )
181
- )
182
-
183
- check_estimands_levels_interactions (estimands)
184
- end
185
119
186
120
end
187
121
true
0 commit comments