Skip to content

Commit ccc7f1b

Browse files
authored
fix mapping (#3320)
1 parent 283da92 commit ccc7f1b

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

fastdeploy/rl/rollout_model.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(
8989
super(BaseRLModel, self).__init__()
9090
self.infer_to_train_mapping = {}
9191
self.fd_config = None
92+
self._mappings_built = False
9293

9394
@classmethod
9495
def name(cls) -> str:
@@ -145,6 +146,12 @@ def name(self) -> str:
145146

146147
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
147148
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
149+
if self._mappings_built:
150+
return self.infer_to_train_mapping
151+
152+
self.infer_to_train_mapping = {}
153+
self._mappings_built = True
154+
148155
# Prepare placeholders
149156
place_holders = ["weight"]
150157

@@ -218,6 +225,11 @@ def name(self) -> str:
218225

219226
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
220227
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
228+
if self._mappings_built:
229+
return self.infer_to_train_mapping
230+
231+
self.infer_to_train_mapping = {}
232+
self._mappings_built = True
221233
# Prepare placeholders
222234
place_holders = ["weight"]
223235

@@ -319,6 +331,11 @@ def name(self) -> str:
319331

320332
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
321333
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
334+
if self._mappings_built:
335+
return self.infer_to_train_mapping
336+
337+
self.infer_to_train_mapping = {}
338+
self._mappings_built = True
322339
# Prepare placeholders
323340
place_holders = ["weight"]
324341

@@ -363,6 +380,11 @@ def name(self) -> str:
363380

364381
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
365382
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
383+
if self._mappings_built:
384+
return self.infer_to_train_mapping
385+
386+
self.infer_to_train_mapping = {}
387+
self._mappings_built = True
366388
# Prepare placeholders
367389
place_holders = ["weight"]
368390

@@ -432,6 +454,11 @@ def name(self) -> str:
432454
return "Qwen3ForCausalLMRL"
433455

434456
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
457+
if self._mappings_built:
458+
return self.infer_to_train_mapping
459+
460+
self.infer_to_train_mapping = {}
461+
self._mappings_built = True
435462
# Prepare placeholders
436463
place_holders = ["weight"]
437464

0 commit comments

Comments
 (0)