@@ -89,6 +89,7 @@ def __init__(
89
89
super (BaseRLModel , self ).__init__ ()
90
90
self .infer_to_train_mapping = {}
91
91
self .fd_config = None
92
+ self ._mappings_built = False
92
93
93
94
@classmethod
94
95
def name (cls ) -> str :
@@ -145,6 +146,12 @@ def name(self) -> str:
145
146
146
147
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
147
148
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
149
+ if self ._mappings_built :
150
+ return self .infer_to_train_mapping
151
+
152
+ self .infer_to_train_mapping = {}
153
+ self ._mappings_built = True
154
+
148
155
# Prepare placeholders
149
156
place_holders = ["weight" ]
150
157
@@ -218,6 +225,11 @@ def name(self) -> str:
218
225
219
226
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
220
227
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
228
+ if self ._mappings_built :
229
+ return self .infer_to_train_mapping
230
+
231
+ self .infer_to_train_mapping = {}
232
+ self ._mappings_built = True
221
233
# Prepare placeholders
222
234
place_holders = ["weight" ]
223
235
@@ -319,6 +331,11 @@ def name(self) -> str:
319
331
320
332
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
321
333
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
334
+ if self ._mappings_built :
335
+ return self .infer_to_train_mapping
336
+
337
+ self .infer_to_train_mapping = {}
338
+ self ._mappings_built = True
322
339
# Prepare placeholders
323
340
place_holders = ["weight" ]
324
341
@@ -363,6 +380,11 @@ def name(self) -> str:
363
380
364
381
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
365
382
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
383
+ if self ._mappings_built :
384
+ return self .infer_to_train_mapping
385
+
386
+ self .infer_to_train_mapping = {}
387
+ self ._mappings_built = True
366
388
# Prepare placeholders
367
389
place_holders = ["weight" ]
368
390
@@ -432,6 +454,11 @@ def name(self) -> str:
432
454
return "Qwen3ForCausalLMRL"
433
455
434
456
def get_name_mappings_to_training (self , trainer_degree = None ) -> Dict [str , str ]:
457
+ if self ._mappings_built :
458
+ return self .infer_to_train_mapping
459
+
460
+ self .infer_to_train_mapping = {}
461
+ self ._mappings_built = True
435
462
# Prepare placeholders
436
463
place_holders = ["weight" ]
437
464
0 commit comments