diff --git a/.github/workflows/cli_vamb.yml b/.github/workflows/cli_vamb.yml index 4b988af4..bbbebccd 100644 --- a/.github/workflows/cli_vamb.yml +++ b/.github/workflows/cli_vamb.yml @@ -61,3 +61,18 @@ jobs: vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npz --clusters_path outdir_taxvamb/vaevae_clusters_split.tsv --markers markers_mock.npz --algorithm kmeans --minfasta 200000 ls -la outdir_recluster cat outdir_recluster/log.txt + - name: Run Partial Composition + run: | + vamb partial composition --outdir outdir_composition --fasta catalogue_mock.fna.gz + ls -la outdir_composition + cat outdir_composition/log.txt + - name: Run Partial Train + run: | + vamb partial train --outdir latent_file --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz -e 3 -q + ls -la latent_file + cat latent_file/log.txt + - name: Run Partial Cluster + run: | + vamb partial cluster --outdir outdir_cluster --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --latent latent_file/latent.npz + ls -la outdir_cluster + cat outdir_cluster/log.txt diff --git a/vamb/__main__.py b/vamb/__main__.py index 3790039b..9feb7753 100755 --- a/vamb/__main__.py +++ b/vamb/__main__.py @@ -360,7 +360,7 @@ class GeneralOptions: @classmethod def from_args(cls, args: argparse.Namespace): return cls( - typeasserted(args.outdir, Path), + typeasserted(Path(args.outdir), Path), typeasserted(args.nthreads, int), typeasserted(args.seed, int), typeasserted(args.cuda, bool), @@ -1088,6 +1088,7 @@ def load_composition_and_abundance( vamb_options.out_dir.path, binsplitter, ) + abundance = calc_abundance( abundance_options, vamb_options.out_dir.path, @@ -1405,13 +1406,11 @@ def run_partial_abundance(opt: PartialAbundanceOptions): ) -def run_bin_default(opt: BinDefaultOptions): - composition, abundance = load_composition_and_abundance( - vamb_options=opt.common.general, - comp_options=opt.common.comp, - abundance_options=opt.common.abundance, - binsplitter=opt.common.output.binsplitter, - ) +def run_train_vae( + opt: BinDefaultOptions, + composition: vamb.parsecontigs.Composition, + abundance: vamb.parsebam.Abundance, +) -> np.ndarray: data_loader = vamb.encode.make_dataloader( abundance.matrix, composition.matrix, @@ -1430,6 +1429,14 @@ def run_bin_default(opt: BinDefaultOptions): del composition, abundance assert comp_metadata.nseqs == len(latent) + return latent + + +def run_cluster_and_write_files( + latent, opt: BinDefaultOptions, composition: vamb.parsecontigs.Composition +): + comp_metadata = composition.metadata + assert comp_metadata.nseqs == len(latent) cluster_and_write_files( opt.common.clustering, opt.common.output.binsplitter, @@ -1442,7 +1449,39 @@ def run_bin_default(opt: BinDefaultOptions): FastaOutput.try_from_common(opt.common), None, ) - del latent + + +class RunDefault: + pass + + +class RunTrain: + pass + + +class RunCluster: + def __init__(self, latent: Path): + self.latent = latent + + +def load_train_bin( + opt: BinDefaultOptions, partial_mode: Union[RunDefault, RunTrain, RunCluster] +): + composition, abundance = load_composition_and_abundance( + vamb_options=opt.common.general, + comp_options=opt.common.comp, + abundance_options=opt.common.abundance, + binsplitter=opt.common.output.binsplitter, + ) + + if isinstance(partial_mode, (RunDefault, RunTrain)): + latent = run_train_vae(opt, composition, abundance) + + if isinstance(partial_mode, RunCluster): + latent = vamb.vambtools.read_npz(partial_mode.latent) + + if isinstance(partial_mode, (RunDefault, RunCluster)): + run_cluster_and_write_files(latent, opt, composition) def run_bin_aae(opt: BinAvambOptions): @@ -2125,7 +2164,12 @@ def add_vae_arguments(subparser: argparse.ArgumentParser): trainos = subparser.add_argument_group(title="Training options", description=None) trainos.add_argument( - "-e", dest="nepochs", metavar="", type=int, default=300, help=argparse.SUPPRESS + "-e", + dest="nepochs", + metavar="", + type=int, + default=300, + help=argparse.SUPPRESS, ) trainos.add_argument( "-t", @@ -2206,6 +2250,20 @@ def add_predictor_arguments(subparser: argparse.ArgumentParser): return subparser +def add_cluster_only_args(subparser: argparse.ArgumentParser): + c_only_arg = subparser.add_argument_group( + title="Clustering options", description=None + ) + c_only_arg.add_argument( + "--latent", + dest="latent_file", + required=True, + metavar="", + type=Path, + help="Path to latent.npz file", + ) + + def add_clustering_arguments(subparser: argparse.ArgumentParser): # Clustering arguments clusto = subparser.add_argument_group(title="Clustering options", description=None) @@ -2383,8 +2441,9 @@ def main(): """, add_help=False, ) - add_help_arguments(vaevae_parserbin_parser) - subparsers_model = vaevae_parserbin_parser.add_subparsers(dest="model_subcommand") + subparsers_model = vaevae_parserbin_parser.add_subparsers( + dest="model_subcommand", required=True + ) vae_parser = subparsers_model.add_parser( VAMB, @@ -2510,6 +2569,35 @@ def main(): add_minlength(general_group) add_composition_npz_argument(abundance_parser) add_abundance_args_nonpz(abundance_parser) + train_parser = partial_part.add_parser( + "train", help="Do training without clustering", add_help=False + ) + + train_parser.set_defaults(model_subcommand=VAMB) + + general_group = add_general_arguments(train_parser) + add_minlength(general_group) + add_composition_arguments(train_parser) + add_abundance_arguments(train_parser) + add_taxonomy_arguments(train_parser) + add_bin_output_arguments(train_parser) + add_vae_arguments(train_parser) + add_clustering_arguments(train_parser) + + cluster_parser = partial_part.add_parser( + "cluster", help="Cluster after training", add_help=False + ) + cluster_parser.set_defaults(model_subcommand=VAMB) + + general_group = add_general_arguments(cluster_parser) + add_minlength(general_group) + add_composition_arguments(cluster_parser) + add_abundance_arguments(cluster_parser) + add_taxonomy_arguments(cluster_parser) + add_bin_output_arguments(cluster_parser) + add_vae_arguments(cluster_parser) + add_clustering_arguments(cluster_parser) + add_cluster_only_args(cluster_parser) args = parser.parse_args() @@ -2524,7 +2612,7 @@ def main(): sys.exit(1) if model == VAMB: opt = BinDefaultOptions.from_args(args) - runner = partial(run_bin_default, opt) + runner = partial(load_train_bin, opt, RunDefault()) run(runner, opt.common.general) elif model == TAXVAMB: opt = BinTaxVambOptions.from_args(args) @@ -2541,6 +2629,8 @@ def main(): runner = partial(run_reclustering, opt) run(runner, opt.general) elif args.subcommand == PARTIAL: + # TODO: args.partial_part is not a string, so why is it being + # compared to a string here?? if args.partial_part == "composition": opt = PartialCompositionOptions.from_args(args) runner = partial(run_partial_composition, opt) @@ -2549,6 +2639,19 @@ def main(): opt = PartialAbundanceOptions.from_args(args) runner = partial(run_partial_abundance, opt) run(runner, opt.general) + elif args.partial_part == "train": + opt = BinDefaultOptions.from_args(args) + runner = partial(load_train_bin, opt, partial_mode=RunTrain()) + run(runner, opt.common.general) + elif args.partial_part == "cluster": + opt = BinDefaultOptions.from_args(args) + runner = partial( + load_train_bin, + opt, + partial_mode=RunCluster(args.latent_file), + ) + run(runner, opt.common.general) + else: # TODO: Add abundance # TODO: Add encoding w. VAE