@@ -86,6 +86,7 @@ def __init__(
86
86
super (BaseRLModel , self ).__init__ ()
87
87
self .infer_to_train_mapping = {}
88
88
self .fd_config = None
89
+ self ._mappings_built = False
89
90
90
91
@classmethod
91
92
def name (cls ) -> str :
@@ -142,6 +143,12 @@ def name(self) -> str:
142
143
143
144
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
144
145
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
146
+ if self ._mappings_built :
147
+ return self .infer_to_train_mapping
148
+
149
+ self .infer_to_train_mapping = {}
150
+ self ._mappings_built = True
151
+
145
152
# Prepare placeholders
146
153
place_holders = ["weight" ]
147
154
@@ -215,6 +222,11 @@ def name(self) -> str:
215
222
216
223
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
217
224
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
225
+ if self ._mappings_built :
226
+ return self .infer_to_train_mapping
227
+
228
+ self .infer_to_train_mapping = {}
229
+ self ._mappings_built = True
218
230
# Prepare placeholders
219
231
place_holders = ["weight" ]
220
232
@@ -316,6 +328,11 @@ def name(self) -> str:
316
328
317
329
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
318
330
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
331
+ if self ._mappings_built :
332
+ return self .infer_to_train_mapping
333
+
334
+ self .infer_to_train_mapping = {}
335
+ self ._mappings_built = True
319
336
# Prepare placeholders
320
337
place_holders = ["weight" ]
321
338
@@ -360,6 +377,11 @@ def name(self) -> str:
360
377
361
378
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
362
379
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
380
+ if self ._mappings_built :
381
+ return self .infer_to_train_mapping
382
+
383
+ self .infer_to_train_mapping = {}
384
+ self ._mappings_built = True
363
385
# Prepare placeholders
364
386
place_holders = ["weight" ]
365
387
@@ -429,4 +451,30 @@ def name(self) -> str:
429
451
return "Qwen3ForCausalLMRL"
430
452
431
453
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
    """Generate mapping between inference and training parameter for RL(donot delete!).

    The mapping is built once and cached on the instance; later calls return
    the cached ``self.infer_to_train_mapping`` without rebuilding it.

    Args:
        trainer_degree: Accepted for interface compatibility; this
            implementation does not read it.

    Returns:
        ``self.infer_to_train_mapping``, a dict keyed by inference parameter
        name whose values are the corresponding training parameter names.
    """
    # Fast path: a previous call already finished building the mapping.
    if self._mappings_built:
        return self.infer_to_train_mapping

    # Start from an empty dict so a previously failed attempt cannot leave
    # stale or partial entries behind.
    self.infer_to_train_mapping = {}

    # Prepare placeholders
    place_holders = ["weight"]

    # Initialize mapping dictionary (helper is defined outside this block;
    # presumably it registers the non-layer entries under the "model" prefix).
    self._update_base_mappings("model")
    base_name = "model.layers"

    # FFN mappings: the inference-side "up_gate_proj" parameter of every
    # layer maps to the training-side "gate_up_fused_proj" parameter.
    for layer_idx in range(self.fd_config.model_config.num_hidden_layers):
        mlp_prefix = f"{base_name}.{layer_idx}.mlp"
        for ph in place_holders:
            self.infer_to_train_mapping[f"{mlp_prefix}.up_gate_proj.{ph}"] = (
                f"{mlp_prefix}.gate_up_fused_proj.{ph}"
            )

    # Helper defined outside this block; presumably fills in entries for
    # parameters not covered above -- confirm against its definition.
    self._complete_missing_mappings()

    # Mark the cache as valid only after the build succeeded. Setting the
    # flag earlier would make a failed build (e.g. fd_config not yet set)
    # permanently return an incomplete mapping on every later call.
    self._mappings_built = True
    return self.infer_to_train_mapping
0 commit comments