Skip to content

Commit 974a698

Browse files
committed
default to tensorboard logging, add wandb later
1 parent c15ce86 commit 974a698

File tree

4 files changed

+30
-4
lines changed

4 files changed

+30
-4
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
.env
2+
logs/
23

34
# Byte-compiled / optimized / DLL files
45
__pycache__/

alphafold3_pytorch/configs.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
DirectoryPath
3232
)
3333

34+
from lightning.fabric.loggers import TensorBoardLogger
35+
3436
# functions
3537

3638
def exists(v):
@@ -166,6 +168,9 @@ class TrainerConfig(BaseModelWithExtra):
166168
checkpoint_folder: str
167169
overwrite_checkpoints: bool
168170
dataset_config: DatasetConfig | None = None
171+
use_tensorboard: bool = True
172+
tensorboard_log_dir: str = './logs'
173+
logger_kwargs: dict = dict()
169174

170175
@classmethod
171176
@typecheck
@@ -193,7 +198,12 @@ def create_instance(
193198
) -> Trainer:
194199

195200
trainer_kwargs = self.model_dump(
196-
exclude = {'dataset_config'}
201+
exclude = {
202+
'dataset_config',
203+
'use_tensorboard',
204+
'tensorboard_log_dir',
205+
'logger_kwargs'
206+
}
197207
)
198208

199209
assert exists(self.model) ^ exists(model), 'either model is available on the trainer config, or passed in when creating the instance, but not both or neither'
@@ -261,6 +271,13 @@ def create_instance(
261271

262272
assert 'dataset' in trainer_kwargs, 'dataset is absent - dataset_type must be specified along with train folders (pdb for now), or the Dataset instance must be passed in'
263273

274+
# handle loggers
275+
276+
loggers = []
277+
278+
if self.use_tensorboard:
279+
loggers.append(TensorBoardLogger(self.tensorboard_log_dir, **self.logger_kwargs))
280+
264281
# handle rest
265282

266283
trainer_kwargs.update(dict(
@@ -270,7 +287,8 @@ def create_instance(
270287
optimizer = optimizer,
271288
scheduler = scheduler,
272289
valid_dataset = valid_dataset,
273-
map_dataset_input_fn = map_dataset_input_fn
290+
map_dataset_input_fn = map_dataset_input_fn,
291+
loggers = loggers
274292
))
275293

276294
trainer = Trainer(**trainer_kwargs)

alphafold3_pytorch/trainer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from ema_pytorch import EMA
4848

4949
from lightning import Fabric
50+
from lightning.fabric.loggers import Logger
5051
from lightning.fabric.wrappers import _unwrap_objects
5152

5253
from shortuuid import uuid
@@ -290,6 +291,7 @@ def __init__(
290291
default_lambda_lr = default_lambda_lr_fn,
291292
train_sampler: Sampler | None = None,
292293
fabric: Fabric | None = None,
294+
loggers: List[Logger] = [],
293295
accelerator = 'auto',
294296
checkpoint_prefix = 'af3.ckpt.',
295297
checkpoint_every: int = 1000,
@@ -317,7 +319,11 @@ def __init__(
317319
# instantiate fabric
318320

319321
if not exists(fabric):
320-
fabric = Fabric(accelerator = accelerator, **fabric_kwargs)
322+
fabric = Fabric(
323+
accelerator = accelerator,
324+
loggers = loggers,
325+
**fabric_kwargs
326+
)
321327

322328
self.fabric = fabric
323329
fabric.launch()

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.2.111"
3+
version = "0.2.114"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },
@@ -48,6 +48,7 @@ dependencies = [
4848
"scikit-learn>=1.5.0",
4949
"sh>=2.0.7",
5050
"shortuuid",
51+
"tensorboard",
5152
"taylor-series-linear-attention>=0.1.9",
5253
"timeout_decorator>=0.5.0",
5354
'torch_geometric',

0 commit comments

Comments
 (0)