Error using wandb when training on TPU #20880

@intexcor

Bug description

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/wandb/sdk/wandb_init.py", line 1114, in _attach
    attach_settings = service.inform_attach(attach_id=attach_id)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/wandb/sdk/lib/service_connection.py", line 177, in inform_attach
    raise WandbAttachFailedError(
wandb.sdk.lib.service_connection.WandbAttachFailedError: Failed to attach because the run does not belong to the current service process, or because the service process is busy (unlikely).

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/lib/python3.11/concurrent/futures/process.py", line 261, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/concurrent/futures/process.py", line 210, in _process_chunk
    return [fn(*args) for args in chunk]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/concurrent/futures/process.py", line 210, in <listcomp>
    return [fn(*args) for args in chunk]
            ^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch_xla/_internal/pjrt.py", line 77, in _run_thread_per_device
    replica_results = list(
                      ^^^^^
  File "/usr/lib/python3.11/concurrent/futures/_base.py", line 619, in result_iterator
    yield _result_or_cancel(fs.pop())
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/concurrent/futures/_base.py", line 317, in _result_or_cancel
    return fut.result(timeout)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/concurrent/futures/_base.py", line 456, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/usr/lib/python3.11/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch_xla/_internal/pjrt.py", line 70, in _thread_fn
    return fn()
           ^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch_xla/_internal/pjrt.py", line 185, in __call__
    self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/strategies/launchers/xla.py", line 139, in _wrapping_function
    trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs))
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 211, in _deepcopy_tuple
    y = [deepcopy(a, memo) for a in x]
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 211, in <listcomp>
    y = [deepcopy(a, memo) for a in x]
         ^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 206, in _deepcopy_list
    append(deepcopy(a, memo))
           ^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 161, in deepcopy
    rv = reductor(4)
         ^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loggers/wandb.py", line 360, in __getstate__
    _ = self.experiment
        ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/lightning/fabric/loggers/logger.py", line 118, in experiment
    return fn(self)
           ^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/lightning/pytorch/loggers/wandb.py", line 404, in experiment
    self._experiment = wandb._attach(attach_id)
                       ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/wandb/sdk/wandb_init.py", line 1116, in _attach
    raise UsageError(f"Unable to attach to run {attach_id}") from e
wandb.errors.errors.UsageError: Unable to attach to run aoaxb9zk
"""

The above exception was the direct cause of the following exception:

UsageError                                Traceback (most recent call last)
<ipython-input-2-1b9dc128dae1> in <cell line: 0>()
    160 )
    161 
--> 162 trainer.fit(model, train_loader, val_loader)
    163 
    164 

11 frames
/usr/lib/python3.11/concurrent/futures/_base.py in __get_result(self)
    399         if self._exception:
    400             try:
--> 401                 raise self._exception
    402             finally:
    403                 # Break a reference cycle with the exception in self._exception

UsageError: Unable to attach to run aoaxb9zk
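
For context, my reading of the frames above (an interpretation, not verified beyond what the traceback shows): the XLA launcher deep-copies `(trainer, function, args, kwargs)` inside each spawned worker, `WandbLogger.__getstate__` accesses `self.experiment`, and in the worker that property has to re-attach to the run created in the main process via `wandb._attach(attach_id)`, which fails because the worker cannot reach the wandb service process. A minimal single-process sketch of the deepcopy-forces-experiment part (the project name is a placeholder; in a single process the access succeeds, it only breaks inside the spawned XLA workers):

import copy
import os

from lightning.pytorch.loggers import WandbLogger

os.environ["WANDB_MODE"] = "offline"  # avoid needing wandb credentials for this demo

logger = WandbLogger(project="attach-demo")  # placeholder project name
# The wandb run is created lazily. deepcopy triggers WandbLogger.__getstate__,
# which accesses `self.experiment` (see the frames above). In this single
# process that simply creates/returns the run; inside a spawned XLA worker the
# same access becomes wandb._attach(attach_id) and raises the UsageError.
copied = copy.deepcopy(logger)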

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

import torch
from torch.utils.data import DataLoader, Dataset
import lightning as pl
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import os


torch.set_float32_matmul_precision('high')


class ChatDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=512):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        conversation = self.dataset[idx]["conversation"]
        chat = ""
        for message in conversation:
            role = message["role"]
            content = message["content"]
            chat += f"<|im_start|>{role}\n{content}<|im_end|>\n"

        encoding = self.tokenizer(
            chat,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()

        labels = input_ids.clone()

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }


class LanguageModelLightning(pl.LightningModule):
    def __init__(self, model_name, learning_rate=2e-5, weight_decay=0.01):
        super().__init__()
        self.save_hyperparameters()

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto"
        )
        self.model.train()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        return outputs

    def training_step(self, batch, batch_idx):
        outputs = self.forward(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )

        loss = outputs.loss
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self.forward(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )

        loss = outputs.loss
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
        return optimizer


dataset = load_dataset('Vikhrmodels/GrandMaster-PRO-MAX')

model = LanguageModelLightning("Qwen/Qwen3-0.6B")

train_dataset = ChatDataset(dataset["train"], model.tokenizer)
val_dataset = ChatDataset(dataset["test"], model.tokenizer)


train_loader = DataLoader(
    train_dataset,
    batch_size=90,
    shuffle=True,
    num_workers=16,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=90,
    shuffle=False,
    num_workers=16,
    pin_memory=True
)

wandb_logger = WandbLogger(
    project="IGen",
    name="IGen"
)

checkpoint_callback = ModelCheckpoint(
    dirpath="IGen/checkpoints",
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_last=True
)

trainer = Trainer(
    max_epochs=5,
    precision="bf16-true",
    accelerator="auto",
    strategy="auto",
    devices="auto",
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
    check_val_every_n_epoch=1,
    log_every_n_steps=50,
    enable_model_summary=True,
    enable_progress_bar=True,
)

trainer.fit(model, train_loader, val_loader)


model.model.save_pretrained("IGen/final_model")
model.tokenizer.save_pretrained("IGen/final_model")
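
Until this is fixed, one way to keep the TPU run going (an untested suggestion on my side, not part of the original report) is to swap in a logger that has no external service process, e.g. CSVLogger, which also confirms that everything except the wandb attach works on the TPU:

from lightning.pytorch.loggers import CSVLogger

# Untested workaround sketch: a file-based logger has nothing to re-attach to
# in the spawned XLA workers, so the deepcopy in the launcher should not fail.
csv_logger = CSVLogger(save_dir="IGen/logs", name="IGen")

trainer = Trainer(
    max_epochs=5,
    precision="bf16-true",
    accelerator="auto",
    strategy="auto",
    devices="auto",
    logger=csv_logger,  # only this line differs from the repro above
    callbacks=[checkpoint_callback],
    check_val_every_n_epoch=1,
    log_every_n_steps=50,
    enable_model_summary=True,
    enable_progress_bar=True,
)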

Error messages and logs

See the full traceback in the bug description above.

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux): Ubuntu 25.04
#- CUDA/cuDNN version: 12.8
#- GPU models and configuration: TPU v2-8
#- How you installed Lightning(`conda`, `pip`, source): pip

More info

No response

cc @lantiga @Borda

Labels: bug (Something isn't working), logger (Related to the Loggers), ver: 2.5.x
