Skip to content

Error when training on TPU #20891

@intexcor

Description

@intexcor

Bug description


BrokenProcessPool Traceback (most recent call last)
in <cell line: 0>()
159 )
160
--> 161 trainer.fit(model, train_loader, val_loader)
162
163

11 frames
/usr/lib/python3.11/concurrent/futures/_base.py in __get_result(self)
399 if self._exception:
400 try:
--> 401 raise self._exception
402 finally:
403 # Break a reference cycle with the exception in self._exception

BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

import torch
from torch.utils.data import DataLoader, Dataset
import lightning as pl
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import os


# Credentials are taken from the process environment instead of being
# hard-coded here: a secret written into source (or pasted into a bug report)
# is leaked and must be revoked. Export them before launching, e.g.
#   export WANDB_API_KEY=...
#   export HF_TOKEN=...
for _secret_name in ("WANDB_API_KEY", "HF_TOKEN"):
    if not os.environ.get(_secret_name):
        # Best-effort notice only: a public model/dataset or an offline
        # logger may not need the credential at all.
        print(f"warning: {_secret_name} is not set in the environment")


# Let float32 matmuls use faster, reduced-precision internal formats where
# the backend supports them.
torch.set_float32_matmul_precision('high')


class ChatDataset(Dataset):
    """Map-style dataset that renders records as chat text and tokenizes them.

    Each record of ``dataset`` must provide the string fields ``"input"`` and
    ``"output"``. They are wrapped in ``<|im_start|>`` / ``<|im_end|>`` chat
    markers, then tokenized with padding and truncation to ``max_length``.

    Args:
        dataset: Indexable collection of records supporting ``len()``.
        tokenizer: Callable accepting the tokenizer keyword arguments used
            below and returning a mapping with ``'input_ids'`` and
            ``'attention_mask'`` tensors that carry a leading batch axis.
        max_length: Length every example is padded or truncated to.
    """

    # Label value that the loss skips (the default ``ignore_index`` of
    # PyTorch's cross-entropy loss).
    IGNORE_INDEX = -100

    def __init__(self, dataset, tokenizer, max_length=1024):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Tokenize record ``idx``.

        Returns:
            dict with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
            ``labels`` equals ``input_ids`` except at positions where the
            attention mask is 0, which are set to ``IGNORE_INDEX``.
        """
        record = self.dataset[idx]

        chat = ("<|im_start|>user\n" + record["input"] + "<|im_end|>\n" +
                "<|im_start|>assistant\n<think>\n \n</think>\n" + record["output"] + "<|im_end|>\n")

        encoding = self.tokenizer(
            chat,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # squeeze(0) drops only the batch axis added by return_tensors='pt';
        # a bare squeeze() would also collapse a length-1 sequence axis.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # Exclude padding from the loss: without this mask the model is also
        # trained to predict the pad tokens that fill each example.
        labels = input_ids.masked_fill(attention_mask == 0, self.IGNORE_INDEX)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }


class LanguageModelLightning(pl.LightningModule):
    """Lightning module that fine-tunes a pretrained causal language model.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            model and its tokenizer.
        learning_rate: Learning rate given to AdamW.
        weight_decay: Weight decay given to AdamW.
    """

    def __init__(self, model_name, learning_rate=2e-5, weight_decay=0.01):
        super().__init__()
        self.save_hyperparameters()

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto"
        )
        self.model.train()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model and return its output object unchanged."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

    def _batch_loss(self, batch):
        """Return the model's loss for a batch dict produced by the dataloader."""
        # Call the module itself rather than .forward() so that module hooks
        # and wrappers around __call__ are not bypassed.
        outputs = self(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        return outputs.loss

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        loss = self._batch_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss for one batch."""
        loss = self._batch_loss(batch)
        # Reduce across devices so every process sees the same 'val_loss'
        # when a callback monitors it on a multi-device run.
        self.log('val_loss', loss, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over all module parameters."""
        return torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )


# End-to-end fine-tuning script: load data and model, build loaders,
# train with Lightning, then export the model and tokenizer.
dataset = load_dataset('intexcp/russian-llm-training-dataset')

model = LanguageModelLightning("Qwen/Qwen3-0.6B")
#model = torch.compile(model)

train_dataset = ChatDataset(dataset["train"], model.tokenizer)
val_dataset = ChatDataset(dataset["test"], model.tokenizer)

# Never request more loader workers than the host has cores; every training
# process starts its own set of workers, so oversubscription multiplies.
NUM_WORKERS = min(16, os.cpu_count() or 1)
# Pinned host memory only benefits CUDA transfers.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)


val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

wandb_logger = WandbLogger(
    project="IGen",
    name="IGen"
)

# Keeps the best checkpoint by 'val_loss' plus the most recent one.
checkpoint_callback = ModelCheckpoint(
    dirpath="IGen/checkpoints",
    filename='{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_last=True
)

trainer = Trainer(
    max_epochs=2,
    precision="bf16-true",
    accelerator="auto",
    strategy="auto",
    devices="auto",
    # The W&B logger was previously created but never handed to the Trainer,
    # so metrics were not sent to it.
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
    check_val_every_n_epoch=1,
    log_every_n_steps=50,
    enable_model_summary=True,
    enable_progress_bar=True,
)

trainer.fit(model, train_loader, val_loader)


# Export the fine-tuned weights and tokenizer in Hugging Face format.
model.model.save_pretrained("IGen/final_model")
model.tokenizer.save_pretrained("IGen/final_model")

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions