uv sync
uv run -m uvicorn server:app --host 127.0.0.1 --port 8000后台运行:
nohup uv run -m server:app --host 127.0.0.1 --port 8000 > server.log 2>&1 &提交训练请求,返回最终训练损失。首次请求会通过 SLURM 提交 GPU 训练任务并阻塞等待结果。
参数:
| 参数 | 类型 | 范围 |
|---|---|---|
d_model |
int | [64, 1024],且必须被 num_heads 整除 |
num_layers |
int | [2, 24] |
num_heads |
int | [2, 16] |
batch_size |
int | 128 或 256 |
learning_rate |
float | [1e-4, 1e-3] |
train_flops |
int | 1e13, 3e13, 6e13, 1e14, 3e14, 6e14, 1e15, 3e15, 6e15, 1e16, 3e16, 6e16, 1e17, 3e17, 6e17, 1e18 |
api_key |
str | 你的 SSH 公钥 |
示例:
import requests
config = {
"d_model": 1024,
"num_layers": 24,
"num_heads": 16,
"batch_size": 128,
"learning_rate": 0.001,
"train_flops": int(1e16),
"api_key": "<YOUR_API_KEY>"
}
r = requests.get("http://localhost:8000/loss", params=config)
print(r.json())
# {'loss': 9.07, 'total_flops_used': 10000000000000000}响应:
| 状态码 | 说明 |
|---|---|
| 200 | 返回 {"loss": float, "total_flops_used": float} |
| 404 | 参数超出范围 |
| 403 | API key 无效 / FLOPs 配额已用尽 |
| 422 | 该配置正在训练中(含进度信息) |
| 500 | 训练任务失败 |
查询当前 API key 已使用的总 FLOPs。
requests.get("http://localhost:8000/total_flops_used", params={"api_key": "<YOUR_API_KEY>"}).json()
# 10000000000000000查询当前 API key 的所有历史训练记录。
requests.get("http://localhost:8000/previous_runs", params={"api_key": "<YOUR_API_KEY>"}).json()
# {'previous_runs': [{'d_model': 1024, 'num_layers': 24, ...}]}API key 为你的 SSH 公钥(完整一行,不含换行符),存储在 valid_api_keys.txt 中。
每个用户配额 1e19 FLOPs,硬上限 1.2e19 FLOPs。超出后 /loss 返回 403,但仍可查询 /total_flops_used 和 /previous_runs。
- 相同超参数组合只训练一次,所有用户共享结果
- 同一用户重复查询相同配置不会重复扣减 FLOPs
- 训练结果持久化在
/data/share/hw3-data/scaling_api.db
- 模型:
BasicsTransformerLM(Pre-Norm Transformer),d_ff = 4 × d_model - 数据集:SlimPajama(byte-level BPE,32K vocab,context_length=512)
- 优化器:AdamW(weight_decay=0.01, grad_clip=1.0)
- 学习率:Cosine schedule 衰减至 initial_lr / 10,无 warmup
- Dropout:attention 和 residual 均为 0.1
- FLOPs 估算:6 × N_params × batch_size × context_length per step
server.py # FastAPI 服务端
train.py # GPU 训练脚本(通过 SLURM 提交)
model.py # Transformer 模型定义
data_loader.py # 数据加载与 tokenization
database.py # SQLite 缓存与 FLOPs 追踪
valid_api_keys.txt # 合法 API key 列表