
Commit 615cf3d

[Doc] update loss.md after refactoring loss ctx (#1474)

* [Doc] update loss.md after refactoring loss ctx
* update chunk mode time by running it twice

1 parent e368d87

File tree

1 file changed: +23 −26 lines

  • docs/zh_cn/pretrain_sft/advanced_tutorial

docs/zh_cn/pretrain_sft/advanced_tutorial/loss.md

Lines changed: 23 additions & 26 deletions
@@ -63,15 +63,15 @@ print(f"Chunk mode Time taken: {time.time() - t1:.2f} seconds")
 Eager mode Loss: 12.125
 Eager mode hidden_states grad norm: 0.0031890869140625
 Eager mode lm_head weight grad norm: 0.353515625
-Eager mode Max memory allocated: 38.53 GB
-Eager mode Max memory reserved: 38.54 GB
-Eager mode Time taken: 0.57 seconds
-Chunk mode Loss: 12.096513748168945
+Eager mode Max memory allocated: 38.57 GB
+Eager mode Max memory reserved: 47.81 GB
+Eager mode Time taken: 0.42 seconds
+Chunk mode Loss: 12.094674110412598
 Chunk mode hidden_states grad norm: 0.0031890869140625
 Chunk mode lm_head weight grad norm: 0.353515625
-Chunk mode Max memory allocated: 8.32 GB
-Chunk mode Max memory reserved: 8.40 GB
-Chunk mode Time taken: 0.40 seconds
+Chunk mode Max memory allocated: 6.87 GB
+Chunk mode Max memory reserved: 9.56 GB
+Chunk mode Time taken: 0.26 seconds
 ```
 
 (global-average)=
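The updated numbers above show chunk mode matching eager mode's gradients while using a fraction of the memory. The saving comes from never materializing the full `[seq_len, vocab_size]` logits tensor at once. The sketch below illustrates only that general idea in plain Python; the helper names are hypothetical and this is not XTuner's actual chunk implementation:

```python
import math
import random

def logits_row(h, weight):
    # One row of hidden_states @ weight.T: a vocab-sized list of logits.
    return [sum(hi * wi for hi, wi in zip(h, w)) for w in weight]

def ce_token_loss(row, label):
    # Numerically stable -log_softmax(row)[label].
    m = max(row)
    logsumexp = m + math.log(sum(math.exp(x - m) for x in row))
    return logsumexp - row[label]

def ce_loss_eager(hidden, weight, labels):
    # Materialize all logit rows first: peak memory O(seq_len * vocab).
    all_logits = [logits_row(h, weight) for h in hidden]
    return sum(ce_token_loss(r, y) for r, y in zip(all_logits, labels)) / len(labels)

def ce_loss_chunk(hidden, weight, labels, chunk_size=4):
    # Only chunk_size logit rows are alive at a time: peak memory O(chunk_size * vocab).
    total = 0.0
    for s in range(0, len(labels), chunk_size):
        chunk_logits = [logits_row(h, weight) for h in hidden[s:s + chunk_size]]
        total += sum(ce_token_loss(r, y) for r, y in zip(chunk_logits, labels[s:s + chunk_size]))
    return total / len(labels)

random.seed(0)
hidden = [[random.gauss(0, 1) for _ in range(8)] for _ in range(16)]   # seq=16, dim=8
weight = [[random.gauss(0, 1) for _ in range(8)] for _ in range(32)]   # vocab=32
labels = [random.randrange(32) for _ in range(16)]
print(abs(ce_loss_eager(hidden, weight, labels) - ce_loss_chunk(hidden, weight, labels)))
```

Both paths compute the same per-token losses; only the peak number of live logit rows differs, which is why the loss and gradient norms in the benchmark agree across modes.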
@@ -91,7 +91,7 @@ Global loss calibration means that no matter how many GPUs are used and which parallel
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem, CELossContext
+from xtuner.v1.loss.ce_loss import CELossConfig, CELossContext
 from mmengine.dist import infer_launcher, init_dist
 import torch.distributed as dist
 
@@ -131,10 +131,10 @@ hidden_states = torch.chunk(hidden_states, world_size, dim=0)[rank]
 shifted_labels = torch.chunk(shifted_labels_gt, world_size, dim=0)[rank]
 hidden_states = hidden_states.unsqueeze(0)
 shifted_labels = shifted_labels.unsqueeze(0)
-loss_ctx_input_list = [CELossContextInputItem(shifted_labels=shifted_labels)]
 loss_cfg = CELossConfig(mode='chunk', chunk_size=1024, loss_reduction="token")
-batches_loss_kwargs = CELossContext.build_batches_loss_kwargs(loss_ctx_input_list, loss_cfg)
-loss_ctx = CELossContext(loss_cfg, batches_loss_kwargs[0])
+loss_ctx = loss_cfg.build(shifted_labels)
+loss_ctx_list = CELossContext.build_batches([loss_ctx])
+loss_ctx = loss_ctx_list[0]
 loss, _ = loss_ctx.forward(hidden_states, lm_head.weight)
 loss.backward()
 
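The rewritten snippet above runs each rank's shard through `loss_cfg.build` and `CELossContext.build_batches` so that every token ends up weighted by the global token count. The arithmetic behind that calibration can be sketched without any distributed setup (toy numbers, not the XTuner implementation): averaging per-rank means over-weights small shards, while weighting every token by `1 / total_tokens` reproduces the single-device token mean.

```python
# Two "ranks" hold unevenly sized shards of the same batch's per-token losses.
per_token_loss = [2.0, 1.0, 3.0, 4.0, 6.0]
ranks = [per_token_loss[:2], per_token_loss[2:]]       # rank 0: 2 tokens, rank 1: 3 tokens

# Naive: mean of per-rank means -- over-weights the 2-token shard.
naive = sum(sum(r) / len(r) for r in ranks) / len(ranks)

# Calibrated: every token weighted by 1 / total_tokens, as global calibration does.
total_tokens = sum(len(r) for r in ranks)
calibrated = sum(loss / total_tokens for r in ranks for loss in r)

print(naive)        # ~2.917, differs from the true batch mean
print(calibrated)   # ~3.2, matches the single-device token mean
```

This is why the doc's distributed example produces the same loss as the single-GPU run regardless of `world_size`.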
@@ -198,7 +198,7 @@ All loss computation in XTuner involves two core components, `LossConfig` and `Los
 ```python
 import torch
 import torch.nn as nn
-from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem, CELossContext
+from xtuner.v1.loss.ce_loss import CELossConfig, CELossContext
 
 emb = nn.Embedding(4, 2)
 head = nn.Linear(2, 4, bias=False)
@@ -210,8 +210,9 @@ hidden_states = emb(input_ids)
 
 loss_ctx_input_list = [CELossContextInputItem(shifted_labels=shifted_labels)]
 loss_cfg = CELossConfig(mode='chunk', chunk_size=1024, loss_reduction="token")
-batches_loss_kwargs = CELossContext.build_batches_loss_kwargs(loss_ctx_input_list, loss_cfg)
-loss_ctx = CELossContext(loss_cfg, batches_loss_kwargs[0])
+loss_ctx = loss_cfg.build(shifted_labels=data["shifted_labels"])
+loss_ctx_list = CELossContext.build_batches([loss_ctx])
+loss_ctx = loss_ctx_list[0]
 loss, _ = loss_ctx.forward(hidden_states, head.weight)
 loss.backward()
 ```
@@ -235,14 +236,13 @@ class CELossConfig:
 
 ### CELossContext
 
-In `CELossContext` we introduce two additional data structures, [`CELossKwargs`](xtuner.v1.loss.ce_loss.CELossKwargs) and [`CELossContextInputItem`](xtuner.v1.loss.ce_loss.CELossContextInputItem):
+In `CELossContext` we introduce one additional data structure, [`CELossKwargs`](xtuner.v1.loss.ce_loss.CELossKwargs):
 
 - `CELossKwargs` describes the arguments the actual CE loss computation needs, namely `shifted_labels` and `loss_weight`. Note that at this point `loss_weight` has already gone through global calibration; see `xtuner/v1/loss/ce_loss.py` for the detailed implementation.
-- `CELossContextInputItem` describes the information needed to compute `CELossKwargs`, namely `shifted_labels`.
 
 We only need to implement two interfaces in `CELossContext`:
 
-1. For global loss calibration, the classmethod `build_batches_loss_kwargs` takes the `CELossContextInputItem` of every sample within the gradient-accumulation window and computes the `CELossKwargs` for each iter.
+1. For global loss calibration, the staticmethod `build_batches` computes the globally calibrated loss weight.
 2. `loss_fn` computes the loss of the current iter from `CELossKwargs`.
 
 Other functionality (e.g. chunk loss) is shared across losses, so it is implemented once in `BaseLossContext`.
@@ -281,25 +281,22 @@ class CustomLossKwargs(BaseLossKwargs):
     ...
 ```
 
-Step two: inherit from `BaseLossContext` and implement the classmethod `build_batches_loss_kwargs` and `loss_fn` in `CustomLossContext`:
+Step two: inherit from `BaseLossContext` and implement the staticmethod `build_batches` and `loss_fn` in `CustomLossContext`:
 
 ```python
 from xtuner.v1.loss import BaseLossContext, BaseLossKwargs
-from xtuner.v1.loss.ce_loss import CELossContextInputItem
 
-class CustomLossContext(BaseLossContext[CELossContextInputItem]):
+class CustomLossContext(BaseLossContext):
     loss_cfg: CustomLossConfig
     loss_kwargs: CustomLossKwargs
 
-    @classmethod
-    def build_batches_loss_kwargs(
-        cls,
-        data_batches: list[RLLossContextInputItem],
-        loss_cfg: CustomLossConfig,
+    @staticmethod
+    def build_batches(
+        loss_ctx_list: list["CELossContext"],
         # To improve compute efficiency, XTuner packs multiple short samples into one long sample for training.
         # If computing CustomLossKwargs requires unpacking back into short samples, pass cu_seq_lens_list;
         # the default None is fine.
-        cu_seq_lens_list: list[torch.Tensor] | None = None,
+        cu_seq_lens_list: Sequence[torch.IntTensor] | None = None,
         # If sequence parallelism (sp) is enabled and computing CustomLossKwargs needs the data from before the sp split, pass sp_mesh;
         # the default None is fine.
         sp_mesh: DeviceMesh | None = None,
