@@ -63,15 +63,15 @@ print(f"Chunk mode Time taken: {time.time() - t1:.2f} seconds")
 Eager mode Loss: 12.125
 Eager mode hidden_states grad norm: 0.0031890869140625
 Eager mode lm_head weight grad norm: 0.353515625
-Eager mode Max memory allocated: 38.53 GB
-Eager mode Max memory reserved: 38.54 GB
-Eager mode Time taken: 0.57 seconds
-Chunk mode Loss: 12.096513748168945
+Eager mode Max memory allocated: 38.57 GB
+Eager mode Max memory reserved: 47.81 GB
+Eager mode Time taken: 0.42 seconds
+Chunk mode Loss: 12.094674110412598
 Chunk mode hidden_states grad norm: 0.0031890869140625
 Chunk mode lm_head weight grad norm: 0.353515625
-Chunk mode Max memory allocated: 8.32 GB
-Chunk mode Max memory reserved: 8.40 GB
-Chunk mode Time taken: 0.40 seconds
+Chunk mode Max memory allocated: 6.87 GB
+Chunk mode Max memory reserved: 9.56 GB
+Chunk mode Time taken: 0.26 seconds
 ```
 
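The memory gap above comes from chunk mode never materializing the full `[seq_len, vocab]` logits tensor at once. The following is a minimal sketch of the idea only, not XTuner's actual implementation (the function name and shapes are illustrative):

```python
# Sketch of chunked CE loss: project hidden states to logits one chunk at a
# time, so only a [chunk_size, vocab] logits tensor is alive at any moment.
import torch
import torch.nn.functional as F

def chunked_ce_loss(hidden_states, lm_head_weight, shifted_labels, chunk_size=1024):
    # hidden_states: [seq, hidden]; lm_head_weight: [vocab, hidden]; shifted_labels: [seq]
    num_tokens = (shifted_labels != -100).sum()  # valid (non-ignored) tokens
    total = hidden_states.new_zeros((), dtype=torch.float32)
    for start in range(0, hidden_states.size(0), chunk_size):
        logits = hidden_states[start:start + chunk_size] @ lm_head_weight.T  # [chunk, vocab]
        total = total + F.cross_entropy(
            logits.float(),
            shifted_labels[start:start + chunk_size],
            ignore_index=-100,
            reduction="sum",
        )
    return total / num_tokens  # "token" reduction: mean over valid tokens
```

Note that this naive loop still saves each chunk's logits for backward; practical implementations typically pair chunking with recomputation in the backward pass to realize the full memory saving.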
 (global-average)=
@@ -91,7 +91,7 @@ Global loss calibration means that no matter how many GPUs are used or what parallel
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem, CELossContext
+from xtuner.v1.loss.ce_loss import CELossConfig, CELossContext
 from mmengine.dist import infer_launcher, init_dist
 import torch.distributed as dist
 
@@ -131,10 +131,10 @@ hidden_states = torch.chunk(hidden_states, world_size, dim=0)[rank]
 shifted_labels = torch.chunk(shifted_labels_gt, world_size, dim=0)[rank]
 hidden_states = hidden_states.unsqueeze(0)
 shifted_labels = shifted_labels.unsqueeze(0)
-loss_ctx_input_list = [CELossContextInputItem(shifted_labels=shifted_labels)]
 loss_cfg = CELossConfig(mode='chunk', chunk_size=1024, loss_reduction="token")
-batches_loss_kwargs = CELossContext.build_batches_loss_kwargs(loss_ctx_input_list, loss_cfg)
-loss_ctx = CELossContext(loss_cfg, batches_loss_kwargs[0])
+loss_ctx = loss_cfg.build(shifted_labels)
+loss_ctx_list = CELossContext.build_batches([loss_ctx])
+loss_ctx = loss_ctx_list[0]
 loss, _ = loss_ctx.forward(hidden_states, lm_head.weight)
 loss.backward()
 
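For intuition, with token-level reduction the calibration amounts to dividing each rank's summed token loss by the global number of valid tokens rather than the local one. This is a sketch of the idea only, not XTuner's exact code; it assumes an initialized process group and that gradients are summed, not averaged, across data-parallel ranks:

```python
import torch
import torch.distributed as dist

def globally_calibrated_loss(per_token_loss, shifted_labels):
    # per_token_loss, shifted_labels: [seq]; -100 marks ignored positions
    num_tokens = (shifted_labels != -100).sum().float()
    dist.all_reduce(num_tokens)  # N_global: valid tokens across all ranks
    # Each rank contributes sum(local token losses) / N_global, so after the
    # gradient all-reduce the result equals the single-GPU mean over all tokens.
    return per_token_loss.sum() / num_tokens
```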
@@ -198,7 +198,7 @@ All loss computation in XTuner involves two core components, `LossConfig` and `LossContext`
 ```python
 import torch
 import torch.nn as nn
-from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem, CELossContext
+from xtuner.v1.loss.ce_loss import CELossConfig, CELossContext
 
 emb = nn.Embedding(4, 2)
 head = nn.Linear(2, 4, bias=False)
@@ -210,8 +210,8 @@ hidden_states = emb(input_ids)
 
-loss_ctx_input_list = [CELossContextInputItem(shifted_labels=shifted_labels)]
 loss_cfg = CELossConfig(mode='chunk', chunk_size=1024, loss_reduction="token")
-batches_loss_kwargs = CELossContext.build_batches_loss_kwargs(loss_ctx_input_list, loss_cfg)
-loss_ctx = CELossContext(loss_cfg, batches_loss_kwargs[0])
+loss_ctx = loss_cfg.build(shifted_labels=shifted_labels)
+loss_ctx_list = CELossContext.build_batches([loss_ctx])
+loss_ctx = loss_ctx_list[0]
 loss, _ = loss_ctx.forward(hidden_states, head.weight)
 loss.backward()
 ```
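For reference, `loss_reduction="token"` in this example reads as a global token average. Assuming $\ell_i$ is the CE loss of valid token $i$, $N_{\text{local}}$ the number of valid tokens on this rank, and $N_{\text{global}}$ the count across all ranks and gradient-accumulation steps (the exact weighting lives in `xtuner/v1/loss/ce_loss.py`):

$$
\mathcal{L}_{\text{token}} = \frac{1}{N_{\text{global}}} \sum_{i=1}^{N_{\text{local}}} \ell_i
$$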
@@ -235,14 +236,13 @@ class CELossConfig:
 
 ### CELossContext
 
-In `CELossContext` we introduce two additional data structures: [`CELossKwargs`](xtuner.v1.loss.ce_loss.CELossKwargs) and [`CELossContextInputItem`](xtuner.v1.loss.ce_loss.CELossContextInputItem).
+In `CELossContext` we introduce one additional data structure: [`CELossKwargs`](xtuner.v1.loss.ce_loss.CELossKwargs).
 
 - `CELossKwargs` describes the arguments actually needed when computing the CE loss, namely `shifted_labels` and `loss_weight`. Note that at this point `loss_weight` has already gone through global calibration; see `xtuner/v1/loss/ce_loss.py` for the details.
-- `CELossContextInputItem` describes the information needed to compute `CELossKwargs`, namely `shifted_labels`.
 
 We only need to implement two interfaces in `CELossContext`:
 
-1. For global loss calibration, the classmethod `build_batches_loss_kwargs` takes the `CELossContextInputItem` of every sample within the gradient-accumulation range and computes the `CELossKwargs` of each iter.
+1. For global loss calibration, the staticmethod `build_batches` computes the loss weight corresponding to global calibration.
 2. `loss_fn` computes the loss of the current iter from `CELossKwargs` (a sketch follows below).
 
 Other functionality (e.g., chunk loss) is shared across different losses, so it is implemented once in `BaseLossContext`.
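As a point of reference for the second interface, a `loss_fn` for a plain CE-style loss might look like the following. This is a hedged sketch only: the attribute names follow `CELossKwargs` as described above, and the real signature lives in `xtuner/v1/loss`:

```python
import torch
import torch.nn.functional as F

def loss_fn(logits: torch.Tensor, loss_kwargs) -> torch.Tensor:
    # logits: [seq, vocab]; loss_kwargs.shifted_labels / .loss_weight: [seq]
    per_token = F.cross_entropy(
        logits.float(), loss_kwargs.shifted_labels, ignore_index=-100, reduction="none"
    )
    # loss_weight already encodes the global calibration (e.g. 1 / N_global),
    # so a plain weighted sum reproduces the globally averaged loss.
    return (per_token * loss_kwargs.loss_weight).sum()
```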
@@ -281,25 +281,22 @@ class CustomLossKwargs(BaseLossKwargs):
     ...
 ```
 
-Step two: inherit from `BaseLossContext` and implement the classmethod `build_batches_loss_kwargs` and `loss_fn` in `CustomLossContext`:
+Step two: inherit from `BaseLossContext` and implement the staticmethod `build_batches` and `loss_fn` in `CustomLossContext`:
 
 ```python
 from xtuner.v1.loss import BaseLossContext, BaseLossKwargs
-from xtuner.v1.loss.ce_loss import CELossContextInputItem
 
-class CustomLossContext(BaseLossContext[CELossContextInputItem]):
+class CustomLossContext(BaseLossContext):
     loss_cfg: CustomLossConfig
     loss_kwargs: CustomLossKwargs
 
-    @classmethod
-    def build_batches_loss_kwargs(
-        cls,
-        data_batches: list[RLLossContextInputItem],
-        loss_cfg: CustomLossConfig,
+    @staticmethod
+    def build_batches(
+        loss_ctx_list: list["CELossContext"],
         # To improve compute efficiency, XTuner packs several short samples into one long sequence for training.
         # If computing CustomLossKwargs requires unpacking back into the short samples, pass cu_seq_lens_list;
         # the default None is fine otherwise.
-        cu_seq_lens_list: list[torch.Tensor] | None = None,
+        cu_seq_lens_list: Sequence[torch.IntTensor] | None = None,
         # If sequence parallel (sp) is enabled and computing CustomLossKwargs needs the data from before the sp split, pass sp_mesh;
         # the default None is fine otherwise.
         sp_mesh: DeviceMesh | None = None,