diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml
index c0b48eb..03ea709 100644
--- a/.github/workflows/python-tests.yml
+++ b/.github/workflows/python-tests.yml
@@ -41,6 +41,6 @@ jobs:
       - name: Run Unit Tests
         run: |
           # When run the first time, it'll build the library
-          python -m unittest tests.test_pyard tests.test_smart_sort
+          python -m unittest tests.unit.test_pyard tests.unit.test_smart_sort
           # When run the second time, it should use the already installed library
-          python -m unittest tests.test_pyard tests.test_smart_sort
+          python -m unittest tests.unit.test_pyard tests.unit.test_smart_sort
diff --git a/Dockerfile b/Dockerfile
index 0195989..5a36ef0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -4,7 +4,7 @@ LABEL MAINTAINER="Pradeep Bashyal"
 
 WORKDIR /app
 
-ARG PY_ARD_VERSION=1.5.5
+ARG PY_ARD_VERSION=2.0.0b1
 
 COPY requirements.txt /app
 RUN pip install --no-cache-dir --upgrade pip && \
diff --git a/api-spec.yaml b/api-spec.yaml
index 935cfec..c4baca3 100644
--- a/api-spec.yaml
+++ b/api-spec.yaml
@@ -2,7 +2,7 @@ openapi: 3.0.3
 info:
   title: ARD Reduction
   description: Reduce to ARD Level
-  version: "1.5.5"
+  version: "2.0.0b1"
 servers:
   - url: 'http://localhost:8080'
 tags:
diff --git a/pyard/__init__.py b/pyard/__init__.py
index 3c81f16..81c5da6 100644
--- a/pyard/__init__.py
+++ b/pyard/__init__.py
@@ -26,7 +26,7 @@ from .misc import get_imgt_db_versions as db_versions
 
 __author__ = """NMDP Bioinformatics"""
-__version__ = "1.5.5"
+__version__ = "2.0.0b1"
 
 
 def init(
diff --git a/pyard/data_repository.py b/pyard/data_repository.py
index 6ad0e94..c8e80b8 100644
--- a/pyard/data_repository.py
+++ b/pyard/data_repository.py
@@ -23,18 +23,24 @@
 import copy
 import functools
 import sqlite3
 
-import pyard.load
+import pyard.loader
+import pyard.loader.cwd
+import pyard.loader.mac_codes
+import pyard.loader.serology
 from pyard.smart_sort import smart_sort_comparator
 from . import db
 from .constants import expression_chars
-from .load import (
-    load_g_group,
-    load_p_group,
-    load_allele_list,
-    load_serology_mappings,
-    load_latest_version,
-)
+from .loader.allele_list import load_allele_list
+from .loader.serology import load_serology_mappings, load_serology_broad_split_mapping
+from .loader.version import load_latest_version
+
+from .loader.p_group import load_p_group
+from .loader.g_group import load_g_group
+
+from .simple_table import Table
+
 from .mappings import (
     ars_mapping_tables,
     ARSMapping,
@@ -53,7 +59,7 @@ from .smart_sort import smart_sort_comparator
 
 
-def expression_reduce(df):
+def expression_reduce(exp_alleles_table):
     """
    For each group of expression alleles, check if __all__ of them
    have the same expression character. If so, the group can be reduced
    to the second field level if all three-field and/or four-field
    alleles have the same expression character.

-    :param df: dataframe with Allele column that is all expression characters
-    :return: 2 field allele or None
+    Given
+        allele_groups = {
+            'A*01:01': [
+                {'AlleleID': 'HLA02169', 'Allele': 'A*01:01:01:02N', '2d': 'A*01:01', '3d': 'A*01:01:01',
+                 'Exp': 'A*01:01:01:02N'},
+                {'AlleleID': 'HLA03587', 'Allele': 'A*01:01:38L', '2d': 'A*01:01', '3d': 'A*01:01:38L',
+                 'Exp': 'A*01:01:38L'}
+            ],
+            'A*01:04': [
+                {'AlleleID': 'HLA00004', 'Allele': 'A*01:04:01:01N', '2d': 'A*01:04', '3d': 'A*01:04:01',
+                 'Exp': 'A*01:04:01:01N'},
+                {'AlleleID': 'HLA18724', 'Allele': 'A*01:04:01:02N', '2d': 'A*01:04', '3d': 'A*01:04:01',
+                 'Exp': 'A*01:04:01:02N'}
+            ],
+            'A*01:52': [
+                {'AlleleID': 'HLA04761', 'Allele': 'A*01:52:01N', '2d': 'A*01:52', '3d': 'A*01:52:01N',
+                 'Exp': 'A*01:52:01N'},
+                {'AlleleID': 'HLA14127', 'Allele': 'A*01:52:02N', '2d': 'A*01:52', '3d': 'A*01:52:02N',
+                 'Exp': 'A*01:52:02N'}]
+        }
+
+    :param exp_alleles_table: Table of alleles that end in an expression character
+    :return: dict mapping each qualifying 2-field allele to the 2-field allele
+             with its common expression character appended
     """
-    for e in expression_chars:
-        if df["Allele"].str.endswith(e).all():
-            return df["2d"].iloc[0] + e
-    return None
+    allele_groups = exp_alleles_table.group_by("2d")
+    valid_2d_exp_alleles = dict()
+    for allele_2d, allele_group in allele_groups.items():
+        # Get the expression characters seen for the current allele_2d
+        # (local name avoids shadowing pyard.constants.expression_chars)
+        exp_chars_seen = {allele["Exp"][-1] for allele in allele_group}
+
+        # Check if all expression characters are the same
+        if len(exp_chars_seen) == 1:
+            # If so, record the 2d allele with the expression character appended
+            valid_2d_exp_alleles[allele_2d] = allele_2d + exp_chars_seen.pop()
+
+    return valid_2d_exp_alleles
+
+
+def join_allele_list(alleles: list):
+    return "/".join(sorted(alleles, key=functools.cmp_to_key(smart_sort_comparator)))
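For reference, the rewritten loop behaves like this pure-Python sketch on the docstring's sample data (dicts trimmed to the `Exp` key; illustrative only):

```python
# Mirror of expression_reduce() on the docstring's sample groups.
allele_groups = {
    "A*01:01": [{"Exp": "A*01:01:01:02N"}, {"Exp": "A*01:01:38L"}],  # N vs L -> skipped
    "A*01:04": [{"Exp": "A*01:04:01:01N"}, {"Exp": "A*01:04:01:02N"}],  # all N -> kept
    "A*01:52": [{"Exp": "A*01:52:01N"}, {"Exp": "A*01:52:02N"}],  # all N -> kept
}
valid_2d_exp_alleles = {}
for allele_2d, allele_group in allele_groups.items():
    exp_chars_seen = {allele["Exp"][-1] for allele in allele_group}
    if len(exp_chars_seen) == 1:
        valid_2d_exp_alleles[allele_2d] = allele_2d + exp_chars_seen.pop()

assert valid_2d_exp_alleles == {"A*01:04": "A*01:04N", "A*01:52": "A*01:52N"}
```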


 def generate_ard_mapping(db_connection: sqlite3.Connection, imgt_version) -> ARSMapping:
     if db.tables_exist(db_connection, ars_mapping_tables):
         return db.load_ars_mappings(db_connection)
 
-    import pandas as pd
-
     df_g_group = load_g_group(imgt_version)
     df_p_group = load_p_group(imgt_version)
 
@@ -86,100 +121,73 @@ def generate_ard_mapping(db_connection: sqlite3.Connection, imgt_version) -> ARS
     p_not_in_g = set(df_p_group["2d"]) - set(df_g_group["2d"])
 
     # filter to find these 2-field alleles (2d) in the P-group data frame
-    df_p_not_g = df_p_group[df_p_group["2d"].isin(p_not_in_g)]
-    # dictionary which will define the table
-    p_not_g = df_p_not_g.set_index("A")["lgx"].to_dict()
+    p_not_g = df_p_group.where_in("2d", p_not_in_g, ["A", "lgx"]).to_dict("A", "lgx")
 
     # multiple Gs
     # goal: identify 2-field alleles that are in multiple G-groups
-    # group by 2d and G, and select the 2d column and count the columns
-    mg = df_g_group.drop_duplicates(["2d", "G"])["2d"].value_counts()
-    # filter out the mg with count > 1, leaving only duplicates
-    # take the index from the 2d version the data frame, make that a column
-    # and turn that into a list
-    multiple_g_list = mg[mg > 1].index.to_list()
+    multiple_g = df_g_group.unique(["2d", "G"]).value_counts("2d")
+    # filter out the multiple_g with count > 1, leaving only duplicates
+    multiple_g_list = multiple_g.where("count > 1")["2d"]
 
-    # Keep only the alleles that have more than 1 mapping
+    # Keep only the alleles that have more than 1 mapping as allele list
     dup_g = (
-        df_g_group[df_g_group["2d"].isin(multiple_g_list)][["G", "2d"]]
-        .drop_duplicates()
-        .groupby("2d", as_index=True)
-        .agg("/".join)
-        .to_dict()["G"]
+        df_g_group.where_in("2d", multiple_g_list, ["G", "2d"])
+        .unique(["G", "2d"])
+        .agg("2d", "G", join_allele_list)
+        .to_dict("2d", "agg")
     )
 
     # multiple lgx
-    mlgx = df_g_group.drop_duplicates(["2d", "lgx"])["2d"].value_counts()
-    multiple_lgx_list = mlgx[mlgx > 1].index.to_list()
+    # goal: identify 2-field alleles that are in multiple lgx-groups
+    # group by 2d and lgx and count the occurrences of each 2d allele
+    mlgx = df_g_group.unique(["2d", "lgx"]).value_counts("2d")
+    # filter out the mlgx with count > 1, leaving only duplicates
+    multiple_lgx_list = mlgx.where("count > 1")["2d"]
 
-    # Keep only the alleles that have more than 1 mapping
+    # Keep only the alleles that have more than 1 mapping as allele list
     dup_lgx = (
-        df_g_group[df_g_group["2d"].isin(multiple_lgx_list)][["lgx", "2d"]]
-        .drop_duplicates()
-        .groupby("2d", as_index=True)
-        .agg("/".join)
-        .to_dict()["lgx"]
+        df_g_group.where_in("2d", multiple_lgx_list, ["lgx", "2d"])
+        .unique(["lgx", "2d"])
+        .agg("2d", "lgx", join_allele_list)
+        .to_dict("2d", "agg")
     )
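The recurring pandas idiom `drop_duplicates` → `value_counts` → boolean mask → `groupby().agg()` maps onto the `Table` API introduced in pyard/simple_table.py later in this diff. A self-contained sketch with made-up rows (plain `"/".join` stands in for `join_allele_list`):

```python
from pyard.simple_table import Table

rows = [
    ("A*02:01", "A*02:01:01G"),
    ("A*02:01", "A*02:01:01G"),  # duplicate pair, dropped by unique()
    ("A*02:01", "A*02:01:02G"),  # a second G-group for A*02:01
    ("A*03:01", "A*03:01:01G"),
]
t = Table(rows, ["2d", "G"])

counts = t.unique(["2d", "G"]).value_counts("2d")  # Table of (2d, count)
multiple_g_list = counts.where("count > 1")["2d"]  # Column of ambiguous 2d alleles
dup_g = (
    t.where_in("2d", set(multiple_g_list), ["G", "2d"])
    .unique(["G", "2d"])
    .agg("2d", "G", "/".join)  # the real code uses join_allele_list for smart ordering
    .to_dict("2d", "agg")
)
# dup_g == {"A*02:01": "A*02:01:01G/A*02:01:02G"}
```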

     # Extract G group mapping
-    df_g = pd.concat(
-        [
-            df_g_group[["2d", "G"]].rename(columns={"2d": "A"}),
-            df_g_group[["3d", "G"]].rename(columns={"3d": "A"}),
-            df_g_group[["A", "G"]],
-        ],
-        ignore_index=True,
-    )
-    g_group = df_g.set_index("A")["G"].to_dict()
+    g_2d = df_g_group[["2d", "G"]].rename(column_mapping={"2d": "A"})
+    g_3d = df_g_group[["3d", "G"]].rename(column_mapping={"3d": "A"})
+    g_a = df_g_group[["A", "G"]]
+    g_all = g_2d.union(g_3d).union(g_a)
+    g_group = g_all.to_dict("A", "G")
 
     # Extract P group mapping
-    df_p = pd.concat(
-        [
-            df_p_group[["2d", "P"]].rename(columns={"2d": "A"}),
-            df_p_group[["3d", "P"]].rename(columns={"3d": "A"}),
-            df_p_group[["A", "P"]],
-        ],
-        ignore_index=True,
-    )
-    p_group = df_p.set_index("A")["P"].to_dict()
+    p_2d = df_p_group[["2d", "P"]].rename(column_mapping={"2d": "A"})
+    p_3d = df_p_group[["3d", "P"]].rename(column_mapping={"3d": "A"})
+    p_a = df_p_group[["A", "P"]]
+    p_all = p_2d.union(p_3d).union(p_a)
+    p_group = p_all.to_dict("A", "P")
 
     # Extract lgx group mapping
-    df_lgx = pd.concat(
-        [
-            df_g_group[["2d", "lgx"]].rename(columns={"2d": "A"}),
-            df_g_group[["3d", "lgx"]].rename(columns={"3d": "A"}),
-            df_g_group[["A", "lgx"]],
-        ]
-    )
-    lgx_group = df_lgx.set_index("A")["lgx"].to_dict()
+    lgx_2d = df_g_group[["2d", "lgx"]].rename(column_mapping={"2d": "A"})
+    lgx_3d = df_g_group[["3d", "lgx"]].rename(column_mapping={"3d": "A"})
+    lgx_a = df_g_group[["A", "lgx"]]
+    lgx_all = lgx_2d.union(lgx_3d).union(lgx_a)
+    lgx_group = lgx_all.to_dict("A", "lgx")
 
-    # Find the alleles that have more than 1 mapping
-    dup_lgx = (
-        df_g_group[df_g_group["2d"].isin(multiple_lgx_list)][["lgx", "2d"]]
-        .drop_duplicates()
-        .groupby("2d", as_index=True)
-        .agg(list)
-        .to_dict()["lgx"]
-    )
     # Do not keep duplicate alleles for lgx. Issue #333
     # DPA1*02:02/DPA1*02:07 ==> DPA1*02:02
     #
     lowest_numbered_dup_lgx = {
-        k: sorted(v, key=functools.cmp_to_key(smart_sort_comparator))[0]
+        k: sorted(v.split("/"), key=functools.cmp_to_key(smart_sort_comparator))[0]
         for k, v in dup_lgx.items()
     }
     # Update the lgx_group with the allele with the lowest number
     lgx_group.update(lowest_numbered_dup_lgx)
 
     # Extract exon mapping
-    df_exon = pd.concat(
-        [
-            df_g_group[["A", "3d"]].rename(columns={"3d": "exon"}),
-        ]
-    )
-    exon_group = df_exon.set_index("A")["exon"].to_dict()
+    exon_a = df_g_group[["A", "3d"]].rename(column_mapping={"3d": "exon"})
+    exon_group = exon_a.to_dict("A", "exon")
 
     ars_mapping = ARSMapping(
         dup_g=dup_g,
@@ -200,21 +208,17 @@ def generate_alleles_and_xx_codes_and_who(
     if db.tables_exist(db_connection, code_mapping_tables + allele_tables):
         return db.load_code_mappings(db_connection)
 
-    import pandas as pd
-
     allele_df = load_allele_list(imgt_version)
 
-    # Create a set of valid alleles
-    # All 2-field, 3-field and the original Alleles are considered valid alleles
+    # Create columns of alleles of various fields
+    allele_df["1d"] = allele_df["Allele"].apply(get_1field_allele)
     allele_df["2d"] = allele_df["Allele"].apply(get_2field_allele)
     allele_df["3d"] = allele_df["Allele"].apply(get_3field_allele)
 
-    # For all Alleles with expression characters, find 2-field valid alleles
-    exp_alleles = allele_df[
-        allele_df["Allele"].apply(
-            lambda a: a[-1] in expression_chars and number_of_fields(a) > 2
-        )
-    ]
-    exp_alleles = exp_alleles.groupby("2d").apply(expression_reduce).dropna()
+    allele_df["Exp"] = allele_df["Allele"].apply(
+        lambda a: a if a[-1] in expression_chars and number_of_fields(a) > 2 else None
+    )
+    exp_alleles_table = allele_df.where_not_null("Exp")
+    exp_alleles = expression_reduce(exp_alleles_table)
 
     # Create valid set of alleles:
     # All full length alleles
@@ -224,15 +228,17 @@ def generate_alleles_and_xx_codes_and_who(
         set(allele_df["Allele"])
         .union(set(allele_df["2d"]))
         .union(set(allele_df["3d"]))
-        .union(set(exp_alleles))
+        .union(set(exp_alleles.values()))
     )
+    valid_alleles = sorted(valid_alleles)
 
-    # Create xx_codes mapping from the unique alleles in 2-field column
-    xx_df = pd.DataFrame(allele_df["2d"].unique(), columns=["Allele"])
-    # Also create a first-field column
-    xx_df["1d"] = xx_df["Allele"].apply(lambda x: x.split(":")[0])
-    # xx_codes maps a first field name to its 2 field expansion
-    xx_codes = xx_df.groupby(["1d"]).apply(lambda x: list(x["Allele"])).to_dict()
+    # xx_codes maps a first-field name to its 2-field expansion
+    xx_codes = allele_df.agg("1d", "2d", list)
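With `func` in `{list, set}`, `Table.agg` returns a plain dict rather than a new Table, which is what the xx_codes mapping above relies on. An illustrative sketch:

```python
from pyard.simple_table import Table

t = Table(
    [("A*01", "A*01:01"), ("A*01", "A*01:02"), ("A*02", "A*02:01")],
    ["1d", "2d"],
)
xx_codes = t.agg("1d", "2d", list)  # func is a builtin container -> dict
# xx_codes == {"A*01": ["A*01:01", "A*01:02"], "A*02": ["A*02:01"]}
```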

     # Update xx codes with broads and splits
     for broad, splits in broad_splits_dna_mapping.items():
@@ -252,31 +258,22 @@ def generate_alleles_and_xx_codes_and_who(
     who_alleles = allele_df["Allele"].to_list()
 
     # Create WHO mapping from the unique alleles in the 1-field column
-    allele_df["1d"] = allele_df["Allele"].apply(get_1field_allele)
-    who_codes = pd.concat(
-        [
-            allele_df[["Allele", "1d"]].rename(columns={"1d": "nd"}),
-            allele_df[["Allele", "2d"]].rename(columns={"2d": "nd"}),
-            allele_df[["Allele", "3d"]].rename(columns={"3d": "nd"}),
-            pd.DataFrame(ars_mappings.g_group.items(), columns=["Allele", "nd"]),
-            pd.DataFrame(ars_mappings.p_group.items(), columns=["Allele", "nd"]),
-        ],
-        ignore_index=True,
-    )
-
-    # remove valid alleles from who_codes to avoid recursion
-    for k in who_alleles:
-        if k in who_codes["nd"]:
-            who_codes.drop(labels=k, axis="index")
+    a1d = allele_df[["Allele", "1d"]].rename(column_mapping={"1d": "nd"})
+    a2d = allele_df[["Allele", "2d"]].rename(column_mapping={"2d": "nd"})
+    a3d = allele_df[["Allele", "3d"]].rename(column_mapping={"3d": "nd"})
+    ag = Table(ars_mappings.g_group.items(), columns=["Allele", "nd"])
+    ap = Table(ars_mappings.p_group.items(), columns=["Allele", "nd"])
+    who_codes = a1d.union(a2d).union(a3d).union(ag).union(ap)
 
     # drop duplicates
-    who_codes = who_codes.drop_duplicates()
-
-    # who_codes maps a first field name to its 2 field expansion
-    who_group = who_codes.groupby(["nd"]).apply(lambda x: list(x["Allele"])).to_dict()
-
+    unique_who_codes = who_codes.unique(["Allele", "nd"])
+    # who_group maps an n-field name to the list of alleles it expands to
+    who_group = unique_who_codes.agg("nd", "Allele", list)
     flat_who_group = {
         k: "/".join(sorted(v, key=functools.cmp_to_key(smart_sort_comparator)))
         for k, v in who_group.items()
     }
@@ -328,7 +325,12 @@ def generate_short_nulls(db_connection, who_group):
         # there is nothing to be done for who_groups that have both Q and L for example
         for a_shortnull in expression_alleles:
             # e.g. DRB4*01:03N
-            shortnulls[a_shortnull] = "/".join(expression_alleles[a_shortnull])
+            shortnulls[a_shortnull] = "/".join(
+                sorted(
+                    expression_alleles[a_shortnull],
+                    key=functools.cmp_to_key(smart_sort_comparator),
+                )
+            )
 
     db.save_shortnulls(db_connection, shortnulls)
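The added `sorted(..., smart_sort_comparator)` makes the shortnull expansions deterministic and ordered numerically by field rather than lexically; e.g., assuming smart_sort's usual numeric field ordering:

```python
import functools
from pyard.smart_sort import smart_sort_comparator

alleles = ["B*40:100", "B*40:99"]
print("/".join(sorted(alleles, key=functools.cmp_to_key(smart_sort_comparator))))
# B*40:99/B*40:100 -- plain sorted() would order "100" before "99" lexically
```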
@@ -348,9 +350,9 @@ def generate_mac_codes(
     if load_mac:
         mac_table_name = "mac_codes"
         if refresh_mac or not db.table_exists(db_connection, mac_table_name):
-            df_mac = pyard.load.load_mac_codes()
+            df_mac = pyard.loader.mac_codes.load_mac_codes()
             # Create a dict from code to alleles
-            mac = df_mac.set_index("Code")["Alleles"].to_dict()
+            mac = df_mac.to_dict()
             db.save_mac_codes(db_connection, mac, mac_table_name)
 
@@ -382,43 +384,41 @@ def generate_serology_mapping(
     if not db.table_exists(db_connection, "serology_mapping"):
         df_sero = load_serology_mappings(imgt_version)
 
-        import pandas as pd
+        df_sero["Locus*Allele"] = df_sero.concat_columns(["Locus", "Allele"])
 
         # Remove 0 and ? from USA
-        df_sero = df_sero[(df_sero["USA"] != "0") & (df_sero["USA"] != "?")]
-        df_sero["Allele"] = df_sero.loc[:, "Locus"] + df_sero.loc[:, "Allele"]
-
-        usa = df_sero[["Locus", "Allele", "USA"]].dropna()
-        usa["Sero"] = usa["Locus"] + usa["USA"]
-
-        psa = df_sero[["Locus", "Allele", "PSA"]].dropna()
-        psa["PSA"] = psa["PSA"].apply(lambda row: row.split("/"))
-        psa = psa.explode("PSA")
-        psa = psa[(psa["PSA"] != "0") & (psa["PSA"] != "?")].dropna()
-        psa["Sero"] = psa["Locus"] + psa["PSA"]
-
-        asa = df_sero[["Locus", "Allele", "ASA"]].dropna()
-        asa["ASA"] = asa["ASA"].apply(lambda x: x.split("/"))
-        asa = asa.explode("ASA")
-        asa = asa[(asa["ASA"] != "0") & (asa["ASA"] != "?")].dropna()
-        asa["Sero"] = asa["Locus"] + asa["ASA"]
-
-        sero_mapping_combined = pd.concat(
-            [usa[["Sero", "Allele"]], psa[["Sero", "Allele"]], asa[["Sero", "Allele"]]]
+        usa = df_sero.where("USA is not null and USA not in ('0', '?')")
+        usa["Sero"] = usa.concat_columns(["Locus", "USA"])
+
+        psa = df_sero.where_not_null("PSA")
+        psa = psa.explode("PSA", "/")
+        psa = psa.where("PSA not in ('0', '?')")
+        psa["Sero"] = psa.concat_columns(["Locus", "PSA"])
+
+        asa = df_sero.where_not_null("ASA")
+        asa = asa.explode("ASA", "/")
+        asa = asa.where("ASA not in ('0', '?')")
+        asa["Sero"] = asa.concat_columns(["Locus", "ASA"])
+
+        sero_mapping_combined = (
+            usa[["Sero", "Locus*Allele"]]
+            .union(psa[["Sero", "Locus*Allele"]])
+            .union(asa[["Sero", "Locus*Allele"]])
         )
 
         # Map to only valid serological antigen name
         sero_mapping_combined["Sero"] = sero_mapping_combined["Sero"].apply(
             to_serological_name
         )
 
-        sero_mapping_combined["lgx"] = sero_mapping_combined["Allele"].apply(
+        sero_mapping_combined["lgx"] = sero_mapping_combined["Locus*Allele"].apply(
             lambda allele: redux_function(allele, "lgx")
         )
 
-        sero_mapping = (
-            sero_mapping_combined.groupby("Sero")
-            .apply(lambda x: (set(x["Allele"]), set(x["lgx"])))
-            .to_dict()
-        )
+        sero_allele_mapping = sero_mapping_combined.agg("Sero", "Locus*Allele", set)
+        sero_lgx_mapping = sero_mapping_combined.agg("Sero", "lgx", set)
+        sero_mapping = {
+            k: (sero_allele_mapping[k], sero_lgx_mapping[k])
+            for k in sero_allele_mapping.keys()
+        }
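Here `where_not_null` + `explode` + `where` replaces the pandas `dropna`/`str.split`/`explode`/mask chain. A sketch of the PSA branch on a single made-up row:

```python
from pyard.simple_table import Table

df_sero = Table(
    [("B*", "15:01:01:01", "62", "15/70", None, None)],
    ["Locus", "Allele", "USA", "PSA", "ASA", "EAE"],
)
df_sero["Locus*Allele"] = df_sero.concat_columns(["Locus", "Allele"])

psa = df_sero.where_not_null("PSA").explode("PSA", "/")  # one row per possible antigen
psa = psa.where("PSA not in ('0', '?')")
psa["Sero"] = psa.concat_columns(["Locus", "PSA"])
# psa["Sero"].to_list() == ["B*15", "B*70"]
```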
"cwd2"): - cwd2_map = pyard.load.load_cwd2() + cwd2_map = pyard.loader.cwd.load_cwd2() db.save_cwd2(db_connection, cwd2_map) diff --git a/pyard/db.py b/pyard/db.py index 45ce425..87d2ec8 100644 --- a/pyard/db.py +++ b/pyard/db.py @@ -641,12 +641,10 @@ def load_serology_associated_mappings(db_connection): def save_serology_broad_split_mappings(db_connection, sero_mapping): - # Save the `splits` as a "/" delimited string to db - sero_splits = {sero: "/".join(splits) for sero, splits in sero_mapping.items()} save_dict( db_connection, table_name="serology_broad_split_mapping", - dictionary=sero_splits, + dictionary=sero_mapping, columns=("broad", "splits"), ) diff --git a/pyard/load.py b/pyard/load.py deleted file mode 100644 index 6ce2a3c..0000000 --- a/pyard/load.py +++ /dev/null @@ -1,317 +0,0 @@ -# -# py-ard -# Copyright (c) 2023 Be The Match operated by National Marrow Donor Program. All Rights Reserved. -# -# This library is free software; you can redistribute it and/or modify it -# under the terms of the GNU Lesser General Public License as published -# by the Free Software Foundation; either version 3 of the License, or (at -# your option) any later version. -# -# This library is distributed in the hope that it will be useful, but WITHOUT -# ANY WARRANTY; with out even the implied warranty of MERCHANTABILITY or -# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public -# License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this library; if not, write to the Free Software Foundation, -# Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. -# -# > http://www.fsf.org/licensing/licenses/lgpl.html -# > http://www.opensource.org/licenses/lgpl-license.php -# -import sys -from typing import Dict, List, Tuple -from urllib.error import URLError - -from pyard.misc import get_G_name, get_2field_allele, get_3field_allele, get_P_name - -# GitHub URL where IMGT HLA files are downloaded. 
-IMGT_HLA_URL = "https://raw.githubusercontent.com/ANHIG/IMGTHLA/" - - -def add_locus_name(locus: str, splits: str) -> List: - split_list = map(lambda sero: locus + sero, splits.split("/")) - return list(split_list) - - -# -# Derived from rel_ser_ser.txt -# https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/wmda/rel_ser_ser.txt -# -def load_serology_broad_split_mapping(imgt_version: str) -> Tuple[Dict, Dict]: - import pandas as pd - - ser_ser_url = f"{IMGT_HLA_URL}{imgt_version}/wmda/rel_ser_ser.txt" - try: - df_p = pd.read_csv( - ser_ser_url, - skiprows=6, - names=["Locus", "A", "Splits", "Associated"], - dtype="string", - sep=";", - ) - except URLError as e: - print(f"Error downloading {ser_ser_url}", e, file=sys.stderr) - sys.exit(1) - - splits_df = df_p[["Locus", "A", "Splits"]].dropna() - associated_df = df_p[["Locus", "A", "Associated"]].dropna() - - splits_df["Sero"] = splits_df["Locus"] + splits_df["A"] - splits_df["Splits"] = splits_df[["Locus", "Splits"]].apply( - lambda x: add_locus_name(x["Locus"], x["Splits"]), axis=1 - ) - splits_df = splits_df.astype({"A": "int32"}).sort_values(by=["Locus", "A"]) - - associated_df["Sero"] = associated_df["Locus"] + associated_df["A"] - associated_df["Associated"] = associated_df[["Locus", "Associated"]].apply( - lambda x: add_locus_name(x["Locus"], x["Associated"]), axis=1 - ) - associated_df = associated_df.astype({"A": "int32"}).sort_values(by=["Locus", "A"]) - - splits_mapping = splits_df[["Sero", "Splits"]].set_index("Sero")["Splits"].to_dict() - associated_mapping = ( - associated_df.explode("Associated")[["Associated", "Sero"]] - .set_index("Associated")["Sero"] - .to_dict() - ) - - return splits_mapping, associated_mapping - - -def load_g_group(imgt_version): - import pandas as pd - - # load the hla_nom_g.txt - ars_g_url = f"{IMGT_HLA_URL}{imgt_version}/wmda/hla_nom_g.txt" - try: - df = pd.read_csv( - ars_g_url, skiprows=6, names=["Locus", "A", "G"], sep=";" - ).dropna() - except URLError as e: - print(f"Error downloading {ars_g_url}", e, file=sys.stderr) - sys.exit(1) - - # the G-group is named for its first allele - df["G"] = df["A"].apply(get_G_name) - # convert slash delimited string to a list - df["A"] = df["A"].apply(lambda a: a.split("/")) - # convert the list into separate rows for each element - df = df.explode("A") - # A* + 02:01 = A*02:01 - df["A"] = df["Locus"] + df["A"] - df["G"] = df["Locus"] + df["G"] - # Create 2,3 field versions of the alleles - df["2d"] = df["A"].apply(get_2field_allele) - df["3d"] = df["A"].apply(get_3field_allele) - # lgx is 2 field version of the G group allele - df["lgx"] = df["G"].apply(get_2field_allele) - - return df - - -def load_p_group(imgt_version): - import pandas as pd - - # load the hla_nom_p.txt - ars_p_url = f"{IMGT_HLA_URL}{imgt_version}/wmda/hla_nom_p.txt" - # example: C*;06:06:01:01/06:06:01:02/06:271;06:06P - try: - df_p = pd.read_csv( - ars_p_url, skiprows=6, names=["Locus", "A", "P"], sep=";" - ).dropna() - except URLError as e: - print(f"Error downloading {ars_p_url}", e, file=sys.stderr) - sys.exit(1) - - # the P-group is named for its first allele - # The P column is already present in the file - # df_p["P"] = df_p["A"].apply(get_P_name) - # convert slash delimited string to a list - df_p["A"] = df_p["A"].apply(lambda a: a.split("/")) - df_p = df_p.explode("A") - # C* 06:06:01:01/06:06:01:02/06:271 06:06P - df_p["A"] = df_p["Locus"] + df_p["A"] - df_p["P"] = df_p["Locus"] + df_p["P"] - # C* 06:06:01:01 06:06P - # C* 06:06:01:02 06:06P - # C* 06:271 06:06P - df_p["2d"] = 
df_p["A"].apply(get_2field_allele) - df_p["3d"] = df_p["A"].apply(get_3field_allele) - # lgx has the P-group name without the P for comparison - df_p["lgx"] = df_p["P"].apply(get_2field_allele) - return df_p - - -def load_allele_list(imgt_version): - """ - The format of the AlleleList file has a 6-line header with a header - on the 7th line - ``` - # file: Allelelist.3290.txt - # date: 2017-07-10 - # version: IPD-IMGT/HLA 3.29.0 - # origin: https://github.com/ANHIG/IMGTHLA/Allelelist.3290.txt - # repository: https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/allelelist/Allelelist.3290.txt - # author: WHO, Steven G. E. Marsh (steven.marsh@ucl.ac.uk) - AlleleID,Allele - HLA00001,A*01:01:01:01 - HLA02169,A*01:01:01:02N - HLA14798,A*01:01:01:03 - HLA15760,A*01:01:01:04 - HLA16415,A*01:01:01:05 - HLA16417,A*01:01:01:06 - HLA16436,A*01:01:01:07 - ``` - - :param imgt_version: IMGT database version - :return: pandas Dataframe of Alleles - """ - - # Create a Pandas DataFrame from the mac_code list file - # Skip the header (first 6 lines) and use only the Allele column - if imgt_version == "Latest": - allele_list_url = f"{IMGT_HLA_URL}Latest/Allelelist.txt" - else: - if imgt_version == "3130": - # 3130 was renamed to 3131 for Allelelist file only 🤷🏾‍ - imgt_version = "3131" - allele_list_url = ( - f"{IMGT_HLA_URL}Latest/allelelist/Allelelist.{imgt_version}.txt" - ) - import pandas as pd - - try: - allele_df = pd.read_csv(allele_list_url, header=6, usecols=["Allele"]) - except URLError as e: - print(f"Error downloading {allele_list_url}", e, file=sys.stderr) - sys.exit(1) - - return allele_df - - -def load_serology_mappings(imgt_version): - """ - Read `rel_dna_ser.txt` file that contains alleles and their serological equivalents. - - The fields of the Alleles->Serological mapping file are: - Locus - HLA Locus - Allele - HLA Allele Name - USA - Unambiguous Serological Antigen associated with allele - PSA - Possible Serological Antigen associated with allele - ASA - Assumed Serological Antigen associated with allele - EAE - Expert Assigned Exceptions in search determinants of some registries - - EAE is ignored when generating the serology map. - """ - rel_dna_ser_url = f"{IMGT_HLA_URL}{imgt_version}/wmda/rel_dna_ser.txt" - # Load WMDA serology mapping data from URL - import pandas as pd - - try: - df_sero = pd.read_csv( - rel_dna_ser_url, - sep=";", - skiprows=6, - names=["Locus", "Allele", "USA", "PSA", "ASA", "EAE"], - index_col=False, - ) - except URLError as e: - print(f"Error downloading {rel_dna_ser_url}", e, file=sys.stderr) - sys.exit(1) - - return df_sero - - -def load_mac_codes(): - """ - MAC files come in 2 different versions: - - Martin: when they’re printed, the first is better for encoding and the - second is better for decoding. The entire list was maintained both as an - Excel spreadsheet and also as a sybase database table. The Excel was the - one that was printed and distributed. 
- - **==> numer.v3.txt <==** - - Sorted by the length and the values in the list - ``` - "LAST UPDATED: 09/30/20" - CODE SUBTYPE - - AB 01/02 - AC 01/03 - AD 01/04 - AE 01/05 - AG 01/06 - AH 01/07 - AJ 01/08 - ``` - - **==> alpha.v3.txt <==** - - Sorted by the code - - ``` - "LAST UPDATED: 10/01/20" - * CODE SUBTYPE - - AA 01/02/03/05 - AB 01/02 - AC 01/03 - AD 01/04 - AE 01/05 - AF 01/09 - AG 01/06 - ``` - """ - # Load the MAC file to a DataFrame - mac_url = "https://hml.nmdp.org/mac/files/numer.v3.zip" - import pandas as pd - - try: - df_mac = pd.read_csv( - mac_url, - sep="\t", - compression="zip", - skiprows=3, - names=["Code", "Alleles"], - keep_default_na=False, - ) - except URLError as e: - print(f"Error downloading {mac_url}", e, file=sys.stderr) - sys.exit(1) - - return df_mac - - -def load_latest_version(): - from urllib.request import urlopen - - version_txt = ( - "https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/release_version.txt" - ) - try: - response = urlopen(version_txt) - except URLError as e: - print(f"Error downloading {version_txt}", e, file=sys.stderr) - sys.exit(1) - - version = 0 - for line in response: - l = line.decode("utf-8") - if l.find("version:") != -1: - # Version line looks like - # # version: IPD-IMGT/HLA 3.51.0 - version = l.split()[-1].replace(".", "") - return version - - -def load_cwd2(): - import pandas as pd - import os - - cwd_csv_path = os.path.join(os.path.dirname(__file__), "CWD2.csv") - df = pd.read_csv(cwd_csv_path) - cwd_map = df.set_index("ALLELE")["LOCUS"].to_dict() - return cwd_map diff --git a/pyard/CWD2.csv b/pyard/loader/CWD2.csv similarity index 100% rename from pyard/CWD2.csv rename to pyard/loader/CWD2.csv diff --git a/pyard/loader/__init__.py b/pyard/loader/__init__.py new file mode 100644 index 0000000..b4624a8 --- /dev/null +++ b/pyard/loader/__init__.py @@ -0,0 +1,2 @@ +# GitHub URL where IMGT HLA files are downloaded. +IMGT_HLA_URL = "https://raw.githubusercontent.com/ANHIG/IMGTHLA/" diff --git a/pyard/loader/allele_list.py b/pyard/loader/allele_list.py new file mode 100644 index 0000000..2389120 --- /dev/null +++ b/pyard/loader/allele_list.py @@ -0,0 +1,59 @@ +from urllib.request import urlopen +from urllib.error import URLError +import csv +import sys +from ..simple_table import Table +from ..loader import IMGT_HLA_URL + + +def load_allele_list(imgt_version): + """ + The format of the AlleleList file has a 6-line header with a header + on the 7th line + ``` + # file: Allelelist.3290.txt + # date: 2017-07-10 + # version: IPD-IMGT/HLA 3.29.0 + # origin: https://github.com/ANHIG/IMGTHLA/Allelelist.3290.txt + # repository: https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/allelelist/Allelelist.3290.txt + # author: WHO, Steven G. E. 
Marsh (steven.marsh@ucl.ac.uk) + AlleleID,Allele + HLA00001,A*01:01:01:01 + HLA02169,A*01:01:01:02N + HLA14798,A*01:01:01:03 + HLA15760,A*01:01:01:04 + HLA16415,A*01:01:01:05 + HLA16417,A*01:01:01:06 + HLA16436,A*01:01:01:07 + ``` + + Returns a Table object with AlleleID and Allele columns + + :param imgt_version: IMGT database version + :return: Table object with AlleleID and Allele data + """ + + if imgt_version == "Latest": + allele_list_url = f"{IMGT_HLA_URL}Latest/Allelelist.txt" + else: + if imgt_version == "3130": + # 3130 was renamed to 3131 for Allelelist file only 🤷🏾 + imgt_version = "3131" + allele_list_url = ( + f"{IMGT_HLA_URL}Latest/allelelist/Allelelist.{imgt_version}.txt" + ) + + try: + response = urlopen(allele_list_url) + lines = [line.decode("utf-8").strip() for line in response] + + # Skip first 6 header lines + data_lines = lines[6:] + + reader = csv.DictReader(data_lines) + columns = ["AlleleID", "Allele"] + + return Table(reader, columns) + except URLError as e: + print(f"Error downloading {allele_list_url}", e, file=sys.stderr) + sys.exit(1) diff --git a/pyard/loader/cwd.py b/pyard/loader/cwd.py new file mode 100644 index 0000000..3d17117 --- /dev/null +++ b/pyard/loader/cwd.py @@ -0,0 +1,14 @@ +import os +import csv + + +def load_cwd2(): + cwd_csv_path = os.path.join(os.path.dirname(__file__), "CWD2.csv") + cwd_map = {} + + with open(cwd_csv_path, "r") as file: + reader = csv.DictReader(file) + for row in reader: + cwd_map[row["ALLELE"]] = row["LOCUS"] + + return cwd_map diff --git a/pyard/loader/g_group.py b/pyard/loader/g_group.py new file mode 100644 index 0000000..3e4036a --- /dev/null +++ b/pyard/loader/g_group.py @@ -0,0 +1,46 @@ +import sys +from urllib.error import URLError +from urllib.request import urlopen + +from ..loader import IMGT_HLA_URL +from ..misc import get_G_name, get_2field_allele, get_3field_allele +from ..simple_table import Table + + +def load_g_group(imgt_version): + # load the hla_nom_g.txt + ars_g_url = f"{IMGT_HLA_URL}{imgt_version}/wmda/hla_nom_g.txt" + try: + response = urlopen(ars_g_url) + lines = [line.decode("utf-8").strip() for line in response] + data_lines = lines[6:] # Skip first 6 header lines + + data_tuples = [] + for line in data_lines: + if line: + fields = line.split(";") + if len(fields) >= 3 and fields[1] and fields[2]: + locus, a_list, g = fields[0], fields[1], fields[2] + g_name = get_G_name(a_list) + + # Explode slash-delimited alleles + for a in a_list.split("/"): + full_a = locus + a + full_g = locus + g_name + data_tuples.append( + ( + locus, + full_a, + full_g, + get_2field_allele(full_a), + get_3field_allele(full_a), + get_2field_allele(full_g), + ) + ) + + columns = ["Locus", "A", "G", "2d", "3d", "lgx"] + return Table(data_tuples, columns) + + except URLError as e: + print(f"Error downloading {ars_g_url}", e, file=sys.stderr) + sys.exit(1) diff --git a/pyard/loader/mac_codes.py b/pyard/loader/mac_codes.py new file mode 100644 index 0000000..5857b6c --- /dev/null +++ b/pyard/loader/mac_codes.py @@ -0,0 +1,74 @@ +import sys +from urllib.error import URLError +from urllib.request import urlopen +import zipfile +import io +from ..simple_table import Table + + +def load_mac_codes(): + """ + MAC files come in 2 different versions: + + Martin: when they’re printed, the first is better for encoding and the + second is better for decoding. The entire list was maintained both as an + Excel spreadsheet and also as a sybase database table. The Excel was the + one that was printed and distributed. 
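+
+    The returned Table has columns ["Code", "Alleles"]: e.g. the numer.v3
+    row for code AB is stored as ("AB", "01/02"), meaning MAC code AB
+    expands to subtypes 01 and 02.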
+ + **==> numer.v3.txt <==** + + Sorted by the length and the values in the list + ``` + "LAST UPDATED: 09/30/20" + CODE SUBTYPE + + AB 01/02 + AC 01/03 + AD 01/04 + AE 01/05 + AG 01/06 + AH 01/07 + AJ 01/08 + ``` + + **==> alpha.v3.txt <==** + + Sorted by the code + + ``` + "LAST UPDATED: 10/01/20" + * CODE SUBTYPE + + AA 01/02/03/05 + AB 01/02 + AC 01/03 + AD 01/04 + AE 01/05 + AF 01/09 + AG 01/06 + ``` + """ + mac_url = "https://hml.nmdp.org/mac/files/numer.v3.zip" + try: + response = urlopen(mac_url) + zip_data = response.read() + + with zipfile.ZipFile(io.BytesIO(zip_data)) as zip_file: + file_name = zip_file.namelist()[0] + with zip_file.open(file_name) as file: + lines = [line.decode("utf-8").strip() for line in file] + data_lines = lines[3:] # Skip first 3 header lines + + data_tuples = [] + for line in data_lines: + if line: + fields = line.split("\t") + if len(fields) >= 2: + data_tuples.append((fields[0], fields[1])) + + columns = ["Code", "Alleles"] + return Table(data_tuples, columns) + + except URLError as e: + print(f"Error downloading {mac_url}", e, file=sys.stderr) + sys.exit(1) diff --git a/pyard/loader/p_group.py b/pyard/loader/p_group.py new file mode 100644 index 0000000..9aeee8a --- /dev/null +++ b/pyard/loader/p_group.py @@ -0,0 +1,45 @@ +import sys +from urllib.error import URLError +from urllib.request import urlopen + +from ..loader import IMGT_HLA_URL +from ..misc import get_2field_allele, get_3field_allele +from ..simple_table import Table + + +def load_p_group(imgt_version): + # load the hla_nom_p.txt + ars_p_url = f"{IMGT_HLA_URL}{imgt_version}/wmda/hla_nom_p.txt" + try: + response = urlopen(ars_p_url) + lines = [line.decode("utf-8").strip() for line in response] + data_lines = lines[6:] # Skip first 6 header lines + + data_tuples = [] + for line in data_lines: + if line: + fields = line.split(";") + if len(fields) >= 3 and fields[1] and fields[2]: + locus, a_list, p = fields[0], fields[1], fields[2] + + # Explode slash-delimited alleles + for a in a_list.split("/"): + full_a = locus + a + full_p = locus + p + data_tuples.append( + ( + locus, + full_a, + full_p, + get_2field_allele(full_a), + get_3field_allele(full_a), + get_2field_allele(full_p), + ) + ) + + columns = ["Locus", "A", "P", "2d", "3d", "lgx"] + return Table(data_tuples, columns) + + except URLError as e: + print(f"Error downloading {ars_p_url}", e, file=sys.stderr) + sys.exit(1) diff --git a/pyard/loader/serology.py b/pyard/loader/serology.py new file mode 100644 index 0000000..391da7b --- /dev/null +++ b/pyard/loader/serology.py @@ -0,0 +1,121 @@ +import sys +import csv +import io +from typing import Tuple, Dict, List +from urllib.request import urlopen +from urllib.error import URLError +from ..simple_table import Table + +# GitHub URL where IMGT HLA files are downloaded. +IMGT_HLA_URL = "https://raw.githubusercontent.com/ANHIG/IMGTHLA/" + + +def load_serology_mappings(imgt_version): + """ + Read `rel_dna_ser.txt` file that contains alleles and their serological equivalents. + + The fields of the Alleles->Serological mapping file are: + Locus - HLA Locus + Allele - HLA Allele Name + USA - Unambiguous Serological Antigen associated with allele + PSA - Possible Serological Antigen associated with allele + ASA - Assumed Serological Antigen associated with allele + EAE - Expert Assigned Exceptions in search determinants of some registries + + EAE is ignored when generating the serology map. 
+ + :param imgt_version: IMGT database version + :return: Table object with serology mapping data + """ + + rel_dna_ser_url = f"{IMGT_HLA_URL}{imgt_version}/wmda/rel_dna_ser.txt" + + try: + response = urlopen(rel_dna_ser_url) + lines = [line.decode("utf-8").strip() for line in response] + + # Skip first 6 header lines + data_lines = lines[6:] + + # Convert semicolon-separated data to list of tuples + # Original format: "A;A*01:01:01:01;A1;A1;;" + data_tuples = [] + for line in data_lines: + if line: + fields = line.split(";") + if len(fields) >= 6: + # Extract first 6 fields as tuple, replace empty strings with None + rel_dna_fields = tuple( + field if field else None for field in fields[:6] + ) + data_tuples.append(rel_dna_fields) + + columns = ["Locus", "Allele", "USA", "PSA", "ASA", "EAE"] + + return Table(data_tuples, columns) + except URLError as e: + print(f"Error downloading {rel_dna_ser_url}", e, file=sys.stderr) + sys.exit(1) + + +def load_serology_broad_split_mapping(imgt_version: str) -> Tuple[Table, Table]: + """ + Load serology broad/split mapping from rel_ser_ser.txt file. + + :param imgt_version: IMGT database version + :return: Tuple of (splits_table, associated_table) Table objects + - splits_table: Table with 'broad' and 'splits' columns + - associated_table: Table with 'split' and 'broad' columns + """ + + ser_ser_url = f"{IMGT_HLA_URL}{imgt_version}/wmda/rel_ser_ser.txt" + try: + response = urlopen(ser_ser_url) + lines = [line.decode("utf-8").strip() for line in response] + + # Skip first 6 header lines + data_lines = lines[6:] + + # Prepare data as lists of tuples + splits_tuples = [] + associated_tuples = [] + + for line in data_lines: + if line: # Skip empty lines + fields = line.split(";") + if len(fields) >= 4: + locus, a, splits, associated = ( + fields[0], + fields[1], + fields[2], + fields[3], + ) + + # Process splits: broad antigen -> list of splits + if splits: + sero = locus + a # e.g. 
"A" + "10" = "A10" + splits_list = add_locus_name( + locus, splits + ) # Add locus prefix to each split + splits_str = "/".join(splits_list) + splits_tuples.append((sero, splits_str)) + + # Process associated: create reverse mapping from split -> broad + if associated: + sero = locus + a + associated_list = add_locus_name(locus, associated) + for assoc in associated_list: + associated_tuples.append((assoc, sero)) + + splits_table = Table(splits_tuples, ["broad", "splits"]) + associated_table = Table(associated_tuples, ["split", "broad"]) + + return splits_table, associated_table + except URLError as e: + print(f"Error downloading {ser_ser_url}", e, file=sys.stderr) + sys.exit(1) + + +def add_locus_name(locus: str, splits: str) -> List: + split_list = map(lambda sero: locus + sero, splits.split("/")) + return list(split_list) diff --git a/pyard/loader/version.py b/pyard/loader/version.py new file mode 100644 index 0000000..a42b764 --- /dev/null +++ b/pyard/loader/version.py @@ -0,0 +1,24 @@ +import sys +from urllib.error import URLError + + +def load_latest_version(): + from urllib.request import urlopen + + version_txt = ( + "https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/release_version.txt" + ) + try: + response = urlopen(version_txt) + except URLError as e: + print(f"Error downloading {version_txt}", e, file=sys.stderr) + sys.exit(1) + + version = 0 + for line in response: + l = line.decode("utf-8") + if l.find("version:") != -1: + # Version line looks like + # # version: IPD-IMGT/HLA 3.51.0 + version = l.split()[-1].replace(".", "") + return version diff --git a/pyard/simple_table.py b/pyard/simple_table.py new file mode 100644 index 0000000..118d309 --- /dev/null +++ b/pyard/simple_table.py @@ -0,0 +1,337 @@ +import sqlite3 +import csv +import itertools +from collections import defaultdict +from typing import List + + +class Table: + def __init__(self, data, columns: list, table_name: str = "data"): + self._conn = sqlite3.connect(":memory:") + self._name = table_name + self._columns = columns + if isinstance(data, csv.DictReader): + self._create_table_from_reader(data, columns) + else: + self._create_table_from_tuples(data, columns) + + def _create_table_from_reader(self, reader: csv.DictReader, columns: list): + rows = list(reader) + if not rows: + return + + column_defs = ", ".join(f"`{col}` TEXT" for col in columns) + + self._conn.execute(f"CREATE TABLE {self._name} ({column_defs})") + + placeholders = ", ".join("?" * len(columns)) + for row in rows: + values = [row[col] for col in columns] + self._conn.execute( + f"INSERT INTO {self._name} VALUES ({placeholders})", values + ) + + self._conn.commit() + + def _create_table_from_tuples(self, data: list, columns: list): + if not data: + return + + column_defs = ", ".join(f"`{col}` TEXT" for col in columns) + + self._conn.execute(f"CREATE TABLE {self._name} ({column_defs})") + + placeholders = ", ".join("?" 
* len(columns))
+        for row in data:
+            self._conn.execute(f"INSERT INTO {self._name} VALUES ({placeholders})", row)
+
+        self._conn.commit()
+
+    def query(self, sql: str):
+        return self._conn.execute(sql).fetchall()
+
+    def close(self):
+        if hasattr(self, "_conn") and self._conn:
+            self._conn.close()
+
+    @property
+    def columns(self):
+        cursor = self._conn.execute(f"PRAGMA table_info({self._name})")
+        return [row[1] for row in cursor.fetchall()]
+
+    def head(self, n: int = 5):
+        cursor = self._conn.execute(f"SELECT * FROM {self._name} LIMIT {n}")
+        rows = cursor.fetchall()
+        return PrintableTable(self.columns, rows)
+
+    def tail(self, n: int = 5):
+        cursor = self._conn.execute(
+            f"SELECT * FROM {self._name} ORDER BY rowid DESC LIMIT {n}"
+        )
+        rows = cursor.fetchall()
+        return PrintableTable(self.columns, rows)
+
+    def group_by(self, group_by_column: str, return_columns: List[str] = None):
+        if group_by_column not in self.columns:
+            raise ValueError(f"Column '{group_by_column}' not found in table")
+        if return_columns is None:
+            return_columns = self.columns
+        elif group_by_column not in return_columns:
+            # the group key must be selected for the rows to be grouped
+            return_columns = [group_by_column] + return_columns
+        column_names = ", ".join([f"`{col}`" for col in return_columns])
+        cursor = self._conn.execute(
+            f"SELECT {column_names} FROM {self._name} ORDER BY `{group_by_column}`"
+        )
+        rows = cursor.fetchall()
+        # index into the selected columns, not the full table schema
+        col_index = return_columns.index(group_by_column)
+        grouped = itertools.groupby(rows, key=lambda row: row[col_index])
+        return {
+            key: [{col: row[i] for i, col in enumerate(return_columns)} for row in group]
+            for key, group in grouped
+        }
+
+    def unique(self, columns):
+        if isinstance(columns, str):
+            cursor = self._conn.execute(
+                f"SELECT DISTINCT `{columns}` FROM {self._name}"
+            )
+            values = [row[0] for row in cursor.fetchall()]
+            return Column(columns, values)
+        else:
+            column_names = ", ".join([f"`{col}`" for col in columns])
+            cursor = self._conn.execute(
+                f"SELECT DISTINCT {column_names} FROM {self._name}"
+            )
+            return Table(cursor.fetchall(), columns, f"{self._name}_unique")
+
+    def where(self, where_clause: str):
+        try:
+            cursor = self._conn.execute(
+                f"SELECT * FROM {self._name} WHERE {where_clause}"
+            )
+            return Table(cursor.fetchall(), self.columns, f"{self._name}_filtered")
+        except Exception as e:
+            raise ValueError(f"Invalid WHERE clause: {where_clause}") from e
+
+    def where_not_null(self, null_column):
+        if isinstance(null_column, list):
+            conditions = " AND ".join([f"`{col}` IS NOT NULL" for col in null_column])
+            table_suffix = "_".join(null_column)
+        else:
+            conditions = f"`{null_column}` IS NOT NULL"
+            table_suffix = null_column
+
+        table_name = f"{self._name}_not_null_{table_suffix}"
+        cursor = self._conn.execute(f"SELECT * FROM {self._name} WHERE {conditions}")
+        return Table(cursor.fetchall(), table_name=table_name, columns=self.columns)
+
+    def where_in(self, column_name: str, values: set, columns: list):
+        placeholders = ", ".join("?" 
* len(values)) + column_names = ", ".join([f"`{col}`" for col in columns]) + cursor = self._conn.execute( + f"SELECT {column_names} FROM {self._name} WHERE `{column_name}` IN ({placeholders})", + list(values), + ) + return Table(cursor.fetchall(), columns, f"{self._name}_filtered") + + def to_dict(self, key_column: str = None, value_column: str = None): + if not key_column and not value_column: + key_column, value_column = self.columns + elif key_column not in self.columns or value_column not in self.columns: + raise ValueError( + f"Columns {key_column} and {value_column} must be in the table" + ) + if key_column == value_column: + raise ValueError( + f"Columns {key_column} and {value_column} must be different" + ) + cursor = self._conn.execute( + f"SELECT `{key_column}`, `{value_column}` FROM {self._name}" + ) + return dict(cursor.fetchall()) + + def value_counts(self, column: str): + if column not in self.columns: + raise ValueError(f"Column '{column}' not found in table") + cursor = self._conn.execute( + f"SELECT `{column}`, COUNT(*) FROM {self._name} GROUP BY `{column}` ORDER BY COUNT(*) DESC" + ) + return Table(cursor.fetchall(), [column, "count"], f"{self._name}_counts") + + def agg(self, group_column: str, agg_column: str, func): + builtin_funcs = {list, set} + query = f"SELECT `{group_column}`, `{agg_column}` FROM {self._name} GROUP BY `{group_column}`, `{agg_column}`" + result = self._conn.execute(query).fetchall() + d = defaultdict(list) + for k, v in result: + d[k].append(v) + for k, v in d.items(): + d[k] = func(v) + if func in builtin_funcs: + return d + return Table(list(d.items()), [group_column, "agg"], f"{self._name}_agg") + + def __setitem__(self, column: str, values): + if column in self.columns: + self._conn.execute(f"ALTER TABLE {self._name} DROP COLUMN `{column}`") + self._conn.execute(f"ALTER TABLE {self._name} ADD COLUMN `{column}` TEXT") + for i, value in enumerate(values): + self._conn.execute( + f"UPDATE {self._name} SET `{column}` = ? WHERE rowid = ?", + (value, i + 1), + ) + self._conn.commit() + + def __getitem__(self, column): + if isinstance(column, list): + for col in column: + if col not in self.columns: + raise ValueError(f"Column '{col}' not found in table") + column_names = ", ".join([f"`{col}`" for col in column]) + result = self._conn.execute( + f"SELECT {column_names} FROM {self._name}" + ).fetchall() + return Table(result, column, f"{self._name}_subset") + else: + if column not in self.columns: + raise ValueError(f"Column '{column}' not found in table") + result = self._conn.execute( + f"SELECT `{column}` FROM {self._name}" + ).fetchall() + values = [row[0] for row in result] + return Column(column, values) + + def rename(self, column_mapping: dict): + for old_name, new_name in column_mapping.items(): + if old_name not in self.columns: + raise ValueError(f"Column '{old_name}' not found in table") + self._conn.execute( + f"ALTER TABLE {self._name} RENAME COLUMN `{old_name}` TO `{new_name}`" + ) + self._conn.commit() + return self + + def union(self, other_table): + if self.columns != other_table.columns: + raise ValueError("Tables must have the same columns for union") + + self_data = self._conn.execute(f"SELECT * FROM {self._name}").fetchall() + other_data = other_table._conn.execute( + f"SELECT * FROM {other_table._name}" + ).fetchall() + + union_data = self_data + other_data + return Table(union_data, self.columns, f"{self._name}_union") + + def remove(self, column_name: str, values): + placeholders = ", ".join("?" 
* len(values))
+        self._conn.execute(
+            f"DELETE FROM {self._name} WHERE `{column_name}` IN ({placeholders})",
+            list(values),
+        )
+        self._conn.commit()
+        return self
+
+    def concat_columns(self, columns: list):
+        for col in columns:
+            if col not in self.columns:
+                raise ValueError(f"Column '{col}' not found in table")
+        column_names = " || ".join([f"`{col}`" for col in columns])
+        result = self._conn.execute(
+            f"SELECT {column_names} FROM {self._name}"
+        ).fetchall()
+        values = [row[0] for row in result]
+        concat_name = "_".join(columns)
+        return Column(concat_name, values)
+
+    def explode(self, column: str, delimiter: str):
+        if column not in self.columns:
+            raise ValueError(f"Column '{column}' not found in table")
+        all_data = self._conn.execute(f"SELECT * FROM {self._name}").fetchall()
+        col_index = self.columns.index(column)
+
+        exploded_data = []
+        for row in all_data:
+            if row[col_index]:
+                split_values = row[col_index].split(delimiter)
+                for value in split_values:
+                    new_row = list(row)
+                    new_row[col_index] = value.strip()
+                    exploded_data.append(tuple(new_row))
+            else:
+                exploded_data.append(row)
+
+        return Table(exploded_data, self.columns, f"{self._name}_exploded")
+
+    def __len__(self):
+        cursor = self._conn.execute(f"SELECT COUNT(*) FROM {self._name}")
+        return cursor.fetchone()[0]
+
+    def __str__(self):
+        return str(self.head()) + "\n" + "." * 10 + "\n" + str(self.tail())
+
+    def __repr__(self):
+        return str(self)
+
+    def __del__(self):
+        self.close()
+
+
+class Column:
+    def __init__(self, name, values):
+        self._name = name
+        self._values = values
+
+    @property
+    def name(self):
+        return self._name
+
+    def apply(self, func):
+        return [func(value) for value in self._values]
+
+    def to_list(self):
+        return list(self._values)
+
+    def __len__(self):
+        return len(self._values)
+
+    def __getitem__(self, index):
+        return self._values[index]
+
+    def __iter__(self):
+        return iter(self._values)
+
+
+class PrintableTable:
+    def __init__(self, columns, rows):
+        self.columns = columns
+        self.rows = rows
+
+    def __str__(self):
+        if not self.rows:
+            return ""
+
+        # Calculate column widths
+        widths = [len(col) for col in self.columns]
+        for row in self.rows:
+            for i, cell in enumerate(row):
+                widths[i] = max(widths[i], len(str(cell)))
+
+        # Create header
+        header = (
+            "| "
+            + " | ".join(col.ljust(widths[i]) for i, col in enumerate(self.columns))
+            + " |"
+        )
+        separator = "|" + "".join("-" * (w + 2) + "|" for w in widths)
+
+        # Create rows
+        result = [separator, header, separator]
+        for row in self.rows:
+            row_str = (
+                "| "
+                + " | ".join(str(cell).ljust(widths[i]) for i, cell in enumerate(row))
+                + " |"
+            )
+            result.append(row_str)
+        result.append(separator)
+
+        return "\n".join(result)
diff --git a/requirements.txt b/requirements.txt
index 6c2cf6b..f7a9e73 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,2 @@
 setuptools==78.1.1
 toml==0.10.2
-numpy==2.0.2
-pandas==2.2.2
diff --git a/setup.cfg b/setup.cfg
index 659797f..5c5d004 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 1.5.5
+current_version = 2.0.0b1
 commit = True
 tag = True
diff --git a/setup.py b/setup.py
index 3981ef1..f47f7a5 100644
--- a/setup.py
+++ b/setup.py
@@ -36,7 +36,7 @@
 setup(
     name="py-ard",
-    version="1.5.5",
+    version="2.0.0b1",
     description="ARD reduction for HLA with Python",
     long_description=readme,
     long_description_content_type="text/markdown",
diff --git a/tests/test_pyard.py b/tests/test_pyard.py
deleted file mode 100644
index a773acd..0000000
--- 
a/tests/test_pyard.py +++ /dev/null @@ -1,222 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# py-ard -# Copyright (c) 2023 Be The Match operated by National Marrow Donor Program. All Rights Reserved. -# -# This library is free software; you can redistribute it and/or modify it -# under the terms of the GNU Lesser General Public License as published -# by the Free Software Foundation; either version 3 of the License, or (at -# your option) any later version. -# -# This library is distributed in the hope that it will be useful, but WITHOUT -# ANY WARRANTY; with out even the implied warranty of MERCHANTABILITY or -# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public -# License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this library; if not, write to the Free Software Foundation, -# Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. -# -# > http://www.fsf.org/licensing/licenses/lgpl.html -# > http://www.opensource.org/licenses/lgpl-license.php -# - -""" -test_pyard ----------------------------------- - -Tests for `py-ard` module. -""" -import json -import os -import unittest - -import pyard -from pyard.constants import DEFAULT_CACHE_SIZE -from pyard.exceptions import InvalidAlleleError -from pyard.misc import validate_reduction_type - - -class TestPyArd(unittest.TestCase): - db_version = None - ard = None - - @classmethod - def setUpClass(cls) -> None: - cls.db_version = "3440" - cls.ard = pyard.init(cls.db_version, data_dir="/tmp/py-ard") - - def addDuration(self, test, elapsed): # Required for Python >= 3.12 - pass - - def test_no_mac(self): - self.assertEqual(self.ard.redux("A*01:01:01", "G"), "A*01:01:01G") - self.assertEqual(self.ard.redux("A*01:01:01", "lg"), "A*01:01g") - self.assertEqual(self.ard.redux("A*01:01:01", "lgx"), "A*01:01") - self.assertEqual(self.ard.redux("HLA-A*01:01:01", "G"), "HLA-A*01:01:01G") - self.assertEqual(self.ard.redux("HLA-A*01:01:01", "lg"), "HLA-A*01:01g") - self.assertEqual(self.ard.redux("HLA-A*01:01:01", "lgx"), "HLA-A*01:01") - - def test_remove_invalid(self): - self.assertEqual(self.ard.redux("A*01:01:01", "G"), "A*01:01:01G") - - def test_mac(self): - self.assertEqual(self.ard.redux("A*01:AB", "G"), "A*01:01:01G/A*01:02") - self.assertEqual( - self.ard.redux("HLA-A*01:AB", "G"), "HLA-A*01:01:01G/HLA-A*01:02" - ) - - def test_redux(self): - data_dir = os.path.dirname(__file__) - expected_json = data_dir + "/expected.json" - with open(expected_json) as json_data: - expected = json.load(json_data) - for ex in expected["redux"]: - glstring = ex["glstring"] - ard_type = ex["ard_type"] - expected_gl = ex["expected_gl"] - self.assertEqual(self.ard.redux(glstring, ard_type), expected_gl) - - def test_serology(self): - data_dir = os.path.dirname(__file__) - expected_json = data_dir + "/expected-serology.json" - with open(expected_json) as json_data: - expected = json.load(json_data) - for ex in expected["redux"]: - glstring = ex["glstring"] - ard_type = ex["ard_type"] - expected_gl = ex["expected_gl"] - self.assertEqual(self.ard.redux(glstring, ard_type), expected_gl) - - def test_mac_G(self): - self.assertEqual(self.ard.redux("A*01:01:01", "G"), "A*01:01:01G") - self.assertEqual( - self.ard.redux("HLA-A*01:AB", "G"), "HLA-A*01:01:01G/HLA-A*01:02" - ) - with self.assertRaises(InvalidAlleleError): - self.ard._redux_allele("HLA-A*01:AB", "G") - - def test_xx_code(self): - expanded_string = """ - 
B*40:01:01G/B*40:01:03G/B*40:02:01G/B*40:03:01G/B*40:04:01G/B*40:05:01G/B*40:06:01G/B*40:07/B*40:08/B*40:09/B*40:10:01G/B*40:11:01G/B*40:12/B*40:13/B*40:14/B*40:15/B*40:16:01G/B*40:18/B*40:19/B*40:20:01G/B*40:21/B*40:22N/B*40:23/B*40:24/B*40:25/B*40:26/B*40:27/B*40:28/B*40:29/B*40:30/B*40:31/B*40:32/B*40:33/B*40:34/B*40:35/B*40:36/B*40:37/B*40:38/B*40:39/B*40:40:01G/B*40:42/B*40:43/B*40:44/B*40:45/B*40:46/B*40:47/B*40:48/B*40:49/B*40:50:01G/B*40:51/B*40:52/B*40:53/B*40:54/B*40:57/B*40:58/B*40:59/B*40:60/B*40:61/B*40:62/B*40:63/B*40:64:01G/B*40:65/B*40:66/B*40:67/B*40:68/B*40:69/B*40:70/B*40:71/B*40:72/B*40:73/B*40:74/B*40:75/B*40:76/B*40:77/B*40:78/B*40:79/B*40:80/B*40:81/B*40:82/B*40:83/B*40:84/B*40:85/B*40:86/B*40:87/B*40:88/B*40:89/B*40:90/B*40:91/B*40:92/B*40:93/B*40:94/B*40:95/B*40:96/B*40:98/B*40:99/B*40:100/B*40:101/B*40:102/B*40:103/B*40:104/B*40:105/B*40:106/B*40:107/B*40:108/B*40:109/B*40:110/B*40:111/B*40:112/B*40:113/B*40:114:01G/B*40:115/B*40:116/B*40:117/B*40:118N/B*40:119/B*40:120/B*40:121/B*40:122/B*40:123/B*40:124/B*40:125/B*40:126/B*40:127/B*40:128/B*40:129/B*40:130/B*40:131/B*40:132/B*40:133Q/B*40:134/B*40:135/B*40:136/B*40:137/B*40:138/B*40:139/B*40:140/B*40:142N/B*40:143/B*40:145/B*40:146/B*40:147/B*40:148/B*40:149/B*40:152/B*40:153/B*40:154/B*40:155:01G/B*40:156/B*40:157/B*40:158/B*40:159/B*40:160/B*40:161/B*40:162/B*40:163/B*40:164/B*40:165/B*40:166/B*40:167/B*40:168/B*40:169/B*40:170/B*40:171/B*40:172/B*40:173/B*40:174/B*40:175/B*40:177/B*40:178/B*40:180/B*40:181/B*40:182/B*40:183/B*40:184/B*40:185/B*40:186/B*40:187/B*40:188/B*40:189/B*40:190/B*40:191/B*40:192/B*40:193/B*40:194/B*40:195/B*40:196/B*40:197/B*40:198/B*40:199/B*40:200/B*40:201/B*40:202/B*40:203/B*40:204/B*40:205/B*40:206/B*40:207/B*40:208/B*40:209/B*40:210/B*40:211/B*40:212/B*40:213:01G/B*40:214/B*40:215/B*40:216N/B*40:217/B*40:218/B*40:219/B*40:220/B*40:222/B*40:223/B*40:224/B*40:225/B*40:226/B*40:227/B*40:228/B*40:230/B*40:231/B*40:232/B*40:233/B*40:234/B*40:235/B*40:237/B*40:238/B*40:239/B*40:240/B*40:242/B*40:243/B*40:244/B*40:245/B*40:246/B*40:248/B*40:249/B*40:250/B*40:251/B*40:252/B*40:253/B*40:254/B*40:255/B*40:256N/B*40:257/B*40:258/B*40:259/B*40:260/B*40:261/B*40:262/B*40:263N/B*40:265N/B*40:266/B*40:268/B*40:269/B*40:270/B*40:271/B*40:273/B*40:274/B*40:275/B*40:276/B*40:277/B*40:279/B*40:280/B*40:281/B*40:282/B*40:283/B*40:284/B*40:285/B*40:286N/B*40:287/B*40:288/B*40:289/B*40:290/B*40:291N/B*40:292/B*40:293/B*40:294/B*40:295/B*40:296/B*40:297/B*40:298/B*40:300/B*40:302/B*40:304/B*40:305/B*40:306/B*40:307/B*40:308/B*40:309/B*40:310/B*40:311/B*40:312/B*40:313/B*40:314/B*40:315/B*40:316/B*40:317/B*40:318/B*40:319/B*40:320/B*40:321/B*40:322/B*40:323/B*40:324/B*40:325/B*40:326/B*40:327/B*40:328/B*40:330/B*40:331/B*40:332/B*40:333/B*40:334/B*40:335/B*40:336/B*40:337N/B*40:339/B*40:340/B*40:341/B*40:342/B*40:343/B*40:344/B*40:345N/B*40:346/B*40:347/B*40:348/B*40:349/B*40:350/B*40:351/B*40:352/B*40:354/B*40:355/B*40:357/B*40:358/B*40:359/B*40:360/B*40:361N/B*40:362/B*40:363/B*40:364/B*40:365/B*40:366/B*40:367/B*40:368/B*40:369/B*40:370/B*40:371/B*40:372N/B*40:373/B*40:374/B*40:375/B*40:376/B*40:377/B*40:378/B*40:380/B*40:381/B*40:382/B*40:385/B*40:388/B*40:389/B*40:390/B*40:391/B*40:392/B*40:393/B*40:394/B*40:396/B*40:397/B*40:398/B*40:399N/B*40:400/B*40:401/B*40:402/B*40:403/B*40:404/B*40:407/B*40:408/B*40:409/B*40:410/B*40:411/B*40:412/B*40:413/B*40:414/B*40:415/B*40:420/B*40:421Q/B*40:422/B*40:423/B*40:424/B*40:426N/B*40:428N/B*40:429/B*40:430/B*40:432/B*40:433/B*40:434/B*40:436/B*40:437/B*40:438N
/B*40:441/B*40:445/B*40:447/B*40:448/B*40:449/B*40:451/B*40:452/B*40:454/B*40:457/B*40:458/B*40:459/B*40:460/B*40:461/B*40:462/B*40:463/B*40:465/B*40:466/B*40:467/B*40:468/B*40:469/B*40:470/B*40:471/B*40:472/B*40:477/B*40:478/B*40:479/B*40:481N/B*40:482 - """.strip() - gl = self.ard.redux("B*40:XX", "G") - self.assertEqual(gl, expanded_string) - - def test_xx_code_with_prefix(self): - expanded_string = """ - HLA-B*40:01:01G/HLA-B*40:01:03G/HLA-B*40:02:01G/HLA-B*40:03:01G/HLA-B*40:04:01G/HLA-B*40:05:01G/HLA-B*40:06:01G/HLA-B*40:07/HLA-B*40:08/HLA-B*40:09/HLA-B*40:10:01G/HLA-B*40:11:01G/HLA-B*40:12/HLA-B*40:13/HLA-B*40:14/HLA-B*40:15/HLA-B*40:16:01G/HLA-B*40:18/HLA-B*40:19/HLA-B*40:20:01G/HLA-B*40:21/HLA-B*40:22N/HLA-B*40:23/HLA-B*40:24/HLA-B*40:25/HLA-B*40:26/HLA-B*40:27/HLA-B*40:28/HLA-B*40:29/HLA-B*40:30/HLA-B*40:31/HLA-B*40:32/HLA-B*40:33/HLA-B*40:34/HLA-B*40:35/HLA-B*40:36/HLA-B*40:37/HLA-B*40:38/HLA-B*40:39/HLA-B*40:40:01G/HLA-B*40:42/HLA-B*40:43/HLA-B*40:44/HLA-B*40:45/HLA-B*40:46/HLA-B*40:47/HLA-B*40:48/HLA-B*40:49/HLA-B*40:50:01G/HLA-B*40:51/HLA-B*40:52/HLA-B*40:53/HLA-B*40:54/HLA-B*40:57/HLA-B*40:58/HLA-B*40:59/HLA-B*40:60/HLA-B*40:61/HLA-B*40:62/HLA-B*40:63/HLA-B*40:64:01G/HLA-B*40:65/HLA-B*40:66/HLA-B*40:67/HLA-B*40:68/HLA-B*40:69/HLA-B*40:70/HLA-B*40:71/HLA-B*40:72/HLA-B*40:73/HLA-B*40:74/HLA-B*40:75/HLA-B*40:76/HLA-B*40:77/HLA-B*40:78/HLA-B*40:79/HLA-B*40:80/HLA-B*40:81/HLA-B*40:82/HLA-B*40:83/HLA-B*40:84/HLA-B*40:85/HLA-B*40:86/HLA-B*40:87/HLA-B*40:88/HLA-B*40:89/HLA-B*40:90/HLA-B*40:91/HLA-B*40:92/HLA-B*40:93/HLA-B*40:94/HLA-B*40:95/HLA-B*40:96/HLA-B*40:98/HLA-B*40:99/HLA-B*40:100/HLA-B*40:101/HLA-B*40:102/HLA-B*40:103/HLA-B*40:104/HLA-B*40:105/HLA-B*40:106/HLA-B*40:107/HLA-B*40:108/HLA-B*40:109/HLA-B*40:110/HLA-B*40:111/HLA-B*40:112/HLA-B*40:113/HLA-B*40:114:01G/HLA-B*40:115/HLA-B*40:116/HLA-B*40:117/HLA-B*40:118N/HLA-B*40:119/HLA-B*40:120/HLA-B*40:121/HLA-B*40:122/HLA-B*40:123/HLA-B*40:124/HLA-B*40:125/HLA-B*40:126/HLA-B*40:127/HLA-B*40:128/HLA-B*40:129/HLA-B*40:130/HLA-B*40:131/HLA-B*40:132/HLA-B*40:133Q/HLA-B*40:134/HLA-B*40:135/HLA-B*40:136/HLA-B*40:137/HLA-B*40:138/HLA-B*40:139/HLA-B*40:140/HLA-B*40:142N/HLA-B*40:143/HLA-B*40:145/HLA-B*40:146/HLA-B*40:147/HLA-B*40:148/HLA-B*40:149/HLA-B*40:152/HLA-B*40:153/HLA-B*40:154/HLA-B*40:155:01G/HLA-B*40:156/HLA-B*40:157/HLA-B*40:158/HLA-B*40:159/HLA-B*40:160/HLA-B*40:161/HLA-B*40:162/HLA-B*40:163/HLA-B*40:164/HLA-B*40:165/HLA-B*40:166/HLA-B*40:167/HLA-B*40:168/HLA-B*40:169/HLA-B*40:170/HLA-B*40:171/HLA-B*40:172/HLA-B*40:173/HLA-B*40:174/HLA-B*40:175/HLA-B*40:177/HLA-B*40:178/HLA-B*40:180/HLA-B*40:181/HLA-B*40:182/HLA-B*40:183/HLA-B*40:184/HLA-B*40:185/HLA-B*40:186/HLA-B*40:187/HLA-B*40:188/HLA-B*40:189/HLA-B*40:190/HLA-B*40:191/HLA-B*40:192/HLA-B*40:193/HLA-B*40:194/HLA-B*40:195/HLA-B*40:196/HLA-B*40:197/HLA-B*40:198/HLA-B*40:199/HLA-B*40:200/HLA-B*40:201/HLA-B*40:202/HLA-B*40:203/HLA-B*40:204/HLA-B*40:205/HLA-B*40:206/HLA-B*40:207/HLA-B*40:208/HLA-B*40:209/HLA-B*40:210/HLA-B*40:211/HLA-B*40:212/HLA-B*40:213:01G/HLA-B*40:214/HLA-B*40:215/HLA-B*40:216N/HLA-B*40:217/HLA-B*40:218/HLA-B*40:219/HLA-B*40:220/HLA-B*40:222/HLA-B*40:223/HLA-B*40:224/HLA-B*40:225/HLA-B*40:226/HLA-B*40:227/HLA-B*40:228/HLA-B*40:230/HLA-B*40:231/HLA-B*40:232/HLA-B*40:233/HLA-B*40:234/HLA-B*40:235/HLA-B*40:237/HLA-B*40:238/HLA-B*40:239/HLA-B*40:240/HLA-B*40:242/HLA-B*40:243/HLA-B*40:244/HLA-B*40:245/HLA-B*40:246/HLA-B*40:248/HLA-B*40:249/HLA-B*40:250/HLA-B*40:251/HLA-B*40:252/HLA-B*40:253/HLA-B*40:254/HLA-B*40:255/HLA-B*40:256N/HLA-B*40:257/HLA-B*40:258/HLA-B
*40:259/HLA-B*40:260/HLA-B*40:261/HLA-B*40:262/HLA-B*40:263N/HLA-B*40:265N/HLA-B*40:266/HLA-B*40:268/HLA-B*40:269/HLA-B*40:270/HLA-B*40:271/HLA-B*40:273/HLA-B*40:274/HLA-B*40:275/HLA-B*40:276/HLA-B*40:277/HLA-B*40:279/HLA-B*40:280/HLA-B*40:281/HLA-B*40:282/HLA-B*40:283/HLA-B*40:284/HLA-B*40:285/HLA-B*40:286N/HLA-B*40:287/HLA-B*40:288/HLA-B*40:289/HLA-B*40:290/HLA-B*40:291N/HLA-B*40:292/HLA-B*40:293/HLA-B*40:294/HLA-B*40:295/HLA-B*40:296/HLA-B*40:297/HLA-B*40:298/HLA-B*40:300/HLA-B*40:302/HLA-B*40:304/HLA-B*40:305/HLA-B*40:306/HLA-B*40:307/HLA-B*40:308/HLA-B*40:309/HLA-B*40:310/HLA-B*40:311/HLA-B*40:312/HLA-B*40:313/HLA-B*40:314/HLA-B*40:315/HLA-B*40:316/HLA-B*40:317/HLA-B*40:318/HLA-B*40:319/HLA-B*40:320/HLA-B*40:321/HLA-B*40:322/HLA-B*40:323/HLA-B*40:324/HLA-B*40:325/HLA-B*40:326/HLA-B*40:327/HLA-B*40:328/HLA-B*40:330/HLA-B*40:331/HLA-B*40:332/HLA-B*40:333/HLA-B*40:334/HLA-B*40:335/HLA-B*40:336/HLA-B*40:337N/HLA-B*40:339/HLA-B*40:340/HLA-B*40:341/HLA-B*40:342/HLA-B*40:343/HLA-B*40:344/HLA-B*40:345N/HLA-B*40:346/HLA-B*40:347/HLA-B*40:348/HLA-B*40:349/HLA-B*40:350/HLA-B*40:351/HLA-B*40:352/HLA-B*40:354/HLA-B*40:355/HLA-B*40:357/HLA-B*40:358/HLA-B*40:359/HLA-B*40:360/HLA-B*40:361N/HLA-B*40:362/HLA-B*40:363/HLA-B*40:364/HLA-B*40:365/HLA-B*40:366/HLA-B*40:367/HLA-B*40:368/HLA-B*40:369/HLA-B*40:370/HLA-B*40:371/HLA-B*40:372N/HLA-B*40:373/HLA-B*40:374/HLA-B*40:375/HLA-B*40:376/HLA-B*40:377/HLA-B*40:378/HLA-B*40:380/HLA-B*40:381/HLA-B*40:382/HLA-B*40:385/HLA-B*40:388/HLA-B*40:389/HLA-B*40:390/HLA-B*40:391/HLA-B*40:392/HLA-B*40:393/HLA-B*40:394/HLA-B*40:396/HLA-B*40:397/HLA-B*40:398/HLA-B*40:399N/HLA-B*40:400/HLA-B*40:401/HLA-B*40:402/HLA-B*40:403/HLA-B*40:404/HLA-B*40:407/HLA-B*40:408/HLA-B*40:409/HLA-B*40:410/HLA-B*40:411/HLA-B*40:412/HLA-B*40:413/HLA-B*40:414/HLA-B*40:415/HLA-B*40:420/HLA-B*40:421Q/HLA-B*40:422/HLA-B*40:423/HLA-B*40:424/HLA-B*40:426N/HLA-B*40:428N/HLA-B*40:429/HLA-B*40:430/HLA-B*40:432/HLA-B*40:433/HLA-B*40:434/HLA-B*40:436/HLA-B*40:437/HLA-B*40:438N/HLA-B*40:441/HLA-B*40:445/HLA-B*40:447/HLA-B*40:448/HLA-B*40:449/HLA-B*40:451/HLA-B*40:452/HLA-B*40:454/HLA-B*40:457/HLA-B*40:458/HLA-B*40:459/HLA-B*40:460/HLA-B*40:461/HLA-B*40:462/HLA-B*40:463/HLA-B*40:465/HLA-B*40:466/HLA-B*40:467/HLA-B*40:468/HLA-B*40:469/HLA-B*40:470/HLA-B*40:471/HLA-B*40:472/HLA-B*40:477/HLA-B*40:478/HLA-B*40:479/HLA-B*40:481N/HLA-B*40:482 - """.strip() - gl = self.ard.redux("HLA-B*40:XX", "G") - self.assertEqual(expanded_string, gl) - - def test_expand_mac(self): - mac_ab_expanded = ["A*01:01", "A*01:02"] - self.assertEqual(self.ard.expand_mac("A*01:AB"), "/".join(mac_ab_expanded)) - - mac_hla_ab_expanded = ["HLA-A*01:01", "HLA-A*01:02"] - self.assertEqual( - self.ard.expand_mac("HLA-A*01:AB"), "/".join(mac_hla_ab_expanded) - ) - - mac_ac_expanded = ["A*01:01", "A*01:03"] - self.assertEqual(self.ard.expand_mac("A*01:AC"), "/".join(mac_ac_expanded)) - - mac_hla_ac_expanded = ["HLA-A*01:01", "HLA-A*01:03"] - self.assertEqual( - self.ard.expand_mac("HLA-A*01:AC"), "/".join(mac_hla_ac_expanded) - ) - - def test_redux_types(self): - self.assertIsNone(validate_reduction_type("G")) - self.assertIsNone(validate_reduction_type("lg")) - self.assertIsNone(validate_reduction_type("lgx")) - self.assertIsNone(validate_reduction_type("W")) - self.assertIsNone(validate_reduction_type("exon")) - with self.assertRaises(ValueError): - validate_reduction_type("XX") - - def test_empty_allele(self): - with self.assertRaises(InvalidAlleleError): - self.ard.redux("A*", "lgx") - - def test_fp_allele(self): - with 
self.assertRaises(InvalidAlleleError): - self.ard.redux("A*0.123", "lgx") - - def test_empty_fields(self): - with self.assertRaises(InvalidAlleleError): - # : without any data - self.ard.redux("DQA1*01:01:01:G", "lgx") - - def test_invalid_serology(self): - # Test that A10 works and the first one is 'A*25:01' - serology_a10 = self.ard.redux("A10", "lgx") - self.assertEqual(serology_a10.split("/")[0], "A*25:01") - # And A100 isn't a valid typing - with self.assertRaises(InvalidAlleleError): - self.ard.redux("A100", "lgx") - - def test_allele_duplicated(self): - # Make sure the reduced alleles are unique - # https://github.com/nmdp-bioinformatics/py-ard/issues/135 - allele_code = "C*02:ACMGS" - allele_code_rx = self.ard.redux(allele_code, "lgx") - self.assertEqual(allele_code_rx, "C*02:02") - - def test_imgt_db_version(self): - self.assertEqual(self.ard.get_db_version(), int(TestPyArd.db_version)) - - def test_xx_codes_broad_split(self): - self.assertFalse( - "DQB1*06" in self.ard.redux("DQB1*05:XX", "lgx"), - "The split shouldn't include other splits", - ) - - def test_cache_info(self): - # validate the default cache size - self.assertEqual( - self.ard._redux_allele.cache_info().maxsize, DEFAULT_CACHE_SIZE - ) - self.assertEqual(self.ard.redux.cache_info().maxsize, DEFAULT_CACHE_SIZE) - # validate you can change the cache size - higher_cache_size = 5_000_000 - another_ard = pyard.init( - self.db_version, data_dir="/tmp/py-ard", cache_size=higher_cache_size - ) - self.assertEqual( - another_ard._redux_allele.cache_info().maxsize, higher_cache_size - ) - self.assertEqual(another_ard.redux.cache_info().maxsize, higher_cache_size) - - def test_is_null(self): - # a null allele - allele = "A*01:01N" - self.assertTrue(self.ard.is_null(allele), msg="A Null Allele") - # not null allele - allele = "A*01:01" - self.assertFalse(self.ard.is_null(allele), msg="not null allele") - # MACs ending with N shouldn't be called as Nulls - allele = "A*01:MN" - self.assertFalse( - self.ard.is_null(allele), - msg="MACs ending with N shouldn't be called as Nulls", - ) - # MACs shouldn't be called as Nulls - allele = "A*01:AB" - self.assertFalse( - self.ard.is_null(allele), msg="MACs shouldn't be called as Nulls" - ) - - def test_default_redux_is_lgx(self): - allele = "A*24:BKKPV+A*26:03^B*15:BKNTS+B*15:07" - lgx_redux = self.ard.redux(allele, "lgx") - default_redux = self.ard.redux(allele) - self.assertEqual(lgx_redux, default_redux, msg="Default redux should be lgx") - - def test_mac_is_reversible(self): - mac_code = "A*68:AJEBX" - expanded_mac = self.ard.expand_mac(mac_code) - lookup_mac = self.ard.lookup_mac(expanded_mac) - self.assertEqual(mac_code, lookup_mac, msg="MACs should be reversible") diff --git a/tests/test_smart_sort.py b/tests/test_smart_sort.py deleted file mode 100644 index 3d330f2..0000000 --- a/tests/test_smart_sort.py +++ /dev/null @@ -1,127 +0,0 @@ -# -# py-ard -# Copyright (c) 2023 Be The Match operated by National Marrow Donor Program. All Rights Reserved. -# -# This library is free software; you can redistribute it and/or modify it -# under the terms of the GNU Lesser General Public License as published -# by the Free Software Foundation; either version 3 of the License, or (at -# your option) any later version. -# -# This library is distributed in the hope that it will be useful, but WITHOUT -# ANY WARRANTY; with out even the implied warranty of MERCHANTABILITY or -# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public -# License for more details. 
-# -# You should have received a copy of the GNU Lesser General Public License -# along with this library; if not, write to the Free Software Foundation, -# Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. -# -# > http://www.fsf.org/licensing/licenses/lgpl.html -# > http://www.opensource.org/licenses/lgpl-license.php -# -import unittest - -from pyard.smart_sort import smart_sort_comparator - - -class TestSmartSort(unittest.TestCase): - def setUp(self) -> None: - super().setUp() - - def addDuration(self, test, elapsed): # Required for Python >= 3.12 - pass - - def test_same_comparator(self): - allele = "HLA-A*01:01" - self.assertEqual(smart_sort_comparator(allele, allele), 0) - - def test_equal_comparator(self): - allele1 = "HLA-A*01:01" - allele2 = "HLA-A*01:01" - self.assertEqual(smart_sort_comparator(allele1, allele2), 0) - - def test_equal_comparator_G(self): - # Should compare without G - allele1 = "HLA-A*01:01G" - allele2 = "HLA-A*01:01" - self.assertEqual(smart_sort_comparator(allele1, allele2), 0) - - def test_equal_comparator_NG(self): - # Should compare without N and G - allele1 = "HLA-A*01:01G" - allele2 = "HLA-A*01:01N" - self.assertEqual(smart_sort_comparator(allele1, allele2), 0) - - def test_first_field_comparator_le(self): - allele1 = "HLA-A*01:01" - allele2 = "HLA-A*02:01" - self.assertEqual(smart_sort_comparator(allele1, allele2), -1) - - def test_first_field_comparator_ge(self): - allele1 = "HLA-A*02:01" - allele2 = "HLA-A*01:01" - self.assertEqual(smart_sort_comparator(allele1, allele2), 1) - - def test_second_field_comparator_le(self): - allele1 = "HLA-A*01:01" - allele2 = "HLA-A*01:02" - self.assertEqual(smart_sort_comparator(allele1, allele2), -1) - - def test_second_field_comparator_le_smart(self): - allele1 = "HLA-A*01:29" - allele2 = "HLA-A*01:100" - self.assertEqual(smart_sort_comparator(allele1, allele2), -1) - - def test_second_field_comparator_ge(self): - allele1 = "HLA-A*01:02" - allele2 = "HLA-A*01:01" - self.assertEqual(smart_sort_comparator(allele1, allele2), 1) - - def test_third_field_comparator_le(self): - allele1 = "HLA-A*01:01:01" - allele2 = "HLA-A*01:01:20" - self.assertEqual(smart_sort_comparator(allele1, allele2), -1) - - def test_third_field_comparator_le_smart(self): - allele1 = "HLA-A*01:01:29" - allele2 = "HLA-A*01:01:100" - self.assertEqual(smart_sort_comparator(allele1, allele2), -1) - - def test_third_field_comparator_ge(self): - allele1 = "HLA-A*01:01:02" - allele2 = "HLA-A*01:01:01" - self.assertEqual(smart_sort_comparator(allele1, allele2), 1) - - def test_fourth_field_comparator_le(self): - allele1 = "HLA-A*01:01:01:01" - allele2 = "HLA-A*01:01:01:20" - self.assertEqual(smart_sort_comparator(allele1, allele2), -1) - - def test_fourth_field_comparator_le_smart(self): - allele1 = "HLA-A*01:01:01:39" - allele2 = "HLA-A*01:01:01:200" - self.assertEqual(smart_sort_comparator(allele1, allele2), -1) - - def test_fourth_field_comparator_ge(self): - allele1 = "HLA-A*01:01:01:30" - allele2 = "HLA-A*01:01:01:09" - self.assertEqual(smart_sort_comparator(allele1, allele2), 1) - - def test_serology_ge(self): - serology1 = "Cw10" - serology2 = "Cw3" - self.assertEqual(smart_sort_comparator(serology1, serology2), 1) - - def test_serology_le(self): - serology1 = "A10" - serology2 = "A25" - self.assertEqual(smart_sort_comparator(serology1, serology2), -1) - - def test_serology_eq(self): - serology1 = "B70" - serology2 = "B70" - self.assertEqual(smart_sort_comparator(serology1, serology2), 0) - - -if __name__ == "__main__": - unittest.main() 
diff --git a/tests/expected-serology.json b/tests/unit/expected-serology.json similarity index 100% rename from tests/expected-serology.json rename to tests/unit/expected-serology.json diff --git a/tests/expected.json b/tests/unit/expected.json similarity index 100% rename from tests/expected.json rename to tests/unit/expected.json diff --git a/tests/unit/simple_table/test_column.py b/tests/unit/simple_table/test_column.py new file mode 100644 index 0000000..ad03154 --- /dev/null +++ b/tests/unit/simple_table/test_column.py @@ -0,0 +1,44 @@ +from pyard.simple_table import Column + + +def test_column_creation(): + col = Column("age", ["25", "30", "35"]) + assert col.name == "-age" + assert len(col) == 3 + + +def test_column_apply(): + col = Column("age", ["25", "30"]) + result = col.apply(lambda x: int(x) * 2) + assert result == [50, 60] + + +def test_column_to_list(): + col = Column("name", ["John", "Jane"]) + assert col.to_list() == ["John", "Jane"] + + +def test_column_getitem(): + col = Column("age", ["25", "30", "35"]) + assert col[0] == "25" + assert col[1] == "30" + assert col[-1] == "35" + + +def test_column_iter(): + col = Column("name", ["John", "Jane"]) + values = list(col) + assert values == ["John", "Jane"] + + +def test_column_len(): + col = Column("empty", []) + assert len(col) == 0 + + col = Column("data", ["a", "b", "c"]) + assert len(col) == 3 + + +def test_column_name_property(): + col = Column("test_column", ["value"]) + assert col.name == "-test_column" diff --git a/tests/unit/simple_table/test_simple_table.py b/tests/unit/simple_table/test_simple_table.py new file mode 100644 index 0000000..e8ac4ce --- /dev/null +++ b/tests/unit/simple_table/test_simple_table.py @@ -0,0 +1,304 @@ +import csv +import io +from pyard.simple_table import Table + + +def test_create_table_with_data(): + csv_data = "name,age,city\nJohn,25,NYC\nJane,30,LA" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age", "city"] + + table = Table(reader, columns) + result = table.query("SELECT * FROM data") + + assert len(result) == 2 + assert result[0] == ("John", "25", "NYC") + assert result[1] == ("Jane", "30", "LA") + table.close() + + +def test_empty_reader(): + reader = csv.DictReader(io.StringIO("")) + columns = ["name", "age"] + + table = Table(reader, columns) + result = table.query("SELECT name FROM sqlite_master WHERE type='table'") + + assert len(result) == 0 + table.close() + + +def test_custom_table_name(): + csv_data = "id,value\n1,test" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["id", "value"] + + table = Table(reader, columns, "custom") + result = table.query("SELECT * FROM custom") + + assert result[0] == ("1", "test") + table.close() + + +def test_query_with_where_clause(): + csv_data = "name,age\nJohn,25\nJane,30" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + + table = Table(reader, columns) + result = table.query("SELECT * FROM data WHERE age > 25") + + assert len(result) == 1 + assert result[0] == ("Jane", "30") + table.close() + + +def test_query_with_order_by(): + csv_data = "name,age\nJohn,25\nJane,30" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + + table = Table(reader, columns) + result = table.query("SELECT * FROM data ORDER BY age DESC") + + assert len(result) == 2 + assert result[0] == ("Jane", "30") + assert result[1] == ("John", "25") + table.close() + + +def test_select_column_with_subscript_operator(): + csv_data = "name,age\nJohn,25\nJane,30" + reader = 
csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + + table = Table(reader, columns) + ages = table["age"] + + assert len(ages) == 2 + assert ages[0] == "25" + assert ages[1] == "30" + + +def test_create_new_column_with_subscript_operator(): + csv_data = "name,age\nJohn,25\nJane,30" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + + table = Table(reader, columns) + table["double_age"] = table["age"].apply(lambda x: int(x) * 2) + double_ages = table["double_age"] + + assert len(double_ages) == 2 + assert double_ages[0] == "50" + assert double_ages[1] == "60" + + +def test_table_from_tuples(): + data = [("John", "25"), ("Jane", "30")] + columns = ["name", "age"] + table = Table(data, columns) + result = table.query("SELECT * FROM data") + assert len(result) == 2 + assert result[0] == ("John", "25") + table.close() + + +def test_columns_property(): + csv_data = "name,age\nJohn,25" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + table = Table(reader, columns) + assert table.columns == ["name", "age"] + table.close() + + +def test_head(): + csv_data = "name,age\nJohn,25\nJane,30\nBob,35" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + table = Table(reader, columns) + head_result = table.head(2) + assert len(head_result.rows) == 2 + table.close() + + +def test_tail(): + csv_data = "name,age\nJohn,25\nJane,30\nBob,35" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + table = Table(reader, columns) + tail_result = table.tail(2) + assert len(tail_result.rows) == 2 + table.close() + + +def test_group_by(): + csv_data = "city,age\nNYC,25\nLA,30\nNYC,35" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["city", "age"] + table = Table(reader, columns) + grouped = table.group_by("city") + assert "NYC" in grouped + assert len(grouped["NYC"]) == 2 + table.close() + + +def test_unique_single_column(): + csv_data = "city,age\nNYC,25\nLA,30\nNYC,35" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["city", "age"] + table = Table(reader, columns) + unique_cities = table.unique("city") + assert len(unique_cities) == 2 + table.close() + + +def test_unique_multiple_columns(): + csv_data = "city,age\nNYC,25\nLA,30\nNYC,25" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["city", "age"] + table = Table(reader, columns) + unique_table = table.unique(["city", "age"]) + assert len(unique_table) == 2 + unique_table.close() + table.close() + + +def test_where(): + csv_data = "name,age\nJohn,25\nJane,30" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + table = Table(reader, columns) + filtered = table.where("age > 25") + assert len(filtered) == 1 + filtered.close() + table.close() + + +def test_where_not_null(): + data = [("John", "25"), ("Jane", None), ("Bob", "35")] + columns = ["name", "age"] + table = Table(data, columns) + filtered = table.where_not_null("age") + assert len(filtered) == 2 + filtered.close() + table.close() + + +def test_where_in(): + csv_data = "name,age\nJohn,25\nJane,30\nBob,35" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + table = Table(reader, columns) + filtered = table.where_in("name", {"John", "Jane"}, ["name", "age"]) + assert len(filtered) == 2 + filtered.close() + table.close() + + +def test_to_dict(): + csv_data = "name,age\nJohn,25\nJane,30" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + table = Table(reader, columns) + 
result_dict = table.to_dict("name", "age") + assert result_dict["John"] == "25" + assert result_dict["Jane"] == "30" + table.close() + + +def test_value_counts(): + csv_data = "city,age\nNYC,25\nLA,30\nNYC,35" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["city", "age"] + table = Table(reader, columns) + counts = table.value_counts("city") + result = counts.query("SELECT * FROM data_counts") + assert len(result) == 2 + counts.close() + table.close() + + +def test_agg(): + csv_data = "city,age\nNYC,25\nLA,30\nNYC,35" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["city", "age"] + table = Table(reader, columns) + result = table.agg("city", "age", list) + assert "NYC" in result + table.close() + + +def test_getitem_multiple_columns(): + csv_data = "name,age,city\nJohn,25,NYC\nJane,30,LA" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age", "city"] + table = Table(reader, columns) + subset = table[["name", "city"]] + assert len(subset) == 2 + subset.close() + table.close() + + +def test_rename(): + csv_data = "name,age\nJohn,25" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + table = Table(reader, columns) + table.rename({"name": "full_name"}) + assert "full_name" in table.columns + table.close() + + +def test_union(): + data1 = [("John", "25")] + data2 = [("Jane", "30")] + columns = ["name", "age"] + table1 = Table(data1, columns) + table2 = Table(data2, columns) + union_table = table1.union(table2) + assert len(union_table) == 2 + union_table.close() + table1.close() + table2.close() + + +def test_remove(): + csv_data = "name,age\nJohn,25\nJane,30\nBob,35" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + table = Table(reader, columns) + table.remove("name", ["John"]) + assert len(table) == 2 + table.close() + + +def test_concat_columns(): + csv_data = "first,last\nJohn,Doe\nJane,Smith" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["first", "last"] + table = Table(reader, columns) + concat_col = table.concat_columns(["first", "last"]) + assert concat_col[0] == "JohnDoe" + table.close() + + +def test_explode(): + csv_data = "name,tags\nJohn,a;b\nJane,c" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "tags"] + table = Table(reader, columns) + exploded = table.explode("tags", ";") + assert len(exploded) == 3 + exploded.close() + table.close() + + +def test_len(): + csv_data = "name,age\nJohn,25\nJane,30" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + table = Table(reader, columns) + assert len(table) == 2 + table.close() diff --git a/tests/unit/simple_table/test_simple_table_failures.py b/tests/unit/simple_table/test_simple_table_failures.py new file mode 100644 index 0000000..88359ff --- /dev/null +++ b/tests/unit/simple_table/test_simple_table_failures.py @@ -0,0 +1,105 @@ +import pytest +import csv +import io +from pyard.simple_table import Table + + +def test_to_dict_same_columns_fails(): + csv_data = "name,age\nJohn,25" + reader = csv.DictReader(io.StringIO(csv_data)) + table = Table(reader, ["name", "age"]) + with pytest.raises(ValueError, match="must be different"): + table.to_dict("name", "name") + table.close() + + +def test_to_dict_invalid_columns_fails(): + csv_data = "name,age\nJohn,25" + reader = csv.DictReader(io.StringIO(csv_data)) + table = Table(reader, ["name", "age"]) + with pytest.raises(ValueError, match="must be in the table"): + table.to_dict("invalid", "age") + table.close() + + +def 
test_union_different_columns_fails(): + table1 = Table([("John", "25")], ["name", "age"]) + table2 = Table([("NYC",)], ["city"]) + with pytest.raises(ValueError, match="same columns"): + table1.union(table2) + table1.close() + table2.close() + + +def test_group_by_invalid_column_fails(): + csv_data = "name,age\nJohn,25" + reader = csv.DictReader(io.StringIO(csv_data)) + table = Table(reader, ["name", "age"]) + with pytest.raises(ValueError): + table.group_by("invalid_column") + table.close() + + +def test_getitem_invalid_column_fails(): + csv_data = "name,age\nJohn,25" + reader = csv.DictReader(io.StringIO(csv_data)) + table = Table(reader, ["name", "age"]) + with pytest.raises(ValueError): + table["invalid_column"] + table.close() + + +def test_where_invalid_syntax_fails(): + csv_data = "name,age\nJohn,25" + reader = csv.DictReader(io.StringIO(csv_data)) + table = Table(reader, ["name", "age"]) + with pytest.raises(ValueError): + table.where("invalid syntax >>>") + table.close() + + +def test_explode_invalid_column_fails(): + csv_data = "name,age\nJohn,25" + reader = csv.DictReader(io.StringIO(csv_data)) + table = Table(reader, ["name", "age"]) + with pytest.raises(ValueError): + table.explode("invalid_column", ";") + table.close() + + +def test_value_counts_invalid_column_fails(): + csv_data = "name,age\nJohn,25" + reader = csv.DictReader(io.StringIO(csv_data)) + table = Table(reader, ["name", "age"]) + with pytest.raises(ValueError): + table.value_counts("invalid_column") + table.close() + + +def test_rename_invalid_column_fails(): + csv_data = "name,age\nJohn,25" + reader = csv.DictReader(io.StringIO(csv_data)) + table = Table(reader, ["name", "age"]) + with pytest.raises(ValueError): + table.rename({"invalid_column": "new_name"}) + table.close() + + +def test_concat_columns_invalid_column_fails(): + csv_data = "name,age\nJohn,25" + reader = csv.DictReader(io.StringIO(csv_data)) + table = Table(reader, ["name", "age"]) + with pytest.raises(ValueError): + table.concat_columns(["name", "invalid_column"]) + table.close() + + +def test_invalid_query(): + csv_data = "name,age\nJohn,25" + reader = csv.DictReader(io.StringIO(csv_data)) + columns = ["name", "age"] + + table = Table(reader, columns) + with pytest.raises(Exception): + table.query("SELECT * FROM non_existent_table") + table.close() diff --git a/tests/unit/test_load_allele_list.py b/tests/unit/test_load_allele_list.py new file mode 100644 index 0000000..ce1ed35 --- /dev/null +++ b/tests/unit/test_load_allele_list.py @@ -0,0 +1,75 @@ +import pytest +from unittest.mock import patch +from urllib.error import URLError +from pyard.loader.allele_list import load_allele_list +from pyard.simple_table import Table + +# These tests cover: +# Success case - parses the mocked CSV data and returns a correctly populated Table +# Version 3130 handling - verifies the version is renamed to 3131 when building the URL +# Latest version - exercises the "Latest" URL path +# Error handling - verifies that a URLError causes SystemExit +# +# Note: urlopen is patched at the module level where it is imported (pyard.loader.allele_list.urlopen), not at the global urllib.request.urlopen level, so the mock intercepts the call where it is actually used.
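+#
+# A minimal sketch of why the patch target matters; the module name
+# "mymodule" below is hypothetical, not part of py-ard:
+#
+#     # mymodule.py does: from urllib.request import urlopen
+#     patch("urllib.request.urlopen")  # misses the name already bound in mymodule
+#     patch("mymodule.urlopen")        # replaces the name mymodule actually calls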
+ +mock_data = """# file: Allelelist.3290.txt +# date: 2017-07-10 +# version: IPD-IMGT/HLA 3.29.0 +# origin: https://github.com/ANHIG/IMGTHLA/Allelelist.3290.txt +# repository: https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/allelelist/Allelelist.3290.txt +# author: WHO, Steven G. E. Marsh (steven.marsh@ucl.ac.uk) +AlleleID,Allele +HLA00001,A*01:01:01:01 +HLA02169,A*01:01:01:02N +HLA14798,A*01:01:01:03""" + + +def test_load_allele_list_success(): + with patch("pyard.loader.allele_list.urlopen") as mock_urlopen: + mock_urlopen.return_value = mock_data.encode().split(b"\n") + + result = load_allele_list("3290") + + assert isinstance(result, Table) + allele_ids = result["AlleleID"] + alleles = result["Allele"] + assert allele_ids[0] == "HLA00001" + assert alleles[0] == "A*01:01:01:01" + assert allele_ids[1] == "HLA02169" + assert alleles[1] == "A*01:01:01:02N" + + +def test_load_allele_list_version_3130(): + with patch("pyard.loader.allele_list.urlopen") as mock_urlopen: + mock_urlopen.return_value = mock_data.encode().split(b"\n") + + load_allele_list("3130") + + expected_url = "https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/allelelist/Allelelist.3131.txt" + mock_urlopen.assert_called_once_with(expected_url) + + +def test_load_allele_list_latest(): + with patch("pyard.loader.allele_list.urlopen") as mock_urlopen: + mock_urlopen.return_value = mock_data.encode().split(b"\n") + + load_allele_list("Latest") + + expected_url = ( + "https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/Allelelist.txt" + ) + mock_urlopen.assert_called_once_with(expected_url) + + +def test_load_allele_list_url_error(): + with patch( + "pyard.loader.allele_list.urlopen", side_effect=URLError("Network error") + ): + with pytest.raises(SystemExit): + load_allele_list("3290") diff --git a/tests/unit/test_load_serology_broad_split_mapping.py b/tests/unit/test_load_serology_broad_split_mapping.py new file mode 100644 index 0000000..6c2a352 --- /dev/null +++ b/tests/unit/test_load_serology_broad_split_mapping.py @@ -0,0 +1,72 @@ +import pytest +from unittest.mock import patch +from urllib.error import URLError +from pyard.loader.serology import load_serology_broad_split_mapping +from pyard.simple_table import Table + + +def test_load_serology_broad_split_mapping_success(): + mock_data = """ +# file: rel_ser_ser.txt +# date: 2025-07-14 +# version: IPD-IMGT/HLA 3.61.0 +# origin: http://hla.alleles.org/wmda/rel_ser_ser.txt +# repository: https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/wmda/rel_ser_ser.txt +# author: WHO, Steven G. E. 
Marsh (steven.marsh@ucl.ac.uk) +A;10;25/26/34/66;25/26/34/66 +B;14;64/65;64/65 +""".strip() + + with patch("pyard.loader.serology.urlopen") as mock_urlopen: + mock_urlopen.return_value = mock_data.encode().split(b"\n") + + splits_table, associated_table = load_serology_broad_split_mapping("3290") + + assert isinstance(splits_table, Table) + assert isinstance(associated_table, Table) + + # Test splits table + broad_col = splits_table["broad"] + splits_col = splits_table["splits"] + assert len(broad_col) == 2 + assert broad_col[0] == "A10" + assert splits_col[0] == "A25/A26/A34/A66" + assert len(splits_col[0].split("/")) == 4 # 4 A splits + assert broad_col[1] == "B14" + assert splits_col[1] == "B64/B65" + assert len(splits_col[1].split("/")) == 2 # 2 B splits + + # Test associated table + split_col = associated_table["split"] + broad_assoc_col = associated_table["broad"] + assert len(split_col) == 6 # 4 A associated + 2 B associated + assert "A25" in split_col + assert "A10" in broad_assoc_col + + +def test_load_serology_broad_split_mapping_empty_splits(): + mock_data = """ +# header 1 +# header 2 +# header 3 +# header 4 +# header 5 +# header 6 +A;10;;; +B;14;64/65;64/65 +""".strip() + + with patch("pyard.loader.serology.urlopen") as mock_urlopen: + mock_urlopen.return_value = mock_data.encode().split(b"\n") + + splits_table, associated_table = load_serology_broad_split_mapping("3290") + + # Only B14 should have splits + assert len(splits_table["broad"]) == 1 + assert splits_table["broad"][0] == "B14" + + +def test_load_serology_broad_split_mapping_url_error(): + with patch("pyard.loader.serology.urlopen", side_effect=URLError("Network error")): + with pytest.raises(SystemExit): + load_serology_broad_split_mapping("3290") diff --git a/tests/unit/test_load_serology_mappings.py b/tests/unit/test_load_serology_mappings.py new file mode 100644 index 0000000..3e7ef2f --- /dev/null +++ b/tests/unit/test_load_serology_mappings.py @@ -0,0 +1,64 @@ +import pytest +from unittest.mock import patch +from urllib.error import URLError +from pyard.loader.serology import load_serology_mappings +from pyard.simple_table import Table + + +def test_load_serology_mappings_success(): + mock_data = """ +# file: rel_dna_ser.txt +# date: 2025-07-14 +# version: IPD-IMGT/HLA 3.61.0 +# origin: http://hla.alleles.org/wmda/rel_dna_ser.txt +# repository: https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/wmda/rel_dna_ser.txt +# author: WHO, Steven G. E. 
Marsh (steven.marsh@ucl.ac.uk) +A*;01:01:01:01;1;;; +A*;01:01:01:02N;0;;; +A*;01:01:01:03;1;;; +A*;01:01:01:04;1;;; +A*;02:1068Q;;;0/2; +B*;01:01:01:05;1;;; + """.strip() + with patch("pyard.loader.serology.urlopen") as mock_urlopen: + mock_urlopen.return_value = mock_data.encode().split(b"\n") + + result = load_serology_mappings("3290") + + assert isinstance(result, Table) + loci = result["Locus"] + alleles = result["Allele"] + usa = result["USA"] + assert len(loci) == 6 + assert loci[0] == "A*" + assert alleles[0] == "01:01:01:01" + assert usa[0] == "1" + assert loci[5] == "B*" + assert alleles[4] == "02:1068Q" + + +def test_load_serology_mappings_empty_lines(): + mock_data = """ +# header line 1 +# header line 2 +# header line 3 +# header line 4 +# header line 5 +# header line 6 +A*;01:01:01:01;A1;A1;; + +B*;07:02:01;B7;B7;; + """.strip() + + with patch("pyard.loader.serology.urlopen") as mock_urlopen: + mock_urlopen.return_value = mock_data.encode().split(b"\n") + + result = load_serology_mappings("3290") + + assert len(result["Locus"]) == 2 + + +def test_load_serology_mappings_url_error(): + with patch("pyard.loader.serology.urlopen", side_effect=URLError("Network error")): + with pytest.raises(SystemExit): + load_serology_mappings("3290") diff --git a/tests/unit/test_pyard.py b/tests/unit/test_pyard.py new file mode 100644 index 0000000..4333c15 --- /dev/null +++ b/tests/unit/test_pyard.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# py-ard +# Copyright (c) 2023 Be The Match operated by National Marrow Donor Program. All Rights Reserved. +# +# This library is free software; you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation; either version 3 of the License, or (at +# your option) any later version. +# +# This library is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library; if not, write to the Free Software Foundation, +# Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. +# +# > http://www.fsf.org/licensing/licenses/lgpl.html +# > http://www.opensource.org/licenses/lgpl-license.php +# + +""" +test_pyard +---------------------------------- + +Tests for `py-ard` module. 
+""" +import json +import os + +import pytest +import pyard +from pyard.constants import DEFAULT_CACHE_SIZE +from pyard.exceptions import InvalidAlleleError +from pyard.misc import validate_reduction_type + + +@pytest.fixture(scope="module") +def ard(): + db_version = "3440" + return pyard.init(db_version, data_dir="/tmp/py-ard") + + +def test_no_mac(ard): + assert ard.redux("A*01:01:01", "G") == "A*01:01:01G" + assert ard.redux("A*01:01:01", "lg") == "A*01:01g" + assert ard.redux("A*01:01:01", "lgx") == "A*01:01" + assert ard.redux("HLA-A*01:01:01", "G") == "HLA-A*01:01:01G" + assert ard.redux("HLA-A*01:01:01", "lg") == "HLA-A*01:01g" + assert ard.redux("HLA-A*01:01:01", "lgx") == "HLA-A*01:01" + + +def test_remove_invalid(ard): + assert ard.redux("A*01:01:01", "G") == "A*01:01:01G" + + +def test_mac(ard): + assert ard.redux("A*01:AB", "G") == "A*01:01:01G/A*01:02" + assert ard.redux("HLA-A*01:AB", "G") == "HLA-A*01:01:01G/HLA-A*01:02" + + +def test_redux(ard): + data_dir = os.path.dirname(__file__) + expected_json = data_dir + "/expected.json" + with open(expected_json) as json_data: + expected = json.load(json_data) + for ex in expected["redux"]: + glstring = ex["glstring"] + ard_type = ex["ard_type"] + expected_gl = ex["expected_gl"] + assert ard.redux(glstring, ard_type) == expected_gl + + +def test_serology(ard): + data_dir = os.path.dirname(__file__) + expected_json = data_dir + "/expected-serology.json" + with open(expected_json) as json_data: + expected = json.load(json_data) + for ex in expected["redux"]: + glstring = ex["glstring"] + ard_type = ex["ard_type"] + expected_gl = ex["expected_gl"] + assert ard.redux(glstring, ard_type) == expected_gl + + +def test_mac_G(ard): + assert ard.redux("A*01:01:01", "G") == "A*01:01:01G" + assert ard.redux("HLA-A*01:AB", "G") == "HLA-A*01:01:01G/HLA-A*01:02" + with pytest.raises(InvalidAlleleError): + ard._redux_allele("HLA-A*01:AB", "G") + + +def test_xx_code(ard): + expanded_string = """ + 
B*40:01:01G/B*40:01:03G/B*40:02:01G/B*40:03:01G/B*40:04:01G/B*40:05:01G/B*40:06:01G/B*40:07/B*40:08/B*40:09/B*40:10:01G/B*40:11:01G/B*40:12/B*40:13/B*40:14/B*40:15/B*40:16:01G/B*40:18/B*40:19/B*40:20:01G/B*40:21/B*40:22N/B*40:23/B*40:24/B*40:25/B*40:26/B*40:27/B*40:28/B*40:29/B*40:30/B*40:31/B*40:32/B*40:33/B*40:34/B*40:35/B*40:36/B*40:37/B*40:38/B*40:39/B*40:40:01G/B*40:42/B*40:43/B*40:44/B*40:45/B*40:46/B*40:47/B*40:48/B*40:49/B*40:50:01G/B*40:51/B*40:52/B*40:53/B*40:54/B*40:57/B*40:58/B*40:59/B*40:60/B*40:61/B*40:62/B*40:63/B*40:64:01G/B*40:65/B*40:66/B*40:67/B*40:68/B*40:69/B*40:70/B*40:71/B*40:72/B*40:73/B*40:74/B*40:75/B*40:76/B*40:77/B*40:78/B*40:79/B*40:80/B*40:81/B*40:82/B*40:83/B*40:84/B*40:85/B*40:86/B*40:87/B*40:88/B*40:89/B*40:90/B*40:91/B*40:92/B*40:93/B*40:94/B*40:95/B*40:96/B*40:98/B*40:99/B*40:100/B*40:101/B*40:102/B*40:103/B*40:104/B*40:105/B*40:106/B*40:107/B*40:108/B*40:109/B*40:110/B*40:111/B*40:112/B*40:113/B*40:114:01G/B*40:115/B*40:116/B*40:117/B*40:118N/B*40:119/B*40:120/B*40:121/B*40:122/B*40:123/B*40:124/B*40:125/B*40:126/B*40:127/B*40:128/B*40:129/B*40:130/B*40:131/B*40:132/B*40:133Q/B*40:134/B*40:135/B*40:136/B*40:137/B*40:138/B*40:139/B*40:140/B*40:142N/B*40:143/B*40:145/B*40:146/B*40:147/B*40:148/B*40:149/B*40:152/B*40:153/B*40:154/B*40:155:01G/B*40:156/B*40:157/B*40:158/B*40:159/B*40:160/B*40:161/B*40:162/B*40:163/B*40:164/B*40:165/B*40:166/B*40:167/B*40:168/B*40:169/B*40:170/B*40:171/B*40:172/B*40:173/B*40:174/B*40:175/B*40:177/B*40:178/B*40:180/B*40:181/B*40:182/B*40:183/B*40:184/B*40:185/B*40:186/B*40:187/B*40:188/B*40:189/B*40:190/B*40:191/B*40:192/B*40:193/B*40:194/B*40:195/B*40:196/B*40:197/B*40:198/B*40:199/B*40:200/B*40:201/B*40:202/B*40:203/B*40:204/B*40:205/B*40:206/B*40:207/B*40:208/B*40:209/B*40:210/B*40:211/B*40:212/B*40:213:01G/B*40:214/B*40:215/B*40:216N/B*40:217/B*40:218/B*40:219/B*40:220/B*40:222/B*40:223/B*40:224/B*40:225/B*40:226/B*40:227/B*40:228/B*40:230/B*40:231/B*40:232/B*40:233/B*40:234/B*40:235/B*40:237/B*40:238/B*40:239/B*40:240/B*40:242/B*40:243/B*40:244/B*40:245/B*40:246/B*40:248/B*40:249/B*40:250/B*40:251/B*40:252/B*40:253/B*40:254/B*40:255/B*40:256N/B*40:257/B*40:258/B*40:259/B*40:260/B*40:261/B*40:262/B*40:263N/B*40:265N/B*40:266/B*40:268/B*40:269/B*40:270/B*40:271/B*40:273/B*40:274/B*40:275/B*40:276/B*40:277/B*40:279/B*40:280/B*40:281/B*40:282/B*40:283/B*40:284/B*40:285/B*40:286N/B*40:287/B*40:288/B*40:289/B*40:290/B*40:291N/B*40:292/B*40:293/B*40:294/B*40:295/B*40:296/B*40:297/B*40:298/B*40:300/B*40:302/B*40:304/B*40:305/B*40:306/B*40:307/B*40:308/B*40:309/B*40:310/B*40:311/B*40:312/B*40:313/B*40:314/B*40:315/B*40:316/B*40:317/B*40:318/B*40:319/B*40:320/B*40:321/B*40:322/B*40:323/B*40:324/B*40:325/B*40:326/B*40:327/B*40:328/B*40:330/B*40:331/B*40:332/B*40:333/B*40:334/B*40:335/B*40:336/B*40:337N/B*40:339/B*40:340/B*40:341/B*40:342/B*40:343/B*40:344/B*40:345N/B*40:346/B*40:347/B*40:348/B*40:349/B*40:350/B*40:351/B*40:352/B*40:354/B*40:355/B*40:357/B*40:358/B*40:359/B*40:360/B*40:361N/B*40:362/B*40:363/B*40:364/B*40:365/B*40:366/B*40:367/B*40:368/B*40:369/B*40:370/B*40:371/B*40:372N/B*40:373/B*40:374/B*40:375/B*40:376/B*40:377/B*40:378/B*40:380/B*40:381/B*40:382/B*40:385/B*40:388/B*40:389/B*40:390/B*40:391/B*40:392/B*40:393/B*40:394/B*40:396/B*40:397/B*40:398/B*40:399N/B*40:400/B*40:401/B*40:402/B*40:403/B*40:404/B*40:407/B*40:408/B*40:409/B*40:410/B*40:411/B*40:412/B*40:413/B*40:414/B*40:415/B*40:420/B*40:421Q/B*40:422/B*40:423/B*40:424/B*40:426N/B*40:428N/B*40:429/B*40:430/B*40:432/B*40:433/B*40:434/B*40:436/B*40:437/B*40:438N
/B*40:441/B*40:445/B*40:447/B*40:448/B*40:449/B*40:451/B*40:452/B*40:454/B*40:457/B*40:458/B*40:459/B*40:460/B*40:461/B*40:462/B*40:463/B*40:465/B*40:466/B*40:467/B*40:468/B*40:469/B*40:470/B*40:471/B*40:472/B*40:477/B*40:478/B*40:479/B*40:481N/B*40:482 + """.strip() + gl = ard.redux("B*40:XX", "G") + assert gl == expanded_string + + +def test_xx_code_with_prefix(ard): + expanded_string = """ + HLA-B*40:01:01G/HLA-B*40:01:03G/HLA-B*40:02:01G/HLA-B*40:03:01G/HLA-B*40:04:01G/HLA-B*40:05:01G/HLA-B*40:06:01G/HLA-B*40:07/HLA-B*40:08/HLA-B*40:09/HLA-B*40:10:01G/HLA-B*40:11:01G/HLA-B*40:12/HLA-B*40:13/HLA-B*40:14/HLA-B*40:15/HLA-B*40:16:01G/HLA-B*40:18/HLA-B*40:19/HLA-B*40:20:01G/HLA-B*40:21/HLA-B*40:22N/HLA-B*40:23/HLA-B*40:24/HLA-B*40:25/HLA-B*40:26/HLA-B*40:27/HLA-B*40:28/HLA-B*40:29/HLA-B*40:30/HLA-B*40:31/HLA-B*40:32/HLA-B*40:33/HLA-B*40:34/HLA-B*40:35/HLA-B*40:36/HLA-B*40:37/HLA-B*40:38/HLA-B*40:39/HLA-B*40:40:01G/HLA-B*40:42/HLA-B*40:43/HLA-B*40:44/HLA-B*40:45/HLA-B*40:46/HLA-B*40:47/HLA-B*40:48/HLA-B*40:49/HLA-B*40:50:01G/HLA-B*40:51/HLA-B*40:52/HLA-B*40:53/HLA-B*40:54/HLA-B*40:57/HLA-B*40:58/HLA-B*40:59/HLA-B*40:60/HLA-B*40:61/HLA-B*40:62/HLA-B*40:63/HLA-B*40:64:01G/HLA-B*40:65/HLA-B*40:66/HLA-B*40:67/HLA-B*40:68/HLA-B*40:69/HLA-B*40:70/HLA-B*40:71/HLA-B*40:72/HLA-B*40:73/HLA-B*40:74/HLA-B*40:75/HLA-B*40:76/HLA-B*40:77/HLA-B*40:78/HLA-B*40:79/HLA-B*40:80/HLA-B*40:81/HLA-B*40:82/HLA-B*40:83/HLA-B*40:84/HLA-B*40:85/HLA-B*40:86/HLA-B*40:87/HLA-B*40:88/HLA-B*40:89/HLA-B*40:90/HLA-B*40:91/HLA-B*40:92/HLA-B*40:93/HLA-B*40:94/HLA-B*40:95/HLA-B*40:96/HLA-B*40:98/HLA-B*40:99/HLA-B*40:100/HLA-B*40:101/HLA-B*40:102/HLA-B*40:103/HLA-B*40:104/HLA-B*40:105/HLA-B*40:106/HLA-B*40:107/HLA-B*40:108/HLA-B*40:109/HLA-B*40:110/HLA-B*40:111/HLA-B*40:112/HLA-B*40:113/HLA-B*40:114:01G/HLA-B*40:115/HLA-B*40:116/HLA-B*40:117/HLA-B*40:118N/HLA-B*40:119/HLA-B*40:120/HLA-B*40:121/HLA-B*40:122/HLA-B*40:123/HLA-B*40:124/HLA-B*40:125/HLA-B*40:126/HLA-B*40:127/HLA-B*40:128/HLA-B*40:129/HLA-B*40:130/HLA-B*40:131/HLA-B*40:132/HLA-B*40:133Q/HLA-B*40:134/HLA-B*40:135/HLA-B*40:136/HLA-B*40:137/HLA-B*40:138/HLA-B*40:139/HLA-B*40:140/HLA-B*40:142N/HLA-B*40:143/HLA-B*40:145/HLA-B*40:146/HLA-B*40:147/HLA-B*40:148/HLA-B*40:149/HLA-B*40:152/HLA-B*40:153/HLA-B*40:154/HLA-B*40:155:01G/HLA-B*40:156/HLA-B*40:157/HLA-B*40:158/HLA-B*40:159/HLA-B*40:160/HLA-B*40:161/HLA-B*40:162/HLA-B*40:163/HLA-B*40:164/HLA-B*40:165/HLA-B*40:166/HLA-B*40:167/HLA-B*40:168/HLA-B*40:169/HLA-B*40:170/HLA-B*40:171/HLA-B*40:172/HLA-B*40:173/HLA-B*40:174/HLA-B*40:175/HLA-B*40:177/HLA-B*40:178/HLA-B*40:180/HLA-B*40:181/HLA-B*40:182/HLA-B*40:183/HLA-B*40:184/HLA-B*40:185/HLA-B*40:186/HLA-B*40:187/HLA-B*40:188/HLA-B*40:189/HLA-B*40:190/HLA-B*40:191/HLA-B*40:192/HLA-B*40:193/HLA-B*40:194/HLA-B*40:195/HLA-B*40:196/HLA-B*40:197/HLA-B*40:198/HLA-B*40:199/HLA-B*40:200/HLA-B*40:201/HLA-B*40:202/HLA-B*40:203/HLA-B*40:204/HLA-B*40:205/HLA-B*40:206/HLA-B*40:207/HLA-B*40:208/HLA-B*40:209/HLA-B*40:210/HLA-B*40:211/HLA-B*40:212/HLA-B*40:213:01G/HLA-B*40:214/HLA-B*40:215/HLA-B*40:216N/HLA-B*40:217/HLA-B*40:218/HLA-B*40:219/HLA-B*40:220/HLA-B*40:222/HLA-B*40:223/HLA-B*40:224/HLA-B*40:225/HLA-B*40:226/HLA-B*40:227/HLA-B*40:228/HLA-B*40:230/HLA-B*40:231/HLA-B*40:232/HLA-B*40:233/HLA-B*40:234/HLA-B*40:235/HLA-B*40:237/HLA-B*40:238/HLA-B*40:239/HLA-B*40:240/HLA-B*40:242/HLA-B*40:243/HLA-B*40:244/HLA-B*40:245/HLA-B*40:246/HLA-B*40:248/HLA-B*40:249/HLA-B*40:250/HLA-B*40:251/HLA-B*40:252/HLA-B*40:253/HLA-B*40:254/HLA-B*40:255/HLA-B*40:256N/HLA-B*40:257/HLA-B*40:258/HLA-B*40:259/HLA-B*
40:260/HLA-B*40:261/HLA-B*40:262/HLA-B*40:263N/HLA-B*40:265N/HLA-B*40:266/HLA-B*40:268/HLA-B*40:269/HLA-B*40:270/HLA-B*40:271/HLA-B*40:273/HLA-B*40:274/HLA-B*40:275/HLA-B*40:276/HLA-B*40:277/HLA-B*40:279/HLA-B*40:280/HLA-B*40:281/HLA-B*40:282/HLA-B*40:283/HLA-B*40:284/HLA-B*40:285/HLA-B*40:286N/HLA-B*40:287/HLA-B*40:288/HLA-B*40:289/HLA-B*40:290/HLA-B*40:291N/HLA-B*40:292/HLA-B*40:293/HLA-B*40:294/HLA-B*40:295/HLA-B*40:296/HLA-B*40:297/HLA-B*40:298/HLA-B*40:300/HLA-B*40:302/HLA-B*40:304/HLA-B*40:305/HLA-B*40:306/HLA-B*40:307/HLA-B*40:308/HLA-B*40:309/HLA-B*40:310/HLA-B*40:311/HLA-B*40:312/HLA-B*40:313/HLA-B*40:314/HLA-B*40:315/HLA-B*40:316/HLA-B*40:317/HLA-B*40:318/HLA-B*40:319/HLA-B*40:320/HLA-B*40:321/HLA-B*40:322/HLA-B*40:323/HLA-B*40:324/HLA-B*40:325/HLA-B*40:326/HLA-B*40:327/HLA-B*40:328/HLA-B*40:330/HLA-B*40:331/HLA-B*40:332/HLA-B*40:333/HLA-B*40:334/HLA-B*40:335/HLA-B*40:336/HLA-B*40:337N/HLA-B*40:339/HLA-B*40:340/HLA-B*40:341/HLA-B*40:342/HLA-B*40:343/HLA-B*40:344/HLA-B*40:345N/HLA-B*40:346/HLA-B*40:347/HLA-B*40:348/HLA-B*40:349/HLA-B*40:350/HLA-B*40:351/HLA-B*40:352/HLA-B*40:354/HLA-B*40:355/HLA-B*40:357/HLA-B*40:358/HLA-B*40:359/HLA-B*40:360/HLA-B*40:361N/HLA-B*40:362/HLA-B*40:363/HLA-B*40:364/HLA-B*40:365/HLA-B*40:366/HLA-B*40:367/HLA-B*40:368/HLA-B*40:369/HLA-B*40:370/HLA-B*40:371/HLA-B*40:372N/HLA-B*40:373/HLA-B*40:374/HLA-B*40:375/HLA-B*40:376/HLA-B*40:377/HLA-B*40:378/HLA-B*40:380/HLA-B*40:381/HLA-B*40:382/HLA-B*40:385/HLA-B*40:388/HLA-B*40:389/HLA-B*40:390/HLA-B*40:391/HLA-B*40:392/HLA-B*40:393/HLA-B*40:394/HLA-B*40:396/HLA-B*40:397/HLA-B*40:398/HLA-B*40:399N/HLA-B*40:400/HLA-B*40:401/HLA-B*40:402/HLA-B*40:403/HLA-B*40:404/HLA-B*40:407/HLA-B*40:408/HLA-B*40:409/HLA-B*40:410/HLA-B*40:411/HLA-B*40:412/HLA-B*40:413/HLA-B*40:414/HLA-B*40:415/HLA-B*40:420/HLA-B*40:421Q/HLA-B*40:422/HLA-B*40:423/HLA-B*40:424/HLA-B*40:426N/HLA-B*40:428N/HLA-B*40:429/HLA-B*40:430/HLA-B*40:432/HLA-B*40:433/HLA-B*40:434/HLA-B*40:436/HLA-B*40:437/HLA-B*40:438N/HLA-B*40:441/HLA-B*40:445/HLA-B*40:447/HLA-B*40:448/HLA-B*40:449/HLA-B*40:451/HLA-B*40:452/HLA-B*40:454/HLA-B*40:457/HLA-B*40:458/HLA-B*40:459/HLA-B*40:460/HLA-B*40:461/HLA-B*40:462/HLA-B*40:463/HLA-B*40:465/HLA-B*40:466/HLA-B*40:467/HLA-B*40:468/HLA-B*40:469/HLA-B*40:470/HLA-B*40:471/HLA-B*40:472/HLA-B*40:477/HLA-B*40:478/HLA-B*40:479/HLA-B*40:481N/HLA-B*40:482 + """.strip() + gl = ard.redux("HLA-B*40:XX", "G") + assert expanded_string == gl + + +def test_expand_mac(ard): + mac_ab_expanded = ["A*01:01", "A*01:02"] + assert ard.expand_mac("A*01:AB") == "/".join(mac_ab_expanded) + + mac_hla_ab_expanded = ["HLA-A*01:01", "HLA-A*01:02"] + assert ard.expand_mac("HLA-A*01:AB") == "/".join(mac_hla_ab_expanded) + + mac_ac_expanded = ["A*01:01", "A*01:03"] + assert ard.expand_mac("A*01:AC") == "/".join(mac_ac_expanded) + + mac_hla_ac_expanded = ["HLA-A*01:01", "HLA-A*01:03"] + assert ard.expand_mac("HLA-A*01:AC") == "/".join(mac_hla_ac_expanded) + + +def test_redux_types(): + assert validate_reduction_type("G") is None + assert validate_reduction_type("lg") is None + assert validate_reduction_type("lgx") is None + assert validate_reduction_type("W") is None + assert validate_reduction_type("exon") is None + with pytest.raises(ValueError): + validate_reduction_type("XX") + + +def test_empty_allele(ard): + with pytest.raises(InvalidAlleleError): + ard.redux("A*", "lgx") + + +def test_fp_allele(ard): + with pytest.raises(InvalidAlleleError): + ard.redux("A*0.123", "lgx") + + +def test_empty_fields(ard): + with pytest.raises(InvalidAlleleError): + # : 
without any data + ard.redux("DQA1*01:01:01:G", "lgx") + + +def test_invalid_serology(ard): + # Test that A10 works and the first one is 'A*25:01' + serology_a10 = ard.redux("A10", "lgx") + assert serology_a10.split("/")[0] == "A*25:01" + # And A100 isn't a valid typing + with pytest.raises(InvalidAlleleError): + ard.redux("A100", "lgx") + + +def test_allele_duplicated(ard): + # Make sure the reduced alleles are unique + # https://github.com/nmdp-bioinformatics/py-ard/issues/135 + allele_code = "C*02:ACMGS" + allele_code_rx = ard.redux(allele_code, "lgx") + assert allele_code_rx == "C*02:02" + + +def test_imgt_db_version(ard): + assert ard.get_db_version() == 3440 + + +def test_xx_codes_broad_split(ard): + assert "DQB1*06" not in ard.redux("DQB1*05:XX", "lgx") + + +def test_cache_info(ard): + # validate the default cache size + assert ard._redux_allele.cache_info().maxsize == DEFAULT_CACHE_SIZE + assert ard.redux.cache_info().maxsize == DEFAULT_CACHE_SIZE + # validate you can change the cache size + higher_cache_size = 5_000_000 + another_ard = pyard.init( + "3440", data_dir="/tmp/py-ard", cache_size=higher_cache_size + ) + assert another_ard._redux_allele.cache_info().maxsize == higher_cache_size + assert another_ard.redux.cache_info().maxsize == higher_cache_size + + +def test_is_null(ard): + # a null allele + allele = "A*01:01N" + assert ard.is_null(allele) + # not null allele + allele = "A*01:01" + assert not ard.is_null(allele) + # MACs ending with N shouldn't be called as Nulls + allele = "A*01:MN" + assert not ard.is_null(allele) + # MACs shouldn't be called as Nulls + allele = "A*01:AB" + assert not ard.is_null(allele) + + +def test_default_redux_is_lgx(ard): + allele = "A*24:BKKPV+A*26:03^B*15:BKNTS+B*15:07" + lgx_redux = ard.redux(allele, "lgx") + default_redux = ard.redux(allele) + assert lgx_redux == default_redux + + +def test_mac_is_reversible(ard): + mac_code = "A*68:AJEBX" + expanded_mac = ard.expand_mac(mac_code) + lookup_mac = ard.lookup_mac(expanded_mac) + assert mac_code == lookup_mac diff --git a/tests/unit/test_smart_sort.py b/tests/unit/test_smart_sort.py new file mode 100644 index 0000000..b4b49b4 --- /dev/null +++ b/tests/unit/test_smart_sort.py @@ -0,0 +1,131 @@ +# +# py-ard +# Copyright (c) 2023 Be The Match operated by National Marrow Donor Program. All Rights Reserved. +# +# This library is free software; you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation; either version 3 of the License, or (at +# your option) any later version. +# +# This library is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this library; if not, write to the Free Software Foundation, +# Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. 
+# +# > http://www.fsf.org/licensing/licenses/lgpl.html +# > http://www.opensource.org/licenses/lgpl-license.php +# +from pyard.smart_sort import smart_sort_comparator + + +def test_same_comparator(): + allele = "HLA-A*01:01" + assert smart_sort_comparator(allele, allele) == 0 + + +def test_equal_comparator(): + allele1 = "HLA-A*01:01" + allele2 = "HLA-A*01:01" + assert smart_sort_comparator(allele1, allele2) == 0 + + +def test_equal_comparator_G(): + # Should compare without G + allele1 = "HLA-A*01:01G" + allele2 = "HLA-A*01:01" + assert smart_sort_comparator(allele1, allele2) == 0 + + +def test_equal_comparator_NG(): + # Should compare without N and G + allele1 = "HLA-A*01:01G" + allele2 = "HLA-A*01:01N" + assert smart_sort_comparator(allele1, allele2) == 0 + + +def test_first_field_comparator_le(): + allele1 = "HLA-A*01:01" + allele2 = "HLA-A*02:01" + assert smart_sort_comparator(allele1, allele2) == -1 + + +def test_first_field_comparator_ge(): + allele1 = "HLA-A*02:01" + allele2 = "HLA-A*01:01" + assert smart_sort_comparator(allele1, allele2) == 1 + + +def test_second_field_comparator_le(): + allele1 = "HLA-A*01:01" + allele2 = "HLA-A*01:02" + assert smart_sort_comparator(allele1, allele2) == -1 + + +def test_second_field_comparator_le_smart(): + allele1 = "HLA-A*01:29" + allele2 = "HLA-A*01:100" + assert smart_sort_comparator(allele1, allele2) == -1 + + +def test_second_field_comparator_ge(): + allele1 = "HLA-A*01:02" + allele2 = "HLA-A*01:01" + assert smart_sort_comparator(allele1, allele2) == 1 + + +def test_third_field_comparator_le(): + allele1 = "HLA-A*01:01:01" + allele2 = "HLA-A*01:01:20" + assert smart_sort_comparator(allele1, allele2) == -1 + + +def test_third_field_comparator_le_smart(): + allele1 = "HLA-A*01:01:29" + allele2 = "HLA-A*01:01:100" + assert smart_sort_comparator(allele1, allele2) == -1 + + +def test_third_field_comparator_ge(): + allele1 = "HLA-A*01:01:02" + allele2 = "HLA-A*01:01:01" + assert smart_sort_comparator(allele1, allele2) == 1 + + +def test_fourth_field_comparator_le(): + allele1 = "HLA-A*01:01:01:01" + allele2 = "HLA-A*01:01:01:20" + assert smart_sort_comparator(allele1, allele2) == -1 + + +def test_fourth_field_comparator_le_smart(): + allele1 = "HLA-A*01:01:01:39" + allele2 = "HLA-A*01:01:01:200" + assert smart_sort_comparator(allele1, allele2) == -1 + + +def test_fourth_field_comparator_ge(): + allele1 = "HLA-A*01:01:01:30" + allele2 = "HLA-A*01:01:01:09" + assert smart_sort_comparator(allele1, allele2) == 1 + + +def test_serology_ge(): + serology1 = "Cw10" + serology2 = "Cw3" + assert smart_sort_comparator(serology1, serology2) == 1 + + +def test_serology_le(): + serology1 = "A10" + serology2 = "A25" + assert smart_sort_comparator(serology1, serology2) == -1 + + +def test_serology_eq(): + serology1 = "B70" + serology2 = "B70" + assert smart_sort_comparator(serology1, serology2) == 0
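+
+
+def test_sorting_allele_list():
+    # A minimal usage sketch: the comparator is meant to be plugged into
+    # functools.cmp_to_key so allele lists sort field-by-field numerically
+    # (01:29 before 01:100) rather than lexically.
+    import functools
+
+    alleles = ["HLA-A*01:100", "HLA-A*02:01", "HLA-A*01:29"]
+    ordered = sorted(alleles, key=functools.cmp_to_key(smart_sort_comparator))
+    assert ordered == ["HLA-A*01:29", "HLA-A*01:100", "HLA-A*02:01"]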