
[BUG]: KeyError: 'sum' When Using Adagrad Optimizer with LowLevelZeroPlugin (initial_scale=8.33) #6396

@sdjasj

Is there an existing issue for this bug?

  • I have searched the existing issues

The bug has not been fixed in the latest main branch

  • I have checked the latest main branch

Do you feel comfortable sharing a concise (minimal) script that reproduces the error? :)

Yes, I will share a minimal reproducible script.

🐛 Describe the bug

When the Adagrad optimizer is used together with `LowLevelZeroPlugin` configured with `initial_scale=8.332635365271916`, `optimizer.step()` raises `KeyError: 'sum'`.
The reproduction script is `main.py` (the command and traceback below run the same file as `bug6.py`):

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

class RandomDataset(Dataset):
    def __init__(self, num_samples=32 * 10, input_dim=1024, num_classes=10):
        self.x = torch.randn(num_samples, input_dim)
        self.y = torch.randint(0, num_classes, (num_samples,))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

class MLP(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=512, num_layers=10, num_classes=10):
        super().__init__()
        layers = []
        for i in range(num_layers):
            in_dim = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

def main():
    seed = 3057219192
    colossalai.launch_from_torch(seed=seed)
    plugin = LowLevelZeroPlugin(
        initial_scale=8.332635365271916
    )

    booster = Booster(plugin=plugin)

    model = MLP()
    optimizer = optim.Adagrad(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    dataset = RandomDataset()
    train_dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

    model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)

    precision = getattr(plugin, "precision", "fp16")
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype = dtype_map.get(precision, torch.float16)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.train()
    for epoch in range(1):
        total_loss = 0
        for step, (x, y) in enumerate(train_dataloader):
            x = x.to(device=device, dtype=dtype)
            y = y.to(device=device)

            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)

            booster.backward(loss, optimizer)
            optimizer.step()

            total_loss += loss.item()
            print(f"[Epoch {epoch}] step {step}, loss = {loss.item():.4f}")

        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch} finished, average loss = {avg_loss:.4f}")


if __name__ == "__main__":
    main()
```

Launching it with 4 processes per node:

```shell
$ colossalai run --nproc_per_node 4 --master_port 29505 ./bug6.py
```

produces the following error log:

```text
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/yanzhen/distributed_test/colossalAI/test/./bug6.py", line 80, in <module>
[rank0]:     main()
[rank0]:   File "/home/yanzhen/distributed_test/colossalAI/test/./bug6.py", line 70, in main
[rank0]:     optimizer.step()
[rank0]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/zero/low_level/low_level_optim.py", line 588, in step
[rank0]:     self.optim.step()
[rank0]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/optim/optimizer.py", line 487, in wrapper
[rank0]:     out = func(*args, **kwargs)
[rank0]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
[rank0]:     ret = func(self, *args, **kwargs)
[rank0]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/optim/adagrad.py", line 164, in step
[rank0]:     has_sparse_grad, has_complex = self._init_group(
[rank0]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/optim/adagrad.py", line 139, in _init_group
[rank0]:     state_sums.append(state["sum"])
[rank0]: KeyError: 'sum'
```
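The traceback ends in PyTorch's `Adagrad._init_group`, which looks up `state["sum"]` for every parameter in `param_groups`. One plausible explanation, not confirmed against the ColossalAI source here: `torch.optim.Adagrad` creates its `sum` and `step` state eagerly in its constructor, keyed by the tensors it was given, while the LowLevelZero optimizer appears to replace those tensors in `param_groups` with its own fp32 master shards afterwards, so the new tensors have no state. Adam, by contrast, creates its state lazily inside `step()`. This torch-only sketch reproduces the same `KeyError` under that assumption; `swap_params_after_init` and `try_step` are hypothetical helpers that only imitate the swap.

```python
import torch


def swap_params_after_init(opt):
    """Hypothetical stand-in for a ZeRO wrapper's master-param swap.

    Assumption: the wrapper puts new fp32 tensors into param_groups AFTER the
    inner optimizer's constructor has run, so opt.state has no entry for them.
    """
    old = opt.param_groups[0]["params"][0]
    master = torch.nn.Parameter(old.detach().clone().float())
    opt.param_groups[0]["params"] = [master]
    master.grad = torch.ones_like(master)  # pretend a backward pass happened
    return master


def try_step(opt_cls):
    """Build opt_cls on one parameter, swap the parameter, then call step()."""
    p = torch.nn.Parameter(torch.randn(4))
    opt = opt_cls([p], lr=1e-3)  # Adagrad fills opt.state[p] here; Adam does not
    swap_params_after_init(opt)
    try:
        opt.step()
        return "ok"
    except KeyError as err:
        return f"KeyError: {err}"


adagrad_result = try_step(torch.optim.Adagrad)
adam_result = try_step(torch.optim.Adam)
print("Adagrad:", adagrad_result)  # Adagrad: KeyError: 'sum'
print("Adam:", adam_result)  # Adam: ok
```

Because the sketch involves no loss scaling at all, it suggests (without proving) that the particular `initial_scale=8.332635365271916` value is not what triggers the error.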

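If the error comes from Adagrad's eagerly created state not existing for the tensors that end up in `param_groups` (an unconfirmed hypothesis), a possible stopgap is to create that state before the first `step()`. The hypothetical helper `init_missing_adagrad_state` sketched here mirrors what `Adagrad.__init__` does for real-valued, non-fused parameters in recent PyTorch releases; it is not an official ColossalAI or PyTorch fix, and the parameter swap is only simulated.

```python
import torch


def init_missing_adagrad_state(opt):
    """Hypothetical helper: create Adagrad's `step`/`sum` state where missing.

    Assumption: mirrors Adagrad's constructor for real-valued, non-fused
    parameters. Returns how many parameters were patched.
    """
    patched = 0
    for group in opt.param_groups:
        init_value = group.get("initial_accumulator_value", 0.0)
        for p in group["params"]:
            state = opt.state[p]
            if "sum" in state:
                continue  # state already exists; leave it untouched
            state["step"] = torch.tensor(0.0)
            state["sum"] = torch.full_like(
                p, init_value, memory_format=torch.preserve_format
            )
            patched += 1
    return patched


# Simulate the suspected swap: build Adagrad, then replace its parameter.
p = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adagrad([p], lr=0.1)
master = torch.nn.Parameter(p.detach().clone())
opt.param_groups[0]["params"] = [master]
master.grad = torch.ones_like(master)

print(init_missing_adagrad_state(opt))  # 1
opt.step()  # no KeyError once the state exists
print(opt.state[master]["sum"])  # tensor([1., 1., 1., 1.])
```

In the reproduction script, the helper would be called on the wrapped PyTorch optimizer right after `booster.boost(...)`. The traceback's `self.optim.step()` suggests that object is reachable as `optimizer.optim`, but that attribute access is an assumption, and whether the patched state round-trips through ColossalAI checkpointing has not been checked.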
Environment

PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.28.3
Libc version: glibc-2.39

Python version: 3.9.23 (main, Jun 5 2025, 13:40:20) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-18-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 4090
GPU 1: NVIDIA GeForce RTX 4090
GPU 2: NVIDIA GeForce RTX 4090
GPU 3: NVIDIA GeForce RTX 4090

Nvidia driver version: 580.65.06
cuDNN version: Probably one of the following:
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-255
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7773X 64-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
Stepping: 2
Frequency boost: enabled
CPU max MHz: 3527.7339
CPU min MHz: 1500.0000
BogoMIPS: 4400.15
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization: AMD-V
L1d cache: 4 MiB (128 instances)
L1i cache: 4 MiB (128 instances)
L2 cache: 64 MiB (128 instances)
L3 cache: 1.5 GiB (16 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-63,128-191
NUMA node1 CPU(s): 64-127,192-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] galore-torch==1.0
[pip3] numpy==2.0.2
[pip3] torch==2.5.1
[pip3] triton==3.1.0
[conda] galore-torch 1.0 pypi_0 pypi
[conda] numpy 2.0.2 pypi_0 pypi
[conda] torch 2.5.1 pypi_0 pypi
[conda] triton 3.1.0 pypi_0 pypi

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working)
