Commit c00811b

Add rope_scaling and rope_theta to megatron workers
1 parent 7ef71a2 commit c00811b

File tree

1 file changed: +76 -0 lines changed


trinity/trainer/verl/megatron_workers.py

Lines changed: 76 additions & 0 deletions
@@ -151,6 +151,82 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
         )
         self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False)

+    def _init_hf_config_and_tf_config(
+        self,
+        model_path,
+        tokenizer_or_path,
+        dtype,
+        override_model_config,
+        override_transformer_config,
+        trust_remote_code=False,
+        use_mbridge=False,
+    ):
+        from transformers import AutoConfig
+        from verl.models.mcore import hf_to_mcore_config
+        from verl.utils import hf_processor, hf_tokenizer
+        from verl.utils.fs import copy_to_local
+        from verl.utils.model import update_model_config
+
+        # Step 1: initialize the tokenizer
+        self.local_path = copy_to_local(model_path)
+        if tokenizer_or_path is None:
+            self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code)
+            self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code)
+        elif isinstance(tokenizer_or_path, str):
+            self.tokenizer = hf_tokenizer(
+                copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code
+            )
+            self.processor = hf_processor(
+                copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code
+            )
+        else:
+            self.tokenizer = tokenizer_or_path
+            self.processor = tokenizer_or_path
+
+        if self.config.model.get("custom_chat_template", None) is not None:
+            if self.processor is not None:
+                self.processor.chat_template = self.config.model.custom_chat_template
+            else:
+                self.tokenizer.chat_template = self.config.model.custom_chat_template
+
+        # Step 2: get the hf config
+        hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code)
+
+        # Step 3: override the hf config
+        override_config_kwargs = {
+            "bos_token_id": self.tokenizer.bos_token_id,
+            "eos_token_id": self.tokenizer.eos_token_id,
+            "pad_token_id": self.tokenizer.pad_token_id,
+        }
+        override_config_kwargs.update(override_model_config.get("model_config", {}))
+
+        # patch for rope
+        if self.config.model.rope_scaling is not None:
+            hf_config.rope_scaling = OmegaConf.to_container(self.config.model.rope_scaling)
+        if self.config.model.rope_theta is not None:
+            hf_config.rope_theta = self.config.model.rope_theta
+
+        self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False)
+        update_model_config(hf_config, override_config_kwargs=override_config_kwargs)
+        self.architectures = getattr(hf_config, "architectures", None)
+        if self.rank == 0:
+            print(f"Model config after override: {hf_config}")
+        tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config)
+
+        if use_mbridge:
+            from verl.models.mcore.mbridge import AutoBridge
+
+            bridge = AutoBridge.from_config(hf_config)
+            bridge.set_extra_args(**override_transformer_config)
+            tf_config = bridge.config
+            self.bridge = bridge
+        else:
+            self.bridge = None
+
+        print(f"TF config: {tf_config}")
+        self.hf_config = hf_config
+        self.tf_config = tf_config
+
     def _build_model_optimizer(
         self, model_path, optim_config, override_model_config, override_transformer_config
     ):
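
For context, the added _init_hf_config_and_tf_config method reads config.model.rope_scaling and config.model.rope_theta and, when they are set, writes them onto the Hugging Face config before it is converted into a Megatron transformer config. Below is a minimal sketch of how such overrides could be expressed on the trainer side, assuming an OmegaConf-based config as used in the patch; the scaling type and numeric values are purely illustrative, not values taken from this repository:

from omegaconf import OmegaConf

# Hypothetical trainer config fragment; only the keys the new method reads are shown.
cfg = OmegaConf.create(
    {
        "model": {
            "rope_scaling": {"rope_type": "linear", "factor": 2.0},  # illustrative values
            "rope_theta": 1000000.0,  # illustrative value
        }
    }
)

# Mirrors the patch above: the OmegaConf node is converted to a plain dict
# before being assigned onto the Hugging Face config (hf_config.rope_scaling).
rope_scaling = OmegaConf.to_container(cfg.model.rope_scaling)
print(rope_scaling)  # {'rope_type': 'linear', 'factor': 2.0}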
