Skip to content

Latest commit

 

History

History
113 lines (83 loc) · 3.11 KB

File metadata and controls

113 lines (83 loc) · 3.11 KB

CS336 Scaling Law Training API

快速启动

uv sync
uv run -m uvicorn server:app --host 127.0.0.1 --port 8000

后台运行:

nohup uv run -m server:app --host 127.0.0.1 --port 8000 > server.log 2>&1 &

API 端点

GET /loss

提交训练请求,返回最终训练损失。首次请求会通过 SLURM 提交 GPU 训练任务并阻塞等待结果。

参数:

参数 类型 范围
d_model int [64, 1024],且必须被 num_heads 整除
num_layers int [2, 24]
num_heads int [2, 16]
batch_size int 128 或 256
learning_rate float [1e-4, 1e-3]
train_flops int 1e13, 3e13, 6e13, 1e14, 3e14, 6e14, 1e15, 3e15, 6e15, 1e16, 3e16, 6e16, 1e17, 3e17, 6e17, 1e18
api_key str 你的 SSH 公钥

示例:

import requests

config = {
    "d_model": 1024,
    "num_layers": 24,
    "num_heads": 16,
    "batch_size": 128,
    "learning_rate": 0.001,
    "train_flops": int(1e16),
    "api_key": "<YOUR_API_KEY>"
}
r = requests.get("http://localhost:8000/loss", params=config)
print(r.json())
# {'loss': 9.07, 'total_flops_used': 10000000000000000}

响应:

状态码 说明
200 返回 {"loss": float, "total_flops_used": float}
404 参数超出范围
403 API key 无效 / FLOPs 配额已用尽
422 该配置正在训练中(含进度信息)
500 训练任务失败

GET /total_flops_used

查询当前 API key 已使用的总 FLOPs。

requests.get("http://localhost:8000/total_flops_used", params={"api_key": "<YOUR_API_KEY>"}).json()
# 10000000000000000

GET /previous_runs

查询当前 API key 的所有历史训练记录。

requests.get("http://localhost:8000/previous_runs", params={"api_key": "<YOUR_API_KEY>"}).json()
# {'previous_runs': [{'d_model': 1024, 'num_layers': 24, ...}]}

API Key

API key 为你的 SSH 公钥(完整一行,不含换行符),存储在 valid_api_keys.txt 中。

FLOPs 配额

每个用户配额 1e19 FLOPs,硬上限 1.2e19 FLOPs。超出后 /loss 返回 403,但仍可查询 /total_flops_used/previous_runs

缓存机制

  • 相同超参数组合只训练一次,所有用户共享结果
  • 同一用户重复查询相同配置不会重复扣减 FLOPs
  • 训练结果持久化在 /data/share/hw3-data/scaling_api.db

训练细节

  • 模型:BasicsTransformerLM(Pre-Norm Transformer),d_ff = 4 × d_model
  • 数据集:SlimPajama(byte-level BPE,32K vocab,context_length=512)
  • 优化器:AdamW(weight_decay=0.01, grad_clip=1.0)
  • 学习率:Cosine schedule 衰减至 initial_lr / 10,无 warmup
  • Dropout:attention 和 residual 均为 0.1
  • FLOPs 估算:6 × N_params × batch_size × context_length per step

项目结构

server.py          # FastAPI 服务端
train.py           # GPU 训练脚本(通过 SLURM 提交)
model.py           # Transformer 模型定义
data_loader.py     # 数据加载与 tokenization
database.py        # SQLite 缓存与 FLOPs 追踪
valid_api_keys.txt # 合法 API key 列表