Skip to content

Commit 31aa0e2

Browse files
authored
Merge pull request #210 from pbashyal-nmdp/cli_improvements
Cli Tools improvements
2 parents (4b942bc + d5a0767); commit 31aa0e2

File tree

8 files changed, with 127 additions and 64 deletions.

pyard/data_repository.py

Lines changed: 26 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -486,7 +486,9 @@ def generate_short_nulls(db_connection, who_group):
486486
return shortnulls
487487

488488

489-
def generate_mac_codes(db_connection: sqlite3.Connection, refresh_mac: bool):
489+
def generate_mac_codes(
490+
db_connection: sqlite3.Connection, refresh_mac: bool = False, load_mac: bool = True
491+
):
490492
"""
491493
MAC files come in 2 different versions:
492494
@@ -530,29 +532,31 @@ def generate_mac_codes(db_connection: sqlite3.Connection, refresh_mac: bool):
530532
531533
:param db_connection: Database connection to the sqlite database
532534
:param refresh_mac: Refresh the database with newer MAC data ?
535+
:param load_mac: Should MAC be loaded at all
533536
:return: None
534537
"""
535-
mac_table_name = "mac_codes"
536-
if refresh_mac or not db.table_exists(db_connection, mac_table_name):
537-
# Load the MAC file to a DataFrame
538-
mac_url = "https://hml.nmdp.org/mac/files/numer.v3.zip"
539-
df_mac = pd.read_csv(
540-
mac_url,
541-
sep="\t",
542-
compression="zip",
543-
skiprows=3,
544-
names=["Code", "Alleles"],
545-
keep_default_na=False,
546-
)
547-
# Create a dict from code to alleles
548-
mac = df_mac.set_index("Code")["Alleles"].to_dict()
549-
# Save the mac dict to db
550-
db.save_dict(
551-
db_connection,
552-
table_name=mac_table_name,
553-
dictionary=mac,
554-
columns=("code", "alleles"),
555-
)
538+
if load_mac:
539+
mac_table_name = "mac_codes"
540+
if refresh_mac or not db.table_exists(db_connection, mac_table_name):
541+
# Load the MAC file to a DataFrame
542+
mac_url = "https://hml.nmdp.org/mac/files/numer.v3.zip"
543+
df_mac = pd.read_csv(
544+
mac_url,
545+
sep="\t",
546+
compression="zip",
547+
skiprows=3,
548+
names=["Code", "Alleles"],
549+
keep_default_na=False,
550+
)
551+
# Create a dict from code to alleles
552+
mac = df_mac.set_index("Code")["Alleles"].to_dict()
553+
# Save the mac dict to db
554+
db.save_dict(
555+
db_connection,
556+
table_name=mac_table_name,
557+
dictionary=mac,
558+
columns=("code", "alleles"),
559+
)
556560

557561

558562
def to_serological_name(locus_name: str):

pyard/db.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ def create_db_connection(data_dir, imgt_version, ro=False):
5555
# Open the database in read-only mode
5656
file_uri = f"file:{db_filename}?mode=ro"
5757
# Multiple threads can access the same connection since it's only ro
58-
return sqlite3.connect(file_uri, check_same_thread=False, uri=True)
58+
return sqlite3.connect(file_uri, check_same_thread=False, uri=True), db_filename
5959

6060
# Check the imgt_version is a valid IMGT DB Version
6161
# by querying the IMGT site
@@ -73,7 +73,7 @@ def create_db_connection(data_dir, imgt_version, ro=False):
7373

7474
# Open the database for read/write
7575
file_uri = f"file:{db_filename}"
76-
return sqlite3.connect(file_uri, uri=True)
76+
return sqlite3.connect(file_uri, uri=True), db_filename
7777

7878

7979
def table_exists(connection: sqlite3.Connection, table_name: str) -> bool:

pyard/misc.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -94,3 +94,14 @@ def get_data_dir(data_dir):
9494
else:
9595
data_dir = db.get_pyard_db_install_directory()
9696
return data_dir
97+
98+
99+
def get_imgt_version(imgt_version):
100+
if imgt_version:
101+
version = imgt_version.replace(".", "")
102+
if version.isdigit():
103+
return version
104+
raise RuntimeError(
105+
f"{imgt_version} is not a valid IMGT database version number"
106+
)
107+
return "Latest"

pyard/pyard.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,11 @@ class ARD(object):
8080
"""
8181

8282
def __init__(
83-
self, imgt_version: str = "Latest", data_dir: str = None, config: dict = None
83+
self,
84+
imgt_version: str = "Latest",
85+
data_dir: str = None,
86+
load_mac: bool = True,
87+
config: dict = None,
8488
):
8589
"""
8690
ARD will load valid alleles, xx codes and MAC mappings for the given
@@ -97,10 +101,10 @@ def __init__(
97101
self._config.update(config)
98102

99103
# Create a database connection for writing
100-
self.db_connection = db.create_db_connection(data_dir, imgt_version)
104+
self.db_connection, _ = db.create_db_connection(data_dir, imgt_version)
101105

102106
# Load MAC codes
103-
dr.generate_mac_codes(self.db_connection, False)
107+
dr.generate_mac_codes(self.db_connection, refresh_mac=False, load_mac=load_mac)
104108
# Load ARS mappings
105109
self.ars_mappings, p_group = dr.generate_ars_mapping(
106110
self.db_connection, imgt_version
@@ -140,7 +144,7 @@ def __init__(
140144
gc.freeze()
141145

142146
# Re-open the connection in read-only mode as we're not updating it anymore
143-
self.db_connection = db.create_db_connection(data_dir, imgt_version, ro=True)
147+
self.db_connection, _ = db.create_db_connection(data_dir, imgt_version, ro=True)
144148

145149
def __del__(self):
146150
"""
@@ -393,7 +397,7 @@ def validate(self, glstring):
393397
except InvalidAlleleError as e:
394398
raise InvalidTypingError(
395399
f"{glstring} is not valid GL String. \n {e.message}", e
396-
)
400+
) from None
397401

398402
def is_XX(self, glstring: str, loc_antigen: str = None, code: str = None) -> bool:
399403
if loc_antigen is None or code is None:
@@ -718,7 +722,7 @@ def refresh_mac_codes(self) -> None:
718722
Refreshes MAC code for the current IMGT db version.
719723
:return: None
720724
"""
721-
dr.generate_mac_codes(self.db_connection, True)
725+
dr.generate_mac_codes(self.db_connection, refresh_mac=True)
722726

723727
def get_db_version(self) -> str:
724728
"""

scripts/pyard

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,8 @@ import argparse
2525
import sys
2626

2727
import pyard
28-
29-
30-
def get_imgt_version(imgt_version):
31-
if imgt_version:
32-
version = imgt_version.replace(".", "")
33-
if version.isdigit():
34-
return version
35-
raise RuntimeError(
36-
f"{imgt_version} is not a valid IMGT database version number"
37-
)
38-
return None
39-
28+
from pyard.exceptions import InvalidAlleleError
29+
from pyard.misc import get_data_dir, get_imgt_version
4030

4131
if __name__ == "__main__":
4232
parser = argparse.ArgumentParser(
@@ -51,6 +41,12 @@ if __name__ == "__main__":
5141
action="store_true",
5242
help="IPD-IMGT/HLA DB Version number",
5343
)
44+
parser.add_argument(
45+
"-d",
46+
"--data-dir",
47+
dest="data_dir",
48+
help="Data directory to store imported data",
49+
)
5450
parser.add_argument(
5551
"-i",
5652
"--imgt-version",
@@ -68,22 +64,33 @@ if __name__ == "__main__":
6864

6965
args = parser.parse_args()
7066

67+
if args.splits:
68+
mapping = pyard.find_broad_splits(args.splits)
69+
if mapping:
70+
print(f"{mapping[0]} = {'/'.join(mapping[1])}")
71+
sys.exit(0)
72+
7173
imgt_version = get_imgt_version(args.imgt_version)
72-
if imgt_version:
73-
ard = pyard.ARD(imgt_version)
74-
else:
75-
ard = pyard.ARD()
74+
data_dir = get_data_dir(args.data_dir)
75+
ard = pyard.ARD(imgt_version=imgt_version, data_dir=data_dir)
7676

7777
if args.version:
7878
version = ard.get_db_version()
7979
print(f"IPD-IMGT/HLA version:", version)
80+
print(f"py-ard version:", pyard.__version__)
8081
sys.exit(0)
8182

82-
if args.splits:
83-
mapping = pyard.find_broad_splits(args.splits)
84-
if mapping:
85-
print(f"{mapping[0]} = {'/'.join(mapping[1])}")
86-
sys.exit(0)
83+
try:
84+
if args.redux_type:
85+
print(ard.redux_gl(args.gl_string, args.redux_type))
86+
else:
87+
for redux_type in pyard.pyard.reduction_types:
88+
redux_type_info = f"Reduction Method: {redux_type}"
89+
print(redux_type_info)
90+
print("-" * len(redux_type_info))
91+
print(ard.redux_gl(args.gl_string, redux_type))
92+
except InvalidAlleleError as e:
93+
print("Error:", e)
8794

88-
print(ard.redux_gl(args.gl_string, args.redux_type))
95+
# Remove ard and close db connection
8996
del ard

scripts/pyard-import

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,13 @@ if __name__ == "__main__":
6868
help="Show Versions of available IMGT Databases",
6969
)
7070
parser.add_argument(
71-
"--db-version",
71+
"-i",
72+
"--imgt-version",
7273
dest="imgt_version",
7374
help="Import supplied IMGT_VERSION DB Version",
7475
)
7576
parser.add_argument(
77+
"-d",
7678
"--data-dir",
7779
dest="data_dir",
7880
help="Data directory to store imported data",
@@ -92,6 +94,12 @@ if __name__ == "__main__":
9294
action="store_true",
9395
help="reinstall a fresh version of database",
9496
)
97+
parser.add_argument(
98+
"--skip-mac",
99+
dest="skip_mac",
100+
action="store_true",
101+
help="Skip creating MAC mapping",
102+
)
95103
args = parser.parse_args()
96104

97105
if args.show_versions:
@@ -118,8 +126,14 @@ if __name__ == "__main__":
118126
db_fullname.unlink(missing_ok=True)
119127

120128
print(f"Importing IMGT database version: {imgt_version}")
129+
if args.skip_mac:
130+
load_mac = False
131+
print(f"Skipping MAC tables creation")
132+
else:
133+
load_mac = True
134+
121135
try:
122-
ard = pyard.ARD(imgt_version=imgt_version, data_dir=data_dir)
136+
ard = pyard.ARD(imgt_version=imgt_version, data_dir=data_dir, load_mac=load_mac)
123137
except ValueError as e:
124138
print(f"Error importing version {imgt_version}:", e)
125139
sys.exit(1)
@@ -141,6 +155,6 @@ if __name__ == "__main__":
141155

142156
if args.refresh_mac:
143157
print(f"Updating MACs")
144-
db_connection = db.create_db_connection(data_dir, imgt_version, ro=False)
145-
data_repository.generate_mac_codes(db_connection, True)
158+
db_connection, _ = db.create_db_connection(data_dir, imgt_version, ro=False)
159+
data_repository.generate_mac_codes(db_connection, refresh_mac=True)
146160
print(f"Updated MACs for {imgt_version} IMGT database.")

scripts/pyard-reduce-csv

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import pandas as pd
3838
import pyard
3939
import pyard.drbx as drbx
4040
from pyard.exceptions import PyArdError
41+
from pyard.misc import get_data_dir, get_imgt_version
4142

4243

4344
def is_serology(allele: str) -> bool:
@@ -171,11 +172,21 @@ def create_drbx(row, locus_in_allele_name):
171172

172173

173174
if __name__ == "__main__":
174-
175175
# config is specified with a -c parameter
176176
parser = argparse.ArgumentParser()
177177
parser.add_argument("-c", "--config", help="JSON Configuration file", required=True)
178-
178+
parser.add_argument(
179+
"-d",
180+
"--data-dir",
181+
dest="data_dir",
182+
help="Data directory to store imported data",
183+
)
184+
parser.add_argument(
185+
"-i",
186+
"--imgt-version",
187+
dest="imgt_version",
188+
help="IPD-IMGT/HLA db to use for redux",
189+
)
179190
args = parser.parse_args()
180191
config_filename = args.config
181192

@@ -197,8 +208,9 @@ if __name__ == "__main__":
197208
print(" pip install openpyxl")
198209
sys.exit(1)
199210

200-
# Instantiate py-ard object with the latest
201-
ard = pyard.ARD()
211+
data_dir = get_data_dir(args.data_dir)
212+
imgt_version = get_imgt_version(args.imgt_version)
213+
ard = pyard.ARD(imgt_version=imgt_version, data_dir=data_dir)
202214

203215
# Read the Input File
204216
# Read only the columns to be saved.

scripts/pyard-status

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,25 @@ def get_latest_imgt_version() -> int:
3939
return max(map(int, pyard.db_versions()[:-1]))
4040

4141

42+
def get_file_size(file_name: str) -> float:
43+
return os.path.getsize(file_name) / 1024 / 1024
44+
45+
4246
if __name__ == "__main__":
4347
parser = argparse.ArgumentParser(
4448
description="""
4549
py-ard tool to provide a status report for reference SQLite databases.
46-
""",
50+
"""
4751
)
4852
parser.add_argument(
53+
"-d",
4954
"--data-dir",
5055
dest="data_dir",
5156
help="Data directory to store imported data",
5257
)
5358

5459
args = parser.parse_args()
5560
data_dir = get_data_dir(args.data_dir)
56-
# print(data_dir)
5761

5862
imgt_regex = re.compile(r"pyard-(.+)\.sqlite3")
5963
for _, _, filenames in os.walk(data_dir):
@@ -62,7 +66,9 @@ if __name__ == "__main__":
6266
# eg: get 3440 from 'pyard-3440.sqlite3'
6367
match = imgt_regex.match(filename)
6468
imgt_version = match.group(1) # Get first group
65-
db_connection = db.create_db_connection(data_dir, imgt_version, ro=True)
69+
db_connection, db_filename = db.create_db_connection(
70+
data_dir, imgt_version, ro=True
71+
)
6672
print("-" * 43)
6773
if imgt_version == "Latest":
6874
db_version = data_repository.get_db_version(db_connection)
@@ -80,11 +86,16 @@ if __name__ == "__main__":
8086
)
8187
else:
8288
print(f"IMGT DB Version: {imgt_version}")
89+
file_size = get_file_size(db_filename)
90+
print(f"File: {db_filename}")
91+
print(f"Size: {file_size:.2f}MB")
8392
print("-" * 43)
8493
print(f"|{'Table Name':20}|{'Rows':20}|")
8594
print(f"|{'-' * 41}|")
8695
for table in (
87-
data_repository.ars_mapping_tables + data_repository.code_mapping_tables
96+
data_repository.ars_mapping_tables
97+
+ data_repository.code_mapping_tables
98+
+ ["mac_codes"]
8899
):
89100
if db.table_exists(db_connection, table):
90101
total_rows = db.count_rows(db_connection, table)

0 commit comments

Comments
 (0)