Skip to content

Commit 328e33d

Browse files
Muhammed Hasan Celik
and authored
bug fix for training (#25)
Co-authored-by: Muhammed Hasan Celik <celik.muhammed_hasan@gene.com>
1 parent f2278a2 commit 328e33d

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

src/decima/cli/finetune.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Finetune the Decima model."""
2+
23
import logging
34
import click
45
import anndata
@@ -9,7 +10,12 @@
910

1011
@click.command()
1112
@click.option("--name", required=True, help="Name of the run.")
12-
@click.option("--model", default="0", type=str, help="Model path or replication number. If a path is provided, the model will be loaded from the path. If a replication number is provided, the model will be loaded from the replication number.")
13+
@click.option(
14+
"--model",
15+
default="0",
16+
type=str,
17+
help="Model path or replication number. If a path is provided, the model will be loaded from the path. If a replication number is provided, the model will be loaded from the replication number.",
18+
)
1319
@click.option("--matrix-file", required=True, help="Matrix file path.")
1420
@click.option("--h5-file", required=True, help="H5 file path.")
1521
@click.option("--outdir", required=True, help="Output directory path to save model checkpoints.")
@@ -24,7 +30,24 @@
2430
@click.option("--logger", default="wandb", type=str, help="Logger.")
2531
@click.option("--num-workers", default=16, type=int, help="Number of workers.")
2632
@click.option("--seed", default=0, type=int, help="Random seed.")
27-
def cli_finetune(name, model, matrix_file, h5_file , outdir, learning_rate, loss_total_weight, gradient_accumulation, batch_size, max_seq_shift, gradient_clipping, save_top_k, epochs, logger, num_workers, seed):
33+
def cli_finetune(
34+
name,
35+
model,
36+
matrix_file,
37+
h5_file,
38+
outdir,
39+
learning_rate,
40+
loss_total_weight,
41+
gradient_accumulation,
42+
batch_size,
43+
max_seq_shift,
44+
gradient_clipping,
45+
save_top_k,
46+
epochs,
47+
logger,
48+
num_workers,
49+
seed,
50+
):
2851
"""Finetune the Decima model."""
2952
train_logger = logger
3053
logger = logging.getLogger("decima")
@@ -62,6 +85,7 @@ def cli_finetune(name, model, matrix_file, h5_file , outdir, learning_rate, loss
6285
}
6386
model_params = {
6487
"n_tasks": ad.shape[0],
88+
"init_borzoi": True,
6589
"replicate": model,
6690
}
6791
logger.info(f"train_params: {train_params}")

src/decima/model/decima_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from pathlib import Path
23
from tempfile import TemporaryDirectory
34

@@ -47,15 +48,18 @@ def __init__(self, n_tasks: int, mask=True, borzoi_kwargs: dict = None, init_bor
4748
model = int(model)
4849

4950
if init_borzoi:
51+
logger = logging.getLogger("decima")
5052
# Load state dict
5153
if Path(str(replicate)).exists():
54+
logger.info(f"Initializing weights from Borzoi model using file: {replicate}")
5255
if replicate.endswith(".h5") or replicate.endswith(".pth") or replicate.endswith(".pt"):
5356
state_dict = torch.load(replicate)
5457
elif replicate.endswith(".ckpt"):
5558
state_dict = torch.load(replicate)["state_dict"]
5659
else:
5760
raise ValueError(f"Invalid replicate path: {replicate}")
5861
else:
62+
logger.info(f"Initializing weights from Borzoi model using wandb for replicate: {replicate}")
5963
wandb.login(host="https://api.wandb.ai/", anonymous="must")
6064
api = wandb.Api(overrides={"base_url": "https://api.wandb.ai/"})
6165
art = api.artifact(f"grelu/borzoi/human_state_dict_fold{replicate}:latest")

0 commit comments

Comments (0)