Skip to content

Commit b4bb54b

Browse files
authored
bugfix (#3322)
1 parent eeec4bd commit b4bb54b

File tree

1 file changed

+49
-1
lines changed

1 file changed

+49
-1
lines changed

fastdeploy/rl/rollout_model.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def __init__(
8686
super(BaseRLModel, self).__init__()
8787
self.infer_to_train_mapping = {}
8888
self.fd_config = None
89+
self._mappings_built = False
8990

9091
@classmethod
9192
def name(cls) -> str:
@@ -142,6 +143,12 @@ def name(self) -> str:
142143

143144
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
144145
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
146+
if self._mappings_built:
147+
return self.infer_to_train_mapping
148+
149+
self.infer_to_train_mapping = {}
150+
self._mappings_built = True
151+
145152
# Prepare placeholders
146153
place_holders = ["weight"]
147154

@@ -215,6 +222,11 @@ def name(self) -> str:
215222

216223
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
217224
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
225+
if self._mappings_built:
226+
return self.infer_to_train_mapping
227+
228+
self.infer_to_train_mapping = {}
229+
self._mappings_built = True
218230
# Prepare placeholders
219231
place_holders = ["weight"]
220232

@@ -316,6 +328,11 @@ def name(self) -> str:
316328

317329
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
318330
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
331+
if self._mappings_built:
332+
return self.infer_to_train_mapping
333+
334+
self.infer_to_train_mapping = {}
335+
self._mappings_built = True
319336
# Prepare placeholders
320337
place_holders = ["weight"]
321338

@@ -360,6 +377,11 @@ def name(self) -> str:
360377

361378
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
362379
"""Generate mapping between inference and training parameter for RL(donot delete!)."""
380+
if self._mappings_built:
381+
return self.infer_to_train_mapping
382+
383+
self.infer_to_train_mapping = {}
384+
self._mappings_built = True
363385
# Prepare placeholders
364386
place_holders = ["weight"]
365387

@@ -429,4 +451,30 @@ def name(self) -> str:
429451
return "Qwen3ForCausalLMRL"
430452

431453
def get_name_mappings_to_training(self, trainer_degree=None) -> Dict[str, str]:
    """Generate mapping between inference and training parameter for RL(donot delete!).

    Maps each inference-side fused FFN weight name
    (``model.layers.<i>.mlp.up_gate_proj.weight``) to its training-side
    counterpart (``model.layers.<i>.mlp.gate_up_fused_proj.weight``) for every
    hidden layer, plus whatever base entries ``self._update_base_mappings`` and
    ``self._complete_missing_mappings`` contribute. The result is cached on the
    instance and returned as-is on repeat calls.

    Args:
        trainer_degree: Unused by this implementation; kept for interface
            compatibility with the sibling RL model classes.

    Returns:
        Dict[str, str]: inference parameter name -> training parameter name.
    """
    # Fast path: the mapping was already built for this instance.
    if self._mappings_built:
        return self.infer_to_train_mapping

    self.infer_to_train_mapping = {}

    # Prepare placeholders
    place_holders = ["weight"]

    # Initialize mapping dictionary with the model-level base entries.
    self._update_base_mappings("model")
    base_name = "model.layers"

    # Helper function to add the per-layer FFN mappings.
    def _add_layer_mappings(layer_idx):
        # FFN mappings: fused up/gate projection is named differently in training.
        for ph in place_holders:
            self.infer_to_train_mapping[f"{base_name}.{layer_idx}.mlp.up_gate_proj.{ph}"] = (
                f"{base_name}.{layer_idx}.mlp.gate_up_fused_proj.{ph}"
            )

    for layer_idx in range(self.fd_config.model_config.num_hidden_layers):
        _add_layer_mappings(layer_idx)

    self._complete_missing_mappings()

    # Mark as built only AFTER the mapping is fully populated: setting the
    # flag up front (as the original did) would cache a partial mapping if
    # any step above raised, silently corrupting every later call.
    self._mappings_built = True

    return self.infer_to_train_mapping

0 commit comments

Comments
 (0)