Context
Mixed-precision training (FP16/BF16) significantly accelerates training while reducing memory usage. The repository currently has a placeholder use_amp flag, but the implementation is incomplete.
Detailed Analysis
- The --use_amp flag exists in run.py but appears unused in the experiment files (see the argparse sketch after this list)
- Modern NVIDIA GPUs (Volta, Turing, Ampere architectures) provide substantial speedups with mixed precision
- The implementation should leverage PyTorch's native torch.cuda.amp module
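A minimal sketch of how the flag could be wired through run.py, assuming the standard argparse setup used by the training scripts (the surrounding arguments and the experiment hand-off are illustrative, not taken from the repository):

import argparse

parser = argparse.ArgumentParser(description='time series forecasting')
# Hypothetical wiring; run.py may already declare this flag in a similar form
parser.add_argument('--use_amp', action='store_true', default=False,
                    help='enable automatic mixed precision training')
args = parser.parse_args()
# args.use_amp would then be passed to the experiment class so the training
# loop can construct a GradScaler and autocast context as shown below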
Implementation Recommendation
from torch.cuda.amp import autocast, GradScaler

# Initialize the scaler once at the beginning of training
scaler = GradScaler() if args.use_amp else None

# In the training loop
optimizer.zero_grad()  # reset gradients from the previous step
with autocast(enabled=args.use_amp):
    outputs = model(batch_x, batch_y, batch_x_mark, batch_y_mark)
    loss = criterion(outputs, batch_y)

if args.use_amp:
    # Scale the loss to avoid FP16 gradient underflow, then step through the scaler
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
else:
    loss.backward()
    optimizer.step()

Expected Outcomes
- 1.5-3x training speedup depending on model size and GPU
- ~50% memory reduction enabling larger models/batches
- Minimal to no impact on final model accuracy
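As a follow-up on the BF16 option mentioned above: on Ampere-class GPUs, BF16 can be used instead of FP16, and because BF16 retains FP32's exponent range, loss scaling with GradScaler is typically unnecessary. A minimal sketch under that assumption (the dtype argument to autocast requires a reasonably recent PyTorch release):

import torch

# BF16 variant: no GradScaler needed since BF16 keeps FP32's dynamic range
with torch.cuda.amp.autocast(enabled=args.use_amp, dtype=torch.bfloat16):
    outputs = model(batch_x, batch_y, batch_x_mark, batch_y_mark)
    loss = criterion(outputs, batch_y)
loss.backward()
optimizer.step()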
Respectfully submitted,
Quality Assurance Team