Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions src/jaxqtl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,12 +321,19 @@ def main(args):
geno, bim, sample_info = geno_reader(args.geno)

pheno_reader = PheBedReader()
# 1. Load raw phenotype data
pheno = pheno_reader(args.pheno)
n_genes_raw = pheno.shape[1] if hasattr(pheno, 'shape') else len(pheno)
logging.info(f"Raw phenotype data loaded. Total genes detected: {n_genes_raw}")

covar = covar_reader(args.covar, args.add_covar, args.covar_test, args.rm_covar)

if args.genelist is not None:
# Load provided gene list
genelist = pd.read_csv(args.genelist, header=None, sep="\t").iloc[:, 0].to_list()
logging.info(f"User-provided gene list loaded. Total target genes: {len(genelist)}")
# Use DEBUG level for verbose data dumping (requires --verbose True)
logging.debug(f"Gene list preview (first 5): {genelist[:5]}")
else:
genelist = None

Expand All @@ -336,8 +343,18 @@ def main(args):
else:
indList = None

# 2. Create ReadyData (Intersection of Genotype, Phenotype, and Covariates)
dat = create_readydata(geno, bim, pheno, covar, autosomal_only=args.autosomal_only, ind_list=indList)

# Check dimensions after sample intersection
n_genes_intersect = dat.pheno.count.shape[1] if hasattr(dat.pheno, 'count') else 0
n_samples_intersect = dat.pheno.count.shape[0] if hasattr(dat.pheno, 'count') else 0

logging.info(f"Data intersection complete. Retained Genes: {n_genes_intersect}, Retained Samples: {n_samples_intersect}")

if n_genes_intersect == 0 or n_samples_intersect == 0:
logging.warning("Data intersection resulted in empty dataset. Check sample ID consistency between phenotype, genotype, and covariates.")

# before filter gene list, calculate library size and set offset, or read in pre-computed log(offset)
if args.offset is None:
total_libsize = jnp.array(dat.pheno.count.sum(axis=1))[:, jnp.newaxis]
Expand All @@ -347,8 +364,9 @@ def main(args):
offset_eta = offset_eta.loc[offset_eta.index.isin(dat.pheno.count.index)].sort_index()
offset_eta = jnp.array(offset_eta)

# filter out genes with no expressions at all
# 3. Filter out non-expressed genes
dat.filter_gene(geneexpr_percent_cutoff=0.0)
logging.info(f"Filtered non-expressed genes (all zeros). Remaining genes: {dat.pheno.count.shape[1]}")

# add expression PCs to covar, genotype PC should appended to covar outside jaxqtl
if args.addpc > 0:
Expand All @@ -359,8 +377,14 @@ def main(args):
# note: use pre-processed file as in tensorqtl
offset_eta = jnp.zeros_like(offset_eta)

# filter gene list
# 4. Apply target gene list and expression percentage filter
dat.filter_gene(gene_list=genelist, geneexpr_percent_cutoff=args.express_percent)
logging.info(f"Applied target gene list and expression threshold. Final gene count for mapping: {dat.pheno.count.shape[1]}")

if dat.pheno.count.shape[1] == 0:
logging.error("No genes remaining after filtering. Please check if gene IDs in the provided list match the phenotype file.")
# Optionally raise an error here to stop execution cleanly
# raise ValueError("No genes available for analysis.")

# permute gene expression for type I error calibration
if args.perm_pheno:
Expand Down