-
Notifications
You must be signed in to change notification settings - Fork 49
Added partial train and cluster commands #452
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: partial
Are you sure you want to change the base?
Changes from all commits
db0f3a1
f97e0b0
1e6ddd3
492f703
4ac4683
44dba9a
89199e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -360,7 +360,7 @@ class GeneralOptions: | |
| @classmethod | ||
| def from_args(cls, args: argparse.Namespace): | ||
| return cls( | ||
| typeasserted(args.outdir, Path), | ||
| typeasserted(Path(args.outdir), Path), | ||
| typeasserted(args.nthreads, int), | ||
| typeasserted(args.seed, int), | ||
| typeasserted(args.cuda, bool), | ||
|
|
@@ -1088,6 +1088,7 @@ def load_composition_and_abundance( | |
| vamb_options.out_dir.path, | ||
| binsplitter, | ||
| ) | ||
|
|
||
| abundance = calc_abundance( | ||
| abundance_options, | ||
| vamb_options.out_dir.path, | ||
|
|
@@ -1405,13 +1406,11 @@ def run_partial_abundance(opt: PartialAbundanceOptions): | |
| ) | ||
|
|
||
|
|
||
| def run_bin_default(opt: BinDefaultOptions): | ||
| composition, abundance = load_composition_and_abundance( | ||
| vamb_options=opt.common.general, | ||
| comp_options=opt.common.comp, | ||
| abundance_options=opt.common.abundance, | ||
| binsplitter=opt.common.output.binsplitter, | ||
| ) | ||
| def run_train_vae( | ||
| opt: BinDefaultOptions, | ||
| composition: vamb.parsecontigs.Composition, | ||
| abundance: vamb.parsebam.Abundance, | ||
| ) -> np.ndarray: | ||
| data_loader = vamb.encode.make_dataloader( | ||
| abundance.matrix, | ||
| composition.matrix, | ||
|
|
@@ -1430,6 +1429,14 @@ def run_bin_default(opt: BinDefaultOptions): | |
| del composition, abundance | ||
| assert comp_metadata.nseqs == len(latent) | ||
|
|
||
| return latent | ||
|
|
||
|
|
||
| def run_cluster_and_write_files( | ||
| latent, opt: BinDefaultOptions, composition: vamb.parsecontigs.Composition | ||
| ): | ||
| comp_metadata = composition.metadata | ||
| assert comp_metadata.nseqs == len(latent) | ||
| cluster_and_write_files( | ||
| opt.common.clustering, | ||
| opt.common.output.binsplitter, | ||
|
|
@@ -1442,7 +1449,39 @@ def run_bin_default(opt: BinDefaultOptions): | |
| FastaOutput.try_from_common(opt.common), | ||
| None, | ||
| ) | ||
| del latent | ||
|
|
||
|
|
||
| class RunDefault: | ||
| pass | ||
|
|
||
|
|
||
| class RunTrain: | ||
| pass | ||
|
|
||
|
|
||
| class RunCluster: | ||
| def __init__(self, latent: Path): | ||
| self.latent = latent | ||
|
Comment on lines
+1454
to
+1464
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @sgalkina I suggested this pattern, to mimick an ADT (or |
||
|
|
||
|
|
||
| def load_train_bin( | ||
| opt: BinDefaultOptions, partial_mode: Union[RunDefault, RunTrain, RunCluster] | ||
| ): | ||
| composition, abundance = load_composition_and_abundance( | ||
| vamb_options=opt.common.general, | ||
| comp_options=opt.common.comp, | ||
| abundance_options=opt.common.abundance, | ||
| binsplitter=opt.common.output.binsplitter, | ||
| ) | ||
|
|
||
| if isinstance(partial_mode, (RunDefault, RunTrain)): | ||
| latent = run_train_vae(opt, composition, abundance) | ||
|
|
||
| if isinstance(partial_mode, RunCluster): | ||
| latent = vamb.vambtools.read_npz(partial_mode.latent) | ||
|
|
||
| if isinstance(partial_mode, (RunDefault, RunCluster)): | ||
| run_cluster_and_write_files(latent, opt, composition) | ||
|
|
||
jakobnissen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def run_bin_aae(opt: BinAvambOptions): | ||
|
|
@@ -2125,7 +2164,12 @@ def add_vae_arguments(subparser: argparse.ArgumentParser): | |
| trainos = subparser.add_argument_group(title="Training options", description=None) | ||
|
|
||
| trainos.add_argument( | ||
| "-e", dest="nepochs", metavar="", type=int, default=300, help=argparse.SUPPRESS | ||
| "-e", | ||
| dest="nepochs", | ||
| metavar="", | ||
| type=int, | ||
| default=300, | ||
| help=argparse.SUPPRESS, | ||
| ) | ||
| trainos.add_argument( | ||
| "-t", | ||
|
|
@@ -2206,6 +2250,20 @@ def add_predictor_arguments(subparser: argparse.ArgumentParser): | |
| return subparser | ||
|
|
||
|
|
||
| def add_cluster_only_args(subparser: argparse.ArgumentParser): | ||
| c_only_arg = subparser.add_argument_group( | ||
| title="Clustering options", description=None | ||
| ) | ||
| c_only_arg.add_argument( | ||
| "--latent", | ||
| dest="latent_file", | ||
| required=True, | ||
| metavar="", | ||
| type=Path, | ||
| help="Path to latent.npz file", | ||
| ) | ||
|
|
||
|
|
||
| def add_clustering_arguments(subparser: argparse.ArgumentParser): | ||
| # Clustering arguments | ||
| clusto = subparser.add_argument_group(title="Clustering options", description=None) | ||
|
|
@@ -2383,8 +2441,9 @@ def main(): | |
| """, | ||
| add_help=False, | ||
| ) | ||
| add_help_arguments(vaevae_parserbin_parser) | ||
| subparsers_model = vaevae_parserbin_parser.add_subparsers(dest="model_subcommand") | ||
| subparsers_model = vaevae_parserbin_parser.add_subparsers( | ||
| dest="model_subcommand", required=True | ||
| ) | ||
|
|
||
| vae_parser = subparsers_model.add_parser( | ||
| VAMB, | ||
|
|
@@ -2510,6 +2569,35 @@ def main(): | |
| add_minlength(general_group) | ||
| add_composition_npz_argument(abundance_parser) | ||
| add_abundance_args_nonpz(abundance_parser) | ||
| train_parser = partial_part.add_parser( | ||
| "train", help="Do training without clustering", add_help=False | ||
| ) | ||
|
|
||
| train_parser.set_defaults(model_subcommand=VAMB) | ||
|
|
||
| general_group = add_general_arguments(train_parser) | ||
| add_minlength(general_group) | ||
| add_composition_arguments(train_parser) | ||
| add_abundance_arguments(train_parser) | ||
| add_taxonomy_arguments(train_parser) | ||
| add_bin_output_arguments(train_parser) | ||
| add_vae_arguments(train_parser) | ||
| add_clustering_arguments(train_parser) | ||
|
|
||
| cluster_parser = partial_part.add_parser( | ||
| "cluster", help="Cluster after training", add_help=False | ||
| ) | ||
| cluster_parser.set_defaults(model_subcommand=VAMB) | ||
|
|
||
| general_group = add_general_arguments(cluster_parser) | ||
| add_minlength(general_group) | ||
| add_composition_arguments(cluster_parser) | ||
| add_abundance_arguments(cluster_parser) | ||
| add_taxonomy_arguments(cluster_parser) | ||
| add_bin_output_arguments(cluster_parser) | ||
| add_vae_arguments(cluster_parser) | ||
| add_clustering_arguments(cluster_parser) | ||
| add_cluster_only_args(cluster_parser) | ||
|
|
||
| args = parser.parse_args() | ||
|
|
||
|
|
@@ -2524,7 +2612,7 @@ def main(): | |
| sys.exit(1) | ||
| if model == VAMB: | ||
| opt = BinDefaultOptions.from_args(args) | ||
| runner = partial(run_bin_default, opt) | ||
| runner = partial(load_train_bin, opt, RunDefault()) | ||
| run(runner, opt.common.general) | ||
| elif model == TAXVAMB: | ||
| opt = BinTaxVambOptions.from_args(args) | ||
|
|
@@ -2541,6 +2629,8 @@ def main(): | |
| runner = partial(run_reclustering, opt) | ||
| run(runner, opt.general) | ||
| elif args.subcommand == PARTIAL: | ||
| # TODO: args.partial_part is not a string, so why is it being | ||
| # compared to a string here?? | ||
| if args.partial_part == "composition": | ||
| opt = PartialCompositionOptions.from_args(args) | ||
| runner = partial(run_partial_composition, opt) | ||
|
|
@@ -2549,6 +2639,19 @@ def main(): | |
| opt = PartialAbundanceOptions.from_args(args) | ||
| runner = partial(run_partial_abundance, opt) | ||
| run(runner, opt.general) | ||
sgalkina marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| elif args.partial_part == "train": | ||
| opt = BinDefaultOptions.from_args(args) | ||
| runner = partial(load_train_bin, opt, partial_mode=RunTrain()) | ||
| run(runner, opt.common.general) | ||
| elif args.partial_part == "cluster": | ||
| opt = BinDefaultOptions.from_args(args) | ||
| runner = partial( | ||
| load_train_bin, | ||
| opt, | ||
| partial_mode=RunCluster(args.latent_file), | ||
| ) | ||
| run(runner, opt.common.general) | ||
|
|
||
| else: | ||
| # TODO: Add abundance | ||
| # TODO: Add encoding w. VAE | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.