Skip to content

Commit d5a0767

Browse files
committed
- Allow pyard and pyard-reduce-csv to support --data-dir and --imgt-version options
- pyard-status shows the sqlite db and the corresponding file size
1 parent 2807cd7 commit d5a0767

File tree

7 files changed

+62
-35
lines changed

7 files changed

+62
-35
lines changed

pyard/db.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def create_db_connection(data_dir, imgt_version, ro=False):
5555
# Open the database in read-only mode
5656
file_uri = f"file:{db_filename}?mode=ro"
5757
# Multiple threads can access the same connection since it's only ro
58-
return sqlite3.connect(file_uri, check_same_thread=False, uri=True)
58+
return sqlite3.connect(file_uri, check_same_thread=False, uri=True), db_filename
5959

6060
# Check the imgt_version is a valid IMGT DB Version
6161
# by querying the IMGT site
@@ -73,7 +73,7 @@ def create_db_connection(data_dir, imgt_version, ro=False):
7373

7474
# Open the database for read/write
7575
file_uri = f"file:{db_filename}"
76-
return sqlite3.connect(file_uri, uri=True)
76+
return sqlite3.connect(file_uri, uri=True), db_filename
7777

7878

7979
def table_exists(connection: sqlite3.Connection, table_name: str) -> bool:

pyard/misc.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,14 @@ def get_data_dir(data_dir):
9494
else:
9595
data_dir = db.get_pyard_db_install_directory()
9696
return data_dir
97+
98+
99+
def get_imgt_version(imgt_version):
    """
    Normalize a user-supplied IPD-IMGT/HLA database version.

    Dots are removed from the supplied value, so "3.44.0" becomes "3440".
    A falsy value (``None`` or an empty string) selects "Latest".

    :param imgt_version: version string such as "3.44.0" or "3440", or a
        falsy value when no version was requested
    :return: the version with all dots removed, or "Latest"
    :raises RuntimeError: if the value, once dots are removed, is not made
        up entirely of digits
    """
    # Nothing requested: fall back to the "Latest" database.
    if not imgt_version:
        return "Latest"

    # Drop the dot separators; whatever is left must be purely numeric.
    digits_only = "".join(imgt_version.split("."))
    if not digits_only.isdigit():
        raise RuntimeError(
            f"{imgt_version} is not a valid IMGT database version number"
        )
    return digits_only

pyard/pyard.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(
101101
self._config.update(config)
102102

103103
# Create a database connection for writing
104-
self.db_connection = db.create_db_connection(data_dir, imgt_version)
104+
self.db_connection, _ = db.create_db_connection(data_dir, imgt_version)
105105

106106
# Load MAC codes
107107
dr.generate_mac_codes(self.db_connection, refresh_mac=False, load_mac=load_mac)
@@ -144,7 +144,7 @@ def __init__(
144144
gc.freeze()
145145

146146
# Re-open the connection in read-only mode as we're not updating it anymore
147-
self.db_connection = db.create_db_connection(data_dir, imgt_version, ro=True)
147+
self.db_connection, _ = db.create_db_connection(data_dir, imgt_version, ro=True)
148148

149149
def __del__(self):
150150
"""

scripts/pyard

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,7 @@ import sys
2626

2727
import pyard
2828
from pyard.exceptions import InvalidAlleleError
29-
30-
31-
def get_imgt_version(imgt_version):
32-
if imgt_version:
33-
version = imgt_version.replace(".", "")
34-
if version.isdigit():
35-
return version
36-
raise RuntimeError(
37-
f"{imgt_version} is not a valid IMGT database version number"
38-
)
39-
return None
40-
29+
from pyard.misc import get_data_dir, get_imgt_version
4130

4231
if __name__ == "__main__":
4332
parser = argparse.ArgumentParser(
@@ -52,6 +41,12 @@ if __name__ == "__main__":
5241
action="store_true",
5342
help="IPD-IMGT/HLA DB Version number",
5443
)
44+
parser.add_argument(
45+
"-d",
46+
"--data-dir",
47+
dest="data_dir",
48+
help="Data directory to store imported data",
49+
)
5550
parser.add_argument(
5651
"-i",
5752
"--imgt-version",
@@ -69,24 +64,22 @@ if __name__ == "__main__":
6964

7065
args = parser.parse_args()
7166

67+
if args.splits:
68+
mapping = pyard.find_broad_splits(args.splits)
69+
if mapping:
70+
print(f"{mapping[0]} = {'/'.join(mapping[1])}")
71+
sys.exit(0)
72+
7273
imgt_version = get_imgt_version(args.imgt_version)
73-
if imgt_version:
74-
ard = pyard.ARD(imgt_version)
75-
else:
76-
ard = pyard.ARD()
74+
data_dir = get_data_dir(args.data_dir)
75+
ard = pyard.ARD(imgt_version=imgt_version, data_dir=data_dir)
7776

7877
if args.version:
7978
version = ard.get_db_version()
8079
print(f"IPD-IMGT/HLA version:", version)
8180
print(f"py-ard version:", pyard.__version__)
8281
sys.exit(0)
8382

84-
if args.splits:
85-
mapping = pyard.find_broad_splits(args.splits)
86-
if mapping:
87-
print(f"{mapping[0]} = {'/'.join(mapping[1])}")
88-
sys.exit(0)
89-
9083
try:
9184
if args.redux_type:
9285
print(ard.redux_gl(args.gl_string, args.redux_type))

scripts/pyard-import

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,13 @@ if __name__ == "__main__":
6868
help="Show Versions of available IMGT Databases",
6969
)
7070
parser.add_argument(
71-
"--db-version",
71+
"-i",
72+
"--imgt-version",
7273
dest="imgt_version",
7374
help="Import supplied IMGT_VERSION DB Version",
7475
)
7576
parser.add_argument(
77+
"-d",
7678
"--data-dir",
7779
dest="data_dir",
7880
help="Data directory to store imported data",
@@ -153,6 +155,6 @@ if __name__ == "__main__":
153155

154156
if args.refresh_mac:
155157
print(f"Updating MACs")
156-
db_connection = db.create_db_connection(data_dir, imgt_version, ro=False)
158+
db_connection, _ = db.create_db_connection(data_dir, imgt_version, ro=False)
157159
data_repository.generate_mac_codes(db_connection, refresh_mac=True)
158160
print(f"Updated MACs for {imgt_version} IMGT database.")

scripts/pyard-reduce-csv

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import pandas as pd
3838
import pyard
3939
import pyard.drbx as drbx
4040
from pyard.exceptions import PyArdError
41+
from pyard.misc import get_data_dir, get_imgt_version
4142

4243

4344
def is_serology(allele: str) -> bool:
@@ -171,11 +172,21 @@ def create_drbx(row, locus_in_allele_name):
171172

172173

173174
if __name__ == "__main__":
174-
175175
# config is specified with a -c parameter
176176
parser = argparse.ArgumentParser()
177177
parser.add_argument("-c", "--config", help="JSON Configuration file", required=True)
178-
178+
parser.add_argument(
179+
"-d",
180+
"--data-dir",
181+
dest="data_dir",
182+
help="Data directory to store imported data",
183+
)
184+
parser.add_argument(
185+
"-i",
186+
"--imgt-version",
187+
dest="imgt_version",
188+
help="IPD-IMGT/HLA db to use for redux",
189+
)
179190
args = parser.parse_args()
180191
config_filename = args.config
181192

@@ -197,8 +208,9 @@ if __name__ == "__main__":
197208
print(" pip install openpyxl")
198209
sys.exit(1)
199210

200-
# Instantiate py-ard object with the latest
201-
ard = pyard.ARD()
211+
data_dir = get_data_dir(args.data_dir)
212+
imgt_version = get_imgt_version(args.imgt_version)
213+
ard = pyard.ARD(imgt_version=imgt_version, data_dir=data_dir)
202214

203215
# Read the Input File
204216
# Read only the columns to be saved.

scripts/pyard-status

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,25 @@ def get_latest_imgt_version() -> int:
3939
return max(map(int, pyard.db_versions()[:-1]))
4040

4141

42+
def get_file_size(file_name: str) -> float:
    """
    Return the size of a file expressed in megabytes.

    :param file_name: path of the file to measure
    :return: size in bytes divided by 1024 twice (bytes -> KB -> MB)
    """
    size_in_bytes = os.stat(file_name).st_size
    return size_in_bytes / 1024 / 1024
44+
45+
4246
if __name__ == "__main__":
4347
parser = argparse.ArgumentParser(
4448
description="""
4549
py-ard tool to provide a status report for reference SQLite databases.
46-
""",
50+
"""
4751
)
4852
parser.add_argument(
53+
"-d",
4954
"--data-dir",
5055
dest="data_dir",
5156
help="Data directory to store imported data",
5257
)
5358

5459
args = parser.parse_args()
5560
data_dir = get_data_dir(args.data_dir)
56-
# print(data_dir)
5761

5862
imgt_regex = re.compile(r"pyard-(.+)\.sqlite3")
5963
for _, _, filenames in os.walk(data_dir):
@@ -62,7 +66,9 @@ if __name__ == "__main__":
6266
# eg: get 3440 from 'pyard-3440.sqlite3'
6367
match = imgt_regex.match(filename)
6468
imgt_version = match.group(1) # Get first group
65-
db_connection = db.create_db_connection(data_dir, imgt_version, ro=True)
69+
db_connection, db_filename = db.create_db_connection(
70+
data_dir, imgt_version, ro=True
71+
)
6672
print("-" * 43)
6773
if imgt_version == "Latest":
6874
db_version = data_repository.get_db_version(db_connection)
@@ -80,6 +86,9 @@ if __name__ == "__main__":
8086
)
8187
else:
8288
print(f"IMGT DB Version: {imgt_version}")
89+
file_size = get_file_size(db_filename)
90+
print(f"File: {db_filename}")
91+
print(f"Size: {file_size:.2f}MB")
8392
print("-" * 43)
8493
print(f"|{'Table Name':20}|{'Rows':20}|")
8594
print(f"|{'-' * 41}|")

0 commit comments

Comments
 (0)