|
1 | 1 | """Finetune the Decima model.""" |
| 2 | + |
2 | 3 | import logging |
3 | 4 | import click |
4 | 5 | import anndata |
|
9 | 10 |
|
10 | 11 | @click.command() |
11 | 12 | @click.option("--name", required=True, help="Name of the run.") |
12 | | -@click.option("--model", default="0", type=str, help="Model path or replication number. If a path is provided, the model will be loaded from the path. If a replication number is provided, the model will be loaded from the replication number.") |
| 13 | +@click.option( |
| 14 | + "--model", |
| 15 | + default="0", |
| 16 | + type=str, |
| 17 | + help="Model path or replication number. If a path is provided, the model will be loaded from the path. If a replication number is provided, the model will be loaded from the replication number.", |
| 18 | +) |
13 | 19 | @click.option("--matrix-file", required=True, help="Matrix file path.") |
14 | 20 | @click.option("--h5-file", required=True, help="H5 file path.") |
15 | 21 | @click.option("--outdir", required=True, help="Output directory path to save model checkpoints.") |
|
24 | 30 | @click.option("--logger", default="wandb", type=str, help="Logger.") |
25 | 31 | @click.option("--num-workers", default=16, type=int, help="Number of workers.") |
26 | 32 | @click.option("--seed", default=0, type=int, help="Random seed.") |
27 | | -def cli_finetune(name, model, matrix_file, h5_file , outdir, learning_rate, loss_total_weight, gradient_accumulation, batch_size, max_seq_shift, gradient_clipping, save_top_k, epochs, logger, num_workers, seed): |
| 33 | +def cli_finetune( |
| 34 | + name, |
| 35 | + model, |
| 36 | + matrix_file, |
| 37 | + h5_file, |
| 38 | + outdir, |
| 39 | + learning_rate, |
| 40 | + loss_total_weight, |
| 41 | + gradient_accumulation, |
| 42 | + batch_size, |
| 43 | + max_seq_shift, |
| 44 | + gradient_clipping, |
| 45 | + save_top_k, |
| 46 | + epochs, |
| 47 | + logger, |
| 48 | + num_workers, |
| 49 | + seed, |
| 50 | +): |
28 | 51 | """Finetune the Decima model.""" |
29 | 52 | train_logger = logger |
30 | 53 | logger = logging.getLogger("decima") |
@@ -62,6 +85,7 @@ def cli_finetune(name, model, matrix_file, h5_file , outdir, learning_rate, loss |
62 | 85 | } |
63 | 86 | model_params = { |
64 | 87 | "n_tasks": ad.shape[0], |
| 88 | + "init_borzoi": True, |
65 | 89 | "replicate": model, |
66 | 90 | } |
67 | 91 | logger.info(f"train_params: {train_params}") |
|
0 commit comments