Skip to content

Commit 6a03213

Browse files
GWIS configuration
1 parent 15d842e commit 6a03213

File tree

5 files changed

+164
-3
lines changed

5 files changed

+164
-3
lines changed

Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.10.4"
44
manifest_format = "2.0"
5-
project_hash = "2b58dcc0ffd21f9ddd37c869a6c32c55af516134"
5+
project_hash = "b1a916e4f68d2fb953c6eba21c0fc7546e7dc5c6"
66

77
[[deps.ARFFFiles]]
88
deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"]

src/inputs_from_config.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,22 @@ function estimands_from_variants(
6060
estimand_constructor,
6161
outcomes,
6262
confounders;
63+
extra_treatments=[],
6364
outcome_extra_covariates=[],
6465
positivity_constraint=0.,
6566
verbosity=1
6667
)
6768
estimands = []
6869
for variant in variants
69-
treatments = treatments_from_variant(variant, dataset)
70+
71+
if isempty(extra_treatments)
72+
treatments = treatments_from_variant(variant, dataset)
73+
elseif length(extra_treatments) == 1
74+
treatments = Dict(treatments_from_variant(variant, dataset)..., treatments_from_variant(string(extra_treatments[1]), dataset)...)
75+
else
76+
error("GWIS mode only supports pairwise interaction with one extra treatment.")
77+
end
78+
7079
local Ψ
7180
try
7281
Ψ = factorialEstimands(
@@ -171,13 +180,15 @@ function treatments_from_variant(variant::String, dataset::DataFrame)
171180
end
172181

173182
function estimands_from_gwas(dataset, variants, outcomes, confounders;
183+
extra_treatments=extra_treatments,
174184
outcome_extra_covariates = [],
175185
positivity_constraint=0.,
176186
verbosity=0
177187
)
178188
variants_groups = Iterators.partition(variants, length(variants) ÷ Threads.nthreads())
179189
estimands_tasks = map(variants_groups) do variants
180190
Threads.@spawn estimands_from_variants(variants, dataset, ATE, outcomes, confounders;
191+
extra_treatments=extra_treatments,
181192
outcome_extra_covariates=outcome_extra_covariates,
182193
positivity_constraint=positivity_constraint,
183194
verbosity=verbosity
@@ -187,6 +198,24 @@ function estimands_from_gwas(dataset, variants, outcomes, confounders;
187198
return vcat(estimands_partitions...)
188199
end
189200

201+
function estimands_from_gwis(dataset, variants, outcomes, confounders;
202+
extra_treatments=extra_treatments,
203+
outcome_extra_covariates = [],
204+
positivity_constraint=0.,
205+
verbosity=0
206+
)
207+
variants_groups = Iterators.partition(variants, length(variants) ÷ Threads.nthreads())
208+
estimands_tasks = map(variants_groups) do variants
209+
Threads.@spawn estimands_from_variants(variants, dataset, AIE, outcomes, confounders;
210+
extra_treatments=extra_treatments,
211+
outcome_extra_covariates=outcome_extra_covariates,
212+
positivity_constraint=positivity_constraint,
213+
verbosity=verbosity
214+
)
215+
end
216+
estimands_partitions = fetch.(estimands_tasks)
217+
return vcat(estimands_partitions...)
218+
end
190219

191220
get_only_file_with_suffix(files, suffix) = files[only(findall(x -> endswith(x, suffix), files))]
192221

@@ -210,7 +239,7 @@ function get_genotypes_from_beds(bedprefix)
210239
end
211240

212241
function make_genotypes(genotype_prefix, config, call_threshold)
213-
genotypes = if config["type"] == "gwas"
242+
genotypes = if config["type"] == "gwas"|| config["type"] == "gwis"
214243
get_genotypes_from_beds(genotype_prefix)
215244
else
216245
variants_set = Set(retrieve_variants_list(config["variants"]))
@@ -263,10 +292,19 @@ function inputs_from_config(config_file, genotypes_prefix, traits_file, pcs_file
263292
elseif config_type == "gwas"
264293
variants = filter(!=("SAMPLE_ID"), names(genotypes))
265294
estimands_from_gwas(dataset, variants, outcomes, confounders;
295+
extra_treatments=extra_treatments,
266296
outcome_extra_covariates=outcome_extra_covariates,
267297
positivity_constraint=positivity_constraint,
268298
verbosity=verbosity
269299
)
300+
elseif config_type == "gwis"
301+
variants = filter(!=("SAMPLE_ID"), names(genotypes))
302+
estimands_from_gwis(dataset, variants, outcomes, confounders;
303+
extra_treatments=extra_treatments,
304+
outcome_extra_covariates=outcome_extra_covariates,
305+
positivity_constraint=positivity_constraint,
306+
verbosity=verbosity
307+
)
270308
else
271309
throw(ArgumentError(string("Unknown extraction type: ", config_type, ", use any of: (flat, groups, gwas)")))
272310
end

test/data/config_gwis.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
type: gwis
2+
3+
extra_treatments:
4+
- 22001
5+
6+
outcome_extra_covariates:
7+
- COV_1
8+
9+
extra_confounders:
10+
- 21003

test/inputs_from_gwis_config.jl

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
module TestGwisEstimands
2+
3+
using Test
4+
using SnpArrays
5+
using TargeneCore
6+
using Arrow
7+
using DataFrames
8+
using Serialization
9+
using TMLE
10+
using CSV
11+
12+
TESTDIR = joinpath(pkgdir(TargeneCore), "test")
13+
14+
include(joinpath(TESTDIR, "testutils.jl"))
15+
16+
function get_summary_stats(estimands)
17+
outcomes = [TargeneCore.get_outcome(Ψ) for Ψ in estimands]
18+
results = DataFrame(ESTIMAND = estimands, OUTCOME = outcomes)
19+
return sort(combine(groupby(results, :OUTCOME), nrow), :OUTCOME)
20+
end
21+
22+
function check_estimands_levels_order(estimands)
23+
for Ψ in estimands
24+
# If the two components are present, the first is the 0 -> 1 and the second is the 1 -> 2
25+
variant = collect(keys.args[1].treatment_values))[2]
26+
if length.args) == 2
27+
@test Ψ.args[1].treatment_values[variant] == (control = 0x00, case = 0x01)
28+
@test Ψ.args[2].treatment_values[variant] == (control = 0x01, case = 0x02)
29+
else
30+
# Otherwise we check they are one or the other
31+
arg = only.args)
32+
@test arg.treatment_values[variant]==(control = 0x00, case = 0x01) ||
33+
arg.treatment_values[variant]==( control = 0x01, case = 0x02)
34+
end
35+
end
36+
end
37+
38+
@testset "Test inputs_from_config gwas: no positivity constraint" begin
39+
tmpdir = mktempdir()
40+
copy!(ARGS, [
41+
"estimation-inputs",
42+
joinpath(TESTDIR, "data", "config_gwis.yaml"),
43+
string("--traits-file=", joinpath(TESTDIR, "data", "ukbb_traits.csv")),
44+
string("--pcs-file=", joinpath(TESTDIR, "data", "ukbb_pcs.csv")),
45+
string("--genotypes-prefix=", joinpath(TESTDIR, "data", "ukbb", "genotypes" , "ukbb_1.")),
46+
string("--outprefix=", joinpath(tmpdir, "final")),
47+
"--batchsize=5",
48+
"--verbosity=0",
49+
"--positivity-constraint=0"
50+
])
51+
TargeneCore.julia_main()
52+
# Check dataset
53+
dataset = DataFrame(Arrow.Table(joinpath(tmpdir, "final.data.arrow")))
54+
@test size(dataset) == (1940, 886)
55+
56+
# Check estimands
57+
estimands = []
58+
for file in readdir(tmpdir, join=true)
59+
if endswith(file, "jls")
60+
append!(estimands, deserialize(file).estimands)
61+
end
62+
end
63+
@test all(e isa JointEstimand for e in estimands)
64+
65+
# There are 875 variants in the dataset
66+
summary_stats = get_summary_stats(estimands)
67+
@test summary_stats == DataFrame(
68+
OUTCOME = [:BINARY_1, :BINARY_2, :CONTINUOUS_1, :CONTINUOUS_2, :TREAT_1],
69+
nrow = repeat([875], 5)
70+
)
71+
72+
check_estimands_levels_order(estimands)
73+
end
74+
75+
76+
@testset "Test inputs_from_config gwas: positivity constraint" begin
77+
tmpdir = mktempdir()
78+
copy!(ARGS, [
79+
"estimation-inputs",
80+
joinpath(TESTDIR, "data", "config_gwis.yaml"),
81+
string("--traits-file=", joinpath(TESTDIR, "data", "ukbb_traits.csv")),
82+
string("--pcs-file=", joinpath(TESTDIR, "data", "ukbb_pcs.csv")),
83+
string("--genotypes-prefix=", joinpath(TESTDIR, "data", "ukbb", "genotypes" , "ukbb_1.")),
84+
string("--outprefix=", joinpath(tmpdir, "final")),
85+
"--batchsize=5",
86+
"--verbosity=0",
87+
"--positivity-constraint=0.2"
88+
])
89+
TargeneCore.julia_main()
90+
# Check dataset
91+
dataset = DataFrame(Arrow.Table(joinpath(tmpdir, "final.data.arrow")))
92+
@test size(dataset) == (1940, 886)
93+
# Check estimands
94+
estimands = []
95+
for file in readdir(tmpdir, join=true)
96+
if endswith(file, "jls")
97+
append!(estimands, deserialize(file).estimands)
98+
end
99+
end
100+
# The positivity constraint reduces the number of variants
101+
@test all(e isa JointEstimand for e in estimands)
102+
summary_stats = get_summary_stats(estimands)
103+
@test summary_stats == DataFrame(
104+
OUTCOME = [:BINARY_1, :BINARY_2, :CONTINUOUS_1, :CONTINUOUS_2, :TREAT_1],
105+
nrow = repeat([142], 5)
106+
)
107+
108+
check_estimands_levels_order(estimands)
109+
end
110+
111+
112+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ TESTDIR = joinpath(pkgdir(TargeneCore), "test")
1010
@test include(joinpath(TESTDIR, "inputs_from_estimands.jl"))
1111
@test include(joinpath(TESTDIR, "inputs_from_config.jl"))
1212
@test include(joinpath(TESTDIR, "inputs_from_gwas_config.jl"))
13+
@test include(joinpath(TESTDIR, "inputs_from_gwis_config.jl"))
1314
@test include(joinpath(TESTDIR, "sieve_variance.jl"))
1415
end

0 commit comments

Comments
 (0)