Skip to content

Commit f9a5203

Browse files
authored
Finetune tutorial (#30)
* added finetuning tutorial
* finetuning tutorial
* fixed wandb connection
* ran test
* added dataset
1 parent 2f190bb commit f9a5203

File tree

5 files changed

+1552
-9
lines changed

5 files changed

+1552
-9
lines changed

src/decima/cli/finetune.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616
type=str,
1717
help="Model path or replication number. If a path is provided, the model will be loaded from the path. If a replication number is provided, the model will be loaded from the replication number.",
1818
)
19+
@click.option(
20+
"--device",
21+
type=str,
22+
default="0",
23+
help="Device to use. Default: 0",
24+
)
1925
@click.option("--matrix-file", required=True, help="Matrix file path.")
2026
@click.option("--h5-file", required=True, help="H5 file path.")
2127
@click.option("--outdir", required=True, help="Output directory path to save model checkpoints.")
@@ -33,6 +39,7 @@
3339
def cli_finetune(
3440
name,
3541
model,
42+
device,
3643
matrix_file,
3744
h5_file,
3845
outdir,
@@ -66,11 +73,14 @@ def cli_finetune(
6673
)
6774
val_dataset = HDF5Dataset(h5_file=h5_file, ad=ad, key="val", max_seq_shift=0)
6875

76+
if isinstance(device, str) and device.isdigit():
77+
device = int(device)
78+
6979
train_params = {
7080
"name": name,
7181
"batch_size": batch_size,
7282
"num_workers": num_workers,
73-
"devices": 0,
83+
"devices": device,
7484
"logger": train_logger,
7585
"save_dir": outdir,
7686
"max_epochs": epochs,
@@ -94,10 +104,12 @@ def cli_finetune(
94104
logger.info("Initializing model")
95105
model = LightningModel(model_params=model_params, train_params=train_params)
96106

97-
logger.info("Training")
98-
if logger == "wandb":
99-
wandb.login(host="https://genentech.wandb.io")
107+
if train_logger == "wandb":
108+
logger.info("Connecting to wandb.")
109+
wandb.login(host="https://genentech.wandb.io", anonymous="never")
100110
run = wandb.init(project="decima", dir=name, name=name)
111+
112+
logger.info("Training")
101113
model.train_on_dataset(train_dataset, val_dataset)
102114
train_dataset.close()
103115
val_dataset.close()

src/decima/data/write_hdf5.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from grelu.sequence.utils import get_unique_length
55

66

7-
def write_hdf5(file, ad, pad=0):
7+
def write_hdf5(file, ad, pad=0, genome="hg38"):
88
# Calculate seq_len
99
seq_len = get_unique_length(ad.var)
1010

@@ -45,7 +45,7 @@ def write_hdf5(file, ad, pad=0):
4545
arr = ad.var[["chrom", "start", "end", "strand"]].copy()
4646
arr.start = arr.start - pad
4747
arr.end = arr.end + pad
48-
arr = convert_input_type(arr, "indices", genome="hg38")
48+
arr = convert_input_type(arr, "indices", genome=genome)
4949
print(f"Writing sequence array of shape: {arr.shape}")
5050
f.create_dataset("sequences", shape=arr.shape, dtype=np.int8, data=arr)
5151

tutorials/2-variant-effect-prediction.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4859,7 +4859,7 @@
48594859
],
48604860
"metadata": {
48614861
"kernelspec": {
4862-
"display_name": "decima",
4862+
"display_name": "Python 3 (ipykernel)",
48634863
"language": "python",
48644864
"name": "python3"
48654865
},
@@ -4873,9 +4873,9 @@
48734873
"name": "python",
48744874
"nbconvert_exporter": "python",
48754875
"pygments_lexer": "ipython3",
4876-
"version": "3.11.12"
4876+
"version": "3.11.10"
48774877
}
48784878
},
48794879
"nbformat": 4,
4880-
"nbformat_minor": 2
4880+
"nbformat_minor": 4
48814881
}

0 commit comments

Comments
 (0)