diff --git a/malariagen_data/anoph/snp_frq.py b/malariagen_data/anoph/snp_frq.py
index e9dc0bce..30eff3f1 100644
--- a/malariagen_data/anoph/snp_frq.py
+++ b/malariagen_data/anoph/snp_frq.py
@@ -8,7 +8,6 @@
 from numpydoc_decorator import doc  # type: ignore
 import xarray as xr
 import numba  # type: ignore
-
 from .. import veff
 from ..util import (
     check_types,
@@ -581,8 +580,8 @@ def snp_allele_frequencies_advanced(
             raise ValueError("No SNPs remaining after dropping invariant SNPs.")
         df_variants = df_variants.loc[loc_variant].reset_index(drop=True)
-        count = np.compress(loc_variant, count, axis=0)
-        nobs = np.compress(loc_variant, nobs, axis=0)
+        count = np.compress(loc_variant, count, axis=0).reshape(-1, count.shape[1])
+        nobs = np.compress(loc_variant, nobs, axis=0).reshape(-1, nobs.shape[1])
         frequency = np.compress(loc_variant, frequency, axis=0)

         # Set up variant effect annotator.
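A note on the `snp_frq.py` change: `np.compress` with a boolean mask along axis 0 already preserves the column dimension, so the appended `.reshape(...)` acts as an explicit 2-D shape guarantee (presumably to keep downstream code happy under numpy 2.2) rather than changing any values. A minimal sketch with illustrative shapes, not the real frequency arrays:

```python
import numpy as np

count = np.arange(12).reshape(6, 2)  # (n_variants, n_cohorts), illustrative
loc_variant = np.array([True, False, True, True, False, True])

# np.compress drops the rows where the mask is False...
filtered = np.compress(loc_variant, count, axis=0)
assert filtered.shape == (4, 2)

# ...and the appended reshape pins the column count explicitly, so the
# result is guaranteed 2-D with the original number of columns, even
# when the mask selects zero rows.
filtered = np.compress(loc_variant, count, axis=0).reshape(-1, count.shape[1])
assert filtered.shape == (4, 2)
```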
"sha256:bc6f24b3d1ecc1eebfbf5d6051faa49af40b03be1aaa781ebdadcbc090b4539b"}, - {file = "numpy-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:17ee83a1f4fef3c94d16dc1802b998668b5419362c8a4f4e8a491de1b41cc3ee"}, - {file = "numpy-2.1.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:15cb89f39fa6d0bdfb600ea24b250e5f1a3df23f901f51c8debaa6a5d122b2f0"}, - {file = "numpy-2.1.3-cp311-cp311-win32.whl", hash = "sha256:d9beb777a78c331580705326d2367488d5bc473b49a9bc3036c154832520aca9"}, - {file = "numpy-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:d89dd2b6da69c4fff5e39c28a382199ddedc3a5be5390115608345dec660b9e2"}, - {file = "numpy-2.1.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f55ba01150f52b1027829b50d70ef1dafd9821ea82905b63936668403c3b471e"}, - {file = "numpy-2.1.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:13138eadd4f4da03074851a698ffa7e405f41a0845a6b1ad135b81596e4e9958"}, - {file = "numpy-2.1.3-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:a6b46587b14b888e95e4a24d7b13ae91fa22386c199ee7b418f449032b2fa3b8"}, - {file = "numpy-2.1.3-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:0fa14563cc46422e99daef53d725d0c326e99e468a9320a240affffe87852564"}, - {file = "numpy-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8637dcd2caa676e475503d1f8fdb327bc495554e10838019651b76d17b98e512"}, - {file = "numpy-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2312b2aa89e1f43ecea6da6ea9a810d06aae08321609d8dc0d0eda6d946a541b"}, - {file = "numpy-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a38c19106902bb19351b83802531fea19dee18e5b37b36454f27f11ff956f7fc"}, - {file = "numpy-2.1.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:02135ade8b8a84011cbb67dc44e07c58f28575cf9ecf8ab304e51c05528c19f0"}, - {file = "numpy-2.1.3-cp312-cp312-win32.whl", hash = "sha256:e6988e90fcf617da2b5c78902fe8e668361b43b4fe26dbf2d7b0f8034d4cafb9"}, - {file = "numpy-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:0d30c543f02e84e92c4b1f415b7c6b5326cbe45ee7882b6b77db7195fb971e3a"}, - {file = "numpy-2.1.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:96fe52fcdb9345b7cd82ecd34547fca4321f7656d500eca497eb7ea5a926692f"}, - {file = "numpy-2.1.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f653490b33e9c3a4c1c01d41bc2aef08f9475af51146e4a7710c450cf9761598"}, - {file = "numpy-2.1.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:dc258a761a16daa791081d026f0ed4399b582712e6fc887a95af09df10c5ca57"}, - {file = "numpy-2.1.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:016d0f6f5e77b0f0d45d77387ffa4bb89816b57c835580c3ce8e099ef830befe"}, - {file = "numpy-2.1.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c181ba05ce8299c7aa3125c27b9c2167bca4a4445b7ce73d5febc411ca692e43"}, - {file = "numpy-2.1.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5641516794ca9e5f8a4d17bb45446998c6554704d888f86df9b200e66bdcce56"}, - {file = "numpy-2.1.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ea4dedd6e394a9c180b33c2c872b92f7ce0f8e7ad93e9585312b0c5a04777a4a"}, - {file = "numpy-2.1.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b0df3635b9c8ef48bd3be5f862cf71b0a4716fa0e702155c45067c6b711ddcef"}, - {file = "numpy-2.1.3-cp313-cp313-win32.whl", hash = "sha256:50ca6aba6e163363f132b5c101ba078b8cbd3fa92c7865fd7d4d62d9779ac29f"}, - {file = "numpy-2.1.3-cp313-cp313-win_amd64.whl", hash = "sha256:747641635d3d44bcb380d950679462fae44f54b131be347d5ec2bce47d3df9ed"}, 
- {file = "numpy-2.1.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:996bb9399059c5b82f76b53ff8bb686069c05acc94656bb259b1d63d04a9506f"}, - {file = "numpy-2.1.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:45966d859916ad02b779706bb43b954281db43e185015df6eb3323120188f9e4"}, - {file = "numpy-2.1.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:baed7e8d7481bfe0874b566850cb0b85243e982388b7b23348c6db2ee2b2ae8e"}, - {file = "numpy-2.1.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:a9f7f672a3388133335589cfca93ed468509cb7b93ba3105fce780d04a6576a0"}, - {file = "numpy-2.1.3-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7aac50327da5d208db2eec22eb11e491e3fe13d22653dce51b0f4109101b408"}, - {file = "numpy-2.1.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4394bc0dbd074b7f9b52024832d16e019decebf86caf909d94f6b3f77a8ee3b6"}, - {file = "numpy-2.1.3-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:50d18c4358a0a8a53f12a8ba9d772ab2d460321e6a93d6064fc22443d189853f"}, - {file = "numpy-2.1.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:14e253bd43fc6b37af4921b10f6add6925878a42a0c5fe83daee390bca80bc17"}, - {file = "numpy-2.1.3-cp313-cp313t-win32.whl", hash = "sha256:08788d27a5fd867a663f6fc753fd7c3ad7e92747efc73c53bca2f19f8bc06f48"}, - {file = "numpy-2.1.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2564fbdf2b99b3f815f2107c1bbc93e2de8ee655a69c261363a1172a79a257d4"}, - {file = "numpy-2.1.3-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:4f2015dfe437dfebbfce7c85c7b53d81ba49e71ba7eadbf1df40c915af75979f"}, - {file = "numpy-2.1.3-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:3522b0dfe983a575e6a9ab3a4a4dfe156c3e428468ff08ce582b9bb6bd1d71d4"}, - {file = "numpy-2.1.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c006b607a865b07cd981ccb218a04fc86b600411d83d6fc261357f1c0966755d"}, - {file = "numpy-2.1.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e14e26956e6f1696070788252dcdff11b4aca4c3e8bd166e0df1bb8f315a67cb"}, - {file = "numpy-2.1.3.tar.gz", hash = "sha256:aa08e04e08aaf974d4458def539dece0d28146d866a39da5639596f4921fd761"}, + {file = "numpy-2.2.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1f4a922da1729f4c40932b2af4fe84909c7a6e167e6e99f71838ce3a29f3fe26"}, + {file = "numpy-2.2.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b6f91524d31b34f4a5fee24f5bc16dcd1491b668798b6d85585d836c1e633a6a"}, + {file = "numpy-2.2.5-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:19f4718c9012e3baea91a7dba661dcab2451cda2550678dc30d53acb91a7290f"}, + {file = "numpy-2.2.5-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:eb7fd5b184e5d277afa9ec0ad5e4eb562ecff541e7f60e69ee69c8d59e9aeaba"}, + {file = "numpy-2.2.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6413d48a9be53e183eb06495d8e3b006ef8f87c324af68241bbe7a39e8ff54c3"}, + {file = "numpy-2.2.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7451f92eddf8503c9b8aa4fe6aa7e87fd51a29c2cfc5f7dbd72efde6c65acf57"}, + {file = "numpy-2.2.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:0bcb1d057b7571334139129b7f941588f69ce7c4ed15a9d6162b2ea54ded700c"}, + {file = "numpy-2.2.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:36ab5b23915887543441efd0417e6a3baa08634308894316f446027611b53bf1"}, + {file = "numpy-2.2.5-cp310-cp310-win32.whl", hash = "sha256:422cc684f17bc963da5f59a31530b3936f57c95a29743056ef7a7903a5dbdf88"}, + {file = 
"numpy-2.2.5-cp310-cp310-win_amd64.whl", hash = "sha256:e4f0b035d9d0ed519c813ee23e0a733db81ec37d2e9503afbb6e54ccfdee0fa7"}, + {file = "numpy-2.2.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c42365005c7a6c42436a54d28c43fe0e01ca11eb2ac3cefe796c25a5f98e5e9b"}, + {file = "numpy-2.2.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:498815b96f67dc347e03b719ef49c772589fb74b8ee9ea2c37feae915ad6ebda"}, + {file = "numpy-2.2.5-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:6411f744f7f20081b1b4e7112e0f4c9c5b08f94b9f086e6f0adf3645f85d3a4d"}, + {file = "numpy-2.2.5-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:9de6832228f617c9ef45d948ec1cd8949c482238d68b2477e6f642c33a7b0a54"}, + {file = "numpy-2.2.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:369e0d4647c17c9363244f3468f2227d557a74b6781cb62ce57cf3ef5cc7c610"}, + {file = "numpy-2.2.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:262d23f383170f99cd9191a7c85b9a50970fe9069b2f8ab5d786eca8a675d60b"}, + {file = "numpy-2.2.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:aa70fdbdc3b169d69e8c59e65c07a1c9351ceb438e627f0fdcd471015cd956be"}, + {file = "numpy-2.2.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:37e32e985f03c06206582a7323ef926b4e78bdaa6915095ef08070471865b906"}, + {file = "numpy-2.2.5-cp311-cp311-win32.whl", hash = "sha256:f5045039100ed58fa817a6227a356240ea1b9a1bc141018864c306c1a16d4175"}, + {file = "numpy-2.2.5-cp311-cp311-win_amd64.whl", hash = "sha256:b13f04968b46ad705f7c8a80122a42ae8f620536ea38cf4bdd374302926424dd"}, + {file = "numpy-2.2.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ee461a4eaab4f165b68780a6a1af95fb23a29932be7569b9fab666c407969051"}, + {file = "numpy-2.2.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ec31367fd6a255dc8de4772bd1658c3e926d8e860a0b6e922b615e532d320ddc"}, + {file = "numpy-2.2.5-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:47834cde750d3c9f4e52c6ca28a7361859fcaf52695c7dc3cc1a720b8922683e"}, + {file = "numpy-2.2.5-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:2c1a1c6ccce4022383583a6ded7bbcda22fc635eb4eb1e0a053336425ed36dfa"}, + {file = "numpy-2.2.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d75f338f5f79ee23548b03d801d28a505198297534f62416391857ea0479571"}, + {file = "numpy-2.2.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a801fef99668f309b88640e28d261991bfad9617c27beda4a3aec4f217ea073"}, + {file = "numpy-2.2.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:abe38cd8381245a7f49967a6010e77dbf3680bd3627c0fe4362dd693b404c7f8"}, + {file = "numpy-2.2.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5a0ac90e46fdb5649ab6369d1ab6104bfe5854ab19b645bf5cda0127a13034ae"}, + {file = "numpy-2.2.5-cp312-cp312-win32.whl", hash = "sha256:0cd48122a6b7eab8f06404805b1bd5856200e3ed6f8a1b9a194f9d9054631beb"}, + {file = "numpy-2.2.5-cp312-cp312-win_amd64.whl", hash = "sha256:ced69262a8278547e63409b2653b372bf4baff0870c57efa76c5703fd6543282"}, + {file = "numpy-2.2.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:059b51b658f4414fff78c6d7b1b4e18283ab5fa56d270ff212d5ba0c561846f4"}, + {file = "numpy-2.2.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:47f9ed103af0bc63182609044b0490747e03bd20a67e391192dde119bf43d52f"}, + {file = "numpy-2.2.5-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:261a1ef047751bb02f29dfe337230b5882b54521ca121fc7f62668133cb119c9"}, + {file = "numpy-2.2.5-cp313-cp313-macosx_14_0_x86_64.whl", hash = 
"sha256:4520caa3807c1ceb005d125a75e715567806fed67e315cea619d5ec6e75a4191"}, + {file = "numpy-2.2.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d14b17b9be5f9c9301f43d2e2a4886a33b53f4e6fdf9ca2f4cc60aeeee76372"}, + {file = "numpy-2.2.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ba321813a00e508d5421104464510cc962a6f791aa2fca1c97b1e65027da80d"}, + {file = "numpy-2.2.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4cbdef3ddf777423060c6f81b5694bad2dc9675f110c4b2a60dc0181543fac7"}, + {file = "numpy-2.2.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:54088a5a147ab71a8e7fdfd8c3601972751ded0739c6b696ad9cb0343e21ab73"}, + {file = "numpy-2.2.5-cp313-cp313-win32.whl", hash = "sha256:c8b82a55ef86a2d8e81b63da85e55f5537d2157165be1cb2ce7cfa57b6aef38b"}, + {file = "numpy-2.2.5-cp313-cp313-win_amd64.whl", hash = "sha256:d8882a829fd779f0f43998e931c466802a77ca1ee0fe25a3abe50278616b1471"}, + {file = "numpy-2.2.5-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:e8b025c351b9f0e8b5436cf28a07fa4ac0204d67b38f01433ac7f9b870fa38c6"}, + {file = "numpy-2.2.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8dfa94b6a4374e7851bbb6f35e6ded2120b752b063e6acdd3157e4d2bb922eba"}, + {file = "numpy-2.2.5-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:97c8425d4e26437e65e1d189d22dff4a079b747ff9c2788057bfb8114ce1e133"}, + {file = "numpy-2.2.5-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:352d330048c055ea6db701130abc48a21bec690a8d38f8284e00fab256dc1376"}, + {file = "numpy-2.2.5-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b4c0773b6ada798f51f0f8e30c054d32304ccc6e9c5d93d46cb26f3d385ab19"}, + {file = "numpy-2.2.5-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55f09e00d4dccd76b179c0f18a44f041e5332fd0e022886ba1c0bbf3ea4a18d0"}, + {file = "numpy-2.2.5-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:02f226baeefa68f7d579e213d0f3493496397d8f1cff5e2b222af274c86a552a"}, + {file = "numpy-2.2.5-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c26843fd58f65da9491165072da2cccc372530681de481ef670dcc8e27cfb066"}, + {file = "numpy-2.2.5-cp313-cp313t-win32.whl", hash = "sha256:1a161c2c79ab30fe4501d5a2bbfe8b162490757cf90b7f05be8b80bc02f7bb8e"}, + {file = "numpy-2.2.5-cp313-cp313t-win_amd64.whl", hash = "sha256:d403c84991b5ad291d3809bace5e85f4bbf44a04bdc9a88ed2bb1807b3360bb8"}, + {file = "numpy-2.2.5-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b4ea7e1cff6784e58fe281ce7e7f05036b3e1c89c6f922a6bfbc0a7e8768adbe"}, + {file = "numpy-2.2.5-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:d7543263084a85fbc09c704b515395398d31d6395518446237eac219eab9e55e"}, + {file = "numpy-2.2.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0255732338c4fdd00996c0421884ea8a3651eea555c3a56b84892b66f696eb70"}, + {file = "numpy-2.2.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d2e3bdadaba0e040d1e7ab39db73e0afe2c74ae277f5614dad53eadbecbbb169"}, + {file = "numpy-2.2.5.tar.gz", hash = "sha256:a9c0d994680cd991b1cb772e8b297340085466a6fe964bc9d4e80f5e2f43c291"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index ca85554c..65a521b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.10,<3.13" -numpy = "<2.2" +numpy = "*" numba = ">=0.60.0" llvmlite = "*" scipy = "*" diff --git a/tests/anoph/conftest.py b/tests/anoph/conftest.py index 9f258c29..a79f1ded 
diff --git a/tests/anoph/conftest.py b/tests/anoph/conftest.py
index 9f258c29..a79f1ded 100644
--- a/tests/anoph/conftest.py
+++ b/tests/anoph/conftest.py
@@ -2,7 +2,6 @@
 import shutil
 import string
 from pathlib import Path
-from random import choice, choices, randint
 from typing import Any, Dict, Tuple

 import numpy as np
@@ -29,6 +28,9 @@
 # real data in GCS, but which is much smaller and so can be used
 # for faster test runs.

+# Global RNG for test file; functions may override with local RNG for reproducibility
+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture(scope="session")
 def fixture_dir():
@@ -37,10 +39,10 @@ def fixture_dir():


 def simulate_contig(*, low, high, base_composition):
-    size = np.random.randint(low=low, high=high)
+    size = int(rng.integers(low=low, high=high))
     bases = np.array([b"a", b"c", b"g", b"t", b"n", b"A", b"C", b"G", b"T", b"N"])
     p = np.array([base_composition[b] for b in bases])
-    seq = np.random.choice(bases, size=size, replace=True, p=p)
+    seq = rng.choice(bases, size=size, replace=True, p=p)
     return seq

@@ -148,9 +150,9 @@ def simulate_genes(self, *, contig, contig_size):
         # Simulate genes.
         for gene_ix in range(self.max_genes):
             gene_id = f"gene-{contig}-{gene_ix}"
-            strand = choice(["+", "-"])
-            inter_size = randint(self.inter_size_low, self.inter_size_high)
-            gene_size = randint(self.gene_size_low, self.gene_size_high)
+            strand = rng.choice(["+", "-"])
+            inter_size = int(rng.integers(self.inter_size_low, self.inter_size_high))
+            gene_size = int(rng.integers(self.gene_size_low, self.gene_size_high))
             if strand == "+":
                 gene_start = cur_fwd + inter_size
             else:
@@ -163,7 +165,11 @@ def simulate_genes(self, *, contig, contig_size):
             gene_attrs = f"ID={gene_id}"
             for attr in self.attrs:
                 random_str = "".join(
-                    choices(string.ascii_uppercase + string.digits, k=5)
+                    rng.choice(
+                        list(string.ascii_uppercase + string.digits),
+                        size=5,
+                        replace=True,
+                    )
                 )
                 gene_attrs += f";{attr}={random_str}"
             gene = (
@@ -209,7 +215,7 @@ def simulate_transcripts(
         # accurate in real data.

         for transcript_ix in range(
-            randint(self.n_transcripts_low, self.n_transcripts_high)
+            int(rng.integers(self.n_transcripts_low, self.n_transcripts_high))
         ):
             transcript_id = f"transcript-{contig}-{gene_ix}-{transcript_ix}"
             transcript_start = gene_start
@@ -257,13 +263,16 @@ def simulate_exons(
         transcript_size = transcript_end - transcript_start
         exons = []
         exon_end = transcript_start
-        n_exons = randint(self.n_exons_low, self.n_exons_high)
+        n_exons = int(rng.integers(self.n_exons_low, self.n_exons_high))
         for exon_ix in range(n_exons):
             exon_id = f"exon-{contig}-{gene_ix}-{transcript_ix}-{exon_ix}"
             if exon_ix > 0:
                 # Insert an intron between this exon and the previous one.
-                intron_size = randint(
-                    self.intron_size_low, min(transcript_size, self.intron_size_high)
+                intron_size = int(
+                    rng.integers(
+                        self.intron_size_low,
+                        min(transcript_size, self.intron_size_high),
+                    )
                 )
                 exon_start = exon_end + intron_size
                 if exon_start >= transcript_end:
@@ -272,7 +281,7 @@ def simulate_exons(
             else:
                 # First exon, assume exon starts where the transcript starts.
                 exon_start = transcript_start
-            exon_size = randint(self.exon_size_low, self.exon_size_high)
+            exon_size = int(rng.integers(self.exon_size_low, self.exon_size_high))
             exon_end = min(exon_start + exon_size, transcript_end)
             assert exon_end > exon_start
             exon = (
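One semantic subtlety in the migration above: `random.randint(a, b)` includes both endpoints, whereas `Generator.integers(low, high)` excludes `high` by default, so the simulated sizes quietly lose their top value. That is harmless for these fixtures, but `endpoint=True` reproduces the old inclusive behaviour exactly if it ever matters:

```python
import numpy as np

rng = np.random.default_rng(seed=42)

draws = rng.integers(1, 3, size=1_000)                 # values in {1, 2}
assert draws.max() == 2

draws = rng.integers(1, 3, size=1_000, endpoint=True)  # values in {1, 2, 3},
assert draws.max() == 3                                # like random.randint(1, 3)
```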
@@ -308,7 +317,7 @@ def simulate_exons(
                 else:
                     feature_type = self.cds_type
                 # Cheat a little, random phase.
-                phase = choice([1, 2, 3])
+                phase = rng.choice([1, 2, 3])
                 feature = (
                     contig,
                     self.source,
@@ -363,7 +372,7 @@ def simulate_site_filters(path, contigs, p_pass, n_sites):
    for contig in contigs:
        variants = root.require_group(contig).require_group("variants")
        size = n_sites[contig]
-        filter_pass = np.random.choice([False, True], size=size, p=p)
+        filter_pass = rng.choice([False, True], size=size, p=p)
        variants.create_dataset(name="filter_pass", data=filter_pass)
    zarr.consolidate_metadata(path)
@@ -386,7 +395,7 @@ def simulate_snp_genotypes(
        contig_n_sites = n_sites[contig]

        # Simulate genotype calls.
-        gt = np.random.choice(
+        gt = rng.choice(
            np.arange(4, dtype="i1"),
            size=(contig_n_sites, n_samples, 2),
            replace=True,
@@ -395,9 +404,7 @@ def simulate_snp_genotypes(

        # Simulate missing calls.
        n_calls = contig_n_sites * n_samples
-        loc_missing = np.random.choice(
-            [False, True], size=n_calls, replace=True, p=p_missing
-        )
+        loc_missing = rng.choice([False, True], size=n_calls, replace=True, p=p_missing)
        gt.reshape(-1, 2)[loc_missing] = -1

        # Store genotype calls.
@@ -438,7 +445,7 @@ def simulate_site_annotations(path, genome):
    p = [0.897754, 0.0, 0.060577, 0.014287, 0.011096, 0.016286]
    for contig in contigs:
        size = genome[contig].shape[0]
-        x = np.random.choice(vals, size=size, replace=True, p=p)
+        x = rng.choice(vals, size=size, replace=True, p=p)
        grp.create_dataset(name=contig, data=x)

    # codon_nonsyn
@@ -447,7 +454,7 @@ def simulate_site_annotations(path, genome):
    p = [0.91404, 0.001646, 0.018698, 0.065616]
    for contig in contigs:
        size = genome[contig].shape[0]
-        x = np.random.choice(vals, size=size, replace=True, p=p)
+        x = rng.choice(vals, size=size, replace=True, p=p)
        grp.create_dataset(name=contig, data=x)

    # codon_position
@@ -456,7 +463,7 @@ def simulate_site_annotations(path, genome):
    p = [0.897754, 0.034082, 0.034082, 0.034082]
    for contig in contigs:
        size = genome[contig].shape[0]
-        x = np.random.choice(vals, size=size, replace=True, p=p)
+        x = rng.choice(vals, size=size, replace=True, p=p)
        grp.create_dataset(name=contig, data=x)

    # seq_cls
@@ -477,28 +484,28 @@
    ]
    for contig in contigs:
        size = genome[contig].shape[0]
-        x = np.random.choice(vals, size=size, replace=True, p=p)
+        x = rng.choice(vals, size=size, replace=True, p=p)
        grp.create_dataset(name=contig, data=x)

    # seq_flen
    grp = root.require_group("seq_flen")
    for contig in contigs:
        size = genome[contig].shape[0]
-        x = np.random.randint(low=0, high=40_000, size=size)
+        x = rng.integers(low=0, high=40_000, size=size)
        grp.create_dataset(name=contig, data=x)

    # seq_relpos_start
    grp = root.require_group("seq_relpos_start")
    for contig in contigs:
        size = genome[contig].shape[0]
-        x = np.random.beta(a=0.4, b=4, size=size) * 40_000
+        x = rng.beta(a=0.4, b=4, size=size) * 40_000
        grp.create_dataset(name=contig, data=x)

    # seq_relpos_stop
    grp = root.require_group("seq_relpos_stop")
    for contig in contigs:
        size = genome[contig].shape[0]
-        x = np.random.beta(a=0.4, b=4, size=size) * 40_000
+        x = rng.beta(a=0.4, b=4, size=size) * 40_000
        grp.create_dataset(name=contig, data=x)

    zarr.consolidate_metadata(path)
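The `np.random.choice` calls with a `p=` argument port directly to `Generator.choice`, which accepts the same probability vector; seeding the generator makes the whole simulation deterministic across runs. A small self-contained sketch of the pattern used for site filters and genotypes above (toy sizes, not the fixture's):

```python
import numpy as np

rng = np.random.default_rng(seed=42)

# Boolean pass/fail filter with pass probability p.
p_pass = 0.9
filter_pass = rng.choice([False, True], size=1_000, p=[1 - p_pass, p_pass])
assert filter_pass.dtype == bool

# Diploid genotype calls over 4 alleles, shaped (sites, samples, ploidy).
gt = rng.choice(np.arange(4, dtype="i1"), size=(100, 10, 2), replace=True)
assert gt.dtype == np.int8 and gt.shape == (100, 10, 2)
```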
@@ -514,7 +521,7 @@ def simulate_hap_sites(path, contigs, snp_sites, p_site):

        # Simulate POS.
        snp_pos = snp_sites[f"{contig}/variants/POS"][:]
-        loc_hap_sites = np.random.choice(
+        loc_hap_sites = rng.choice(
            [False, True], size=snp_pos.shape[0], p=[1 - p_site, p_site]
        )
        pos = snp_pos[loc_hap_sites]
@@ -527,7 +534,7 @@ def simulate_hap_sites(path, contigs, snp_sites, p_site):

        # Simulate ALT.
        snp_alt = snp_sites[f"{contig}/variants/ALT"][:]
-        sim_alt_choice = np.random.choice(3, size=pos.shape[0])
+        sim_alt_choice = rng.choice(3, size=pos.shape[0])
        alt = np.take_along_axis(
            snp_alt[loc_hap_sites], indices=sim_alt_choice[:, None], axis=1
        )[:, 0]
@@ -547,8 +554,8 @@ def simulate_aim_variants(path, contigs, snp_sites, n_sites_low, n_sites_high):
    for contig_index, contig in enumerate(contigs):
        # Simulate AIM positions variable.
        snp_pos = snp_sites[f"{contig}/variants/POS"][:]
-        loc_aim_sites = np.random.choice(
-            snp_pos.shape[0], size=np.random.randint(n_sites_low, n_sites_high)
+        loc_aim_sites = rng.choice(
+            snp_pos.shape[0], size=int(rng.integers(n_sites_low, n_sites_high))
        )
        loc_aim_sites.sort()
        aim_pos = snp_pos[loc_aim_sites]
@@ -564,10 +571,7 @@ def simulate_aim_variants(path, contigs, snp_sites, n_sites_low, n_sites_high):
        snp_alleles = np.concatenate([snp_ref[:, None], snp_alt], axis=1)
        aim_site_snp_alleles = snp_alleles[loc_aim_sites]
        sim_allele_choice = np.vstack(
-            [
-                np.random.choice(4, size=2, replace=False)
-                for _ in range(len(loc_aim_sites))
-            ]
+            [rng.choice(4, size=2, replace=False) for _ in range(len(loc_aim_sites))]
        )
        aim_alleles = np.take_along_axis(
            aim_site_snp_alleles, indices=sim_allele_choice, axis=1
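The `vstack` of per-site `rng.choice(4, size=2, replace=False)` draws keeps the original one-row-at-a-time structure. If this loop ever became a bottleneck, `Generator.permuted` offers a vectorised equivalent: shuffle each row of a tiled index array independently and keep the first two columns. A sketch of that alternative (a hypothetical optimisation, not part of the patch):

```python
import numpy as np

rng = np.random.default_rng(seed=42)
n_sites = 5

# As in the comprehension above: two distinct allele indices per site.
pairs = np.vstack([rng.choice(4, size=2, replace=False) for _ in range(n_sites)])

# Vectorised alternative: independent per-row shuffles via Generator.permuted.
pairs_fast = rng.permuted(np.tile(np.arange(4), (n_sites, 1)), axis=1)[:, :2]

assert pairs.shape == pairs_fast.shape == (n_sites, 2)
```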
@@ -612,7 +616,7 @@ def simulate_cnv_hmm(zarr_path, metadata_path, contigs, contig_sizes):
    # - samples [1D array] [str]

    # Get a random probability for a sample being high variance, between 0 and 1.
-    p_variance = np.random.random()
+    p_variance = rng.random()

    # Open a zarr at the specified path.
    root = zarr.open(zarr_path, mode="w")
@@ -626,11 +630,11 @@ def simulate_cnv_hmm(zarr_path, metadata_path, contigs, contig_sizes):
    n_samples = len(df_samples)

    # Simulate sample_coverage_variance array.
-    sample_coverage_variance = np.random.uniform(low=0, high=0.5, size=n_samples)
+    sample_coverage_variance = rng.uniform(low=0, high=0.5, size=n_samples)
    root.create_dataset(name="sample_coverage_variance", data=sample_coverage_variance)

    # Simulate sample_is_high_variance array.
-    sample_is_high_variance = np.random.choice(
+    sample_is_high_variance = rng.choice(
        [False, True], size=n_samples, p=[1 - p_variance, p_variance]
    )
    root.create_dataset(name="sample_is_high_variance", data=sample_is_high_variance)
@@ -661,9 +665,9 @@ def simulate_cnv_hmm(zarr_path, metadata_path, contigs, contig_sizes):
    )

    # Simulate CN, NormCov, RawCov under calldata.
-    cn = np.random.randint(low=-1, high=12, size=(n_windows, n_samples))
-    normCov = np.random.randint(low=0, high=356, size=(n_windows, n_samples))
-    rawCov = np.random.randint(low=-1, high=18465, size=(n_windows, n_samples))
+    cn = rng.integers(low=-1, high=12, size=(n_windows, n_samples))
+    normCov = rng.integers(low=0, high=356, size=(n_windows, n_samples))
+    rawCov = rng.integers(low=-1, high=18465, size=(n_windows, n_samples))
    calldata_grp.create_dataset(name="CN", data=cn)
    calldata_grp.create_dataset(name="NormCov", data=normCov)
    calldata_grp.create_dataset(name="RawCov", data=rawCov)
@@ -705,13 +709,13 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes)
    # - POS [1D array] [int for n_variants]

    # Get a random probability for choosing allele 1, between 0 and 1.
-    p_allele = np.random.random()
+    p_allele = rng.random()

    # Get a random probability for passing a particular SNP site (position), between 0 and 1.
-    p_filter_pass = np.random.random()
+    p_filter_pass = rng.random()

    # Get a random probability for applying qMerge filter to a particular SNP site (position), between 0 and 1.
-    p_filter_qMerge = np.random.random()
+    p_filter_qMerge = rng.random()

    # Open a zarr at the specified path.
    root = zarr.open(zarr_path, mode="w")
@@ -733,17 +737,14 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes)
        contig_length_bp = contig_sizes[contig]

        # Get a random number of CNV alleles ("variants") to simulate.
-        n_cnv_alleles = np.random.randint(1, 5_000)
+        n_cnv_alleles = int(rng.integers(1, 5_000))

        # Produce a set of random start positions for each allele as a sorted list.
-        allele_start_pos = sorted(
-            np.random.randint(1, contig_length_bp, size=n_cnv_alleles)
-        )
-
+        allele_start_pos = sorted(rng.integers(1, contig_length_bp, size=n_cnv_alleles))
        # Produce a set of random allele lengths for each allele, according to a range.
        allele_length_bp_min = 100
        allele_length_bp_max = 100_000
-        allele_lengths_bp = np.random.randint(
+        allele_lengths_bp = rng.integers(
            allele_length_bp_min, allele_length_bp_max, size=n_cnv_alleles
        )
@@ -755,7 +756,7 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes)

        # Simulate the genotype calls.
        # Note: this is only 2D, unlike SNP, HAP, AIM GT which are 3D
-        gt = np.random.choice(
+        gt = rng.choice(
            np.array([0, 1], dtype="i1"),
            size=(n_cnv_alleles, n_samples),
            replace=True,
@@ -772,8 +773,8 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes)
        variants_grp = contig_grp.require_group("variants")

        # Simulate the CIEND and CIPOS arrays under variants.
-        ciend = np.random.randint(low=0, high=13200, size=n_cnv_alleles)
-        cipos = np.random.randint(low=0, high=37200, size=n_cnv_alleles)
+        ciend = rng.integers(low=0, high=13200, size=n_cnv_alleles)
+        cipos = rng.integers(low=0, high=37200, size=n_cnv_alleles)
        variants_grp.create_dataset(name="CIEND", data=ciend)
        variants_grp.create_dataset(name="CIPOS", data=cipos)
@@ -787,10 +788,10 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes)
        variants_grp.create_dataset(name="ID", data=variant_IDs)

        # Simulate the filters under variants.
-        filter_pass = np.random.choice(
+        filter_pass = rng.choice(
            [False, True], size=n_cnv_alleles, p=[1 - p_filter_pass, p_filter_pass]
        )
-        filter_qMerge = np.random.choice(
+        filter_qMerge = rng.choice(
            [False, True], size=n_cnv_alleles, p=[1 - p_filter_qMerge, p_filter_qMerge]
        )
        variants_grp.create_dataset(name="FILTER_PASS", data=filter_pass)
@@ -806,6 +807,8 @@ def simulate_cnv_coverage_calls(zarr_path, metadata_path, contigs, contig_sizes)

 def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig_sizes):
+    # Initialize a default RNG with a fixed seed for general random calls
+    default_rng = np.random.default_rng(seed=123)  # Arbitrary seed for reproducibility
    # zarr_path is the output path to the zarr store
    # metadata_path is the input path for the sample metadata
    # contigs is the list of contigs, e.g. Ag has ('2R', '3R', 'X')
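Unlike the other simulators, `simulate_cnv_discordant_read_calls` gets a function-local generator with a fixed seed rather than drawing from the module-level `rng`. The effect is isolation: its draws are identical on every run, regardless of how many draws other fixtures have consumed. In miniature:

```python
import numpy as np

def simulate_probabilities():
    # Function-local generator: reproducible in isolation, immune to the
    # draw order of any module-level rng elsewhere in the test session.
    local_rng = np.random.default_rng(seed=123)
    return local_rng.random(), local_rng.random()

assert simulate_probabilities() == simulate_probabilities()
```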
root = zarr.open(zarr_path, mode="w") @@ -845,11 +848,11 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig n_samples = len(df_samples) # Simulate sample_coverage_variance array. - sample_coverage_variance = np.random.uniform(low=0, high=0.5, size=n_samples) + sample_coverage_variance = default_rng.uniform(low=0, high=0.5, size=n_samples) root.create_dataset(name="sample_coverage_variance", data=sample_coverage_variance) # Simulate sample_is_high_variance array. - sample_is_high_variance = np.random.choice( + sample_is_high_variance = default_rng.choice( [False, True], size=n_samples, p=[1 - p_variance, p_variance] ) root.create_dataset(name="sample_is_high_variance", data=sample_is_high_variance) @@ -864,7 +867,7 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig for i, contig in enumerate(contigs): # Use the same random seed per contig, otherwise n_cnv_variants (and shapes) will not align. unique_seed = fixed_seed + i - np.random.seed(unique_seed) + rng = np.random.default_rng(seed=unique_seed) # Create the contig group. contig_grp = root.require_group(contig) @@ -876,17 +879,17 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig contig_length_bp = contig_sizes[contig] # Get a random number of CNV variants to simulate. - n_cnv_variants = np.random.randint(1, 100) + n_cnv_variants = int(rng.integers(1, 100)) # Produce a set of random start positions for each variant as a sorted list. variant_start_pos = sorted( - np.random.randint(1, contig_length_bp, size=n_cnv_variants) + rng.integers(1, contig_length_bp, size=n_cnv_variants) ) # Produce a set of random lengths for each variant, according to a range. variant_length_bp_min = 100 variant_length_bp_max = 100_000 - variant_lengths_bp = np.random.randint( + variant_lengths_bp = rng.integers( variant_length_bp_min, variant_length_bp_max, size=n_cnv_variants ) @@ -898,7 +901,7 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig # Simulate the genotype calls. # Note: this is only 2D, unlike SNP, HAP, AIM GT which are 3D - gt = np.random.choice( + gt = rng.choice( np.array([0, 1], dtype="i1"), size=(n_cnv_variants, n_samples), replace=True, @@ -915,8 +918,8 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig variants_grp = contig_grp.require_group("variants") # Simulate the StartBreakpointMethod and EndBreakpointMethod arrays. 
@@ -915,8 +918,8 @@ def simulate_cnv_discordant_read_calls(zarr_path, metadata_path, contigs, contig
        variants_grp = contig_grp.require_group("variants")

        # Simulate the StartBreakpointMethod and EndBreakpointMethod arrays.
-        startBreakpointMethod = np.random.randint(low=-1, high=1, size=n_cnv_variants)
-        endBreakpointMethod = np.random.randint(low=-1, high=2, size=n_cnv_variants)
+        startBreakpointMethod = rng.integers(low=-1, high=1, size=n_cnv_variants)
+        endBreakpointMethod = rng.integers(low=-1, high=2, size=n_cnv_variants)
        variants_grp.create_dataset(
            name="StartBreakpointMethod", data=startBreakpointMethod
        )
@@ -1012,20 +1015,20 @@ def contigs(self) -> Tuple[str, ...]:
        return tuple(self.config["CONTIGS"])

    def random_contig(self):
-        return choice(self.contigs)
+        return rng.choice(self.contigs)

    def random_transcript_id(self):
        df_transcripts = self.genome_features.query("type == 'mRNA'")
        transcript_ids = [
            gff3_parse_attributes(t)["ID"] for t in df_transcripts.loc[:, "attributes"]
        ]
-        transcript_id = choice(transcript_ids)
+        transcript_id = rng.choice(transcript_ids)
        return transcript_id

    def random_region_str(self, region_size=None):
        contig = self.random_contig()
        contig_size = self.contig_sizes[contig]
-        region_start = randint(1, contig_size)
+        region_start = int(rng.integers(1, contig_size))
        if region_size:
            # Ensure we the region span doesn't exceed the contig size.
            if contig_size - region_start < region_size:
@@ -1033,7 +1036,7 @@ def random_region_str(self, region_size=None):
            region_end = region_start + region_size
        else:
-            region_end = randint(region_start, contig_size)
+            region_end = int(rng.integers(region_start, contig_size))
        region = f"{contig}:{region_start:,}-{region_end:,}"
        return region
@@ -1135,7 +1138,7 @@ def init_public_release_manifest(self):
        manifest = pd.DataFrame(
            {
                "sample_set": ["AG1000G-AO", "AG1000G-BF-A"],
-                "sample_count": [randint(10, 50), randint(10, 40)],
+                "sample_count": [int(rng.integers(10, 50)), int(rng.integers(10, 40))],
                "study_id": ["AG1000G-AO", "AG1000G-BF-1"],
                "study_url": [
                    "https://www.malariagen.net/network/where-we-work/AG1000G-AO",
@@ -1167,7 +1170,7 @@ def init_pre_release_manifest(self):
                    "1177-VO-ML-LEHMANN-VMF00004",
                ],
                # Make sure we have some gambiae, coluzzii and arabiensis.
-                "sample_count": [randint(20, 60)],
+                "sample_count": [int(rng.integers(20, 60))],
                "study_id": ["1177-VO-ML-LEHMANN"],
                "study_url": [
                    "https://www.malariagen.net/network/where-we-work/1177-VO-ML-LEHMANN"
@@ -1567,7 +1570,7 @@ def init_haplotypes(self):
        root.create_dataset(name="samples", data=samples, dtype=str)
        for contig in self.contigs:
            n_sites = self.n_hap_sites[analysis][contig]
-            gt = np.random.choice(
+            gt = rng.choice(
                np.array([0, 1], dtype="i1"),
                size=(n_sites, n_samples, 2),
                replace=True,
@@ -1598,7 +1601,7 @@ def init_haplotypes(self):
        root.create_dataset(name="samples", data=samples, dtype=str)
        for contig in self.contigs:
            n_sites = self.n_hap_sites[analysis][contig]
-            gt = np.random.choice(
+            gt = rng.choice(
                np.array([0, 1], dtype="i1"),
                size=(n_sites, n_samples, 2),
                replace=True,
@@ -1629,7 +1632,7 @@ def init_haplotypes(self):
        root.create_dataset(name="samples", data=samples, dtype=str)
        for contig in self.contigs:
            n_sites = self.n_hap_sites[analysis][contig]
-            gt = np.random.choice(
+            gt = rng.choice(
                np.array([0, 1], dtype="i1"),
                size=(n_sites, n_samples, 2),
                replace=True,
@@ -1695,7 +1698,7 @@ def init_aim_calls(self):
        ds["sample_id"] = ("samples",), df_samples["sample_id"]

        # Add call_genotype variable.
-        gt = np.random.choice(
+        gt = rng.choice(
            np.arange(2, dtype="i1"),
            size=(ds.sizes["variants"], ds.sizes["samples"], 2),
            replace=True,
@@ -2190,7 +2193,7 @@ def init_hap_sites(self):
            path=path,
            contigs=self.contigs,
            snp_sites=self.snp_sites,
-            p_site=np.random.random(),
+            p_site=rng.random(),
        )

    def init_haplotypes(self):
@@ -2217,7 +2220,7 @@ def init_haplotypes(self):

        # Simulate haplotypes.
        analysis = "funestus"
-        p_1 = np.random.random()
+        p_1 = rng.random()
        samples = df_samples["sample_id"].values
        self.phasing_samples[sample_set, analysis] = samples
        n_samples = len(samples)
@@ -2233,7 +2236,7 @@ def init_haplotypes(self):
        root.create_dataset(name="samples", data=samples, dtype=str)
        for contig in self.contigs:
            n_sites = self.n_hap_sites[analysis][contig]
-            gt = np.random.choice(
+            gt = rng.choice(
                np.array([0, 1], dtype="i1"),
                size=(n_sites, n_samples, 2),
                replace=True,
diff --git a/tests/anoph/test_aim_data.py b/tests/anoph/test_aim_data.py
index 8a1c76b3..4c4e3769 100644
--- a/tests/anoph/test_aim_data.py
+++ b/tests/anoph/test_aim_data.py
@@ -1,14 +1,14 @@
 import itertools
-import random
-
 import plotly.graph_objects as go
 import pytest
 import xarray as xr
 from numpy.testing import assert_array_equal
-
+import numpy as np
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.aim_data import AnophelesAimData

+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -88,9 +88,9 @@ def test_aim_calls(aims, ag3_sim_api):
    all_releases = api.releases
    parametrize_sample_sets = [
        None,
-        random.choice(all_sample_sets),
-        random.sample(all_sample_sets, 2),
-        random.choice(all_releases),
+        rng.choice(all_sample_sets),
+        rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        rng.choice(all_releases),
    ]

    # Parametrize sample_query.
@@ -179,9 +179,9 @@ def test_plot_aim_heatmap(aims, ag3_sim_api):
    all_releases = api.releases
    parametrize_sample_sets = [
        None,
-        random.choice(all_sample_sets),
-        random.sample(all_sample_sets, 2),
-        random.choice(all_releases),
+        rng.choice(all_sample_sets),
+        rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        rng.choice(all_releases),
    ]

    # Parametrize sample_query.
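`random.sample(xs, k)` maps onto `rng.choice(xs, k, replace=False)`, with one wrinkle: `Generator.choice` returns a numpy array, so the trailing `.tolist()` restores the plain `list[str]` the old code produced. For example:

```python
import numpy as np

rng = np.random.default_rng(seed=42)
all_sample_sets = ["AG1000G-AO", "AG1000G-BF-A", "AG1000G-BF-B"]  # illustrative IDs

pair = rng.choice(all_sample_sets, 2, replace=False)
assert isinstance(pair, np.ndarray)       # not yet a list

pair = pair.tolist()
assert all(type(s) is str for s in pair)  # plain Python strings again
assert len(set(pair)) == 2                # distinct elements, like random.sample
```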
diff --git a/tests/anoph/test_base.py b/tests/anoph/test_base.py
index 8d7e2249..43bcfbd4 100644
--- a/tests/anoph/test_base.py
+++ b/tests/anoph/test_base.py
@@ -8,6 +8,9 @@
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.base import AnophelesBase

+# Global RNG for test file; functions may override with local RNG for reproducibility
+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -210,7 +213,7 @@ def test_lookup_study(fixture, api):
    # Set up test.
    df_sample_sets = api.sample_sets()
    all_sample_sets = df_sample_sets["sample_set"].values
-    sample_set = np.random.choice(all_sample_sets)
+    sample_set = rng.choice(all_sample_sets)
    study_rec_by_sample_set = api.lookup_study(sample_set)
    df_sample_set = df_sample_sets.set_index("sample_set").loc[sample_set]
diff --git a/tests/anoph/test_cnv_data.py b/tests/anoph/test_cnv_data.py
index 54c1ddf1..15bb229b 100644
--- a/tests/anoph/test_cnv_data.py
+++ b/tests/anoph/test_cnv_data.py
@@ -1,5 +1,3 @@
-import random
-
 import bokeh.models
 import dask.array as da
 import numpy as np
@@ -13,6 +11,9 @@
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.cnv_data import AnophelesCnvData

+# Global RNG for test file; functions may override with local RNG for reproducibility
+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -136,14 +137,14 @@ def test_open_cnv_coverage_calls(fixture, api: AnophelesCnvData):
    # Check with a sample set that should not exist
    with pytest.raises(ValueError):
        root = api.open_cnv_coverage_calls(
-            sample_set="foobar", analysis=random.choice(api.coverage_calls_analysis_ids)
+            sample_set="foobar", analysis=rng.choice(api.coverage_calls_analysis_ids)
        )

    # Check with an analysis that should not exist
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    with pytest.raises(ValueError):
        root = api.open_cnv_coverage_calls(
-            sample_set=random.choice(all_sample_sets), analysis="foobar"
+            sample_set=rng.choice(all_sample_sets), analysis="foobar"
        )

    # Check with a sample set and analysis that should not exist
@@ -343,15 +344,15 @@ def test_cnv_hmm(fixture, api: AnophelesCnvData):
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    parametrize_sample_sets = [
        None,
-        random.choice(all_sample_sets),
-        random.sample(all_sample_sets, 2),
-        random.choice(all_releases),
+        rng.choice(all_sample_sets),
+        rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        rng.choice(all_releases),
    ]

    # Parametrize region.
    parametrize_region = [
        fixture.random_contig(),
-        random.sample(api.contigs, 2),
+        rng.choice(api.contigs, 2, replace=False).tolist(),
        fixture.random_region_str(),
    ]
@@ -421,11 +422,11 @@ def test_cnv_hmm__max_coverage_variance(fixture, api: AnophelesCnvData):
    # Set up test.
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_set = random.choice(all_sample_sets)
+    sample_set = rng.choice(all_sample_sets)
    region = fixture.random_contig()

    # Parametrize max_coverage_variance.
-    parametrize_max_coverage_variance = np.random.uniform(low=0, high=1, size=4)
+    parametrize_max_coverage_variance = rng.uniform(low=0, high=1, size=4)

    for max_coverage_variance in parametrize_max_coverage_variance:
        ds = api.cnv_hmm(
@@ -465,7 +466,7 @@ def test_cnv_coverage_calls(fixture, api: AnophelesCnvData):
    # Parametrize sample_sets.
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    parametrize_sample_sets = random.sample(all_sample_sets, 3)
+    parametrize_sample_sets = rng.choice(all_sample_sets, 3, replace=False).tolist()

    # Parametrize analysis.
    parametrize_analysis = api.coverage_calls_analysis_ids

    # Parametrize region.
@@ -473,7 +474,7 @@ def test_cnv_coverage_calls(fixture, api: AnophelesCnvData):
    parametrize_region = [
        fixture.random_contig(),
-        random.sample(api.contigs, 2),
+        rng.choice(api.contigs, 2, replace=False).tolist(),
        fixture.random_region_str(),
    ]
@@ -551,15 +552,15 @@ def test_cnv_discordant_read_calls(fixture, api: AnophelesCnvData):
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    parametrize_sample_sets = [
        None,
-        random.choice(all_sample_sets),
-        random.sample(all_sample_sets, 2),
-        random.choice(all_releases),
+        rng.choice(all_sample_sets),
+        rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        rng.choice(all_releases),
    ]

    # Parametrize contig.
    parametrize_contig = [
-        random.choice(api.contigs),
-        random.sample(api.contigs, 2),
+        rng.choice(api.contigs),
+        rng.choice(api.contigs, 2, replace=False).tolist(),
    ]

    for sample_sets in parametrize_sample_sets:
@@ -628,13 +629,13 @@ def test_cnv_discordant_read_calls(fixture, api: AnophelesCnvData):
    # Check with a contig that should not exist
    with pytest.raises(ValueError):
        api.cnv_discordant_read_calls(
-            contig="foobar", sample_sets=random.choice(all_sample_sets)
+            contig="foobar", sample_sets=rng.choice(all_sample_sets)
        )

    # Check with a sample set that should not exist
    with pytest.raises(ValueError):
        api.cnv_discordant_read_calls(
-            contig=random.choice(api.contigs), sample_sets="foobar"
+            contig=rng.choice(api.contigs), sample_sets="foobar"
        )

    # Check with a contig and sample set that should not exist
@@ -806,11 +807,11 @@ def test_plot_cnv_hmm_coverage_track(fixture, api: AnophelesCnvData):
    # Set up test.
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_set = random.choice(all_sample_sets)
+    sample_set = rng.choice(all_sample_sets)
    region = fixture.random_contig()
    df_samples = api.sample_metadata(sample_sets=sample_set)
    all_sample_ids = df_samples["sample_id"].values
-    sample_id = np.random.choice(all_sample_ids)
+    sample_id = rng.choice(all_sample_ids)

    fig = api.plot_cnv_hmm_coverage_track(
        sample=sample_id,
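A related wrinkle when drawing a single element: `rng.choice` on a list of strings returns a numpy scalar (`np.str_`), which subclasses `str` and so passes through most APIs unchanged; that is why the single-choice call sites above need no cast. Worth knowing when exact types matter (e.g. serialisation):

```python
import numpy as np

rng = np.random.default_rng(seed=42)
one = rng.choice(["2R", "3R", "X"])

assert isinstance(one, str)   # np.str_ subclasses str
assert type(one) is not str   # ...but is not the builtin type itself
assert type(str(one)) is str  # an explicit cast normalises it
```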
@@ -859,11 +860,11 @@ def test_plot_cnv_hmm_coverage(fixture, api: AnophelesCnvData):
    # Set up test.
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_set = random.choice(all_sample_sets)
+    sample_set = rng.choice(all_sample_sets)
    region = fixture.random_contig()
    df_samples = api.sample_metadata(sample_sets=sample_set)
    all_sample_ids = df_samples["sample_id"].values
-    sample_id = np.random.choice(all_sample_ids)
+    sample_id = rng.choice(all_sample_ids)

    fig = api.plot_cnv_hmm_coverage(
        sample=sample_id,
@@ -913,9 +914,9 @@ def test_plot_cnv_hmm_heatmap_track(fixture, api: AnophelesCnvData):
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    parametrize_sample_sets = [
        None,
-        random.choice(all_sample_sets),
-        random.sample(all_sample_sets, 2),
-        random.choice(all_releases),
+        rng.choice(all_sample_sets),
+        rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        rng.choice(all_releases),
    ]

    for region in parametrize_region:
diff --git a/tests/anoph/test_cnv_frq.py b/tests/anoph/test_cnv_frq.py
index c9787eb1..44db52ac 100644
--- a/tests/anoph/test_cnv_frq.py
+++ b/tests/anoph/test_cnv_frq.py
@@ -1,5 +1,3 @@
-import random
-
 import numpy as np
 import pandas as pd
 import xarray as xr
@@ -20,6 +18,8 @@
     add_random_year,
 )

+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -95,10 +95,10 @@ def test_gene_cnv_frequencies_with_str_cohorts(
    api: AnophelesCnvFrequencyAnalysis,
    cohorts,
 ):
-    region = random.choice(api.contigs)
+    region = rng.choice(api.contigs)
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    min_cohort_size = random.randint(0, 2)
+    sample_sets = rng.choice(all_sample_sets)
+    min_cohort_size = int(rng.integers(0, 2))

    # Set up call params.
    params = dict(
@@ -148,8 +148,8 @@ def test_gene_cnv_frequencies_with_min_cohort_size(
 ):
    # Pick test parameters at random.
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    region = random.choice(api.contigs)
+    sample_sets = rng.choice(all_sample_sets)
+    region = rng.choice(api.contigs)
    cohorts = "admin1_year"

    # Set up call params.
@@ -199,13 +199,11 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query(
    # Pick test parameters at random.
    sample_sets = None
    min_cohort_size = 0
-    region = random.choice(api.contigs)
-    cohorts = random.choice(
-        ["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
-    )
+    region = rng.choice(api.contigs)
+    cohorts = rng.choice(["admin1_year", "admin1_month", "admin2_year", "admin2_month"])
    df_samples = api.sample_metadata(sample_sets=sample_sets)
    countries = df_samples["country"].unique()
-    country = random.choice(countries)
+    country = rng.choice(countries)
    sample_query = f"country == '{country}'"

    # Figure out expected cohort labels.
@@ -247,13 +245,11 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query_options(
    # Pick test parameters at random.
    sample_sets = None
    min_cohort_size = 0
-    region = random.choice(api.contigs)
-    cohorts = random.choice(
-        ["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
-    )
+    region = rng.choice(api.contigs)
+    cohorts = rng.choice(["admin1_year", "admin1_month", "admin2_year", "admin2_month"])
    df_samples = api.sample_metadata(sample_sets=sample_sets)
    countries = df_samples["country"].unique().tolist()
-    countries_list = random.sample(countries, 2)
+    countries_list = rng.choice(countries, 2, replace=False).tolist()
    sample_query_options = {
        "local_dict": {
            "countries_list": countries_list,
        }
@@ -303,8 +299,8 @@ def test_gene_cnv_frequencies_with_dict_cohorts(
 ):
    # Pick test parameters at random.
    sample_sets = None  # all sample sets
-    min_cohort_size = random.randint(0, 2)
-    region = random.choice(api.contigs)
+    min_cohort_size = int(rng.integers(0, 2))
+    region = rng.choice(api.contigs)

    # Create cohorts by country.
    df_samples = api.sample_metadata(sample_sets=sample_sets)
@@ -343,10 +339,10 @@ def test_gene_cnv_frequencies_without_drop_invariant(
 ):
    # Pick test parameters at random.
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    min_cohort_size = random.randint(0, 2)
-    region = random.choice(api.contigs)
-    cohorts = random.choice(["admin1_year", "admin2_month", "country"])
+    sample_sets = rng.choice(all_sample_sets)
+    min_cohort_size = int(rng.integers(0, 2))
+    region = rng.choice(api.contigs)
+    cohorts = rng.choice(["admin1_year", "admin2_month", "country"])

    # Figure out expected cohort labels.
    df_samples = api.sample_metadata(sample_sets=sample_sets)
@@ -398,9 +394,9 @@ def test_gene_cnv_frequencies_with_bad_region(
 ):
    # Pick test parameters at random.
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    min_cohort_size = random.randint(0, 2)
-    cohorts = random.choice(["admin1_year", "admin2_month", "country"])
+    sample_sets = rng.choice(all_sample_sets)
+    min_cohort_size = int(rng.integers(0, 2))
+    cohorts = rng.choice(["admin1_year", "admin2_month", "country"])

    # Set up call params.
    params = dict(
@@ -424,9 +420,9 @@ def test_gene_cnv_frequencies_with_max_coverage_variance(
    max_coverage_variance,
 ):
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    cohorts = random.choice(["admin1_year", "admin2_month", "country"])
-    region = random.choice(api.contigs)
+    sample_sets = rng.choice(all_sample_sets)
+    cohorts = rng.choice(["admin1_year", "admin2_month", "country"])
+    region = rng.choice(api.contigs)

    params = dict(
        region=region,
@@ -503,7 +499,7 @@ def test_gene_cnv_frequencies_advanced_with_sample_query(
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    df_samples = api.sample_metadata(sample_sets=all_sample_sets)
    countries = df_samples["country"].unique()
-    country = random.choice(countries)
+    country = rng.choice(countries)
    sample_query = f"country == '{country}'"

    check_gene_cnv_frequencies_advanced(
@@ -522,7 +518,7 @@ def test_gene_cnv_frequencies_advanced_with_sample_query_options(
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    df_samples = api.sample_metadata(sample_sets=all_sample_sets)
    countries = df_samples["country"].unique().tolist()
-    countries_list = random.sample(countries, 2)
+    countries_list = rng.choice(countries, 2, replace=False).tolist()
    sample_query_options = {
        "local_dict": {
            "countries_list": countries_list,
        }
@@ -549,7 +545,7 @@ def test_gene_cnv_frequencies_advanced_with_min_cohort_size(
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    area_by = "admin1_iso"
    period_by = "year"
-    region = random.choice(api.contigs)
+    region = rng.choice(api.contigs)

    if min_cohort_size <= 10:
        # Expect this to find at least one cohort, so go ahead with full
@@ -585,7 +581,7 @@ def test_gene_cnv_frequencies_advanced_with_max_coverage_variance(
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    area_by = "admin1_iso"
    period_by = "year"
-    region = random.choice(api.contigs)
+    region = rng.choice(api.contigs)

    if max_coverage_variance >= 0.4:
        # Expect this to find at least one cohort, so go ahead with full
@@ -620,7 +616,7 @@ def test_gene_cnv_frequencies_advanced_with_nobs_mode(
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    area_by = "admin1_iso"
    period_by = "year"
-    region = random.choice(api.contigs)
+    region = rng.choice(api.contigs)

    check_gene_cnv_frequencies_advanced(
        api=api,
@@ -642,7 +638,7 @@ def test_gene_cnv_frequencies_advanced_with_variant_query(
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    area_by = "admin1_iso"
    period_by = "year"
-    region = random.choice(api.contigs)
+    region = rng.choice(api.contigs)
    variant_query = "cnv_type == '{variant_query_option}'"

    check_gene_cnv_frequencies_advanced(
@@ -710,16 +706,16 @@ def check_gene_cnv_frequencies_advanced(
 ):
    # Pick test parameters at random.
    if region is None:
-        region = random.choice(api.contigs)
+        region = rng.choice(api.contigs)
    if area_by is None:
-        area_by = random.choice(["country", "admin1_iso", "admin2_name"])
+        area_by = rng.choice(["country", "admin1_iso", "admin2_name"])
    if period_by is None:
-        period_by = random.choice(["year", "quarter", "month", "random_year"])
+        period_by = rng.choice(["year", "quarter", "month"])
    if sample_sets is None:
        all_sample_sets = api.sample_sets()["sample_set"].to_list()
-        sample_sets = random.choice(all_sample_sets)
+        sample_sets = rng.choice(all_sample_sets)
    if min_cohort_size is None:
-        min_cohort_size = random.randint(0, 2)
+        min_cohort_size = int(rng.integers(0, 2))

    if period_by == "random_year":
        # Add a random_year column to the sample metadata, if there isn't already.
diff --git a/tests/anoph/test_dipclust.py b/tests/anoph/test_dipclust.py
index c0bbad03..23b04a41 100644
--- a/tests/anoph/test_dipclust.py
+++ b/tests/anoph/test_dipclust.py
@@ -1,17 +1,18 @@
-import random
 import pytest
 from pytest_cases import parametrize_with_cases
-
+import numpy as np
 from malariagen_data import af1 as _af1
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.dipclust import AnophelesDipClustAnalysis

+rng = np.random.default_rng(seed=42)
+

 def random_transcripts_contig(*, api, contig, n):
     df_gff = api.genome_features(attributes=["ID", "Parent"])
     df_transcripts = df_gff.query(f"type == 'mRNA' and contig == '{contig}'")
     transcript_ids = df_transcripts["ID"].dropna().to_list()
-    transcripts = random.sample(transcript_ids, n)
+    transcripts = rng.choice(transcript_ids, n, replace=False).tolist()
     return transcripts

@@ -97,12 +98,13 @@ def test_plot_diplotype_clustering(
        "ward",
    )
    sample_queries = (None, "sex_call == 'F'")
+    idx = rng.choice([0, 1])
    dipclust_params = dict(
        region=fixture.random_region_str(region_size=5000),
-        sample_sets=[random.choice(all_sample_sets)],
-        linkage_method=random.choice(linkage_methods),
+        sample_sets=[rng.choice(all_sample_sets)],
+        linkage_method=str(rng.choice(linkage_methods)),
        distance_metric=distance_metric,
-        sample_query=random.choice(sample_queries),
+        sample_query=sample_queries[idx],
        show=False,
    )
@@ -127,12 +129,13 @@ def test_plot_diplotype_clustering_advanced(
        "ward",
    )
    sample_queries = (None, "sex_call == 'F'")
+    idx = rng.choice([0, 1])
    dipclust_params = dict(
        region=fixture.random_region_str(region_size=5000),
-        sample_sets=[random.choice(all_sample_sets)],
-        linkage_method=random.choice(linkage_methods),
+        sample_sets=[rng.choice(all_sample_sets)],
+        linkage_method=str(rng.choice(linkage_methods)),
        distance_metric=distance_metric,
-        sample_query=random.choice(sample_queries),
+        sample_query=sample_queries[idx],
        show=False,
    )
@@ -159,13 +162,14 @@ def test_plot_diplotype_clustering_advanced_with_transcript(
        "ward",
    )
    sample_queries = (None, "sex_call == 'F'")
+    idx = rng.choice([0, 1])
    dipclust_params = dict(
        region=contig,
        snp_transcript=transcripts,
-        sample_sets=[random.choice(all_sample_sets)],
-        linkage_method=random.choice(linkage_methods),
+        sample_sets=[rng.choice(all_sample_sets)],
+        linkage_method=str(rng.choice(linkage_methods)),
        distance_metric="cityblock",
-        sample_query=random.choice(sample_queries),
+        sample_query=sample_queries[idx],
        show=False,
    )
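The `idx = rng.choice([0, 1])` pattern above is worth a note: the old code drew `random.choice(sample_queries)` directly, but `sample_queries` mixes `None` with a string, and `Generator.choice` would first coerce that tuple into a numpy array. Drawing an index and subscripting the tuple sidesteps any dtype coercion and hands back the original Python objects untouched:

```python
import numpy as np

rng = np.random.default_rng(seed=42)
sample_queries = (None, "sex_call == 'F'")

idx = int(rng.choice([0, 1]))      # draw an index, not the element
sample_query = sample_queries[idx]
assert sample_query is None or isinstance(sample_query, str)
```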
diff --git a/tests/anoph/test_dipclust.py b/tests/anoph/test_dipclust.py
index c0bbad03..23b04a41 100644
--- a/tests/anoph/test_dipclust.py
+++ b/tests/anoph/test_dipclust.py
@@ -1,17 +1,18 @@
-import random
 import pytest
 from pytest_cases import parametrize_with_cases
-
+import numpy as np
 from malariagen_data import af1 as _af1
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.dipclust import AnophelesDipClustAnalysis

+rng = np.random.default_rng(seed=42)
+

 def random_transcripts_contig(*, api, contig, n):
     df_gff = api.genome_features(attributes=["ID", "Parent"])
     df_transcripts = df_gff.query(f"type == 'mRNA' and contig == '{contig}'")
     transcript_ids = df_transcripts["ID"].dropna().to_list()
-    transcripts = random.sample(transcript_ids, n)
+    transcripts = rng.choice(transcript_ids, n, replace=False).tolist()
     return transcripts
@@ -97,12 +98,13 @@ def test_plot_diplotype_clustering(
         "ward",
     )
     sample_queries = (None, "sex_call == 'F'")
+    idx = rng.choice([0, 1])
     dipclust_params = dict(
         region=fixture.random_region_str(region_size=5000),
-        sample_sets=[random.choice(all_sample_sets)],
-        linkage_method=random.choice(linkage_methods),
+        sample_sets=[rng.choice(all_sample_sets)],
+        linkage_method=str(rng.choice(linkage_methods)),
         distance_metric=distance_metric,
-        sample_query=random.choice(sample_queries),
+        sample_query=sample_queries[idx],
         show=False,
     )
@@ -127,12 +129,13 @@ def test_plot_diplotype_clustering_advanced(
         "ward",
     )
     sample_queries = (None, "sex_call == 'F'")
+    idx = rng.choice([0, 1])
     dipclust_params = dict(
         region=fixture.random_region_str(region_size=5000),
-        sample_sets=[random.choice(all_sample_sets)],
-        linkage_method=random.choice(linkage_methods),
+        sample_sets=[rng.choice(all_sample_sets)],
+        linkage_method=str(rng.choice(linkage_methods)),
         distance_metric=distance_metric,
-        sample_query=random.choice(sample_queries),
+        sample_query=sample_queries[idx],
         show=False,
     )
@@ -159,13 +162,14 @@ def test_plot_diplotype_clustering_advanced_with_transcript(
         "ward",
     )
     sample_queries = (None, "sex_call == 'F'")
+    idx = rng.choice([0, 1])
     dipclust_params = dict(
         region=contig,
         snp_transcript=transcripts,
-        sample_sets=[random.choice(all_sample_sets)],
-        linkage_method=random.choice(linkage_methods),
+        sample_sets=[rng.choice(all_sample_sets)],
+        linkage_method=str(rng.choice(linkage_methods)),
         distance_metric="cityblock",
-        sample_query=random.choice(sample_queries),
+        sample_query=sample_queries[idx],
         show=False,
     )
@@ -190,13 +194,14 @@ def test_plot_diplotype_clustering_advanced_with_cnv_region(
         "ward",
     )
     sample_queries = (None, "sex_call == 'F'")
+    idx = rng.choice([0, 1])
     dipclust_params = dict(
         region=region,
         cnv_region=region,
-        sample_sets=[random.choice(all_sample_sets)],
-        linkage_method=random.choice(linkage_methods),
+        sample_sets=[rng.choice(all_sample_sets)],
+        linkage_method=str(rng.choice(linkage_methods)),
         distance_metric="cityblock",
-        sample_query=random.choice(sample_queries),
+        sample_query=sample_queries[idx],
         show=False,
     )
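
On the idx indirection for sample_query above: Generator.choice first coerces its argument with np.asarray, so drawing an index and then subscripting the original (None, str) tuple keeps the chosen option as the untouched Python object. A sketch of the pattern, assuming that avoiding array coercion is the motivation:

    import numpy as np

    rng = np.random.default_rng(seed=42)
    sample_queries = (None, "sex_call == 'F'")

    idx = int(rng.choice([0, 1]))       # draw an index, not the value
    sample_query = sample_queries[idx]  # None or the query string, unmodified
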
diff --git a/tests/anoph/test_distance.py b/tests/anoph/test_distance.py
index 9091ee45..2694b492 100644
--- a/tests/anoph/test_distance.py
+++ b/tests/anoph/test_distance.py
@@ -1,5 +1,3 @@
-import random
-
 import numpy as np
 import plotly.graph_objects as go  # type: ignore
 import pytest
@@ -9,6 +7,10 @@
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.distance import AnophelesDistanceAnalysis
 from malariagen_data.anoph import pca_params

+from .conftest import Af1Simulator, Ag3Simulator  # Import the simulator classes
+
+
+rng = np.random.default_rng(seed=42)


 @pytest.fixture
@@ -81,7 +83,7 @@ def check_biallelic_diplotype_pairwise_distance(*, api, data_params, metric):
     ds = api.biallelic_snp_calls(**data_params)
     n_samples = ds.sizes["samples"]
     n_snps_available = ds.sizes["variants"]
-    n_snps = random.randint(4, n_snps_available)
+    n_snps = int(rng.integers(4, n_snps_available))

     # Run the distance computation.
     dist, samples, n_snps_used = api.biallelic_diplotype_pairwise_distances(
@@ -123,9 +125,9 @@ def test_biallelic_diplotype_pairwise_distance_with_metric(
 ):
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     data_params = dict(
-        region=random.choice(api.contigs),
-        sample_sets=random.sample(all_sample_sets, 2),
-        site_mask=random.choice((None,) + api.site_mask_ids),
+        region=rng.choice(api.contigs),
+        sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        site_mask=rng.choice(np.array([""] + list(api.site_mask_ids), dtype=object)),
         min_minor_ac=pca_params.min_minor_ac_default,
         max_missing_an=pca_params.max_missing_an_default,
     )
@@ -143,7 +145,7 @@ def check_njt(*, api, data_params, metric, algorithm):
     ds = api.biallelic_snp_calls(**data_params)
     n_samples = ds.sizes["samples"]
     n_snps_available = ds.sizes["variants"]
-    n_snps = random.randint(4, n_snps_available)
+    n_snps = int(rng.integers(4, n_snps_available))

     # Run the distance computation.
     Z, samples, n_snps_used = api.njt(
@@ -171,15 +173,21 @@ def check_njt(*, api, data_params, metric, algorithm):
 @parametrize_with_cases("fixture,api", cases=".")
 def test_njt_with_metric(fixture, api: AnophelesDistanceAnalysis):
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
+    if isinstance(fixture, Af1Simulator):
+        expected_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        expected_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        expected_site_masks = [""] + list(api.site_mask_ids)
     data_params = dict(
-        region=random.choice(api.contigs),
-        sample_sets=random.sample(all_sample_sets, 2),
-        site_mask=random.choice((None,) + api.site_mask_ids),
+        region=rng.choice(api.contigs),
+        sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        site_mask=str(rng.choice(np.array(expected_site_masks, dtype=object))),
         min_minor_ac=pca_params.min_minor_ac_default,
         max_missing_an=pca_params.max_missing_an_default,
     )
     parametrize_metric = "cityblock", "euclidean", "sqeuclidean"
-    algorithm = random.choice(["dynamic", "rapid", "canonical"])
+    algorithm = str(rng.choice(["dynamic", "rapid", "canonical"]))
     for metric in parametrize_metric:
         check_njt(
             api=api,
@@ -192,14 +200,20 @@ def test_njt_with_algorithm(fixture, api: AnophelesDistanceAnalysis):
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
+    if isinstance(fixture, Af1Simulator):
+        expected_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        expected_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        expected_site_masks = [""] + list(api.site_mask_ids)
     data_params = dict(
-        region=random.choice(api.contigs),
-        sample_sets=random.sample(all_sample_sets, 2),
-        site_mask=random.choice((None,) + api.site_mask_ids),
+        region=rng.choice(api.contigs),
+        sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        site_mask=str(rng.choice(np.array(expected_site_masks, dtype=object))),
         min_minor_ac=pca_params.min_minor_ac_default,
         max_missing_an=pca_params.max_missing_an_default,
     )
-    metric = random.choice(["cityblock", "euclidean", "sqeuclidean"])
+    metric = str(rng.choice(["cityblock", "euclidean", "sqeuclidean"]))
     parametrize_algorithm = "dynamic", "rapid", "canonical"
     for algorithm in parametrize_algorithm:
         check_njt(
@@ -213,15 +227,21 @@ def test_plot_njt(fixture, api: AnophelesDistanceAnalysis):
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
+    if isinstance(fixture, Af1Simulator):
+        expected_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        expected_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        expected_site_masks = [""] + list(api.site_mask_ids)
     data_params = dict(
-        region=random.choice(api.contigs),
-        sample_sets=random.sample(all_sample_sets, 2),
-        site_mask=random.choice((None,) + api.site_mask_ids),
+        region=rng.choice(api.contigs),
+        sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        site_mask=str(rng.choice(np.array(expected_site_masks, dtype=object))),
         min_minor_ac=pca_params.min_minor_ac_default,
         max_missing_an=pca_params.max_missing_an_default,
     )
-    metric = random.choice(["cityblock", "euclidean", "sqeuclidean"])
-    algorithm = random.choice(["dynamic", "rapid", "canonical"])
+    metric = str(rng.choice(["cityblock", "euclidean", "sqeuclidean"]))
+    algorithm = str(rng.choice(["dynamic", "rapid", "canonical"]))
     custom_cohorts = {
         "male": "sex_call == 'M'",
         "female": "sex_call == 'F'",
@@ -232,8 +252,7 @@ def test_plot_njt(fixture, api: AnophelesDistanceAnalysis):
     # Check available data.
     ds = api.biallelic_snp_calls(**data_params)
     n_snps_available = ds.sizes["variants"]
-    n_snps = random.randint(4, n_snps_available)
-
+    n_snps = int(rng.integers(4, n_snps_available))
     # Exercise the function.
     for color, symbol in zip(colors, symbols):
         fig = api.plot_njt(
diff --git a/tests/anoph/test_frq.py b/tests/anoph/test_frq.py
index 94ac5229..9172aa83 100644
--- a/tests/anoph/test_frq.py
+++ b/tests/anoph/test_frq.py
@@ -1,9 +1,9 @@
 import pytest
 import plotly.graph_objects as go  # type: ignore
-
 import numpy as np
 import pandas as pd
-import random
+
+rng = np.random.default_rng(seed=42)


 def check_plot_frequencies_heatmap(api, frq_df):
@@ -11,8 +11,15 @@ def check_plot_frequencies_heatmap(api, frq_df):
     assert isinstance(fig, go.Figure)

     # Test max_len behaviour.
+    # With more than one row, use a max_len one less than the number of rows;
+    # with a single row, fall back to max_len=0. Either way, ValueError is expected.
+    if len(frq_df) > 1:
+        test_max_len = len(frq_df) - 1
+    else:
+        test_max_len = 0
+
     with pytest.raises(ValueError):
-        api.plot_frequencies_heatmap(frq_df, show=False, max_len=len(frq_df) - 1)
+        api.plot_frequencies_heatmap(frq_df, show=False, max_len=test_max_len)

     # Test index parameter - if None, should use dataframe index.
     fig = api.plot_frequencies_heatmap(frq_df, show=False, index=None, max_len=None)
@@ -41,7 +48,7 @@ def check_plot_frequencies_time_series_with_taxa(api, ds):
     ds = ds.isel(variants=slice(0, 100))

     taxa = list(ds.cohort_taxon.to_dataframe()["cohort_taxon"].unique())
-    taxon = random.choice(taxa)
+    taxon = rng.choice(taxa)

     # Plot with taxon.
     fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxon)
@@ -66,8 +73,10 @@ def check_plot_frequencies_time_series_with_areas(api, ds):

     # Pick a random area and areas from valid areas.
     cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist()
-    area = random.choice(cohorts_areas)
-    areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas)))
+    area = rng.choice(cohorts_areas)
+    areas = rng.choice(
+        cohorts_areas, int(rng.integers(1, len(cohorts_areas) + 1)), replace=False
+    ).tolist()

     # Plot with area.
     fig = api.plot_frequencies_time_series(ds, show=False, areas=area)
@@ -105,8 +114,8 @@ def add_random_year(*, api):
     # Otherwise we'll get multiple columns with different suffixes, e.g. 'random_year_x' and 'random_year_y'.
     if "random_year" not in sample_metadata_df.columns:
         # Avoid "ValueError: No cohorts available" by selecting only a few different years at random.
-        selected_years = random.sample(range(1900, 2100), 3)
-        random_years_as_list = np.random.choice(selected_years, len(sample_metadata_df))
+        selected_years = rng.choice(range(1900, 2100), size=3, replace=False)
+        random_years_as_list = rng.choice(selected_years, len(sample_metadata_df))
         random_years_as_period_index = pd.PeriodIndex(random_years_as_list, freq="Y")
         extra_metadata_df = pd.DataFrame(
             {
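
For reference, the replace=False form of Generator.choice used above is the direct counterpart of stdlib random.sample; the main difference is that it returns an ndarray, hence the trailing .tolist() calls throughout this diff. A minimal illustration with made-up values:

    import random
    import numpy as np

    rng = np.random.default_rng(seed=42)
    population = ["AG1000G-AO", "AG1000G-BF-A", "AG1000G-BF-B"]

    old = random.sample(population, 2)                       # list of 2 distinct items
    new = rng.choice(population, 2, replace=False).tolist()  # same contract, via NumPy
    assert len(set(new)) == 2
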
diff --git a/tests/anoph/test_fst.py b/tests/anoph/test_fst.py
index 098f0853..6ba6908f 100644
--- a/tests/anoph/test_fst.py
+++ b/tests/anoph/test_fst.py
@@ -1,5 +1,4 @@
 import itertools
-import random
 import pytest
 from pytest_cases import parametrize_with_cases
 import numpy as np
@@ -11,6 +10,8 @@
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.fst import AnophelesFstAnalysis

+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -82,16 +83,16 @@ def test_fst_gwss(fixture, api: AnophelesFstAnalysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     all_countries = api.sample_metadata()["country"].dropna().unique().tolist()
-    countries = random.sample(all_countries, 2)
+    countries = rng.choice(all_countries, 2, replace=False).tolist()
     cohort1_query = f"country == {countries[0]!r}"
     cohort2_query = f"country == {countries[1]!r}"
     fst_params = dict(
-        contig=random.choice(api.contigs),
+        contig=rng.choice(api.contigs),
         sample_sets=all_sample_sets,
         cohort1_query=cohort1_query,
         cohort2_query=cohort2_query,
-        site_mask=random.choice(api.site_mask_ids),
-        window_size=random.randint(10, 50),
+        site_mask=rng.choice(api.site_mask_ids),
+        window_size=int(rng.integers(10, 50)),
         min_cohort_size=1,
     )
@@ -121,17 +122,17 @@ def test_average_fst(fixture, api: AnophelesFstAnalysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     all_countries = api.sample_metadata()["country"].dropna().unique().tolist()
-    countries = random.sample(all_countries, 2)
+    countries = rng.choice(all_countries, 2, replace=False).tolist()
     cohort1_query = f"country == {countries[0]!r}"
     cohort2_query = f"country == {countries[1]!r}"
     fst_params = dict(
-        region=random.choice(api.contigs),
+        region=rng.choice(api.contigs),
         sample_sets=all_sample_sets,
         cohort1_query=cohort1_query,
         cohort2_query=cohort2_query,
-        site_mask=random.choice(api.site_mask_ids),
+        site_mask=rng.choice(api.site_mask_ids),
         min_cohort_size=1,
-        n_jack=random.randint(10, 200),
+        n_jack=int(rng.integers(10, 200)),
     )

     # Run main gwss function under test.
@@ -149,15 +150,15 @@ def test_average_fst_with_min_cohort_size(fixture, api: AnophelesFstAnalysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     all_countries = api.sample_metadata()["country"].dropna().unique().tolist()
-    countries = random.sample(all_countries, 2)
+    countries = rng.choice(all_countries, 2, replace=False).tolist()
     cohort1_query = f"country == {countries[0]!r}"
     cohort2_query = f"country == {countries[1]!r}"
     fst_params = dict(
-        region=random.choice(api.contigs),
+        region=rng.choice(api.contigs),
         sample_sets=all_sample_sets,
         cohort1_query=cohort1_query,
         cohort2_query=cohort2_query,
-        site_mask=random.choice(api.site_mask_ids),
+        site_mask=rng.choice(api.site_mask_ids),
         min_cohort_size=1000,
     )
@@ -221,15 +222,15 @@ def test_pairwise_average_fst_with_str_cohorts(
 ):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    region = random.choice(api.contigs)
-    site_mask = random.choice(api.site_mask_ids)
+    region = rng.choice(api.contigs)
+    site_mask = rng.choice(api.site_mask_ids)
     fst_params = dict(
         region=region,
         cohorts=cohorts,
         sample_sets=all_sample_sets,
         site_mask=site_mask,
         min_cohort_size=1,
-        n_jack=random.randint(10, 200),
+        n_jack=int(rng.integers(10, 200)),
     )

     # Run checks.
@@ -240,8 +241,8 @@ def test_pairwise_average_fst_with_min_cohort_size(fixture, api: AnophelesFstAnalysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    region = random.choice(api.contigs)
-    site_mask = random.choice(api.site_mask_ids)
+    region = rng.choice(api.contigs)
+    site_mask = rng.choice(api.site_mask_ids)
     cohorts = "admin1_year"
     fst_params = dict(
         region=region,
@@ -249,7 +250,7 @@ def test_pairwise_average_fst_with_min_cohort_size(fixture, api: AnophelesFstAnalysis):
         sample_sets=all_sample_sets,
         site_mask=site_mask,
         min_cohort_size=15,
-        n_jack=random.randint(10, 200),
+        n_jack=int(rng.integers(10, 200)),
     )

     # Run checks.
@@ -262,15 +263,15 @@ def test_pairwise_average_fst_with_dict_cohorts(fixture, api: AnophelesFstAnalysis):
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     all_countries = api.sample_metadata()["country"].dropna().unique().tolist()
     cohorts = {country: f"country == '{country}'" for country in all_countries}
-    region = random.choice(api.contigs)
-    site_mask = random.choice(api.site_mask_ids)
+    region = rng.choice(api.contigs)
+    site_mask = rng.choice(api.site_mask_ids)
     fst_params = dict(
         region=region,
         cohorts=cohorts,
         sample_sets=all_sample_sets,
         site_mask=site_mask,
         min_cohort_size=1,
-        n_jack=random.randint(10, 200),
+        n_jack=int(rng.integers(10, 200)),
     )

     # Run checks.
@@ -281,12 +282,12 @@ def test_pairwise_average_fst_with_sample_query(fixture, api: AnophelesFstAnalysis):
     # Set up test parameters.
     all_taxa = api.sample_metadata()["taxon"].dropna().unique().tolist()
-    taxon = random.choice(all_taxa)
+    taxon = rng.choice(all_taxa)
     sample_query = f"taxon == '{taxon}'"
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     cohorts = "admin2_month"
-    region = random.choice(api.contigs)
-    site_mask = random.choice(api.site_mask_ids)
+    region = rng.choice(api.contigs)
+    site_mask = rng.choice(api.site_mask_ids)
     fst_params = dict(
         region=region,
         cohorts=cohorts,
@@ -294,7 +295,7 @@ def test_pairwise_average_fst_with_sample_query(fixture, api: AnophelesFstAnalysis):
         sample_query=sample_query,
         site_mask=site_mask,
         min_cohort_size=1,
-        n_jack=random.randint(10, 200),
+        n_jack=int(rng.integers(10, 200)),
     )

     # Run checks.
@@ -306,8 +307,8 @@ def test_pairwise_average_fst_with_bad_cohorts(fixture, api: AnophelesFstAnalysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     cohorts = "foobar"
-    region = random.choice(api.contigs)
-    site_mask = random.choice(api.site_mask_ids)
+    region = rng.choice(api.contigs)
+    site_mask = rng.choice(api.site_mask_ids)
     fst_params = dict(
         region=region,
         cohorts=cohorts,
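
A caveat on the randint-to-integers migration used in this file and elsewhere: random.randint(a, b) includes b, whereas Generator.integers(a, b) excludes it by default, so draws such as the window_size and n_jack parameters above are now narrower by one. For these tests any value in range works, but where the inclusive upper bound matters, endpoint=True restores the old behaviour:

    import random
    import numpy as np

    rng = np.random.default_rng(seed=42)

    w_old = random.randint(10, 50)                     # 50 can be drawn
    w_new = int(rng.integers(10, 50))                  # 50 can never be drawn
    w_same = int(rng.integers(10, 50, endpoint=True))  # inclusive again
    assert 10 <= w_same <= 50
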
diff --git a/tests/anoph/test_g123.py b/tests/anoph/test_g123.py
index ab5e2553..a2a68e47 100644
--- a/tests/anoph/test_g123.py
+++ b/tests/anoph/test_g123.py
@@ -1,4 +1,3 @@
-import random
 import pytest
 from pytest_cases import parametrize_with_cases
 import numpy as np
@@ -8,6 +7,9 @@
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.g123 import AnophelesG123Analysis

+# Global RNG for test file; functions may override with local RNG for reproducibility
+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -103,9 +105,9 @@ def test_g123_gwss_with_default_sites(fixture, api: AnophelesG123Analysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     g123_params = dict(
-        contig=random.choice(api.contigs),
-        sample_sets=[random.choice(all_sample_sets)],
-        window_size=random.randint(100, 500),
+        contig=rng.choice(api.contigs),
+        sample_sets=[rng.choice(all_sample_sets)],
+        window_size=int(rng.integers(100, 500)),
         min_cohort_size=10,
     )
@@ -118,10 +120,10 @@ def test_g123_gwss_with_phased_sites(fixture, api: AnophelesG123Analysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     g123_params = dict(
-        contig=random.choice(api.contigs),
-        sites=random.choice(api.phasing_analysis_ids),
-        sample_sets=[random.choice(all_sample_sets)],
-        window_size=random.randint(100, 500),
+        contig=rng.choice(api.contigs),
+        sites=rng.choice(api.phasing_analysis_ids),
+        sample_sets=[rng.choice(all_sample_sets)],
+        window_size=int(rng.integers(100, 500)),
         min_cohort_size=10,
     )
@@ -134,11 +136,11 @@ def test_g123_gwss_with_segregating_sites(fixture, api: AnophelesG123Analysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     g123_params = dict(
-        contig=random.choice(api.contigs),
+        contig=rng.choice(api.contigs),
         sites="segregating",
-        site_mask=random.choice(api.site_mask_ids),
-        sample_sets=[random.choice(all_sample_sets)],
-        window_size=random.randint(100, 500),
+        site_mask=rng.choice(api.site_mask_ids),
+        sample_sets=[rng.choice(all_sample_sets)],
+        window_size=int(rng.integers(100, 500)),
         min_cohort_size=10,
     )
@@ -151,11 +153,11 @@ def test_g123_gwss_with_all_sites(fixture, api: AnophelesG123Analysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     g123_params = dict(
-        contig=random.choice(api.contigs),
+        contig=rng.choice(api.contigs),
         sites="all",
         site_mask=None,
-        sample_sets=[random.choice(all_sample_sets)],
-        window_size=random.randint(100, 500),
+        sample_sets=[rng.choice(all_sample_sets)],
+        window_size=int(rng.integers(100, 500)),
         min_cohort_size=10,
     )
@@ -168,9 +170,9 @@ def test_g123_gwss_with_bad_sites(fixture, api: AnophelesG123Analysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     g123_params = dict(
-        contig=random.choice(api.contigs),
-        sample_sets=[random.choice(all_sample_sets)],
-        window_size=random.randint(100, 500),
+        contig=rng.choice(api.contigs),
+        sample_sets=[rng.choice(all_sample_sets)],
+        window_size=int(rng.integers(100, 500)),
         min_cohort_size=10,
         sites="foobar",
     )
@@ -184,12 +186,12 @@ def test_g123_calibration(fixture, api: AnophelesG123Analysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    window_sizes = np.random.randint(100, 500, size=random.randint(2, 5)).tolist()
-    window_sizes = sorted([int(x) for x in window_sizes])
+    window_sizes = rng.integers(100, 500, size=int(rng.integers(2, 5))).tolist()
+    window_sizes = sorted(window_sizes)
     g123_params = dict(
-        contig=random.choice(api.contigs),
-        sites=random.choice(api.phasing_analysis_ids),
-        sample_sets=[random.choice(all_sample_sets)],
+        contig=rng.choice(api.contigs),
+        sites=rng.choice(api.phasing_analysis_ids),
+        sample_sets=[rng.choice(all_sample_sets)],
         min_cohort_size=10,
         window_sizes=window_sizes,
     )
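
The simplified window_sizes construction above relies on ndarray.tolist() already converting NumPy integer scalars to built-in ints, which is what made the explicit int(x) comprehension redundant. For illustration:

    import numpy as np

    rng = np.random.default_rng(seed=42)

    window_sizes = rng.integers(100, 500, size=3).tolist()
    # .tolist() yields plain Python ints, not np.int64 scalars:
    assert all(type(w) is int for w in window_sizes)
    window_sizes = sorted(window_sizes)
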
diff --git a/tests/anoph/test_genome_features.py b/tests/anoph/test_genome_features.py
index eaef4ee4..1f12dac4 100644
--- a/tests/anoph/test_genome_features.py
+++ b/tests/anoph/test_genome_features.py
@@ -10,6 +10,9 @@
 from malariagen_data.anoph.genome_features import AnophelesGenomeFeaturesData
 from malariagen_data.util import Region, resolve_region

+# Global RNG for test file; functions may override with local RNG for reproducibility
+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -145,7 +148,7 @@ def test_plot_genes_with_gene_labels(fixture, api: AnophelesGenomeFeaturesData):
     # If there are no genes, we cannot label them.
     if not genes_df.empty:
         # Get a random number of genes to sample.
-        random_genes_n = np.random.randint(low=1, high=len(genes_df) + 1)
+        random_genes_n = int(rng.integers(low=1, high=len(genes_df) + 1))

         # Get a random sample of genes.
         random_sample_genes_df = genes_df.sample(n=random_genes_n)
@@ -166,7 +169,7 @@ def test_plot_genes_with_gene_labels(fixture, api: AnophelesGenomeFeaturesData):
 def test_plot_transcript(fixture, api: AnophelesGenomeFeaturesData):
     for contig in fixture.contigs:
         df_transcripts = api.genome_features(region=contig).query("type == 'mRNA'")
-        transcript = np.random.choice(df_transcripts["ID"].values)
+        transcript = rng.choice(df_transcripts["ID"].values)
         fig = api.plot_transcript(transcript=transcript, show=False)
         assert isinstance(fig, bokeh.plotting.figure)
@@ -212,7 +215,7 @@ def test_genome_features_virtual_contigs(ag3_sim_api, chrom):
     # Test with region.
     seq = api.genome_sequence(region=chrom)
-    start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2))
+    start, stop = sorted(rng.integers(low=1, high=len(seq), size=2))
     region = f"{chrom}:{start:,}-{stop:,}"
     df = api.genome_features(region=region)
     assert isinstance(df, pd.DataFrame)
diff --git a/tests/anoph/test_genome_sequence.py b/tests/anoph/test_genome_sequence.py
index d2eca6bb..6dafd030 100644
--- a/tests/anoph/test_genome_sequence.py
+++ b/tests/anoph/test_genome_sequence.py
@@ -10,6 +10,9 @@
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.genome_sequence import AnophelesGenomeSequenceData

+# Global RNG for test file; functions may override with local RNG for reproducibility
+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -79,7 +82,7 @@ def test_genome_sequence_region(fixture, api):
     for contig in fixture.contigs:
         contig_seq = api.genome_sequence(region=contig)
         # Pick a random start and stop position.
-        start, stop = sorted(np.random.randint(low=1, high=len(contig_seq), size=2))
+        start, stop = sorted(rng.integers(low=1, high=len(contig_seq), size=2))
         region = f"{contig}:{start:,}-{stop:,}"
         seq = api.genome_sequence(region=region)
         assert isinstance(seq, da.Array)
@@ -118,7 +121,7 @@ def test_genome_sequence_virtual_contigs(ag3_sim_api, chrom):
     )

     # Test with region.
-    start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2))
+    start, stop = sorted(rng.integers(low=1, high=len(seq), size=2))
     region = f"{chrom}:{start:,}-{stop:,}"
     seq_region = api.genome_sequence(region=region)
     assert isinstance(seq_region, da.Array)
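
These region tests share one small recipe: draw two coordinates, sort them, and format a contig:start-stop string. A standalone sketch, with a made-up contig name and length:

    import numpy as np

    rng = np.random.default_rng(seed=42)
    chrom, seq_len = "2RL", 1_000_000  # illustrative contig and length

    # sorted() only guarantees start <= stop; the two draws may still collide.
    start, stop = sorted(int(x) for x in rng.integers(low=1, high=seq_len, size=2))
    region = f"{chrom}:{start:,}-{stop:,}"  # e.g. "2RL:12,345-678,901"
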
diff --git a/tests/anoph/test_h12.py b/tests/anoph/test_h12.py
index 29262e94..6cda7e02 100644
--- a/tests/anoph/test_h12.py
+++ b/tests/anoph/test_h12.py
@@ -1,4 +1,3 @@
-import random
 import pytest
 from pytest_cases import parametrize_with_cases
 import numpy as np
@@ -9,6 +8,9 @@
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.h12 import AnophelesH12Analysis, haplotype_frequencies

+# Global RNG for test file; functions may override with local RNG for reproducibility
+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -104,11 +106,12 @@ def test_haplotype_frequencies():
 def test_h12_calibration(fixture, api: AnophelesH12Analysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    window_sizes = np.random.randint(100, 500, size=random.randint(2, 5)).tolist()
+    window_sizes = rng.integers(100, 500, size=int(rng.integers(2, 5))).tolist()
+    # Deduplicate the window sizes and sort them as plain ints.
     window_sizes = sorted(set([int(x) for x in window_sizes]))
     h12_params = dict(
-        contig=random.choice(api.contigs),
-        sample_sets=[random.choice(all_sample_sets)],
+        contig=rng.choice(api.contigs),
+        sample_sets=[rng.choice(all_sample_sets)],
         window_sizes=window_sizes,
         min_cohort_size=5,
     )
@@ -170,9 +173,9 @@ def test_h12_gwss_with_default_analysis(fixture, api: AnophelesH12Analysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     h12_params = dict(
-        contig=random.choice(api.contigs),
-        sample_sets=[random.choice(all_sample_sets)],
-        window_size=random.randint(100, 500),
+        contig=rng.choice(api.contigs),
+        sample_sets=[rng.choice(all_sample_sets)],
+        window_size=int(rng.integers(100, 500)),
         min_cohort_size=5,
     )
@@ -184,9 +187,9 @@ def test_h12_gwss_with_analysis(fixture, api: AnophelesH12Analysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = [random.choice(all_sample_sets)]
-    contig = random.choice(api.contigs)
-    window_size = random.randint(100, 500)
+    sample_sets = [rng.choice(all_sample_sets)]
+    contig = rng.choice(api.contigs)
+    window_size = int(rng.integers(100, 500))

     for analysis in api.phasing_analysis_ids:
         # Check if any samples available for the given phasing analysis.
@@ -234,13 +237,13 @@ def test_h12_gwss_multi_with_default_analysis(fixture, api: AnophelesH12Analysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     all_countries = api.sample_metadata()["country"].unique().tolist()
-    country1, country2 = random.sample(all_countries, 2)
+    country1, country2 = rng.choice(all_countries, 2, replace=False).tolist()
     cohort1_query = f"country == '{country1}'"
     cohort2_query = f"country == '{country2}'"
     h12_params = dict(
-        contig=random.choice(api.contigs),
+        contig=rng.choice(api.contigs),
         sample_sets=all_sample_sets,
-        window_size=random.randint(100, 500),
+        window_size=int(rng.integers(100, 500)),
         min_cohort_size=1,
         cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query},
     )
@@ -254,15 +257,15 @@ def test_h12_gwss_multi_with_window_size_dict(fixture, api: AnophelesH12Analysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     all_countries = api.sample_metadata()["country"].unique().tolist()
-    country1, country2 = random.sample(all_countries, 2)
+    country1, country2 = rng.choice(all_countries, 2, replace=False).tolist()
     cohort1_query = f"country == '{country1}'"
     cohort2_query = f"country == '{country2}'"
     h12_params = dict(
-        contig=random.choice(api.contigs),
+        contig=rng.choice(api.contigs),
         sample_sets=all_sample_sets,
         window_size={
-            "cohort1": random.randint(100, 500),
-            "cohort2": random.randint(100, 500),
+            "cohort1": int(rng.integers(100, 500)),
+            "cohort2": int(rng.integers(100, 500)),
         },
         min_cohort_size=1,
         cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query},
@@ -277,10 +280,10 @@ def test_h12_gwss_multi_with_analysis(fixture, api: AnophelesH12Analysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     all_countries = api.sample_metadata()["country"].unique().tolist()
-    country1, country2 = random.sample(all_countries, 2)
+    country1, country2 = rng.choice(all_countries, 2, replace=False).tolist()
     cohort1_query = f"country == '{country1}'"
     cohort2_query = f"country == '{country2}'"
-    contig = random.choice(api.contigs)
+    contig = rng.choice(api.contigs)

     for analysis in api.phasing_analysis_ids:
         # Check if any samples available for the given phasing analysis.
@@ -313,7 +316,7 @@ def test_h12_gwss_multi_with_analysis(fixture, api: AnophelesH12Analysis):
             analysis=analysis,
             contig=contig,
             sample_sets=all_sample_sets,
-            window_size=random.randint(100, 500),
+            window_size=int(rng.integers(100, 500)),
             min_cohort_size=min(n1, n2),
             cohorts={"cohort1": cohort1_query, "cohort2": cohort2_query},
         )
diff --git a/tests/anoph/test_h1x.py b/tests/anoph/test_h1x.py
index 627717b5..5c8528e0 100644
--- a/tests/anoph/test_h1x.py
+++ b/tests/anoph/test_h1x.py
@@ -1,4 +1,3 @@
-import random
 import pytest
 from pytest_cases import parametrize_with_cases
 import numpy as np
@@ -9,6 +8,8 @@
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.h1x import AnophelesH1XAnalysis, haplotype_joint_frequencies

+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -141,13 +142,13 @@ def test_h1x_gwss_with_default_analysis(fixture, api: AnophelesH1XAnalysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     all_countries = api.sample_metadata()["country"].unique().tolist()
-    country1, country2 = random.sample(all_countries, 2)
+    country1, country2 = rng.choice(all_countries, 2, replace=False).tolist()
     cohort1_query = f"country == '{country1}'"
     cohort2_query = f"country == '{country2}'"
     h1x_params = dict(
-        contig=random.choice(api.contigs),
+        contig=rng.choice(api.contigs),
         sample_sets=all_sample_sets,
-        window_size=random.randint(100, 500),
+        window_size=int(rng.integers(100, 500)),
         min_cohort_size=1,
         cohort1_query=cohort1_query,
         cohort2_query=cohort2_query,
@@ -162,10 +163,10 @@ def test_h1x_gwss_with_analysis(fixture, api: AnophelesH1XAnalysis):
     # Set up test parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     all_countries = api.sample_metadata()["country"].unique().tolist()
-    country1, country2 = random.sample(all_countries, 2)
+    country1, country2 = rng.choice(all_countries, 2, replace=False).tolist()
     cohort1_query = f"country == '{country1}'"
     cohort2_query = f"country == '{country2}'"
-    contig = random.choice(api.contigs)
+    contig = rng.choice(api.contigs)

     for analysis in api.phasing_analysis_ids:
         # Check if any samples available for the given phasing analysis.
@@ -198,7 +199,7 @@ def test_h1x_gwss_with_analysis(fixture, api: AnophelesH1XAnalysis):
             analysis=analysis,
             contig=contig,
             sample_sets=all_sample_sets,
-            window_size=random.randint(100, 500),
+            window_size=int(rng.integers(100, 500)),
             min_cohort_size=min(n1, n2),
             cohort1_query=cohort1_query,
             cohort2_query=cohort2_query,
diff --git a/tests/anoph/test_hap_data.py b/tests/anoph/test_hap_data.py
index 3b633b3e..f8b9ea7b 100644
--- a/tests/anoph/test_hap_data.py
+++ b/tests/anoph/test_hap_data.py
@@ -1,5 +1,3 @@
-import random
-
 import dask.array as da
 import numpy as np
 import pytest
@@ -12,6 +10,10 @@
 from malariagen_data.anoph.hap_data import AnophelesHapData

+# Global RNG for test file; functions may override with local RNG for reproducibility
+rng = np.random.default_rng(seed=42)
+
+
 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
     return AnophelesHapData(
@@ -322,9 +324,9 @@ def test_haplotypes_with_sample_sets_param(fixture, api: AnophelesHapData):
     all_releases = api.releases
     parametrize_sample_sets = [
         None,
-        random.choice(all_sample_sets),
-        random.sample(all_sample_sets, 2),
-        random.choice(all_releases),
+        rng.choice(all_sample_sets),
+        rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        rng.choice(all_releases),
     ]

     # Run tests.
@@ -342,7 +344,7 @@ def test_haplotypes_with_sample_sets_param(fixture, api: AnophelesHapData):
 def test_haplotypes_with_region_param(fixture, api: AnophelesHapData):
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     analysis = api.phasing_analysis_ids[0]

     # Parametrize region.
@@ -352,7 +354,7 @@ def test_haplotypes_with_region_param(fixture, api: AnophelesHapData):
         contig,
         fixture.random_region_str(),
         [fixture.random_region_str(), fixture.random_region_str()],
-        random.choice(df_gff["ID"].dropna().to_list()),
+        rng.choice(df_gff["ID"].dropna().to_list()),
     ]

     # Run tests.
@@ -370,7 +372,7 @@ def test_haplotypes_with_analysis_param(fixture, api: AnophelesHapData):
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()

     # Parametrize analysis.
@@ -395,7 +397,7 @@ def test_haplotypes_with_sample_query_param(
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
     analysis = api.phasing_analysis_ids[0]
@@ -422,7 +424,7 @@ def test_haplotypes_with_sample_query_options_param(
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
     analysis = api.phasing_analysis_ids[0]
     sample_query_options = {
@@ -461,12 +463,16 @@ def test_haplotypes_with_cohort_size_param(
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
     analysis = api.phasing_analysis_ids[0]

     # Parametrize over cohort_size.
-    parametrize_cohort_size = [random.randint(1, 10), random.randint(10, 50), 1_000]
+    parametrize_cohort_size = [
+        int(rng.integers(1, 10)),
+        int(rng.integers(10, 50)),
+        1_000,
+    ]
     for cohort_size in parametrize_cohort_size:
         check_haplotypes(
             fixture=fixture,
@@ -487,14 +493,14 @@ def test_haplotypes_with_min_cohort_size_param(
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
     analysis = api.phasing_analysis_ids[0]

     # Parametrize over min_cohort_size.
     parametrize_min_cohort_size = [
-        random.randint(1, 10),
-        random.randint(10, 50),
+        int(rng.integers(1, 10)),
+        int(rng.integers(10, 50)),
         1_000,
     ]
     for min_cohort_size in parametrize_min_cohort_size:
@@ -517,14 +523,14 @@ def test_haplotypes_with_max_cohort_size_param(
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
     analysis = api.phasing_analysis_ids[0]

     # Parametrize over max_cohort_size.
     parametrize_max_cohort_size = [
-        random.randint(1, 10),
-        random.randint(10, 50),
+        int(rng.integers(1, 10)),
+        int(rng.integers(10, 50)),
         1_000,
     ]
     for max_cohort_size in parametrize_max_cohort_size:
@@ -605,7 +611,9 @@ def test_haplotypes_virtual_contigs(
     # Test with region.
     seq = api.genome_sequence(region=chrom)
-    start, stop = sorted(map(int, np.random.randint(low=1, high=len(seq), size=2)))
+    start, stop = sorted(
+        [int(x) for x in rng.integers(low=1, high=len(seq), size=2)]
+    )
     region = f"{chrom}:{start:,}-{stop:,}"

     # Standard checks.
@@ -655,7 +663,7 @@ def test_haplotype_sites(fixture, api: AnophelesHapData):

     # Test with genome feature ID.
     df_gff = api.genome_features(attributes=["ID"])
-    region = random.choice(df_gff["ID"].dropna().to_list())
+    region = rng.choice(df_gff["ID"].dropna().to_list())
     check_haplotype_sites(api=api, region=region)
@@ -677,7 +685,7 @@ def test_haplotype_sites_with_virtual_contigs(ag3_sim_api, chrom):

     # Test with region.
     seq = api.genome_sequence(region=chrom)
-    start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2))
+    start, stop = sorted(rng.integers(low=1, high=len(seq), size=2))
     region = f"{chrom}:{start:,}-{stop:,}"

     # Standard checks.
diff --git a/tests/anoph/test_hap_frq.py b/tests/anoph/test_hap_frq.py
index 2b9f3bfe..26e2c1fa 100644
--- a/tests/anoph/test_hap_frq.py
+++ b/tests/anoph/test_hap_frq.py
@@ -1,5 +1,3 @@
-import random
-
 import pandas as pd
 import numpy as np
 import xarray as xr
@@ -18,6 +16,8 @@
     add_random_year,
 )

+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -168,8 +168,8 @@ def test_hap_frequencies_with_str_cohorts(
 ):
     # Pick test parameters at random.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    min_cohort_size = random.randint(0, 2)
+    sample_sets = rng.choice(all_sample_sets)
+    min_cohort_size = int(rng.integers(0, 2))
     region = fixture.random_region_str()

     # Set up call params.
@@ -210,8 +210,8 @@ def test_hap_frequencies_advanced(
     fixture, api: AnophelesHapFrequencyAnalysis, area_by, period_by
 ):
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    min_cohort_size = random.randint(0, 2)
+    sample_sets = rng.choice(all_sample_sets)
+    min_cohort_size = int(rng.integers(0, 2))
     region = fixture.random_region_str()

     if period_by == "random_year":
diff --git a/tests/anoph/test_hapclust.py b/tests/anoph/test_hapclust.py
index 454b6e40..068c229e 100644
--- a/tests/anoph/test_hapclust.py
+++ b/tests/anoph/test_hapclust.py
@@ -1,11 +1,12 @@
-import random
 import pytest
 from pytest_cases import parametrize_with_cases
-
+import numpy as np
 from malariagen_data import af1 as _af1
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.hapclust import AnophelesHapClustAnalysis

+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -86,11 +87,12 @@ def test_plot_haplotype_clustering(fixture, api: AnophelesHapClustAnalysis):
         "ward",
     )
     sample_queries = (None, "sex_call == 'F'")
+    idx = rng.choice([0, 1])  # generate a random index into sample_queries
     hapclust_params = dict(
         region=fixture.random_region_str(region_size=5000),
-        sample_sets=[random.choice(all_sample_sets)],
-        linkage_method=random.choice(linkage_methods),
-        sample_query=random.choice(sample_queries),
+        sample_sets=[rng.choice(all_sample_sets)],
+        linkage_method=str(rng.choice(linkage_methods)),
+        sample_query=sample_queries[idx],
         show=False,
     )
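
The str(...) wrappers around rng.choice in the clustering tests are not redundant: choosing from an array of strings returns a numpy.str_ scalar, which subclasses str but is not the built-in type, and the wrapper normalises it before it reaches parameter handling. A quick illustration:

    import numpy as np

    rng = np.random.default_rng(seed=42)
    linkage_methods = ("single", "complete", "average", "weighted", "ward")

    picked = rng.choice(linkage_methods)
    assert isinstance(picked, str)   # numpy.str_ subclasses str...
    assert type(picked) is not str   # ...but it is not the built-in type

    linkage_method = str(picked)     # normalise to a plain str
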
diff --git a/tests/anoph/test_igv.py b/tests/anoph/test_igv.py
index a468af72..854b4b0f 100644
--- a/tests/anoph/test_igv.py
+++ b/tests/anoph/test_igv.py
@@ -1,5 +1,4 @@
-import random
-
+import numpy as np
 import igv_notebook  # type: ignore
 import pytest
 from pytest_cases import parametrize_with_cases
@@ -8,6 +7,8 @@
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.igv import AnophelesIgv

+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -81,7 +82,7 @@ def test_igv(fixture, api: AnophelesIgv):
 @parametrize_with_cases("fixture,api", cases=".")
 def test_view_alignments(fixture, api: AnophelesIgv):
     region = fixture.random_region_str()
-    sample = random.choice(api.sample_metadata()["sample_id"])
+    sample = rng.choice(api.sample_metadata()["sample_id"])
     ret = api.view_alignments(region=region, sample=sample, init=False)
     # No return value to avoid cluttering notebook output.
     assert ret is None
diff --git a/tests/anoph/test_pca.py b/tests/anoph/test_pca.py
index e5fa667a..54dafd67 100644
--- a/tests/anoph/test_pca.py
+++ b/tests/anoph/test_pca.py
@@ -1,5 +1,3 @@
-import random
-
 import numpy as np
 import pandas as pd
 import plotly.graph_objects as go  # type: ignore
@@ -10,6 +8,9 @@
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.pca import AnophelesPca
 from malariagen_data.anoph import pca_params

+from .conftest import Af1Simulator, Ag3Simulator
+
+rng = np.random.default_rng(seed=42)


 @pytest.fixture
@@ -82,9 +83,9 @@ def test_pca_plotting(fixture, api: AnophelesPca):
     # Parameters for selecting input data.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     data_params = dict(
-        region=random.choice(api.contigs),
-        sample_sets=random.sample(all_sample_sets, 2),
-        site_mask=random.choice((None,) + api.site_mask_ids),
+        region=rng.choice(api.contigs),
+        sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        site_mask=rng.choice(np.array([""] + list(api.site_mask_ids), dtype=object)),
     )
     ds = api.biallelic_snp_calls(
         min_minor_ac=pca_params.min_minor_ac_default,
@@ -95,10 +96,10 @@ def test_pca_plotting(fixture, api: AnophelesPca):
     # PCA parameters.
     n_samples = ds.sizes["samples"]
     n_snps_available = ds.sizes["variants"]
-    n_snps = random.randint(4, n_snps_available)
+    n_snps = int(rng.integers(4, n_snps_available))
     # PC3 required for plot_pca_coords_3d()
     assert min(n_samples, n_snps) > 3
-    n_components = random.randint(3, min(n_samples, n_snps, 10))
+    n_components = int(rng.integers(3, min(n_samples, n_snps, 10)))

     # Run the PCA.
     pca_df, pca_evr = api.pca(
@@ -167,10 +168,17 @@ def test_pca_exclude_samples(fixture, api: AnophelesPca):
     # Parameters for selecting input data.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
+
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
     data_params = dict(
-        region=random.choice(api.contigs),
-        sample_sets=random.sample(all_sample_sets, 2),
-        site_mask=random.choice((None,) + api.site_mask_ids),
+        region=rng.choice(api.contigs),
+        sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        site_mask=rng.choice(np.array(valid_site_masks, dtype=object)),
     )
     ds = api.biallelic_snp_calls(
         min_minor_ac=pca_params.min_minor_ac_default,
@@ -179,15 +187,17 @@ def test_pca_exclude_samples(fixture, api: AnophelesPca):
     )

     # Exclusion parameters.
-    n_samples_excluded = random.randint(1, 5)
+    n_samples_excluded = int(rng.integers(1, 5))
     samples = ds["sample_id"].values.tolist()
-    exclude_samples = random.sample(samples, n_samples_excluded)
+    exclude_samples = rng.choice(
+        samples, int(n_samples_excluded), replace=False
+    ).tolist()

     # PCA parameters.
     n_samples = ds.sizes["samples"] - n_samples_excluded
     n_snps_available = ds.sizes["variants"]
-    n_snps = random.randint(4, n_snps_available)
-    n_components = random.randint(2, min(n_samples, n_snps, 10))
+    n_snps = int(rng.integers(4, n_snps_available))
+    n_components = int(rng.integers(2, min(n_samples, n_snps, 10)))

     # Run the PCA.
     pca_df, pca_evr = api.pca(
@@ -228,10 +238,16 @@ def test_pca_fit_exclude_samples(fixture, api: AnophelesPca):
     # Parameters for selecting input data.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
     data_params = dict(
-        region=random.choice(api.contigs),
-        sample_sets=random.sample(all_sample_sets, 2),
-        site_mask=random.choice((None,) + api.site_mask_ids),
+        region=rng.choice(api.contigs),
+        sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        site_mask=rng.choice(np.array(valid_site_masks, dtype=object)),
     )
     ds = api.biallelic_snp_calls(
         min_minor_ac=pca_params.min_minor_ac_default,
@@ -240,15 +256,17 @@ def test_pca_fit_exclude_samples(fixture, api: AnophelesPca):
     )

     # Exclusion parameters.
-    n_samples_excluded = random.randint(1, 5)
+    n_samples_excluded = int(rng.integers(1, 5))
     samples = ds["sample_id"].values.tolist()
-    exclude_samples = random.sample(samples, n_samples_excluded)
+    exclude_samples = rng.choice(
+        samples, int(n_samples_excluded), replace=False
+    ).tolist()

     # PCA parameters.
     n_samples = ds.sizes["samples"]
     n_snps_available = ds.sizes["variants"]
-    n_snps = random.randint(4, n_snps_available)
-    n_components = random.randint(2, min(n_samples, n_snps, 10))
+    n_snps = int(rng.integers(4, n_snps_available))
+    n_components = int(rng.integers(2, min(n_samples, n_snps, 10)))

     # Run the PCA.
     pca_df, pca_evr = api.pca(
diff --git a/tests/anoph/test_plink_converter.py b/tests/anoph/test_plink_converter.py
index 44e476cd..323af84c 100644
--- a/tests/anoph/test_plink_converter.py
+++ b/tests/anoph/test_plink_converter.py
@@ -1,4 +1,3 @@
-import random
 import pytest
 from pytest_cases import parametrize_with_cases
@@ -8,9 +7,11 @@
 import os

 import bed_reader
-
+import numpy as np
 from numpy.testing import assert_array_equal

+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -83,13 +84,13 @@ def test_plink_converter(fixture, api: PlinkConverter, tmp_path):
     all_sample_sets = api.sample_sets()["sample_set"].to_list()

     data_params = dict(
-        region=random.choice(api.contigs),
-        sample_sets=random.sample(all_sample_sets, 2),
-        site_mask=random.choice((None,) + api.site_mask_ids),
+        region=rng.choice(api.contigs),
+        sample_sets=rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        site_mask=rng.choice(np.array([""] + list(api.site_mask_ids), dtype=object)),
         min_minor_ac=1,
         max_missing_an=1,
         thin_offset=1,
-        random_seed=random.randint(1, 2000),
+        random_seed=int(rng.integers(1, 2000)),
     )

     # Load a ds containing the randomly generated samples and regions to get the number of available snps to subset from.
@@ -98,7 +99,7 @@ def test_plink_converter(fixture, api: PlinkConverter, tmp_path):
     )
     n_snps_available = ds.sizes["variants"]

-    n_snps = random.randint(1, n_snps_available)
+    n_snps = int(rng.integers(1, n_snps_available))

     # Define plink params.
     plink_params = dict(output_dir=str(tmp_path), n_snps=n_snps, **data_params)
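
The random_seed draw in test_plink_converter shows a useful corollary of the module-level generator: it can mint integer seeds for APIs that accept a seed rather than a Generator, keeping the whole run reproducible from the single seed=42. A sketch, where seeded_api is a hypothetical stand-in for any seed-taking function:

    import numpy as np

    rng = np.random.default_rng(seed=42)

    def seeded_api(*, random_seed: int) -> int:
        # Hypothetical stand-in for an API that accepts an integer seed.
        return int(np.random.default_rng(random_seed).integers(0, 100))

    # Each call gets a fresh but reproducible child seed from the module RNG.
    result = seeded_api(random_seed=int(rng.integers(1, 2000)))
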
diff --git a/tests/anoph/test_sample_metadata.py b/tests/anoph/test_sample_metadata.py
index 6df29011..89e6b224 100644
--- a/tests/anoph/test_sample_metadata.py
+++ b/tests/anoph/test_sample_metadata.py
@@ -1,5 +1,3 @@
-import random
-
 import ipyleaflet  # type: ignore
 import numpy as np
 import pandas as pd
@@ -14,6 +12,9 @@
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.sample_metadata import AnophelesSampleMetadata

+# Global RNG for test file; functions may override with local RNG for reproducibility
+rng = np.random.default_rng(seed=42)
+

 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
@@ -120,7 +121,7 @@ def test_general_metadata_with_single_sample_set(fixture, api: AnophelesSampleMetadata):
     df_sample_sets = api.sample_sets().set_index("sample_set")
     sample_count = df_sample_sets["sample_count"]
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_set = random.choice(all_sample_sets)
+    sample_set = rng.choice(all_sample_sets)

     # Call function to be tested.
     df = api.general_metadata(sample_sets=sample_set)
@@ -139,7 +140,7 @@ def test_general_metadata_with_multiple_sample_sets(
     df_sample_sets = api.sample_sets().set_index("sample_set")
     sample_count = df_sample_sets["sample_count"]
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_sets = random.sample(all_sample_sets, 2)
+    sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist()

     # Call function to be tested.
     df = api.general_metadata(sample_sets=sample_sets)
@@ -153,7 +154,7 @@ def test_general_metadata_with_release(fixture, api: AnophelesSampleMetadata):
     # Set up the test.
-    release = random.choice(api.releases)
+    release = rng.choice(api.releases)

     # Call function to be tested.
     df = api.general_metadata(sample_sets=release)
@@ -200,7 +201,7 @@ def test_sequence_qc_metadata_with_single_sample_set(
     df_sample_sets = api.sample_sets().set_index("sample_set")
     sample_count = df_sample_sets["sample_count"]
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_set = random.choice(all_sample_sets)
+    sample_set = rng.choice(all_sample_sets)

     # Call function to be tested.
     df = api.sequence_qc_metadata(sample_sets=sample_set)
@@ -221,7 +222,7 @@ def test_sequence_qc_metadata_with_multiple_sample_sets(
     df_sample_sets = api.sample_sets().set_index("sample_set")
     sample_count = df_sample_sets["sample_count"]
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_sets = random.sample(all_sample_sets, 2)
+    sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist()

     # Call function to be tested.
     df = api.sequence_qc_metadata(sample_sets=sample_sets)
@@ -237,7 +238,7 @@ def test_sequence_qc_metadata_with_release(fixture, api: AnophelesSampleMetadata):
     # Set up the test.
-    release = random.choice(api.releases)
+    release = rng.choice(api.releases)

     # Call function to be tested.
     df = api.sequence_qc_metadata(sample_sets=release)
@@ -311,7 +312,7 @@ def test_aim_metadata_with_single_sample_set(ag3_sim_api):
     df_sample_sets = ag3_sim_api.sample_sets().set_index("sample_set")
     sample_count = df_sample_sets["sample_count"]
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_set = random.choice(all_sample_sets)
+    sample_set = rng.choice(all_sample_sets)

     # Call function to be tested.
     df = ag3_sim_api.aim_metadata(sample_sets=sample_set)
@@ -329,7 +330,7 @@ def test_aim_metadata_with_multiple_sample_sets(ag3_sim_api):
     df_sample_sets = ag3_sim_api.sample_sets().set_index("sample_set")
     sample_count = df_sample_sets["sample_count"]
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_sets = random.sample(all_sample_sets, 2)
+    sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist()

     # Call function to be tested.
     df = ag3_sim_api.aim_metadata(sample_sets=sample_sets)
@@ -344,7 +345,7 @@ def test_aim_metadata_with_release(ag3_sim_api):
     # N.B., only Ag3 has AIM data.

     # Set up the test.
-    release = random.choice(ag3_sim_api.releases)
+    release = rng.choice(ag3_sim_api.releases)

     # Call function to be tested.
     df = ag3_sim_api.aim_metadata(sample_sets=release)
@@ -423,7 +424,7 @@ def test_cohorts_metadata_with_single_sample_set(fixture, api: AnophelesSampleMetadata):
     df_sample_sets = api.sample_sets().set_index("sample_set")
     sample_count = df_sample_sets["sample_count"]
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_set = random.choice(all_sample_sets)
+    sample_set = rng.choice(all_sample_sets)

     # Call function to be tested.
     df = api.cohorts_metadata(sample_sets=sample_set)
@@ -442,7 +443,7 @@ def test_cohorts_metadata_with_multiple_sample_sets(
     df_sample_sets = api.sample_sets().set_index("sample_set")
     sample_count = df_sample_sets["sample_count"]
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_sets = random.sample(all_sample_sets, 2)
+    sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist()

     # Call function to be tested.
     df = api.cohorts_metadata(sample_sets=sample_sets)
@@ -456,7 +457,7 @@ def test_cohorts_metadata_with_release(fixture, api: AnophelesSampleMetadata):
     # Set up test.
-    release = random.choice(api.releases)
+    release = rng.choice(api.releases)

     # Call function to be tested.
     df = api.cohorts_metadata(sample_sets=release)
@@ -517,7 +518,7 @@ def test_sample_metadata_with_single_sample_set(fixture, api: AnophelesSampleMetadata):
     df_sample_sets = api.sample_sets().set_index("sample_set")
     sample_count = df_sample_sets["sample_count"]
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_set = random.choice(all_sample_sets)
+    sample_set = rng.choice(all_sample_sets)

     # Call function to be tested.
     df = api.sample_metadata(sample_sets=sample_set)
@@ -544,7 +545,7 @@ def test_sample_metadata_with_multiple_sample_sets(
     df_sample_sets = api.sample_sets().set_index("sample_set")
     sample_count = df_sample_sets["sample_count"]
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_sets = random.sample(all_sample_sets, 2)
+    sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist()

     # Call function to be tested.
     df = api.sample_metadata(sample_sets=sample_sets)
@@ -566,7 +567,7 @@ def test_sample_metadata_with_release(fixture, api: AnophelesSampleMetadata):
     # Set up test.
-    release = random.choice(api.releases)
+    release = rng.choice(api.releases)

     # Call function to be tested.
     df = api.sample_metadata(sample_sets=release)
@@ -590,10 +591,10 @@ def test_sample_metadata_with_duplicate_sample_sets(
     fixture, api: AnophelesSampleMetadata
 ):
     # Set up test.
-    release = random.choice(api.releases)
+    release = rng.choice(api.releases)
     df_sample_sets = api.sample_sets(release=release).set_index("sample_set")
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_set = random.choice(all_sample_sets)
+    sample_set = rng.choice(all_sample_sets)

     # Call function to be tested.
     assert_frame_equal(
@@ -945,7 +946,7 @@ def test_plot_sample_location_mapbox(fixture, api):
     # Get test sample_sets.
     df_sample_sets = api.sample_sets().set_index("sample_set")
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_sets = random.sample(all_sample_sets, 2)
+    sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist()

     fig = api.plot_sample_location_mapbox(
         sample_sets=sample_sets,
@@ -960,7 +961,7 @@ def test_plot_sample_location_geo(fixture, api):
     # Get test sample_sets.
     df_sample_sets = api.sample_sets().set_index("sample_set")
     all_sample_sets = df_sample_sets.index.to_list()
-    sample_sets = random.sample(all_sample_sets, 2)
+    sample_sets = rng.choice(all_sample_sets, 2, replace=False).tolist()

     fig = api.plot_sample_location_geo(
         sample_sets=sample_sets,
@@ -975,7 +976,7 @@ def test_lookup_sample(fixture, api):
     # Set up test.
     df_samples = api.sample_metadata()
     all_sample_ids = df_samples["sample_id"].values
-    sample_id = np.random.choice(all_sample_ids)
+    sample_id = rng.choice(all_sample_ids)

     # Check we get the same sample_id back.
     sample_rec_by_sample_id = api.lookup_sample(sample_id)
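
The fixture-specific valid_site_masks branching recurs several times in the test_snp_data.py changes below (and in test_pca.py above). If it keeps spreading, it could be hoisted into a small shared helper along these lines (name and placement are suggestions, not part of this diff):

    from .conftest import Af1Simulator, Ag3Simulator

    def valid_site_masks_for(fixture, api):
        # Hypothetical conftest helper consolidating the repeated branching
        # on simulator type; "" stands in for "no site mask".
        if isinstance(fixture, Af1Simulator):
            return ["funestus"]
        elif isinstance(fixture, Ag3Simulator):
            return ["gamb_colu_arab", "gamb_colu", "arab"]
        return [""] + list(api.site_mask_ids)
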
contig = fixture.random_contig() @@ -482,7 +485,7 @@ def test_snp_genotypes_with_region_param(fixture, api: AnophelesSnpData): contig, fixture.random_region_str(), [fixture.random_region_str(), fixture.random_region_str()], - random.choice(df_gff["ID"].dropna().to_list()), + rng.choice(df_gff["ID"].dropna().to_list()), ] # Run tests. @@ -497,7 +500,7 @@ def test_snp_genotypes_with_region_param(fixture, api: AnophelesSnpData): def test_snp_genotypes_with_sample_query_param( ag3_sim_api: AnophelesSnpData, sample_query ): - contig = random.choice(ag3_sim_api.contigs) + contig = rng.choice(ag3_sim_api.contigs) df_samples = ag3_sim_api.sample_metadata().query(sample_query) if len(df_samples) == 0: @@ -526,7 +529,7 @@ def test_snp_genotypes_with_sample_query_param( def test_snp_genotypes_with_sample_query_options_param( ag3_sim_api: AnophelesSnpData, sample_query, sample_query_options ): - contig = random.choice(ag3_sim_api.contigs) + contig = rng.choice(ag3_sim_api.contigs) df_samples = ag3_sim_api.sample_metadata().query( sample_query, **sample_query_options ) @@ -565,7 +568,7 @@ def test_snp_genotypes_with_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" # Standard checks. check_snp_genotypes(api, region=region) @@ -590,7 +593,7 @@ def test_snp_variants_with_virtual_contigs(ag3_sim_api, chrom): # Test with region. seq = api.genome_sequence(region=chrom) - start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2)) + start, stop = sorted(rng.integers(low=1, high=len(seq), size=2)) region = f"{chrom}:{start:,}-{stop:,}" pos = api.snp_sites(region=region, field="POS").compute() ds_region = api.snp_variants(region=region) @@ -692,16 +695,22 @@ def check_snp_calls(api, sample_sets, region, site_mask): def test_snp_calls_with_sample_sets_param(fixture, api: AnophelesSnpData): # Fixed parameters. region = fixture.random_region_str() - site_mask = random.choice((None,) + api.site_mask_ids) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Parametrize sample_sets. all_sample_sets = api.sample_sets()["sample_set"].to_list() all_releases = api.releases parametrize_sample_sets = [ None, - random.choice(all_sample_sets), - random.sample(all_sample_sets, 2), - random.choice(all_releases), + rng.choice(all_sample_sets), + rng.choice(all_sample_sets, 2, replace=False).tolist(), + rng.choice(all_releases), ] # Run tests. @@ -715,8 +724,14 @@ def test_snp_calls_with_sample_sets_param(fixture, api: AnophelesSnpData): def test_snp_calls_with_region_param(fixture, api: AnophelesSnpData): # Fixed parameters. all_sample_sets = api.sample_sets()["sample_set"].to_list() - sample_sets = random.choice(all_sample_sets) - site_mask = random.choice((None,) + api.site_mask_ids) + sample_sets = rng.choice(all_sample_sets) + if isinstance(fixture, Af1Simulator): + valid_site_masks = ["funestus"] + elif isinstance(fixture, Ag3Simulator): + valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"] + else: + valid_site_masks = [""] + list(api.site_mask_ids) + site_mask = rng.choice(valid_site_masks) # Parametrize region. 
     contig = fixture.random_contig()
@@ -725,7 +740,7 @@ def test_snp_calls_with_region_param(fixture, api: AnophelesSnpData):
         contig,
         fixture.random_region_str(),
         [fixture.random_region_str(), fixture.random_region_str()],
-        random.choice(df_gff["ID"].dropna().to_list()),
+        rng.choice(df_gff["ID"].dropna().to_list()),
     ]
 
     # Run tests.
@@ -739,7 +754,7 @@ def test_snp_calls_with_site_mask_param(fixture, api: AnophelesSnpData):
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
 
     # Parametrize site_mask.
@@ -813,7 +828,7 @@ def test_snp_calls_with_min_cohort_size_param(fixture, api: AnophelesSnpData):
    # Randomly fix some input parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
 
     # Test with minimum cohort size.
@@ -836,7 +851,7 @@ def test_snp_calls_with_max_cohort_size_param(fixture, api: AnophelesSnpData):
     # Randomly fix some input parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
 
     # Test with maximum cohort size.
@@ -853,11 +868,11 @@ def test_snp_calls_with_cohort_size_param(fixture, api: AnophelesSnpData):
     # Randomly fix some input parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
 
     # Test with specific cohort size.
-    cohort_size = random.randint(1, 10)
+    cohort_size = int(rng.integers(1, 10, endpoint=True))
     ds = api.snp_calls(
         sample_sets=sample_sets,
         region=region,
@@ -916,7 +931,7 @@ def test_snp_calls_with_virtual_contigs(ag3_sim_api, chrom):
 
     # Test with region.
     seq = api.genome_sequence(region=chrom)
-    start, stop = sorted(np.random.randint(low=1, high=len(seq), size=2))
+    start, stop = sorted(rng.integers(low=1, high=len(seq), size=2))
     region = f"{chrom}:{start:,}-{stop:,}"
 
     # Standard checks.
@@ -964,7 +979,8 @@ def check_snp_allele_counts(
     assert ac.shape == (pos.shape[0], 4)
     assert np.all(ac >= 0)
     an = ac.sum(axis=1)
-    assert an.max() <= 2 * n_samples
+    if an.size > 0:  # Check if 'an' is not empty
+        assert an.max() <= 2 * n_samples
 
     # Run again to ensure loading from results cache produces the same result.
     ac2 = api.snp_allele_counts(
@@ -981,16 +997,21 @@ def test_snp_allele_counts_with_sample_sets_param(fixture, api: AnophelesSnpData):
     # Fixed parameters.
     region = fixture.random_region_str()
-    site_mask = random.choice((None,) + api.site_mask_ids)
-
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
     # Parametrize sample_sets.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     all_releases = api.releases
     parametrize_sample_sets = [
         None,
-        random.choice(all_sample_sets),
-        random.sample(all_sample_sets, 2),
-        random.choice(all_releases),
+        rng.choice(all_sample_sets),
+        rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        rng.choice(all_releases),
     ]
 
     # Run tests.
@@ -1008,8 +1029,14 @@ def test_snp_allele_counts_with_region_param(fixture, api: AnophelesSnpData):
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    site_mask = random.choice((None,) + api.site_mask_ids)
+    sample_sets = rng.choice(all_sample_sets)
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
 
     # Parametrize region.
     contig = fixture.random_contig()
@@ -1018,7 +1045,7 @@ def test_snp_allele_counts_with_region_param(fixture, api: AnophelesSnpData):
         contig,
         fixture.random_region_str(),
         [fixture.random_region_str(), fixture.random_region_str()],
-        random.choice(df_gff["ID"].dropna().to_list()),
+        rng.choice(df_gff["ID"].dropna().to_list()),
     ]
 
     # Run tests.
@@ -1036,7 +1063,7 @@ def test_snp_allele_counts_with_site_mask_param(fixture, api: AnophelesSnpData):
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
 
     # Parametrize site_mask.
@@ -1057,9 +1084,15 @@ def test_snp_allele_counts_with_sample_query_param(fixture, api: AnophelesSnpData):
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
-    site_mask = random.choice((None,) + api.site_mask_ids)
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
 
     # Parametrize sample_query.
     parametrize_sample_query = (None, "sex_call == 'F'")
@@ -1081,9 +1114,15 @@ def test_snp_allele_counts_with_sample_query_options_param(
 ):
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
-    site_mask = random.choice((None,) + api.site_mask_ids)
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
     sample_query_options = {
         "local_dict": {
             "sex_call_list": ["F", "M"],
@@ -1121,7 +1160,7 @@ def test_is_accessible(fixture, api: AnophelesSnpData):
     parametrize_region = [
         contig,
         fixture.random_region_str(),
-        random.choice(df_gff["ID"].dropna().to_list()),
+        rng.choice(df_gff["ID"].dropna().to_list()),
     ]
 
     # Parametrize site_mask.
@@ -1143,9 +1182,9 @@ def test_plot_snps(fixture, api: AnophelesSnpData):
     # Randomly choose parameter values.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
-    site_mask = random.choice(api.site_mask_ids)
+    site_mask = rng.choice(api.site_mask_ids)
 
     # Exercise the function.
     fig = api.plot_snps(
@@ -1288,18 +1327,25 @@ def check_biallelic_snp_calls_and_diplotypes(
 def test_biallelic_snp_calls_and_diplotypes_with_sample_sets_param(
     fixture, api: AnophelesSnpData
 ):
+    all_sample_sets = api.sample_sets()["sample_set"].to_list()
     # Fixed parameters.
     region = fixture.random_region_str()
-    site_mask = random.choice((None,) + api.site_mask_ids)
+    sample_sets = rng.choice(all_sample_sets)
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
 
     # Parametrize sample_sets.
-    all_sample_sets = api.sample_sets()["sample_set"].to_list()
     all_releases = api.releases
     parametrize_sample_sets = [
         None,
-        random.choice(all_sample_sets),
-        random.sample(all_sample_sets, 2),
-        random.choice(all_releases),
+        rng.choice(all_sample_sets),
+        rng.choice(all_sample_sets, 2, replace=False).tolist(),
+        rng.choice(all_releases),
     ]
 
     # Run tests.
@@ -1315,8 +1361,14 @@ def test_biallelic_snp_calls_and_diplotypes_with_region_param(
 ):
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    site_mask = random.choice((None,) + api.site_mask_ids)
+    sample_sets = rng.choice(all_sample_sets)
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
 
     # Parametrize region.
     contig = fixture.random_contig()
@@ -1325,7 +1377,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_region_param(
         contig,
         fixture.random_region_str(),
         [fixture.random_region_str(), fixture.random_region_str()],
-        random.choice(df_gff["ID"].dropna().to_list()),
+        rng.choice(df_gff["ID"].dropna().to_list()),
     ]
 
     # Run tests.
@@ -1341,7 +1393,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_site_mask_param(
 ):
     # Fixed parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
 
     # Parametrize site_mask.
@@ -1420,7 +1472,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_min_cohort_size_param(
 ):
     # Randomly fix some input parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
 
     # Test with minimum cohort size.
@@ -1445,7 +1497,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_max_cohort_size_param(
 ):
     # Randomly fix some input parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
 
     # Test with maximum cohort size.
@@ -1464,11 +1516,11 @@ def test_biallelic_snp_calls_and_diplotypes_with_cohort_size_param(
 ):
     # Randomly fix some input parameters.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
+    sample_sets = rng.choice(all_sample_sets)
     region = fixture.random_region_str()
 
     # Test with specific cohort size.
-    cohort_size = random.randint(1, 10)
+    cohort_size = int(rng.integers(1, 10, endpoint=True))
     ds = api.biallelic_snp_calls(
         sample_sets=sample_sets,
         region=region,
@@ -1502,7 +1554,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_site_class_param(
     ag3_sim_api: AnophelesSnpData, site_class
 ):
-    contig = random.choice(ag3_sim_api.contigs)
+    contig = rng.choice(ag3_sim_api.contigs)
     ds1 = ag3_sim_api.biallelic_snp_calls(region=contig)
     ds2 = ag3_sim_api.biallelic_snp_calls(region=contig, site_class=site_class)
     assert ds2.sizes["variants"] < ds1.sizes["variants"]
@@ -1516,14 +1568,20 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions(
     fixture, api: AnophelesSnpData
 ):
     # Fixed parameters.
-    contig = random.choice(api.contigs)
+    contig = rng.choice(api.contigs)
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    site_mask = random.choice((None,) + api.site_mask_ids)
+    sample_sets = rng.choice(all_sample_sets)
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
 
     # Parametrise conditions.
-    min_minor_ac = random.randint(1, 3)
-    max_missing_an = random.randint(5, 10)
+    min_minor_ac = int(rng.integers(1, 3, endpoint=True))
+    max_missing_an = int(rng.integers(5, 10, endpoint=True))
 
     # Run tests.
     ds = check_biallelic_snp_calls_and_diplotypes(
@@ -1551,7 +1609,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions(
     # This should always be true, although depends on min_minor_ac and max_missing_an,
     # so the range of values for those parameters needs to be chosen with some care.
     assert n_snps_available > 2
-    n_snps_requested = random.randint(1, n_snps_available // 2)
+    n_snps_requested = int(rng.integers(1, n_snps_available // 2, endpoint=True))
     ds_thinned = check_biallelic_snp_calls_and_diplotypes(
         api=api,
         sample_sets=sample_sets,
@@ -1582,14 +1640,20 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions_fractional(
     fixture, api: AnophelesSnpData
 ):
     # Fixed parameters.
-    contig = random.choice(api.contigs)
+    contig = rng.choice(api.contigs)
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    site_mask = random.choice((None,) + api.site_mask_ids)
+    sample_sets = rng.choice(all_sample_sets)
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
 
     # Parametrise conditions.
-    min_minor_ac = random.uniform(0, 0.05)
-    max_missing_an = random.uniform(0.05, 0.2)
+    min_minor_ac = rng.uniform(0, 0.05)
+    max_missing_an = rng.uniform(0.05, 0.2)
 
     # Run tests.
     ds = check_biallelic_snp_calls_and_diplotypes(
@@ -1617,7 +1681,7 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions_fractional(
     # This should always be true, although depends on min_minor_ac and max_missing_an,
     # so the range of values for those parameters needs to be chosen with some care.
     assert n_snps_available > 2
-    n_snps_requested = random.randint(1, n_snps_available // 2)
+    n_snps_requested = int(rng.integers(1, n_snps_available // 2, endpoint=True))
     ds_thinned = check_biallelic_snp_calls_and_diplotypes(
         api=api,
         sample_sets=sample_sets,
diff --git a/tests/anoph/test_snp_frq.py b/tests/anoph/test_snp_frq.py
index f7a0224f..2e2048cc 100644
--- a/tests/anoph/test_snp_frq.py
+++ b/tests/anoph/test_snp_frq.py
@@ -1,5 +1,3 @@
-import random
-
 import numpy as np
 import pandas as pd
 from pandas.testing import assert_frame_equal
@@ -7,6 +5,7 @@
 from pytest_cases import parametrize_with_cases
 import xarray as xr
 from numpy.testing import assert_allclose, assert_array_equal
 
+from .conftest import Af1Simulator, Ag3Simulator
 from malariagen_data import af1 as _af1
 from malariagen_data import ag3 as _ag3
@@ -22,6 +21,9 @@
 )
 
+
+rng = np.random.default_rng(seed=42)
+
 
 @pytest.fixture
 def ag3_sim_api(ag3_sim_fixture):
     return AnophelesSnpFrequencyAnalysis(
@@ -112,7 +114,7 @@ def random_transcript(*, api):
     df_gff = api.genome_features(attributes=["ID", "Parent"])
     df_transcripts = df_gff.query("type == 'mRNA'")
     transcript_ids = df_transcripts["ID"].dropna().to_list()
-    transcript_id = random.choice(transcript_ids)
+    transcript_id = rng.choice(transcript_ids)
     transcript = df_transcripts.set_index("ID").loc[transcript_id]
     return transcript
 
@@ -123,7 +125,13 @@ def test_snp_effects(fixture, api: AnophelesSnpFrequencyAnalysis):
     transcript = random_transcript(api=api)
 
     # Pick a random site mask.
-    site_mask = random.choice(api.site_mask_ids + (None,))
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
 
     # Compute effects.
     df = api.snp_effects(transcript=transcript.name, site_mask=site_mask)
@@ -300,9 +308,16 @@ def test_allele_frequencies_with_str_cohorts(
 ):
     # Pick test parameters at random.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    site_mask = random.choice(api.site_mask_ids + (None,))
-    min_cohort_size = random.randint(0, 2)
+    sample_sets = rng.choice(all_sample_sets)
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
+
+    min_cohort_size = int(rng.integers(0, 2, endpoint=True))
     transcript = random_transcript(api=api)
 
     # Set up call params.
@@ -365,8 +380,14 @@ def test_allele_frequencies_with_min_cohort_size(
 ):
     # Pick test parameters at random.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    site_mask = random.choice(api.site_mask_ids + (None,))
+    sample_sets = rng.choice(all_sample_sets)
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
     transcript = random_transcript(api=api)
     cohorts = "admin1_year"
@@ -430,15 +451,19 @@ def test_allele_frequencies_with_str_cohorts_and_sample_query(
 ):
     # Pick test parameters at random.
     sample_sets = None
-    site_mask = random.choice(api.site_mask_ids + (None,))
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
     min_cohort_size = 0
     transcript = random_transcript(api=api)
-    cohorts = random.choice(
-        ["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
-    )
+    cohorts = rng.choice(["admin1_year", "admin1_month", "admin2_year", "admin2_month"])
     df_samples = api.sample_metadata(sample_sets=sample_sets)
     countries = df_samples["country"].unique()
-    country = random.choice(countries)
+    country = rng.choice(countries)
     sample_query = f"country == '{country}'"
 
     # Figure out expected cohort labels.
@@ -491,15 +516,19 @@ def test_allele_frequencies_with_str_cohorts_and_sample_query_options(
 ):
     # Pick test parameters at random.
     sample_sets = None
-    site_mask = random.choice(api.site_mask_ids + (None,))
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
     min_cohort_size = 0
     transcript = random_transcript(api=api)
-    cohorts = random.choice(
-        ["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
-    )
+    cohorts = rng.choice(["admin1_year", "admin1_month", "admin2_year", "admin2_month"])
     df_samples = api.sample_metadata(sample_sets=sample_sets)
     countries = df_samples["country"].unique().tolist()
-    countries_list = random.sample(countries, 2)
+    countries_list = rng.choice(countries, 2, replace=False).tolist()
     sample_query_options = {
         "local_dict": {
             "countries_list": countries_list,
@@ -561,8 +590,8 @@ def test_allele_frequencies_with_dict_cohorts(
 ):
     # Pick test parameters at random.
     sample_sets = None  # all sample sets
-    site_mask = random.choice(api.site_mask_ids + (None,))
-    min_cohort_size = random.randint(0, 2)
+    site_mask = rng.choice(list(api.site_mask_ids) + [""])
+    min_cohort_size = int(rng.integers(0, 2, endpoint=True))
     transcript = random_transcript(api=api)
 
     # Create cohorts by country.
@@ -614,11 +643,11 @@ def test_allele_frequencies_without_drop_invariant(
 ):
     # Pick test parameters at random.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    site_mask = random.choice(api.site_mask_ids + (None,))
-    min_cohort_size = random.randint(0, 2)
+    sample_sets = rng.choice(all_sample_sets)
+    site_mask = rng.choice(list(api.site_mask_ids) + [""])
+    min_cohort_size = int(rng.integers(0, 2, endpoint=True))
     transcript = random_transcript(api=api)
-    cohorts = random.choice(["admin1_year", "admin2_month", "country"])
+    cohorts = rng.choice(["admin1_year", "admin2_month", "country"])
 
     # Figure out expected cohort labels.
     df_samples = api.sample_metadata(sample_sets=sample_sets)
@@ -670,11 +699,17 @@ def test_allele_frequencies_without_effects(
 ):
     # Pick test parameters at random.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    site_mask = random.choice(api.site_mask_ids + (None,))
-    min_cohort_size = random.randint(0, 2)
+    sample_sets = rng.choice(all_sample_sets)
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
+    min_cohort_size = int(rng.integers(0, 2, endpoint=True))
     transcript = random_transcript(api=api)
-    cohorts = random.choice(["admin1_year", "admin2_month", "country"])
+    cohorts = rng.choice(["admin1_year", "admin2_month", "country"])
 
     # Figure out expected cohort labels.
     df_samples = api.sample_metadata(sample_sets=sample_sets)
@@ -752,10 +787,10 @@ def test_allele_frequencies_with_bad_transcript(
 ):
     # Pick test parameters at random.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    site_mask = random.choice(api.site_mask_ids + (None,))
-    min_cohort_size = random.randint(0, 2)
-    cohorts = random.choice(["admin1_year", "admin2_month", "country"])
+    sample_sets = rng.choice(all_sample_sets)
+    site_mask = rng.choice(list(api.site_mask_ids) + [""])
+    min_cohort_size = int(rng.integers(0, 2, endpoint=True))
+    cohorts = rng.choice(["admin1_year", "admin2_month", "country"])
 
     # Set up call params.
     params = dict(
@@ -779,10 +814,16 @@ def test_allele_frequencies_with_region(
 ):
     # Pick test parameters at random.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_sets = random.choice(all_sample_sets)
-    site_mask = random.choice(api.site_mask_ids + (None,))
-    min_cohort_size = random.randint(0, 2)
-    cohorts = random.choice(["admin1_year", "admin2_month", "country"])
+    sample_sets = rng.choice(all_sample_sets)
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
+    min_cohort_size = int(rng.integers(0, 2, endpoint=True))
+    cohorts = rng.choice(["admin1_year", "admin2_month", "country"])
 
     # This should work, as long as effects=False - i.e., can get frequencies
     # for any genome region.
     transcript = fixture.random_region_str(region_size=500)
@@ -837,11 +878,17 @@ def test_allele_frequencies_with_dup_samples(
 ):
     # Pick test parameters at random.
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_set = random.choice(all_sample_sets)
-    site_mask = random.choice(api.site_mask_ids + (None,))
-    min_cohort_size = random.randint(0, 2)
+    sample_set = rng.choice(all_sample_sets)
+    if isinstance(fixture, Af1Simulator):
+        valid_site_masks = ["funestus"]
+    elif isinstance(fixture, Ag3Simulator):
+        valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+    else:
+        valid_site_masks = [""] + list(api.site_mask_ids)
+    site_mask = rng.choice(valid_site_masks)
+    min_cohort_size = int(rng.integers(0, 2, endpoint=True))
     transcript = random_transcript(api=api)
-    cohorts = random.choice(["admin1_year", "admin2_month", "country"])
+    cohorts = rng.choice(["admin1_year", "admin2_month", "country"])
 
     # Set up call params.
     params = dict(
@@ -873,6 +920,7 @@ def test_allele_frequencies_with_dup_samples(
 
 def check_snp_allele_frequencies_advanced(
     *,
+    fixture,
     api: AnophelesSnpFrequencyAnalysis,
     transcript=None,
     area_by="admin1_iso",
@@ -889,16 +937,22 @@ def check_snp_allele_frequencies_advanced(
     if transcript is None:
         transcript = random_transcript(api=api).name
     if area_by is None:
-        area_by = random.choice(["country", "admin1_iso", "admin2_name"])
+        area_by = rng.choice(["country", "admin1_iso", "admin2_name"])
     if period_by is None:
-        period_by = random.choice(["year", "quarter", "month", "random_year"])
+        period_by = rng.choice(["year", "quarter", "month"])
     if sample_sets is None:
         all_sample_sets = api.sample_sets()["sample_set"].to_list()
-        sample_sets = random.choice(all_sample_sets)
+        sample_sets = rng.choice(all_sample_sets)
     if min_cohort_size is None:
-        min_cohort_size = random.randint(0, 2)
+        min_cohort_size = int(rng.integers(0, 2, endpoint=True))
     if site_mask is None:
-        site_mask = random.choice(api.site_mask_ids + (None,))
+        if isinstance(fixture, Af1Simulator):
+            valid_site_masks = ["funestus"]
+        elif isinstance(fixture, Ag3Simulator):
+            valid_site_masks = ["gamb_colu_arab", "gamb_colu", "arab"]
+        else:
+            valid_site_masks = [""] + list(api.site_mask_ids)
+        site_mask = rng.choice(valid_site_masks)
 
     if period_by == "random_year":
         # Add a random_year column to the sample metadata, if there isn't already.
@@ -1089,14 +1143,14 @@ def check_aa_allele_frequencies_advanced(
     if transcript is None:
         transcript = random_transcript(api=api).name
     if area_by is None:
-        area_by = random.choice(["country", "admin1_iso", "admin2_name"])
+        area_by = rng.choice(["country", "admin1_iso", "admin2_name"])
    if period_by is None:
-        period_by = random.choice(["year", "quarter", "month", "random_year"])
+        period_by = rng.choice(["year", "quarter", "month"])
     if sample_sets is None:
         all_sample_sets = api.sample_sets()["sample_set"].to_list()
-        sample_sets = random.choice(all_sample_sets)
+        sample_sets = rng.choice(all_sample_sets)
     if min_cohort_size is None:
-        min_cohort_size = random.randint(0, 2)
+        min_cohort_size = int(rng.integers(0, 2, endpoint=True))
 
     if period_by == "random_year":
         # Add a random_year column to the sample metadata, if there isn't already.
@@ -1272,6 +1326,7 @@ def test_allele_frequencies_advanced_with_area_by(
     area_by,
 ):
     check_snp_allele_frequencies_advanced(
+        fixture=fixture,
         api=api,
         area_by=area_by,
     )
@@ -1289,6 +1344,7 @@ def test_allele_frequencies_advanced_with_period_by(
     period_by,
 ):
     check_snp_allele_frequencies_advanced(
+        fixture=fixture,
         api=api,
         period_by=period_by,
     )
@@ -1306,10 +1362,11 @@ def test_allele_frequencies_advanced_with_sample_query(
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     df_samples = api.sample_metadata(sample_sets=all_sample_sets)
     countries = df_samples["country"].unique()
-    country = random.choice(countries)
+    country = rng.choice(countries)
     sample_query = f"country == '{country}'"
 
     check_snp_allele_frequencies_advanced(
+        fixture=fixture,
         api=api,
         sample_sets=all_sample_sets,
         sample_query=sample_query,
@@ -1331,7 +1388,7 @@ def test_allele_frequencies_advanced_with_sample_query_options(
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
     df_samples = api.sample_metadata(sample_sets=all_sample_sets)
     countries = df_samples["country"].unique().tolist()
-    countries_list = random.sample(countries, 2)
+    countries_list = rng.choice(countries, 2, replace=False).tolist()
     sample_query_options = {
         "local_dict": {
             "countries_list": countries_list,
@@ -1340,6 +1397,7 @@ def test_allele_frequencies_advanced_with_sample_query_options(
     sample_query = "country in @countries_list"
 
     check_snp_allele_frequencies_advanced(
+        fixture=fixture,
         api=api,
         sample_sets=all_sample_sets,
         sample_query=sample_query,
@@ -1371,6 +1429,7 @@ def test_allele_frequencies_advanced_with_min_cohort_size(
     # Expect this to find at least one cohort, so go ahead with full
     # checks.
     check_snp_allele_frequencies_advanced(
+        fixture=fixture,
         api=api,
         transcript=transcript,
         sample_sets=all_sample_sets,
@@ -1419,6 +1478,7 @@ def test_allele_frequencies_advanced_with_variant_query(
     # Test a query that should succeed.
     variant_query = "effect == 'NON_SYNONYMOUS_CODING'"
     check_snp_allele_frequencies_advanced(
+        fixture=fixture,
         api=api,
         transcript=transcript,
         sample_sets=all_sample_sets,
@@ -1463,6 +1523,7 @@ def test_allele_frequencies_advanced_with_nobs_mode(
     nobs_mode,
 ):
     check_snp_allele_frequencies_advanced(
+        fixture=fixture,
         api=api,
         nobs_mode=nobs_mode,
     )
@@ -1478,10 +1539,11 @@ def test_allele_frequencies_advanced_with_dup_samples(
     fixture,
     api: AnophelesSnpFrequencyAnalysis,
 ):
     all_sample_sets = api.sample_sets()["sample_set"].to_list()
-    sample_set = random.choice(all_sample_sets)
+    sample_set = rng.choice(all_sample_sets)
     sample_sets = [sample_set, sample_set]
 
     check_snp_allele_frequencies_advanced(
+        fixture=fixture,
         api=api,
         sample_sets=sample_sets,
     )
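
A note on the stdlib `random` → `numpy.random.Generator` migration applied throughout the hunks above: the two APIs treat the upper bound differently. `random.randint(low, high)` samples the closed interval [low, high], while `Generator.integers(low, high)` samples the half-open interval [low, high), so a naive one-for-one swap silently narrows the sampled range, and `rng.integers(1, 1)` raises ValueError where `random.randint(1, 1)` returns 1. Passing `endpoint=True`, as the replacement lines above do, restores the inclusive semantics. The following standalone sketch illustrates the difference; it is not part of the patch, and its set-equality checks are probabilistic, though effectively certain at 1,000 draws:

import random

import numpy as np

rng = np.random.default_rng(seed=42)

# random.randint samples the closed interval [1, 3]; 3 can occur.
assert {random.randint(1, 3) for _ in range(1000)} == {1, 2, 3}

# Generator.integers defaults to the half-open interval [1, 3); 3 never occurs.
assert {int(x) for x in rng.integers(1, 3, size=1000)} == {1, 2}

# endpoint=True restores the closed interval, matching random.randint.
assert {int(x) for x in rng.integers(1, 3, size=1000, endpoint=True)} == {1, 2, 3}

# Degenerate bounds: random.randint(1, 1) returns 1, but rng.integers(1, 1)
# raises ValueError (low >= high); endpoint=True also covers this edge case,
# which matters for rng.integers(1, n_snps_available // 2, endpoint=True)
# when n_snps_available // 2 == 1.
assert int(rng.integers(1, 1, endpoint=True)) == 1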