Closed
Labels: bug, needs triage, ver: 2.3.x
Description
Bug description
If the `project` name is not passed (left as `None`) when creating a `WandbLogger`, the logger records a different step value than it does when a project name is given.
What version are you seeing the problem on?
v2.3
How to reproduce the bug
In the minimal example below, `project` is not passed to `WandbLogger`, and the incorrect step is logged:
```python
import os

import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers.wandb import WandbLogger
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)
        self.loss_metric = CrossEntropyLoss()
        # Running sums used to compute epoch-level averages by hand.
        self.train_loss_total = torch.tensor(0, dtype=torch.float32)
        self.validation_loss_total = torch.tensor(0, dtype=torch.float32)
        self.train_steps = torch.tensor(0, dtype=torch.int64)
        self.validation_steps = torch.tensor(0, dtype=torch.int64)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_metric(y_hat, y)
        # Detach so the running sum does not keep the autograd graph alive across steps.
        self.train_loss_total += loss.detach()
        self.train_steps += 1
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_metric(y_hat, y)
        self.validation_loss_total += loss
        self.validation_steps += 1
        return loss

    def on_train_epoch_end(self):
        # Log the epoch-average training loss once per epoch, then reset the accumulators.
        self.log('loss', self.train_loss_total / self.train_steps, on_step=False, on_epoch=True)
        self.train_loss_total.zero_()
        self.train_steps.zero_()

    def on_validation_epoch_end(self):
        # Log the epoch-average validation loss once per epoch, then reset the accumulators.
        self.log('val_loss', self.validation_loss_total / self.validation_steps, on_step=False, on_epoch=True)
        self.validation_loss_total.zero_()
        self.validation_steps.zero_()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


train_loader = DataLoader(
    MNIST(os.getcwd(), download=True, train=True, transform=transforms.ToTensor())
)
val_loader = DataLoader(
    MNIST(os.getcwd(), download=True, train=False, transform=transforms.ToTensor())
)

# `project` is intentionally omitted; with WandbLogger(project='a') the step is logged correctly.
logger = WandbLogger()
trainer = pl.Trainer(
    max_epochs=10,
    limit_train_batches=20,
    limit_val_batches=10,
    log_every_n_steps=0,
    logger=logger,
    accelerator='cpu',
)
model = LitModel()
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
```
As a result, the step is incremented only at the end of each epoch, once for the training log and once for the validation log, so the training and validation metrics end up with different step numbers, each of which is smaller than the corresponding global training step.
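To make the two numbering schemes concrete, here is a toy model in plain Python (not Lightning or wandb code). It assumes, as the report describes, that in the bad case the x-axis value advances once per epoch-end log call, and that validation is logged before training within each epoch; that order is an assumption and does not change the point. The trainer step is approximated as `batches_per_epoch * (epoch + 1)`, ignoring any off-by-one in how Lightning indexes epoch-end steps. `epoch_end_events` and `steps_by_metric` are hypothetical helpers.

```python
from typing import Dict, List, Tuple

# Hypothetical toy model; it does not call Lightning or wandb.
# Each event is (metric_name, trainer_step) for one epoch-end log call.


def epoch_end_events(num_epochs: int, batches_per_epoch: int) -> List[Tuple[str, int]]:
    """Epoch-end log calls in an assumed order: validation first, then training."""
    events: List[Tuple[str, int]] = []
    for epoch in range(num_epochs):
        # Approximate trainer step at the end of this epoch.
        step = batches_per_epoch * (epoch + 1)
        events.append(("val_loss", step))
        events.append(("loss", step))
    return events


def steps_by_metric(events: List[Tuple[str, int]], use_trainer_step: bool) -> Dict[str, List[int]]:
    """Map each metric to the x-axis values it would be plotted against."""
    result: Dict[str, List[int]] = {}
    for call_index, (name, trainer_step) in enumerate(events):
        # Per-call counter: the x value is simply the index of the log call.
        x = trainer_step if use_trainer_step else call_index
        result.setdefault(name, []).append(x)
    return result


if __name__ == "__main__":
    events = epoch_end_events(num_epochs=3, batches_per_epoch=20)
    print(steps_by_metric(events, use_trainer_step=True))
    # {'val_loss': [20, 40, 60], 'loss': [20, 40, 60]}
    print(steps_by_metric(events, use_trainer_step=False))
    # {'val_loss': [0, 2, 4], 'loss': [1, 3, 5]}
```

With a per-call counter, `loss` and `val_loss` land on interleaved steps far below the trainer step; with the trainer step as the axis, both metrics share the same x values.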
Changing only `WandbLogger()` to `WandbLogger(project='a')` changes how the step value is logged: the step is then logged as the global training step, and it is consistent between the training and validation logs.
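A minimal sketch of the workaround the report points to, which is to always pass `project` explicitly. The helper name `wandb_logger_kwargs` and the fallback `DEFAULT_PROJECT = "my-project"` are hypothetical; reading the `WANDB_PROJECT` environment variable is an assumption about how you might choose to configure it, not something Lightning requires. The helper only builds keyword arguments, so it can be checked without wandb installed.

```python
import os
from typing import Any, Dict, Optional

# Hypothetical fallback; replace with the project name you actually use.
DEFAULT_PROJECT = "my-project"


def wandb_logger_kwargs(project: Optional[str] = None, **extra: Any) -> Dict[str, Any]:
    """Build WandbLogger keyword arguments with `project` always set.

    Resolution order (an assumption for this sketch): the explicit argument,
    then the WANDB_PROJECT environment variable, then DEFAULT_PROJECT.
    """
    resolved = project or os.environ.get("WANDB_PROJECT") or DEFAULT_PROJECT
    return {"project": resolved, **extra}


# Usage (requires pytorch_lightning and wandb):
# from pytorch_lightning.loggers.wandb import WandbLogger
# logger = WandbLogger(**wandb_logger_kwargs(name="run-1"))
```

This sidesteps the reported symptom without addressing its cause.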
Environment
Current environment
- CUDA:
- GPU: None
- available: False
- version: None
- Lightning:
- lightning: 2.3.3
- lightning-utilities: 0.11.2
- pytorch-lightning: 2.3.3
- torch: 2.3.1
- torcheval: 0.0.7
- torchmetrics: 1.4.0.post0
- torchvision: 0.18.1
- Packages:
- aiohttp: 3.9.5
- aiosignal: 1.3.1
- alabaster: 0.7.16
- anyio: 4.4.0
- appnope: 0.1.4
- argon2-cffi: 23.1.0
- argon2-cffi-bindings: 21.2.0
- arrow: 1.3.0
- astropy: 6.1.0
- astropy-iers-data: 0.2024.6.3.0.31.14
- astroquery: 0.4.7
- asttokens: 2.4.1
- async-lru: 2.0.4
- atpublic: 4.1.0
- attrs: 23.2.0
- autograd: 1.6.2
- babel: 2.14.0
- backcall: 0.2.0
- backports-strenum: 1.2.8
- backports.tarfile: 1.2.0
- beautifulsoup4: 4.12.3
- bleach: 6.1.0
- bokeh: 3.5.1
- brotli: 1.1.0
- cached-property: 1.5.2
- certifi: 2024.6.2
- cffi: 1.17.0
- charset-normalizer: 3.3.2
- click: 8.1.7
- comm: 0.2.2
- contourpy: 1.2.1
- cycler: 0.12.1
- debugpy: 1.8.5
- decorator: 5.1.1
- defusedxml: 0.7.1
- distlib: 0.3.8
- docker-pycreds: 0.4.0
- docopt: 0.6.2
- docutils: 0.21.2
- entrypoints: 0.4
- exceptiongroup: 1.2.2
- executing: 2.0.1
- fastjsonschema: 2.19.1
- fbpca: 1.0
- filelock: 3.14.0
- fonttools: 4.53.0
- fqdn: 1.5.1
- frozenlist: 1.4.1
- fsspec: 2024.6.0
- furo: 2024.5.6
- future: 1.0.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- h11: 0.14.0
- h2: 4.1.0
- hatch: 1.12.0
- hatchling: 1.24.2
- hpack: 4.0.0
- html5lib: 1.1
- httpcore: 1.0.5
- httpx: 0.27.0
- humanize: 4.11.0
- hyperframe: 6.0.1
- hyperlink: 21.0.0
- idna: 3.7
- imagesize: 1.4.1
- importlib-metadata: 7.1.0
- importlib-resources: 6.4.0
- iniconfig: 2.0.0
- ipykernel: 6.29.5
- ipython: 8.12.3
- ipywidgets: 8.1.3
- isoduration: 20.11.0
- jaraco.classes: 3.4.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jedi: 0.19.1
- jinja2: 3.1.4
- joblib: 1.4.2
- json5: 0.9.25
- jsonpointer: 3.0.0
- jsonschema: 4.22.0
- jsonschema-specifications: 2023.12.1
- jupyter: 1.0.0
- jupyter-client: 8.6.2
- jupyter-console: 6.6.3
- jupyter-core: 5.7.2
- jupyter-events: 0.10.0
- jupyter-lsp: 2.2.5
- jupyter-server: 2.14.2
- jupyter-server-terminals: 0.5.3
- jupyterlab: 4.2.4
- jupyterlab-pygments: 0.3.0
- jupyterlab-server: 2.27.3
- jupyterlab-widgets: 3.0.11
- keyring: 25.2.1
- kiwisolver: 1.4.5
- lightkurve: 2.4.2
- lightning: 2.3.3
- lightning-utilities: 0.11.2
- lxml: 5.2.2
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- matplotlib: 3.9.0
- matplotlib-inline: 0.1.7
- mdit-py-plugins: 0.4.1
- mdurl: 0.1.2
- memoization: 0.4.0
- mistune: 3.0.2
- more-itertools: 10.2.0
- mpmath: 1.3.0
- multidict: 6.0.5
- myst-parser: 3.0.1
- nbclient: 0.10.0
- nbconvert: 7.16.4
- nbformat: 5.10.4
- nest-asyncio: 1.6.0
- networkx: 3.3
- notebook: 7.2.1
- notebook-shim: 0.2.4
- numpy: 1.26.4
- oktopus: 0.1.2
- overrides: 7.7.0
- packaging: 24.0
- pandas: 2.2.2
- pandocfilters: 1.5.0
- parso: 0.8.4
- pathspec: 0.12.1
- patsy: 0.5.6
- peewee: 3.17.5
- pexpect: 4.9.0
- pickleshare: 0.7.5
- pillow: 10.3.0
- pip: 24.0
- pipreqs: 0.5.0
- pkgutil-resolve-name: 1.3.10
- platformdirs: 4.2.2
- plotly: 5.22.0
- pluggy: 1.5.0
- polars: 0.20.31
- prometheus-client: 0.20.0
- prompt-toolkit: 3.0.46
- protobuf: 5.27.1
- psutil: 5.9.8
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- pyarrow: 16.1.0
- pycparser: 2.22
- pyerfa: 2.0.1.4
- pygments: 2.18.0
- pyobjc-core: 10.3.1
- pyobjc-framework-cocoa: 10.3.1
- pyparsing: 3.1.2
- pysocks: 1.7.1
- pytest: 7.4.4
- pytest-asyncio: 0.23.7
- pytest-pycharm: 0.7.0
- python-dateutil: 2.9.0
- python-json-logger: 2.0.7
- pytorch-lightning: 2.3.3
- pytz: 2024.1
- pyvo: 1.5.2
- pyyaml: 6.0.1
- pyzmq: 26.0.3
- qtconsole: 5.5.2
- qtpy: 2.4.1
- qusi: 1.0.3
- qusi-evaluation: 0.0.1
- referencing: 0.35.1
- requests: 2.32.3
- retrying: 1.3.4
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rich: 13.7.1
- rpds-py: 0.18.1
- scikit-learn: 1.5.0
- scipy: 1.13.1
- send2trash: 1.8.3
- sentry-sdk: 2.5.1
- setproctitle: 1.3.3
- setuptools: 70.0.0
- shellingham: 1.5.4
- six: 1.16.0
- smmap: 5.0.1
- sniffio: 1.3.1
- snowballstemmer: 2.2.0
- soupsieve: 2.5
- sphinx: 7.3.7
- sphinx-basic-ng: 1.0.0b2
- sphinxcontrib-applehelp: 1.0.8
- sphinxcontrib-devhelp: 1.0.6
- sphinxcontrib-htmlhelp: 2.0.5
- sphinxcontrib-jsmath: 1.0.1
- sphinxcontrib-qthelp: 1.0.7
- sphinxcontrib-serializinghtml: 1.1.10
- stack-data: 0.6.2
- stringcase: 1.2.0
- sympy: 1.12.1
- tenacity: 8.3.0
- terminado: 0.18.1
- threadpoolctl: 3.5.0
- tinycss2: 1.3.0
- tomli: 2.0.1
- tomli-w: 1.0.0
- tomlkit: 0.12.5
- torch: 2.3.1
- torcheval: 0.0.7
- torchmetrics: 1.4.0.post0
- torchvision: 0.18.1
- tornado: 6.4.1
- tqdm: 4.66.4
- traitlets: 5.14.3
- trove-classifiers: 2024.5.22
- types-python-dateutil: 2.9.0.20240316
- typing-extensions: 4.12.2
- typing-utils: 0.1.0
- tzdata: 2024.1
- uncertainties: 3.2.1
- uri-template: 1.3.0
- urllib3: 2.2.1
- userpath: 1.9.2
- uv: 0.2.11
- uvloop: 0.19.0
- virtualenv: 20.26.2
- wandb: 0.17.1
- wcwidth: 0.2.13
- webcolors: 24.6.0
- webencodings: 0.5.1
- websocket-client: 1.8.0
- wget: 3.2
- wheel: 0.43.0
- widgetsnbextension: 4.0.11
- xyzservices: 2024.6.0
- yarg: 0.1.9
- yarl: 1.9.4
- zenodo-get: 1.6.1
- zipp: 3.19.2
- zstandard: 0.22.0
- System:
- OS: Darwin
- architecture: 64bit
- processor: arm
- python: 3.11.9
- release: 24.0.0
- version: Darwin Kernel Version 24.0.0: Tue Sep 24 23:39:07 PDT 2024; root:xnu-11215.1.12~1/RELEASE_ARM64_T6000