diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index df03723..70ea0e7 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -18,11 +18,11 @@ jobs: - x64 steps: - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: diff --git a/.gitignore b/.gitignore index 9ab9e5f..b9a423b 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ sandbox.jl src/generate_results.jl test_grid.csv test/tl_inputs/real_data.jl +test/data/scratch/ diff --git a/Manifest.toml b/Manifest.toml index 5e87db8..e4c5124 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.4" manifest_format = "2.0" -project_hash = "2b58dcc0ffd21f9ddd37c869a6c32c55af516134" +project_hash = "026a4a228f2fe6f789f0bbad2d45493e7ed45f99" [[deps.ARFFFiles]] deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] @@ -936,6 +936,24 @@ git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" version = "1.0.2" +[[deps.HDF5]] +deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] +git-tree-sha1 = "e856eef26cf5bf2b0f95f8f4fc37553c72c8641c" +uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" +version = "0.17.2" + + [deps.HDF5.extensions] + MPIExt = "MPI" + + [deps.HDF5.weakdeps] + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + +[[deps.HDF5_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "LibCURL_jll", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "OpenSSL_jll", "TOML", "Zlib_jll", "libaec_jll"] +git-tree-sha1 = "82a471768b513dc39e471540fdadc84ff80ff997" +uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" +version = "1.14.3+3" + [[deps.HTTP]] deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" @@ -954,6 +972,12 @@ git-tree-sha1 = "8e070b599339d622e9a081d17230d74a5c473293" uuid = "3e5b6fbb-0976-4d2c-9146-d79de83f2fb0" version = "0.1.17" +[[deps.Hwloc_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "50aedf345a709ab75872f80a2779568dc0bb461b" +uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8" +version = "2.11.2+3" + [[deps.HypergeometricFunctions]] deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] git-tree-sha1 = "7c4195be1649ae622304031ed46a2f4df989f1eb" @@ -1137,9 +1161,9 @@ version = "1.0.0" [[deps.JLD2]] deps = ["FileIO", "MacroTools", "Mmap", "OrderedCollections", "PrecompileTools", "Requires", "TranscodingStreams"] -git-tree-sha1 = "a0746c21bdc986d0dc293efa6b1faee112c37c28" +git-tree-sha1 = "89e1e5c3d43078d42eed2306cab2a11b13e5c6ae" uuid = "033835bb-8acc-5ee8-8aae-3f567f8a3819" -version = "0.4.53" +version = "0.4.54" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] @@ -1358,9 +1382,9 @@ weakdeps = ["ChainRulesCore", "SparseArrays", "Statistics"] [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" +git-tree-sha1 = "13ca9e2586b89836fd20cccf56e57e2b9ae7f38f" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.28" +version = "0.3.29" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -1503,6 +1527,24 @@ git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" version = "0.4.4" +[[deps.MPICH_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Hwloc_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "7715e65c47ba3941c502bffb7f266a41a7f54423" +uuid = "7cb0a576-ebde-5e09-9194-50597f1243b4" +version = "4.2.3+0" + +[[deps.MPIPreferences]] +deps = ["Libdl", "Preferences"] +git-tree-sha1 = "c105fe467859e7f6e9a852cb15cb4301126fac07" +uuid = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" +version = "0.1.11" + +[[deps.MPItrampoline_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "70e830dab5d0775183c99fc75e4c24c614ed7142" +uuid = "f1f71cc9-e9ae-5b93-9b94-4fe0e1ad3748" +version = "5.5.1+2" + [[deps.MacroTools]] deps = ["Markdown", "Random"] git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" @@ -1564,6 +1606,12 @@ git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" uuid = "128add7d-3638-4c79-886c-908ea0c25c34" version = "0.2.0" +[[deps.MicrosoftMPI_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "bc95bf4149bf535c09602e3acdf950d9b4376227" +uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" +version = "10.1.4+3" + [[deps.Missings]] deps = ["DataAPI"] git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" @@ -1697,6 +1745,12 @@ git-tree-sha1 = "6efb039ae888699d5a74fb593f6f3e10c7193e33" uuid = "8b6db2d4-7670-4922-a472-f9537c81ab66" version = "0.3.1" +[[deps.OpenMPI_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] +git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" +uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" +version = "4.1.6+0" + [[deps.OpenSSL]] deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" @@ -2236,10 +2290,10 @@ uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" version = "1.7.0" [[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" +deps = ["AliasTables", "DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "29321314c920c26684834965ec2ce0dacc9cf8e5" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.3" +version = "0.34.4" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] @@ -2300,15 +2354,20 @@ version = "7.2.1+1" [[deps.TMLE]] deps = ["AbstractDifferentiation", "AutoHashEquals", "CategoricalArrays", "Combinatorics", "Distributions", "GLM", "Graphs", "HypothesisTests", "LogExpFunctions", "MLJBase", "MLJGLMInterface", "MLJModels", "MetaGraphsNext", "Missings", "OrderedCollections", "PrecompileTools", "Random", "SplitApplyCombine", "Statistics", "TableOperations", "Tables", "Zygote"] -git-tree-sha1 = "2928dc0b1c3e26d8dcaaa7585dedefe33f687020" +git-tree-sha1 = "01d5f62021293388d0d7a1fb6b3618312d543a33" uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" -version = "0.17.0" -weakdeps = ["JSON", "YAML"] +version = "0.17.1" [deps.TMLE.extensions] + CausalTablesExt = "CausalTables" JSONExt = "JSON" YAMLExt = "YAML" + [deps.TMLE.weakdeps] + CausalTables = "6af48e0c-efc2-4bf7-a92f-a553ccf79fd6" + JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" + YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" + [[deps.TMLECLI]] deps = ["ArgParse", "Arrow", "CSV", "CategoricalArrays", "Combinatorics", "Configurations", "DataFrames", "EvoTrees", "GLMNet", "JLD2", "JSON", "MKL", "MLJ", "MLJBase", "MLJLinearModels", "MLJModelInterface", "MLJModels", "MLJXGBoostInterface", "Mmap", "PackageCompiler", "Random", "Serialization", "TMLE", "Tables", "YAML"] git-tree-sha1 = "eb8a7aa2a17c0dd2fb165dd79a860c1aea0df55a" @@ -2628,6 +2687,12 @@ git-tree-sha1 = "51b5eeb3f98367157a7a12a1fb0aa5328946c03c" uuid = "9a68df92-36a6-505f-a73e-abb412b6bfb4" version = "0.2.3+0" +[[deps.libaec_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "46bf7be2917b59b761247be3f317ddf75e50e997" +uuid = "477f73a3-ac25-53e9-8cc3-50b2fa2566f0" +version = "1.1.2+2" + [[deps.libaom_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] git-tree-sha1 = "1827acba325fdcdf1d2647fc8d5301dd9ba43a9d" diff --git a/Project.toml b/Project.toml index f823dd3..cdac9ef 100644 --- a/Project.toml +++ b/Project.toml @@ -13,8 +13,10 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" @@ -24,6 +26,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" SnpArrays = "4e780e97-f5bf-4111-9dc4-b70aaf691b06" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TMLE = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" TMLECLI = "2573d147-4098-46ba-9db2-8608d210ccac" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" @@ -37,12 +40,12 @@ CairoMakie = "0.12" CategoricalArrays = "0.10" Combinatorics = "1.0" DataFrames = "1.2" +MKL = "0.7" OrderedCollections = "1.6.3" PackageCompiler = "2.1.17" SnpArrays = "0.3" StableRNGs = "1.0.1" Statistics = "1.10" TMLE = "0.17" -MKL = "0.7" YAML = "0.4.9" julia = "1.10, 1" diff --git a/src/TargeneCore.jl b/src/TargeneCore.jl index 4434dff..a81e58d 100644 --- a/src/TargeneCore.jl +++ b/src/TargeneCore.jl @@ -25,6 +25,7 @@ using TMLECLI using HypothesisTests using OrderedCollections using Mmap +using StatsBase ############################################################################### ### INCLUDES ### diff --git a/src/inputs_from_config.jl b/src/inputs_from_config.jl index f5bc918..171bce8 100644 --- a/src/inputs_from_config.jl +++ b/src/inputs_from_config.jl @@ -54,35 +54,63 @@ function try_append_new_estimands!( end end -function estimands_from_variants( +function genome_wide_estimands( variants, dataset, estimand_constructor, outcomes, confounders; + extra_treatments=[], outcome_extra_covariates=[], positivity_constraint=0., verbosity=1 ) estimands = [] for variant in variants - treatments = treatments_from_variant(variant, dataset) - local Ψ - try - Ψ = factorialEstimands( - estimand_constructor, treatments, outcomes; - confounders=confounders, - dataset=dataset, - outcome_extra_covariates=outcome_extra_covariates, - positivity_constraint=positivity_constraint, - verbosity=verbosity-1) + if estimand_constructor == ATE + treatments = treatments_from_variant(variant, dataset) + local Ψ + try + Ψ = factorialEstimands( + estimand_constructor, treatments, outcomes; + confounders=confounders, + dataset=dataset, + outcome_extra_covariates=outcome_extra_covariates, + positivity_constraint=positivity_constraint, + verbosity=verbosity-1) - catch e - if !(e == ArgumentError("No component passed the positivity constraint.")) - throw(e) + catch e + if !(e == ArgumentError("No component passed the positivity constraint.")) + throw(e) + end + else + append!(estimands, Ψ) end + elseif estimand_constructor == AIE && !isempty(extra_treatments) + for treatment in extra_treatments + treatments = Dict(treatments_from_variant(variant, dataset)..., treatments_from_variant(string(treatment), dataset)...) + local Ψ + try + Ψ = factorialEstimands( + estimand_constructor, treatments, outcomes; + confounders=confounders, + dataset=dataset, + outcome_extra_covariates=outcome_extra_covariates, + positivity_constraint=positivity_constraint, + verbosity=verbosity-1) + + catch e + if !(e == ArgumentError("No component passed the positivity constraint.")) + throw(e) + end + else + append!(estimands, Ψ) + end + end + elseif estimand_constructor == AIE && isempty(extra_treatments) + throw(ArgumentError("Extra treatments are necessary for AIE estimands in this configuration.")) else - append!(estimands, Ψ) + throw(ArgumentError(string("Invalid estimand constructor: ", estimand_constructor))) end end return estimands @@ -163,31 +191,37 @@ end """ treatment_from_variant(variant, dataset) - Generate a key-value pair (dicitionary) for treatment structs. + Generate a key-value pair (dictionary) for treatment structs. """ function treatments_from_variant(variant::String, dataset::DataFrame) variant_levels = sort(levels(dataset[!, variant], skipmissing=true)) return Dict{Symbol, Vector{UInt8}}(Symbol(variant)=>variant_levels) end -function estimands_from_gwas(dataset, variants, outcomes, confounders; +function estimands_from_gwas(estimands_configs, dataset, variants, outcomes, confounders; + extra_treatments=extra_treatments, outcome_extra_covariates = [], positivity_constraint=0., verbosity=0 ) - variants_groups = Iterators.partition(variants, length(variants) ÷ Threads.nthreads()) - estimands_tasks = map(variants_groups) do variants - Threads.@spawn estimands_from_variants(variants, dataset, ATE, outcomes, confounders; - outcome_extra_covariates=outcome_extra_covariates, - positivity_constraint=positivity_constraint, - verbosity=verbosity - ) + estimands = [] + for estimands_config in estimands_configs + estimand_constructor = eval(Symbol(estimands_config["type"])) + variants_groups = Iterators.partition(variants, length(variants) ÷ Threads.nthreads()) + estimands_tasks = map(variants_groups) do variants + Threads.@spawn genome_wide_estimands(variants, dataset, estimand_constructor, outcomes, confounders; + extra_treatments=extra_treatments, + outcome_extra_covariates=outcome_extra_covariates, + positivity_constraint=positivity_constraint, + verbosity=verbosity + ) + end + estimands_partitions = fetch.(estimands_tasks) + push!(estimands, vcat(estimands_partitions...)) end - estimands_partitions = fetch.(estimands_tasks) - return vcat(estimands_partitions...) + return vcat(estimands...) end - get_only_file_with_suffix(files, suffix) = files[only(findall(x -> endswith(x, suffix), files))] function read_bed_chromosome(bedprefix) @@ -198,20 +232,50 @@ function read_bed_chromosome(bedprefix) return SnpData(bed_file, famnm=fam_file, bimnm=bim_file) end -function get_genotypes_from_beds(bedprefix) - snpdata = read_bed_chromosome(bedprefix) - genotypes = DataFrame(convert(Matrix{UInt8}, snpdata.snparray), snpdata.snp_info."snpid") - genotype_map = Union{UInt8, Missing}[0, missing, 1, 2] - for col in names(genotypes) - genotypes[!, col] = [genotype_map[x+1] for x in genotypes[!, col]] +function get_genotypes_from_beds(bedprefix, outprefix) + snpdata = read_bed_chromosome(bedprefix) + code_map = Union{UInt8,Missing}[0, missing, 1, 2] + geno_mat = map(x -> code_map[x+1], snpdata.snparray) + genotypes = DataFrame(geno_mat, snpdata.snp_info.snpid; makeunique=true) + insertcols!(genotypes, 1, :SAMPLE_ID => snpdata.person_info.iid) + counts = countmap.(eachcol(select(genotypes, Not(:SAMPLE_ID)))) + + mapping_df = DataFrame( + snpid = snpdata.snp_info.snpid, + allele1 = snpdata.snp_info.allele1, + allele2 = snpdata.snp_info.allele2, + vₘ = get.(counts, missing, 0), + v₀ = get.(counts, UInt8(0), 0), + v₁ = get.(counts, UInt8(1), 0), + v₂ = get.(counts, UInt8(2), 0), + ) + mapping_df.n = mapping_df.v₀ .+ mapping_df.v₁ .+ mapping_df.v₂ + + # if v₀ < v₂, swap all 0↔2 so that 0 always marks the major homozygote + for (i, snpid) in enumerate(mapping_df.snpid) + if mapping_df.v₀[i] < mapping_df.v₂[i] + col = Symbol(snpid) + genotypes[!, col] = map(x -> x === UInt8(0) ? UInt8(2) : x === UInt8(2) ? UInt8(0) : x, genotypes[!, col]) + + # Swap alleles and counts respectively + allele2 = mapping_df.allele2[i] + mapping_df.allele2[i] = mapping_df.allele1[i] + mapping_df.allele1[i] = allele2 + + v₂ = mapping_df.v₀[i] + mapping_df.v₀[i] = mapping_df.v₂[i] + mapping_df.v₂[i] = v₂ + end end - insertcols!(genotypes, 1, :SAMPLE_ID => snpdata.person_info."iid") + mapping_df.MAF = (mapping_df.v₁ .+ (2 .* mapping_df.v₂)) ./ (2 .* (mapping_df.v₀ .+ mapping_df.v₁ .+ mapping_df.v₂)) + CSV.write("$(outprefix).mapping.txt", mapping_df) + return genotypes end -function make_genotypes(genotype_prefix, config, call_threshold) +function make_genotypes(genotype_prefix, config, call_threshold, outprefix) genotypes = if config["type"] == "gwas" - get_genotypes_from_beds(genotype_prefix) + get_genotypes_from_beds(genotype_prefix, outprefix) else variants_set = Set(retrieve_variants_list(config["variants"])) call_genotypes(genotype_prefix, variants_set, call_threshold) @@ -241,7 +305,7 @@ function inputs_from_config(config_file, genotypes_prefix, traits_file, pcs_file # Genotypes and final dataset verbosity > 0 && @info("Building and writing dataset.") - genotypes = make_genotypes(genotypes_prefix, config, call_threshold) + genotypes = make_genotypes(genotypes_prefix, config, call_threshold, outprefix) dataset = merge(traits, pcs, genotypes) Arrow.write(string(outprefix, ".data.arrow"), dataset) @@ -262,7 +326,8 @@ function inputs_from_config(config_file, genotypes_prefix, traits_file, pcs_file ) elseif config_type == "gwas" variants = filter(!=("SAMPLE_ID"), names(genotypes)) - estimands_from_gwas(dataset, variants, outcomes, confounders; + estimands_from_gwas(config["estimands"], dataset, variants, outcomes, confounders; + extra_treatments=extra_treatments, outcome_extra_covariates=outcome_extra_covariates, positivity_constraint=positivity_constraint, verbosity=verbosity diff --git a/test/data/config_gwas.yaml b/test/data/config_gwas.yaml index af80bbd..4e46c3d 100644 --- a/test/data/config_gwas.yaml +++ b/test/data/config_gwas.yaml @@ -1,5 +1,8 @@ type: gwas +estimands: + - type: ATE + outcome_extra_covariates: - COV_1 diff --git a/test/data/config_gweis_first_order.yaml b/test/data/config_gweis_first_order.yaml new file mode 100644 index 0000000..db3fd5c --- /dev/null +++ b/test/data/config_gweis_first_order.yaml @@ -0,0 +1,14 @@ +type: gwas + +estimands: + - type: AIE + +extra_treatments: + - TREAT_1 + +outcome_extra_covariates: + - COV_1 + +extra_confounders: + - 21003 + - 22001 diff --git a/test/inputs_from_gwas_config.jl b/test/inputs_from_gwas_config.jl index 919daff..158c70f 100644 --- a/test/inputs_from_gwas_config.jl +++ b/test/inputs_from_gwas_config.jl @@ -51,7 +51,7 @@ end # Check dataset dataset = DataFrame(Arrow.Table(joinpath(tmpdir, "final.data.arrow"))) @test size(dataset) == (1940, 886) - + @test isfile(joinpath(tmpdir, "final.mapping.txt")) # Check estimands estimands = [] for file in readdir(tmpdir, join=true) @@ -88,6 +88,7 @@ end # Check dataset dataset = DataFrame(Arrow.Table(joinpath(tmpdir, "final.data.arrow"))) @test size(dataset) == (1940, 886) + @test isfile(joinpath(tmpdir, "final.mapping.txt")) # Check estimands estimands = [] for file in readdir(tmpdir, join=true) diff --git a/test/inputs_from_gweis_config.jl b/test/inputs_from_gweis_config.jl new file mode 100644 index 0000000..543dbc8 --- /dev/null +++ b/test/inputs_from_gweis_config.jl @@ -0,0 +1,121 @@ +module TestGweisEstimands + +using Test +using SnpArrays +using TargeneCore +using Arrow +using DataFrames +using Serialization +using TMLE +using CSV +using YAML + +TESTDIR = joinpath(pkgdir(TargeneCore), "test") + +include(joinpath(TESTDIR, "testutils.jl")) + +function get_summary_stats(estimands) + outcomes = [TargeneCore.get_outcome(Ψ) for Ψ in estimands] + results = DataFrame(ESTIMAND = estimands, OUTCOME = outcomes) + return sort(combine(groupby(results, :OUTCOME), nrow), :OUTCOME) +end + +function check_estimands_levels_interactions(estimands) + string_treatments = YAML.load_file(joinpath(TESTDIR, "data", "config_gweis_first_order.yaml"))["extra_treatments"] + extra_treatments = [] + for (i,x) in enumerate(string_treatments) + push!(extra_treatments, Symbol(x)) + end + + for Ψ in estimands + # If the two components are present, the first is the 0 -> 1 and the second is the 1 -> 2 + treatment_set = collect(keys(Ψ.args[1].treatment_values)) + variant = setdiff(treatment_set, extra_treatments)[1] + if length(Ψ.args) == 2 + @test Ψ.args[1].treatment_values[variant] == (control = 0x00, case = 0x01) + @test Ψ.args[2].treatment_values[variant] == (control = 0x01, case = 0x02) + else + # Otherwise we check they are one or the other + arg = only(Ψ.args) + @test arg.treatment_values[variant]==(control = 0x00, case = 0x01) || + arg.treatment_values[variant]==( control = 0x01, case = 0x02) + end + end +end + +@testset "Test inputs_from_config gweis: no positivity constraint" begin + tmpdir = mktempdir() + copy!(ARGS, [ + "estimation-inputs", + joinpath(TESTDIR, "data", "config_gweis_first_order.yaml"), + string("--traits-file=", joinpath(TESTDIR, "data", "ukbb_traits.csv")), + string("--pcs-file=", joinpath(TESTDIR, "data", "ukbb_pcs.csv")), + string("--genotypes-prefix=", joinpath(TESTDIR, "data", "ukbb", "genotypes" , "ukbb_1.")), + string("--outprefix=", joinpath(tmpdir, "final")), + "--batchsize=5", + "--verbosity=0", + "--positivity-constraint=0" + ]) + TargeneCore.julia_main() + # Check dataset + dataset = DataFrame(Arrow.Table(joinpath(tmpdir, "final.data.arrow"))) + @test size(dataset) == (1940, 886) + @test isfile(joinpath(tmpdir, "final.mapping.txt")) + # Check estimands + estimands = [] + for file in readdir(tmpdir, join=true) + if endswith(file, "jls") + append!(estimands, deserialize(file).estimands) + end + end + @test all(e isa JointEstimand for e in estimands) + + # There are 875 variants in the dataset + summary_stats = get_summary_stats(estimands) + @test summary_stats == DataFrame( + OUTCOME = [:BINARY_1, :BINARY_2, :CONTINUOUS_1, :CONTINUOUS_2], + nrow = repeat([875], 4) + ) + check_estimands_levels_interactions(estimands) +end + + +@testset "Test inputs_from_config gweis: positivity constraint" begin + tmpdir = mktempdir() + copy!(ARGS, [ + "estimation-inputs", + joinpath(TESTDIR, "data", "config_gweis_first_order.yaml"), + string("--traits-file=", joinpath(TESTDIR, "data", "ukbb_traits.csv")), + string("--pcs-file=", joinpath(TESTDIR, "data", "ukbb_pcs.csv")), + string("--genotypes-prefix=", joinpath(TESTDIR, "data", "ukbb", "genotypes" , "ukbb_1.")), + string("--outprefix=", joinpath(tmpdir, "final")), + "--batchsize=5", + "--verbosity=0", + "--positivity-constraint=0.2" + ]) + TargeneCore.julia_main() + # Check dataset + dataset = DataFrame(Arrow.Table(joinpath(tmpdir, "final.data.arrow"))) + @test size(dataset) == (1940, 886) + @test isfile(joinpath(tmpdir, "final.mapping.txt")) + # Check estimands + estimands = [] + for file in readdir(tmpdir, join=true) + if endswith(file, "jls") + append!(estimands, deserialize(file).estimands) + end + end + # The positivity constraint reduces the number of variants + @test all(e isa JointEstimand for e in estimands) + summary_stats = get_summary_stats(estimands) + @test summary_stats == DataFrame( + OUTCOME = [:BINARY_1, :BINARY_2, :CONTINUOUS_1, :CONTINUOUS_2], + nrow = repeat([146], 4) + ) + + check_estimands_levels_interactions(estimands) +end + + +end +true \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 60adb9f..084c6b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,5 +10,6 @@ TESTDIR = joinpath(pkgdir(TargeneCore), "test") @test include(joinpath(TESTDIR, "inputs_from_estimands.jl")) @test include(joinpath(TESTDIR, "inputs_from_config.jl")) @test include(joinpath(TESTDIR, "inputs_from_gwas_config.jl")) + @test include(joinpath(TESTDIR, "inputs_from_gweis_config.jl")) @test include(joinpath(TESTDIR, "sieve_variance.jl")) end