Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .github/workflows/cli_vamb.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,18 @@ jobs:
vamb recluster --outdir outdir_recluster --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --latent_path outdir_taxvamb/vaevae_latent.npz --clusters_path outdir_taxvamb/vaevae_clusters_split.tsv --markers markers_mock.npz --algorithm kmeans --minfasta 200000
ls -la outdir_recluster
cat outdir_recluster/log.txt
- name: Run Partial Composition
run: |
vamb partial composition --outdir outdir_composition --fasta catalogue_mock.fna.gz
ls -la outdir_composition
cat outdir_composition/log.txt
- name: Run Partial Train
run: |
vamb partial train --outdir latent_file --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz -e 3 -q
ls -la latent_file
cat latent_file/log.txt
- name: Run Partial Cluster
run: |
vamb partial cluster --outdir outdir_cluster --fasta catalogue_mock.fna.gz --abundance abundance_mock.npz --latent latent_file/latent.npz
ls -la outdir_cluster
cat outdir_cluster/log.txt
129 changes: 116 additions & 13 deletions vamb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ class GeneralOptions:
@classmethod
def from_args(cls, args: argparse.Namespace):
return cls(
typeasserted(args.outdir, Path),
typeasserted(Path(args.outdir), Path),
typeasserted(args.nthreads, int),
typeasserted(args.seed, int),
typeasserted(args.cuda, bool),
Expand Down Expand Up @@ -1088,6 +1088,7 @@ def load_composition_and_abundance(
vamb_options.out_dir.path,
binsplitter,
)

abundance = calc_abundance(
abundance_options,
vamb_options.out_dir.path,
Expand Down Expand Up @@ -1405,13 +1406,11 @@ def run_partial_abundance(opt: PartialAbundanceOptions):
)


def run_bin_default(opt: BinDefaultOptions):
composition, abundance = load_composition_and_abundance(
vamb_options=opt.common.general,
comp_options=opt.common.comp,
abundance_options=opt.common.abundance,
binsplitter=opt.common.output.binsplitter,
)
def run_train_vae(
opt: BinDefaultOptions,
composition: vamb.parsecontigs.Composition,
abundance: vamb.parsebam.Abundance,
) -> np.ndarray:
data_loader = vamb.encode.make_dataloader(
abundance.matrix,
composition.matrix,
Expand All @@ -1430,6 +1429,14 @@ def run_bin_default(opt: BinDefaultOptions):
del composition, abundance
assert comp_metadata.nseqs == len(latent)

return latent


def run_cluster_and_write_files(
latent, opt: BinDefaultOptions, composition: vamb.parsecontigs.Composition
):
comp_metadata = composition.metadata
assert comp_metadata.nseqs == len(latent)
cluster_and_write_files(
opt.common.clustering,
opt.common.output.binsplitter,
Expand All @@ -1442,7 +1449,39 @@ def run_bin_default(opt: BinDefaultOptions):
FastaOutput.try_from_common(opt.common),
None,
)
del latent


class RunDefault:
    """Mode marker for the full run.

    Carries no data. When `load_train_bin` receives an instance of this
    class it both trains the VAE and clusters the resulting latent.
    """


class RunTrain:
    """Mode marker for the training-only run.

    Carries no data. When `load_train_bin` receives an instance of this
    class it trains the VAE and skips the clustering step.
    """


class RunCluster:
    """Mode marker for the clustering-only run.

    Unlike the other mode markers this one carries a payload: the path of
    the latent file that `load_train_bin` reads instead of training.
    """

    def __init__(self, latent: Path):
        # Kept verbatim; the file is only read later, by the consumer.
        self.latent = latent
Comment on lines +1454 to +1464
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sgalkina I suggested this pattern, to mimic an ADT (or enum, in Rust parlance). My idea is to 1) statically check that the argument is the correct type, and 2) make it easier to catch errors if a new partial mode is added. Is there a better way in Python to make ADTs that can be caught by the type checker?



def load_train_bin(
    opt: BinDefaultOptions, partial_mode: Union[RunDefault, RunTrain, RunCluster]
):
    """Run the default binning workflow, or only its train or cluster part.

    Composition and abundance are always loaded. What happens next depends
    on the type of `partial_mode`:

    * `RunDefault`: train the VAE, then cluster and write the output files.
    * `RunTrain`: train the VAE only.
    * `RunCluster`: read the latent from `partial_mode.latent`, then cluster
      and write the output files.

    Args:
        opt: Options for the default binning workflow.
        partial_mode: Marker object selecting which parts to run.

    Raises:
        TypeError: If `partial_mode` is not one of the three known mode
            classes.
    """
    # Reject unknown modes before the (potentially slow) loading step, so a
    # newly added mode that is not handled here fails loudly instead of
    # silently doing nothing.
    if not isinstance(partial_mode, (RunDefault, RunTrain, RunCluster)):
        raise TypeError(
            f"Unknown partial mode: {type(partial_mode).__name__}"
        )

    composition, abundance = load_composition_and_abundance(
        vamb_options=opt.common.general,
        comp_options=opt.common.comp,
        abundance_options=opt.common.abundance,
        binsplitter=opt.common.output.binsplitter,
    )

    # Exactly one branch binds `latent`, so it can never be unbound below.
    if isinstance(partial_mode, RunCluster):
        latent = vamb.vambtools.read_npz(partial_mode.latent)
    else:
        latent = run_train_vae(opt, composition, abundance)

    # Training-only mode stops here; the other two modes go on to cluster.
    if not isinstance(partial_mode, RunTrain):
        run_cluster_and_write_files(latent, opt, composition)


def run_bin_aae(opt: BinAvambOptions):
Expand Down Expand Up @@ -2125,7 +2164,12 @@ def add_vae_arguments(subparser: argparse.ArgumentParser):
trainos = subparser.add_argument_group(title="Training options", description=None)

trainos.add_argument(
"-e", dest="nepochs", metavar="", type=int, default=300, help=argparse.SUPPRESS
"-e",
dest="nepochs",
metavar="",
type=int,
default=300,
help=argparse.SUPPRESS,
)
trainos.add_argument(
"-t",
Expand Down Expand Up @@ -2206,6 +2250,20 @@ def add_predictor_arguments(subparser: argparse.ArgumentParser):
return subparser


def add_cluster_only_args(subparser: argparse.ArgumentParser):
    """Register the options that only the cluster-only subcommand needs.

    Adds a "Clustering options" argument group to `subparser` holding a
    single required `--latent` option, parsed as a `Path` and stored under
    `latent_file`.
    """
    latent_group = subparser.add_argument_group(
        title="Clustering options", description=None
    )
    # Collected in one mapping so the option's settings read as a unit.
    latent_settings = {
        "dest": "latent_file",
        "required": True,
        "metavar": "",
        "type": Path,
        "help": "Path to latent.npz file",
    }
    latent_group.add_argument("--latent", **latent_settings)


def add_clustering_arguments(subparser: argparse.ArgumentParser):
# Clustering arguments
clusto = subparser.add_argument_group(title="Clustering options", description=None)
Expand Down Expand Up @@ -2383,8 +2441,9 @@ def main():
""",
add_help=False,
)
add_help_arguments(vaevae_parserbin_parser)
subparsers_model = vaevae_parserbin_parser.add_subparsers(dest="model_subcommand")
subparsers_model = vaevae_parserbin_parser.add_subparsers(
dest="model_subcommand", required=True
)

vae_parser = subparsers_model.add_parser(
VAMB,
Expand Down Expand Up @@ -2510,6 +2569,35 @@ def main():
add_minlength(general_group)
add_composition_npz_argument(abundance_parser)
add_abundance_args_nonpz(abundance_parser)
train_parser = partial_part.add_parser(
"train", help="Do training without clustering", add_help=False
)

train_parser.set_defaults(model_subcommand=VAMB)

general_group = add_general_arguments(train_parser)
add_minlength(general_group)
add_composition_arguments(train_parser)
add_abundance_arguments(train_parser)
add_taxonomy_arguments(train_parser)
add_bin_output_arguments(train_parser)
add_vae_arguments(train_parser)
add_clustering_arguments(train_parser)

cluster_parser = partial_part.add_parser(
"cluster", help="Cluster after training", add_help=False
)
cluster_parser.set_defaults(model_subcommand=VAMB)

general_group = add_general_arguments(cluster_parser)
add_minlength(general_group)
add_composition_arguments(cluster_parser)
add_abundance_arguments(cluster_parser)
add_taxonomy_arguments(cluster_parser)
add_bin_output_arguments(cluster_parser)
add_vae_arguments(cluster_parser)
add_clustering_arguments(cluster_parser)
add_cluster_only_args(cluster_parser)

args = parser.parse_args()

Expand All @@ -2524,7 +2612,7 @@ def main():
sys.exit(1)
if model == VAMB:
opt = BinDefaultOptions.from_args(args)
runner = partial(run_bin_default, opt)
runner = partial(load_train_bin, opt, RunDefault())
run(runner, opt.common.general)
elif model == TAXVAMB:
opt = BinTaxVambOptions.from_args(args)
Expand All @@ -2541,6 +2629,8 @@ def main():
runner = partial(run_reclustering, opt)
run(runner, opt.general)
elif args.subcommand == PARTIAL:
# TODO: args.partial_part is not a string, so why is it being
# compared to a string here??
if args.partial_part == "composition":
opt = PartialCompositionOptions.from_args(args)
runner = partial(run_partial_composition, opt)
Expand All @@ -2549,6 +2639,19 @@ def main():
opt = PartialAbundanceOptions.from_args(args)
runner = partial(run_partial_abundance, opt)
run(runner, opt.general)
elif args.partial_part == "train":
opt = BinDefaultOptions.from_args(args)
runner = partial(load_train_bin, opt, partial_mode=RunTrain())
run(runner, opt.common.general)
elif args.partial_part == "cluster":
opt = BinDefaultOptions.from_args(args)
runner = partial(
load_train_bin,
opt,
partial_mode=RunCluster(args.latent_file),
)
run(runner, opt.common.general)

else:
# TODO: Add abundance
# TODO: Add encoding w. VAE
Expand Down