diff --git a/src/jaxqtl/cli.py b/src/jaxqtl/cli.py index f31ebbe..6a388e3 100644 --- a/src/jaxqtl/cli.py +++ b/src/jaxqtl/cli.py @@ -321,12 +321,19 @@ def main(args): geno, bim, sample_info = geno_reader(args.geno) pheno_reader = PheBedReader() + # 1. Load raw phenotype data pheno = pheno_reader(args.pheno) + n_genes_raw = pheno.shape[1] if hasattr(pheno, 'shape') else len(pheno) + logging.info(f"Raw phenotype data loaded. Total genes detected: {n_genes_raw}") covar = covar_reader(args.covar, args.add_covar, args.covar_test, args.rm_covar) if args.genelist is not None: + # Load provided gene list genelist = pd.read_csv(args.genelist, header=None, sep="\t").iloc[:, 0].to_list() + logging.info(f"User-provided gene list loaded. Total target genes: {len(genelist)}") + # Use DEBUG level for verbose data dumping (requires --verbose True) + logging.debug(f"Gene list preview (first 5): {genelist[:5]}") else: genelist = None @@ -336,8 +343,18 @@ def main(args): else: indList = None + # 2. Create ReadyData (Intersection of Genotype, Phenotype, and Covariates) dat = create_readydata(geno, bim, pheno, covar, autosomal_only=args.autosomal_only, ind_list=indList) + # Check dimensions after sample intersection + n_genes_intersect = dat.pheno.count.shape[1] if hasattr(dat.pheno, 'count') else 0 + n_samples_intersect = dat.pheno.count.shape[0] if hasattr(dat.pheno, 'count') else 0 + + logging.info(f"Data intersection complete. Retained Genes: {n_genes_intersect}, Retained Samples: {n_samples_intersect}") + + if n_genes_intersect == 0 or n_samples_intersect == 0: + logging.warning("Data intersection resulted in empty dataset. Check sample ID consistency between phenotype, genotype, and covariates.") + # before filter gene list, calculate library size and set offset, or read in pre-computed log(offset) if args.offset is None: total_libsize = jnp.array(dat.pheno.count.sum(axis=1))[:, jnp.newaxis] @@ -347,8 +364,9 @@ def main(args): offset_eta = offset_eta.loc[offset_eta.index.isin(dat.pheno.count.index)].sort_index() offset_eta = jnp.array(offset_eta) - # filter out genes with no expressions at all + # 3. Filter out non-expressed genes dat.filter_gene(geneexpr_percent_cutoff=0.0) + logging.info(f"Filtered non-expressed genes (all zeros). Remaining genes: {dat.pheno.count.shape[1]}") # add expression PCs to covar, genotype PC should appended to covar outside jaxqtl if args.addpc > 0: @@ -359,8 +377,14 @@ def main(args): # note: use pre-processed file as in tensorqtl offset_eta = jnp.zeros_like(offset_eta) - # filter gene list + # 4. Apply target gene list and expression percentage filter dat.filter_gene(gene_list=genelist, geneexpr_percent_cutoff=args.express_percent) + logging.info(f"Applied target gene list and expression threshold. Final gene count for mapping: {dat.pheno.count.shape[1]}") + + if dat.pheno.count.shape[1] == 0: + logging.error("No genes remaining after filtering. Please check if gene IDs in the provided list match the phenotype file.") + # Optionally raise an error here to stop execution cleanly + # raise ValueError("No genes available for analysis.") # permute gene expression for type I error calibration if args.perm_pheno: