Description
Is there an existing issue for this bug?
- I have searched the existing issues
The bug has not been fixed in the latest main branch
- I have checked the latest main branch
Do you feel comfortable sharing a concise (minimal) script that reproduces the error? :)
Yes, I will share a minimal reproducible script.
🐛 Describe the bug
When the Adagrad optimizer is used together with LowLevelZeroPlugin configured with initial_scale=8.332635365271916, optimizer.step() raises KeyError: 'sum'.
The minimal reproduction script (main.py; it is invoked as bug6.py in the command and traceback below):
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin


class RandomDataset(Dataset):
    def __init__(self, num_samples=32 * 10, input_dim=1024, num_classes=10):
        self.x = torch.randn(num_samples, input_dim)
        self.y = torch.randint(0, num_classes, (num_samples,))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class MLP(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=512, num_layers=10, num_classes=10):
        super().__init__()
        layers = []
        for i in range(num_layers):
            in_dim = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def main():
    seed = 3057219192
    colossalai.launch_from_torch(seed=seed)
    plugin = LowLevelZeroPlugin(
        initial_scale=8.332635365271916
    )
    booster = Booster(plugin=plugin)
    model = MLP()
    optimizer = optim.Adagrad(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    dataset = RandomDataset()
    train_dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
    model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
    precision = getattr(plugin, "precision", "fp16")
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype = dtype_map.get(precision, torch.float16)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    for epoch in range(1):
        total_loss = 0
        for step, (x, y) in enumerate(train_dataloader):
            x = x.to(device=device, dtype=dtype)
            y = y.to(device=device)
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, y)
            booster.backward(loss, optimizer)
            optimizer.step()
            total_loss += loss.item()
            print(f"[Epoch {epoch}] step {step}, loss = {loss.item():.4f}")
        avg_loss = total_loss / len(train_dataloader)
        print(f"Epoch {epoch} finished, average loss = {avg_loss:.4f}")


if __name__ == "__main__":
    main()
Running the script with the following command:
$ colossalai run --nproc_per_node 4 --master_port 29505 ./bug6.py
produces the following error log:
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/yanzhen/distributed_test/colossalAI/test/./bug6.py", line 80, in <module>
[rank0]:     main()
[rank0]:   File "/home/yanzhen/distributed_test/colossalAI/test/./bug6.py", line 70, in main
[rank0]:     optimizer.step()
[rank0]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/colossalai/zero/low_level/low_level_optim.py", line 588, in step
[rank0]:     self.optim.step()
[rank0]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/optim/optimizer.py", line 487, in wrapper
[rank0]:     out = func(*args, **kwargs)
[rank0]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
[rank0]:     ret = func(self, *args, **kwargs)
[rank0]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/optim/adagrad.py", line 164, in step
[rank0]:     has_sparse_grad, has_complex = self._init_group(
[rank0]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/yanzhen/miniconda3/envs/colossal/lib/python3.9/site-packages/torch/optim/adagrad.py", line 139, in _init_group
[rank0]:     state_sums.append(state["sum"])
[rank0]: KeyError: 'sum'
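The traceback ends in Adagrad's _init_group, which reads state["sum"] for every parameter that has a gradient. torch.optim.Adagrad creates that state entry eagerly in its constructor rather than lazily in step(), so a parameter that enters param_groups after construction has no "sum" entry. A possible explanation, which I have not confirmed against the ColossalAI source, is that the ZeRO optimizer wrapper replaces the tensors in param_groups after Adagrad was built. The sketch below reproduces the same KeyError in plain PyTorch on CPU under that assumption; the parameter swap is a stand-in for whatever the wrapper does, not ColossalAI code.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = optim.Adagrad(model.parameters(), lr=1e-3)

# Adagrad fills state["sum"] and state["step"] in __init__, not lazily in step().
eager_state_ok = all("sum" in optimizer.state[p] for p in model.parameters())

# Assumption: mimic a wrapper that swaps the tensors held in param_groups
# (for example, to point at sharded master weights) after construction.
group = optimizer.param_groups[0]
replacements = [p.detach().clone().requires_grad_(True) for p in group["params"]]
group["params"] = replacements
for p in replacements:
    p.grad = torch.ones_like(p)  # give the new tensors a gradient so step() visits them

# The new tensors have no entry in optimizer.state, so the lookup of "sum" fails.
try:
    optimizer.step()
    caught = None
except KeyError as exc:
    caught = exc

print(eager_state_ok, repr(caught))
```

If this is indeed the cause, optimizers that create their state lazily inside step() would not be affected in the same way, which may explain why the error shows up with Adagrad specifically.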
Environment
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.28.3
Libc version: glibc-2.39
Python version: 3.9.23 (main, Jun 5 2025, 13:40:20) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-18-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 4090
GPU 1: NVIDIA GeForce RTX 4090
GPU 2: NVIDIA GeForce RTX 4090
GPU 3: NVIDIA GeForce RTX 4090
Nvidia driver version: 580.65.06
cuDNN version: Probably one of the following:
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.6/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-255
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7773X 64-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
Stepping: 2
Frequency boost: enabled
CPU max MHz: 3527.7339
CPU min MHz: 1500.0000
BogoMIPS: 4400.15
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin brs arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm
Virtualization: AMD-V
L1d cache: 4 MiB (128 instances)
L1i cache: 4 MiB (128 instances)
L2 cache: 64 MiB (128 instances)
L3 cache: 1.5 GiB (16 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-63,128-191
NUMA node1 CPU(s): 64-127,192-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] galore-torch==1.0
[pip3] numpy==2.0.2
[pip3] torch==2.5.1
[pip3] triton==3.1.0
[conda] galore-torch 1.0 pypi_0 pypi
[conda] numpy 2.0.2 pypi_0 pypi
[conda] torch 2.5.1 pypi_0 pypi
[conda] triton 3.1.0 pypi_0 pypi