
WandbLogger logging step incorrectly if project name not passed. #20383

@golmschenk


Bug description

If the project name is not passed (left as None) when creating a WandbLogger, the logger logs the step differently than it does when a project name is given.

What version are you seeing the problem on?

v2.3

How to reproduce the bug

In the minimal example below, project is not passed to WandbLogger, and the wrong step is logged:

import os

import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers.wandb import WandbLogger
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(28 * 28, 10)
        self.loss_metric = CrossEntropyLoss()
        self.train_loss_total = torch.tensor(0, dtype=torch.float32)
        self.validation_loss_total = torch.tensor(0, dtype=torch.float32)
        self.train_steps = torch.tensor(0, dtype=torch.int64)
        self.validation_steps = torch.tensor(0, dtype=torch.int64)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_metric(y_hat, y)
        # detach so the running sum does not keep every step's autograd graph alive
        self.train_loss_total += loss.detach()
        self.train_steps += 1
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_metric(y_hat, y)
        self.validation_loss_total += loss.detach()
        self.validation_steps += 1
        return loss

    def on_train_epoch_end(self):
        self.log('loss', self.train_loss_total / self.train_steps, on_step=False, on_epoch=True)
        self.train_loss_total.zero_()
        self.train_steps.zero_()

    def on_validation_epoch_end(self):
        self.log('val_loss', self.validation_loss_total / self.validation_steps, on_step=False, on_epoch=True)
        self.validation_loss_total.zero_()
        self.validation_steps.zero_()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


train_loader = DataLoader(
    MNIST(os.getcwd(), download=True, train=True, transform=transforms.ToTensor())
)
val_loader = DataLoader(
    MNIST(os.getcwd(), download=True, train=False, transform=transforms.ToTensor())
)
logger = WandbLogger()  # project intentionally not passed (left as None)
trainer = pl.Trainer(
    max_epochs=10,
    limit_train_batches=20,
    limit_val_batches=10,
    log_every_n_steps=0,
    logger=logger,
    accelerator='cpu',
)
model = LitModel()

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

This results in the step being incremented only at the end of each epoch, once for the training log and once for the validation log. As a result, training and validation metrics are logged at different step numbers, each of which is smaller than the corresponding global training step.

[Screenshot (2024-11-01 13:45:30): logged step values for the run created with WandbLogger()]
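The numbers in the behaviour described above can be reproduced with a toy model. This is a minimal sketch, assuming that the x-axis in the first screenshot is a counter that advances once per logging call (which is what W&B's built-in step does by default when no explicit step is passed), and assuming that the validation epoch-end log happens before the training epoch-end log. Neither assumption is confirmed in this issue, and `simulate` is a hypothetical helper, not Lightning or W&B code.

```python
# Toy model of the two step counters (ASSUMPTION-based sketch, not Lightning code).
# Assumes the plotted "step" advances once per logging call, and that validation
# epoch-end metrics are logged before training epoch-end metrics.


def simulate(max_epochs=10, train_batches=20):
    """Return (metric, per_call_step, global_step) for each epoch-end log call."""
    rows = []
    per_call_step = 0  # hypothetical counter that advances once per logging call
    for epoch in range(max_epochs):
        # One optimizer step per training batch (no gradient accumulation), so the
        # trainer's global step after this epoch is (epoch + 1) * train_batches.
        global_step = (epoch + 1) * train_batches
        # Assumed order: validation epoch end is logged before training epoch end.
        for metric in ("val_loss", "loss"):
            rows.append((metric, per_call_step, global_step))
            per_call_step += 1
    return rows


if __name__ == "__main__":
    for metric, per_call_step, global_step in simulate(max_epochs=2):
        print(f"{metric:<8} per-call step={per_call_step:<3} global step={global_step}")
```

With max_epochs=2 and 20 training batches, the two metrics land on per-call steps 0/1 and 2/3 while the global step is 20 and 40: different steps for training and validation, each below the global training step, as in the first screenshot.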

Changing only WandbLogger() to WandbLogger(project='a') changes how the step value is logged: the step is then logged as the global training step, and it is consistent between training and validation logging.

[Screenshot (2024-11-01 13:45:50): logged step values for the run created with WandbLogger(project='a')]

Environment

  • CUDA:
    - GPU: None
    - available: False
    - version: None
  • Lightning:
    - lightning: 2.3.3
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.3.3
    - torch: 2.3.1
    - torcheval: 0.0.7
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.18.1
  • Packages:
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - alabaster: 0.7.16
    - anyio: 4.4.0
    - appnope: 0.1.4
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - astropy: 6.1.0
    - astropy-iers-data: 0.2024.6.3.0.31.14
    - astroquery: 0.4.7
    - asttokens: 2.4.1
    - async-lru: 2.0.4
    - atpublic: 4.1.0
    - attrs: 23.2.0
    - autograd: 1.6.2
    - babel: 2.14.0
    - backcall: 0.2.0
    - backports-strenum: 1.2.8
    - backports.tarfile: 1.2.0
    - beautifulsoup4: 4.12.3
    - bleach: 6.1.0
    - bokeh: 3.5.1
    - brotli: 1.1.0
    - cached-property: 1.5.2
    - certifi: 2024.6.2
    - cffi: 1.17.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - comm: 0.2.2
    - contourpy: 1.2.1
    - cycler: 0.12.1
    - debugpy: 1.8.5
    - decorator: 5.1.1
    - defusedxml: 0.7.1
    - distlib: 0.3.8
    - docker-pycreds: 0.4.0
    - docopt: 0.6.2
    - docutils: 0.21.2
    - entrypoints: 0.4
    - exceptiongroup: 1.2.2
    - executing: 2.0.1
    - fastjsonschema: 2.19.1
    - fbpca: 1.0
    - filelock: 3.14.0
    - fonttools: 4.53.0
    - fqdn: 1.5.1
    - frozenlist: 1.4.1
    - fsspec: 2024.6.0
    - furo: 2024.5.6
    - future: 1.0.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - h11: 0.14.0
    - h2: 4.1.0
    - hatch: 1.12.0
    - hatchling: 1.24.2
    - hpack: 4.0.0
    - html5lib: 1.1
    - httpcore: 1.0.5
    - httpx: 0.27.0
    - humanize: 4.11.0
    - hyperframe: 6.0.1
    - hyperlink: 21.0.0
    - idna: 3.7
    - imagesize: 1.4.1
    - importlib-metadata: 7.1.0
    - importlib-resources: 6.4.0
    - iniconfig: 2.0.0
    - ipykernel: 6.29.5
    - ipython: 8.12.3
    - ipywidgets: 8.1.3
    - isoduration: 20.11.0
    - jaraco.classes: 3.4.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jedi: 0.19.1
    - jinja2: 3.1.4
    - joblib: 1.4.2
    - json5: 0.9.25
    - jsonpointer: 3.0.0
    - jsonschema: 4.22.0
    - jsonschema-specifications: 2023.12.1
    - jupyter: 1.0.0
    - jupyter-client: 8.6.2
    - jupyter-console: 6.6.3
    - jupyter-core: 5.7.2
    - jupyter-events: 0.10.0
    - jupyter-lsp: 2.2.5
    - jupyter-server: 2.14.2
    - jupyter-server-terminals: 0.5.3
    - jupyterlab: 4.2.4
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.27.3
    - jupyterlab-widgets: 3.0.11
    - keyring: 25.2.1
    - kiwisolver: 1.4.5
    - lightkurve: 2.4.2
    - lightning: 2.3.3
    - lightning-utilities: 0.11.2
    - lxml: 5.2.2
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.9.0
    - matplotlib-inline: 0.1.7
    - mdit-py-plugins: 0.4.1
    - mdurl: 0.1.2
    - memoization: 0.4.0
    - mistune: 3.0.2
    - more-itertools: 10.2.0
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - myst-parser: 3.0.1
    - nbclient: 0.10.0
    - nbconvert: 7.16.4
    - nbformat: 5.10.4
    - nest-asyncio: 1.6.0
    - networkx: 3.3
    - notebook: 7.2.1
    - notebook-shim: 0.2.4
    - numpy: 1.26.4
    - oktopus: 0.1.2
    - overrides: 7.7.0
    - packaging: 24.0
    - pandas: 2.2.2
    - pandocfilters: 1.5.0
    - parso: 0.8.4
    - pathspec: 0.12.1
    - patsy: 0.5.6
    - peewee: 3.17.5
    - pexpect: 4.9.0
    - pickleshare: 0.7.5
    - pillow: 10.3.0
    - pip: 24.0
    - pipreqs: 0.5.0
    - pkgutil-resolve-name: 1.3.10
    - platformdirs: 4.2.2
    - plotly: 5.22.0
    - pluggy: 1.5.0
    - polars: 0.20.31
    - prometheus-client: 0.20.0
    - prompt-toolkit: 3.0.46
    - protobuf: 5.27.1
    - psutil: 5.9.8
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pyarrow: 16.1.0
    - pycparser: 2.22
    - pyerfa: 2.0.1.4
    - pygments: 2.18.0
    - pyobjc-core: 10.3.1
    - pyobjc-framework-cocoa: 10.3.1
    - pyparsing: 3.1.2
    - pysocks: 1.7.1
    - pytest: 7.4.4
    - pytest-asyncio: 0.23.7
    - pytest-pycharm: 0.7.0
    - python-dateutil: 2.9.0
    - python-json-logger: 2.0.7
    - pytorch-lightning: 2.3.3
    - pytz: 2024.1
    - pyvo: 1.5.2
    - pyyaml: 6.0.1
    - pyzmq: 26.0.3
    - qtconsole: 5.5.2
    - qtpy: 2.4.1
    - qusi: 1.0.3
    - qusi-evaluation: 0.0.1
    - referencing: 0.35.1
    - requests: 2.32.3
    - retrying: 1.3.4
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.7.1
    - rpds-py: 0.18.1
    - scikit-learn: 1.5.0
    - scipy: 1.13.1
    - send2trash: 1.8.3
    - sentry-sdk: 2.5.1
    - setproctitle: 1.3.3
    - setuptools: 70.0.0
    - shellingham: 1.5.4
    - six: 1.16.0
    - smmap: 5.0.1
    - sniffio: 1.3.1
    - snowballstemmer: 2.2.0
    - soupsieve: 2.5
    - sphinx: 7.3.7
    - sphinx-basic-ng: 1.0.0b2
    - sphinxcontrib-applehelp: 1.0.8
    - sphinxcontrib-devhelp: 1.0.6
    - sphinxcontrib-htmlhelp: 2.0.5
    - sphinxcontrib-jsmath: 1.0.1
    - sphinxcontrib-qthelp: 1.0.7
    - sphinxcontrib-serializinghtml: 1.1.10
    - stack-data: 0.6.2
    - stringcase: 1.2.0
    - sympy: 1.12.1
    - tenacity: 8.3.0
    - terminado: 0.18.1
    - threadpoolctl: 3.5.0
    - tinycss2: 1.3.0
    - tomli: 2.0.1
    - tomli-w: 1.0.0
    - tomlkit: 0.12.5
    - torch: 2.3.1
    - torcheval: 0.0.7
    - torchmetrics: 1.4.0.post0
    - torchvision: 0.18.1
    - tornado: 6.4.1
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - trove-classifiers: 2024.5.22
    - types-python-dateutil: 2.9.0.20240316
    - typing-extensions: 4.12.2
    - typing-utils: 0.1.0
    - tzdata: 2024.1
    - uncertainties: 3.2.1
    - uri-template: 1.3.0
    - urllib3: 2.2.1
    - userpath: 1.9.2
    - uv: 0.2.11
    - uvloop: 0.19.0
    - virtualenv: 20.26.2
    - wandb: 0.17.1
    - wcwidth: 0.2.13
    - webcolors: 24.6.0
    - webencodings: 0.5.1
    - websocket-client: 1.8.0
    - wget: 3.2
    - wheel: 0.43.0
    - widgetsnbextension: 4.0.11
    - xyzservices: 2024.6.0
    - yarg: 0.1.9
    - yarl: 1.9.4
    - zenodo-get: 1.6.1
    - zipp: 3.19.2
    - zstandard: 0.22.0
  • System:
    - OS: Darwin
    - architecture: 64bit
    - processor: arm
    - python: 3.11.9
    - release: 24.0.0
    - version: Darwin Kernel Version 24.0.0: Tue Sep 24 23:39:07 PDT 2024; root:xnu-11215.1.12~1/RELEASE_ARM64_T6000

    Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.3.x
