@@ -294,6 +294,10 @@ class LlmMetaConfig:
        ),
    ]

+    moe_attributes = [
+        ("moe_subbatch_token_num", int, 0, "The number of tokens in each subbatch for MoE model processing."),
+    ]
+
    @classmethod
    def _get_defaults(cls):
        ret = {}
@@ -302,6 +306,7 @@ def _get_defaults(cls):
            cls.hybrid_parallel_attributes,
            cls.recompute_attributes,
            cls.loss_attributes,
+            cls.moe_attributes,
        ]:
            for attr in attrs:
                # return dict of key and default values
@@ -316,6 +321,7 @@ def _get_all_meta(cls):
            cls.hybrid_parallel_attributes,
            cls.recompute_attributes,
            cls.loss_attributes,
+            cls.moe_attributes,
        ]:
            for attr in attrs:
                # return dict of key and default values
@@ -330,6 +336,7 @@ def _get_unsavable_keys(cls):
            cls.hybrid_parallel_attributes,
            cls.recompute_attributes,
            cls.loss_attributes,
+            cls.moe_attributes,
        ]:
            for attr in attrs:
                ret.add(attr[0])
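(Aside, not part of the diff: the sketch below illustrates how four-field attribute tuples like the new `moe_attributes` entry are folded into defaults by loops of the kind shown in these hunks. The `DemoMeta` class is hypothetical and only mirrors the visible pattern, where each entry is `(name, type, default, help text)`.)

```python
# Hypothetical miniature of the LlmMetaConfig attribute registry; not the real class.
class DemoMeta:
    moe_attributes = [
        ("moe_subbatch_token_num", int, 0, "The number of tokens in each subbatch for MoE model processing."),
    ]

    @classmethod
    def _get_defaults(cls):
        ret = {}
        for attrs in [cls.moe_attributes]:
            for attr in attrs:
                # keep attribute name -> default value
                ret[attr[0]] = attr[2]
        return ret


print(DemoMeta._get_defaults())  # {'moe_subbatch_token_num': 0}
```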
@@ -488,6 +495,8 @@ class PretrainedConfig:
        problem_type (`str`, *optional*):
            Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
            `"single_label_classification"` or `"multi_label_classification"`.
+        moe_subbatch_token_num (`int`, *optional*, defaults to 0):
+            The number of tokens in each subbatch for MoE model processing.

        > Parameters for general components

@@ -632,6 +641,8 @@ def __init__(self, **kwargs):
        self.dpo_config = kwargs.pop("dpo_config", None)
        self.kto_config = kwargs.pop("kto_config", None)

+        self.moe_subbatch_token_num = kwargs.pop("moe_subbatch_token_num", 0)
+
        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
        self.prefix = kwargs.pop("prefix", None)
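(Aside, not part of the diff: a minimal, self-contained sketch of the `kwargs.pop` pattern the last hunk relies on. `ToyConfig` is a hypothetical stand-in, not the project's `PretrainedConfig`; it only shows how an unspecified `moe_subbatch_token_num` falls back to the documented default of 0.)

```python
# Hypothetical stand-in mirroring only the kwargs.pop line added above.
class ToyConfig:
    def __init__(self, **kwargs):
        # Missing key falls back to the documented default of 0 (sub-batching disabled).
        self.moe_subbatch_token_num = kwargs.pop("moe_subbatch_token_num", 0)


print(ToyConfig().moe_subbatch_token_num)                             # 0
print(ToyConfig(moe_subbatch_token_num=1024).moe_subbatch_token_num)  # 1024
```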