
Conversation

@qianhao0713

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Support OLMoE and training with Megatron.

@gemini-code-assist
Contributor

Summary of Changes

Hello @qianhao0713, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the model support by integrating OLMoE models into the framework. It provides the necessary definitions for the OLMoE model type and its associated templates, alongside a robust integration with Megatron for distributed training. This includes specialized attention mechanisms and state bridging, as well as enabling DeepSpeed ZeRO-3 optimizations for efficient training of these Mixture-of-Experts models.

Highlights

  • OLMoE Model Support: Introduced the olmoe model type across the system, enabling recognition and usage of OLMoE models.
  • Megatron Integration: Implemented specific Megatron configurations for OLMoE, including a custom attention mechanism (OLMoESelfAttention) and a bridge (OLMoEBridge) for seamless state transfer between Hugging Face and Megatron formats.
  • DeepSpeed ZeRO-3 Optimization: Configured DeepSpeed ZeRO-3 to recognize OlmoeSparseMoeBlock as a leaf module, allowing for efficient memory management during training (see the sketch after this list).
  • Dedicated Template: Added a new template specifically for OLMoE models, ensuring correct input/output formatting and interaction.
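
A minimal sketch of the ZeRO-3 leaf-module configuration mentioned in the DeepSpeed highlight above. This assumes the standard DeepSpeed helper set_z3_leaf_modules and the Hugging Face OLMoE classes; it is not necessarily the exact code in this PR:

# Mark OlmoeSparseMoeBlock as a ZeRO-3 "leaf" so its expert parameters are
# gathered as a single unit instead of being partitioned per sub-module.
from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoModelForCausalLM
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock

model = AutoModelForCausalLM.from_pretrained('allenai/OLMoE-1B-7B-0924')
set_z3_leaf_modules(model, [OlmoeSparseMoeBlock])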


Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request adds support for the OLMoE model. The changes are mostly correct, but I've found a critical issue in the Megatron bridge for weight conversion and an incorrect chat template definition. The weight conversion logic would cause a runtime error due to incorrect tensor reshaping, and the wrong template would lead to poor model performance. I've provided suggestions to fix both issues. Please review them carefully.

Comment on lines 135 to 248
def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
    if to_mcore:
        hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
    else:
        hf_state_dict = {}
    hf_attn = self.hf_layers[layer_idx].self_attn
    args = self.args
    num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads)
    hidden_size_block = args.hidden_size // self.fp8_block_size
    if to_mcore:
        if isinstance(mg_attn.linear_qkv, LoraParallelLinear):
            lora_A = hf_state_dict['q_proj.lora_A.weight'].load()
            assert (lora_A == hf_state_dict['k_proj.lora_A.weight'].load()).all() and (
                lora_A == hf_state_dict['v_proj.lora_A.weight'].load()
            ).all(), 'Need to ensure QKV\'s lora_A are consistent'
            q_lora_B = hf_state_dict['q_proj.lora_B.weight'].load()
            lora_B = torch.cat([
                q_lora_B.reshape((num_query_groups, -1, q_lora_B.shape[-1])),
                hf_state_dict['k_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])),
                hf_state_dict['v_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])),
            ], dim=0).reshape((-1, q_lora_B.shape[-1]))
            self._set_weight(mg_attn.linear_qkv.lora_A[self._adapter_name].weight, lora_A,
                             'linear_qkv.lora_A.weight')
            self._set_weight(mg_attn.linear_qkv.lora_B[self._adapter_name].weight, lora_B,
                             'linear_qkv.lora_B.weight')
        else:
            linear_qkv_weight = torch.cat([
                hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)),
                hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)),
                hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, args.hidden_size)),
            ], dim=0).reshape((-1, args.hidden_size))
            qkv_scale_inv = None
            if 'q_proj.weight_scale_inv' in hf_state_dict:
                qkv_scale_inv = torch.cat([
                    hf_state_dict['q_proj.weight_scale_inv'].load().reshape(
                        (num_query_groups, -1, hidden_size_block)),
                    hf_state_dict['k_proj.weight_scale_inv'].load().reshape(
                        (num_query_groups, -1, hidden_size_block)),
                    hf_state_dict['v_proj.weight_scale_inv'].load().reshape(
                        (num_query_groups, -1, hidden_size_block)),
                ], dim=0).reshape((-1, hidden_size_block))
            self._set_weight(
                mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight', hf_scale_inv=qkv_scale_inv)
    else:
        q_dim, kv_dim = hf_attn.q_proj.weight.shape[0] // num_query_groups, hf_attn.k_proj.weight.shape[
            0] // num_query_groups
        q_block = q_dim // self.fp8_block_size
        kv_block = kv_dim // self.fp8_block_size
        is_lora = False if mg_attn is None else isinstance(mg_attn.linear_qkv,
                                                           LoraParallelLinear) and self._is_peft_format
        is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
        if self.pp_size > 1:
            dist.all_reduce(is_lora, group=self.pp_group)
        if is_lora:
            lora_A, _ = self._get_weight(
                None if mg_attn is None else mg_attn.linear_qkv.lora_A[self._adapter_name].weight.data,
                f'linear_qkv.lora_A.{self._adapter_name}.weight')
            lora_B, _ = self._get_weight(
                None if mg_attn is None else mg_attn.linear_qkv.lora_B[self._adapter_name].weight.data,
                f'linear_qkv.lora_B.{self._adapter_name}.weight')
            if lora_A is not None:
                self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'})
                for key in ['q_proj', 'k_proj', 'v_proj']:
                    hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
                hf_state_dict['q_proj.lora_B.weight'] = lora_B[:num_query_groups * q_dim, :].clone()
                hf_state_dict['k_proj.lora_B.weight'] = lora_B[num_query_groups * q_dim:-num_query_groups
                                                               * kv_dim, :].clone()
                hf_state_dict['v_proj.lora_B.weight'] = lora_B[-num_query_groups * kv_dim, :].clone()
        elif not self._is_peft_format:
            mg_attn_weight, scale_inv = self._get_weight(
                None if mg_attn is None else mg_attn.linear_qkv.weight.data, 'linear_qkv.weight')
            if mg_attn_weight is not None:
                hf_state_dict['q_proj.weight'] = mg_attn_weight[:num_query_groups * q_dim, :].clone()
                hf_state_dict['k_proj.weight'] = mg_attn_weight[num_query_groups * q_dim:-num_query_groups
                                                                * kv_dim, :].clone()
                hf_state_dict['v_proj.weight'] = mg_attn_weight[-num_query_groups * kv_dim:, :].clone()
            if scale_inv is not None:
                hf_state_dict['q_proj.weight_scale_inv'] = scale_inv[:num_query_groups * q_block, :].clone()
                hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[num_query_groups * q_block:-num_query_groups
                                                                     * kv_block, :].clone()
                hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[-num_query_groups * kv_block:, :].clone()
            del mg_attn_weight
    self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore)
    if args.add_qkv_bias and not self._is_peft_format:
        if to_mcore:
            linear_qkv_bias = torch.cat([
                hf_state_dict['q_proj.bias'].load().reshape((num_query_groups, -1)),
                hf_state_dict['k_proj.bias'].load().reshape((num_query_groups, -1)),
                hf_state_dict['v_proj.bias'].load().reshape((num_query_groups, -1)),
            ], dim=0).reshape(-1)
            self._set_weight(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias')
        else:
            mg_attn_bias, _ = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.bias.data,
                                               'linear_qkv.bias')
            if mg_attn_bias is not None:
                mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1))
                hf_state_dict['q_proj.bias'] = mg_attn_bias[:num_query_groups * q_dim].clone()
                hf_state_dict['k_proj.bias'] = mg_attn_bias[num_query_groups * q_dim:-num_query_groups
                                                            * kv_dim].clone()
                hf_state_dict['v_proj.bias'] = mg_attn_bias[-num_query_groups * kv_dim:].clone()
    if args.qk_layernorm:
        hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight'
        hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight'
        self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, hf_q_norm_key, to_mcore)
        self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, hf_k_norm_key, to_mcore)
    if to_mcore:
        hf_state_dict = {}
    else:
        hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
    return hf_state_dict
Contributor


critical

The logic for concatenating Q, K, and V weights in _set_attn_state is incorrect and will cause a runtime error. The reshape calls before torch.cat are problematic because Q, K, and V projections have different output feature dimensions in Grouped Query Attention (GQA), leading to incompatible tensor shapes for concatenation on dim=0.

The correct approach is to directly concatenate the weights on dim=0 to form a stacked QKV weight matrix. This applies to lora_B, linear_qkv_weight, qkv_scale_inv, and linear_qkv_bias when converting from HuggingFace to Megatron format (to_mcore=True).

I've provided a corrected implementation of the function below.

    def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
        if to_mcore:
            hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
        else:
            hf_state_dict = {}
        hf_attn = self.hf_layers[layer_idx].self_attn
        args = self.args
        num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads)
        hidden_size_block = args.hidden_size // self.fp8_block_size
        if to_mcore:
            if isinstance(mg_attn.linear_qkv, LoraParallelLinear):
                lora_A = hf_state_dict['q_proj.lora_A.weight'].load()
                assert (lora_A == hf_state_dict['k_proj.lora_A.weight'].load()).all() and (
                    lora_A == hf_state_dict['v_proj.lora_A.weight'].load()
                ).all(), 'Need to ensure QKV\'s lora_A are consistent'
                lora_B = torch.cat([
                    hf_state_dict['q_proj.lora_B.weight'].load(),
                    hf_state_dict['k_proj.lora_B.weight'].load(),
                    hf_state_dict['v_proj.lora_B.weight'].load(),
                ], dim=0)
                self._set_weight(mg_attn.linear_qkv.lora_A[self._adapter_name].weight, lora_A,
                                 'linear_qkv.lora_A.weight')
                self._set_weight(mg_attn.linear_qkv.lora_B[self._adapter_name].weight, lora_B,
                                 'linear_qkv.lora_B.weight')
            else:
                linear_qkv_weight = torch.cat([
                    hf_state_dict['q_proj.weight'].load(),
                    hf_state_dict['k_proj.weight'].load(),
                    hf_state_dict['v_proj.weight'].load(),
                ], dim=0)
                qkv_scale_inv = None
                if 'q_proj.weight_scale_inv' in hf_state_dict:
                    qkv_scale_inv = torch.cat([
                        hf_state_dict['q_proj.weight_scale_inv'].load(),
                        hf_state_dict['k_proj.weight_scale_inv'].load(),
                        hf_state_dict['v_proj.weight_scale_inv'].load(),
                    ], dim=0)
                self._set_weight(
                    mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight', hf_scale_inv=qkv_scale_inv)
        else:
            q_dim, kv_dim = hf_attn.q_proj.weight.shape[0], hf_attn.k_proj.weight.shape[0]
            q_block = q_dim // self.fp8_block_size
            kv_block = kv_dim // self.fp8_block_size
            is_lora = False if mg_attn is None else isinstance(mg_attn.linear_qkv,
                                                               LoraParallelLinear) and self._is_peft_format
            is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
            if self.pp_size > 1:
                dist.all_reduce(is_lora, group=self.pp_group)
            if is_lora:
                lora_A, _ = self._get_weight(
                    None if mg_attn is None else mg_attn.linear_qkv.lora_A[self._adapter_name].weight.data,
                    f'linear_qkv.lora_A.{self._adapter_name}.weight')
                lora_B, _ = self._get_weight(
                    None if mg_attn is None else mg_attn.linear_qkv.lora_B[self._adapter_name].weight.data,
                    f'linear_qkv.lora_B.{self._adapter_name}.weight')
                if lora_A is not None:
                    self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'})
                    for key in ['q_proj', 'k_proj', 'v_proj']:
                        hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
                    hf_state_dict['q_proj.lora_B.weight'] = lora_B[:q_dim, :].clone()
                    hf_state_dict['k_proj.lora_B.weight'] = lora_B[q_dim:-kv_dim, :].clone()
                    hf_state_dict['v_proj.lora_B.weight'] = lora_B[-kv_dim:, :].clone()
            elif not self._is_peft_format:
                mg_attn_weight, scale_inv = self._get_weight(
                    None if mg_attn is None else mg_attn.linear_qkv.weight.data, 'linear_qkv.weight')
                if mg_attn_weight is not None:
                    hf_state_dict['q_proj.weight'] = mg_attn_weight[:q_dim, :].clone()
                    hf_state_dict['k_proj.weight'] = mg_attn_weight[q_dim:-kv_dim, :].clone()
                    hf_state_dict['v_proj.weight'] = mg_attn_weight[-kv_dim:, :].clone()
                if scale_inv is not None:
                    hf_state_dict['q_proj.weight_scale_inv'] = scale_inv[:q_block, :].clone()
                    hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[q_block:-kv_block, :].clone()
                    hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[-kv_block:, :].clone()
                del mg_attn_weight
        self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, 'o_proj.weight', to_mcore)
        if args.add_qkv_bias and not self._is_peft_format:
            if to_mcore:
                linear_qkv_bias = torch.cat([
                    hf_state_dict['q_proj.bias'].load(),
                    hf_state_dict['k_proj.bias'].load(),
                    hf_state_dict['v_proj.bias'].load(),
                ], dim=0)
                self._set_weight(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias')
            else:
                mg_attn_bias, _ = self._get_weight(None if mg_attn is None else mg_attn.linear_qkv.bias.data,
                                                   'linear_qkv.bias')
                if mg_attn_bias is not None:
                    hf_state_dict['q_proj.bias'] = mg_attn_bias[:q_dim].clone()
                    hf_state_dict['k_proj.bias'] = mg_attn_bias[q_dim:-kv_dim].clone()
                    hf_state_dict['v_proj.bias'] = mg_attn_bias[-kv_dim:].clone()
        if args.qk_layernorm:
            hf_q_norm_key = 'q_norm.weight' if hasattr(hf_attn, 'q_norm') else 'query_layernorm.weight'
            hf_k_norm_key = 'k_norm.weight' if hasattr(hf_attn, 'k_norm') else 'key_layernorm.weight'
            self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, hf_q_norm_key, to_mcore)
            self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, hf_k_norm_key, to_mcore)
        if to_mcore:
            hf_state_dict = {}
        else:
            hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
        return hf_state_dict
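
For intuition, a minimal repro of the shape mismatch described above, using hypothetical GQA sizes rather than the actual OLMoE configuration:

import torch

hidden_size, num_query_groups = 2048, 8
head_dim, q_heads_per_group = 64, 2  # hypothetical sizes; Q has more heads per group than K/V

q = torch.randn(num_query_groups * q_heads_per_group * head_dim, hidden_size)
k = torch.randn(num_query_groups * head_dim, hidden_size)
v = torch.randn(num_query_groups * head_dim, hidden_size)

# After reshaping to (num_query_groups, -1, hidden_size), the per-group slices
# of Q (128 rows) and K/V (64 rows) differ, so concatenation on dim=0 fails.
try:
    torch.cat([
        q.reshape(num_query_groups, -1, hidden_size),
        k.reshape(num_query_groups, -1, hidden_size),
        v.reshape(num_query_groups, -1, hidden_size),
    ], dim=0)
except RuntimeError as err:
    print(err)  # sizes of tensors must match except in dimension 0

# Stacking the projections directly keeps the shapes compatible.
qkv = torch.cat([q, k, v], dim=0)  # (q_dim + 2 * kv_dim, hidden_size)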

Comment on lines 415 to 425
register_template(
    TemplateMeta(
        LLMTemplateType.olmoe,
        prefix=['|||IP_ADDRESS|||'],
        system_prefix=['|||IP_ADDRESS|||<|system|>\n{{SYSTEM}}\n'],
        prompt=['<|user|>\n{{QUERY}}\n<|assistant|>\n'],
        chat_sep=['|||IP_ADDRESS|||\n'],
        suffix=['|||IP_ADDRESS|||'],
        default_system='You are a helpful assistant.',
        stop_words=['<|endoftext|>'],
    ))
Contributor


high

The chat template for OLMoE appears to be incorrect. According to the official model documentation, the template should use <|endoftext|> to separate messages, and does not use |||IP_ADDRESS||| as a structural token. The |||IP_ADDRESS||| token is intended for PII masking. Using an incorrect template can lead to degraded model performance. Here is a corrected version of the template registration.

Suggested change
register_template(
    TemplateMeta(
        LLMTemplateType.olmoe,
        prefix=['|||IP_ADDRESS|||'],
        system_prefix=['|||IP_ADDRESS|||<|system|>\n{{SYSTEM}}\n'],
        prompt=['<|user|>\n{{QUERY}}\n<|assistant|>\n'],
        chat_sep=['|||IP_ADDRESS|||\n'],
        suffix=['|||IP_ADDRESS|||'],
        default_system='You are a helpful assistant.',
        stop_words=['<|endoftext|>'],
    ))

register_template(
    TemplateMeta(
        LLMTemplateType.olmoe,
        prefix=[],
        system_prefix=['<|system|>\n{{SYSTEM}}<|endoftext|>'],
        prompt=['<|user|>\n{{QUERY}}<|endoftext|><|assistant|>\n'],
        chat_sep=['<|endoftext|>'],
        suffix=['<|endoftext|>'],
        default_system='You are a helpful assistant.',
        stop_words=['<|endoftext|>'],
    ))


class OLMoEBridge(GPTBridge):

    def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
Collaborator


Hello, could you tell me why this function is overridden?

Author

@qianhao0713 Dec 22, 2025


One difference in model architecture between OLMoE and other GPT-like models is that OLMoE applies LayerNorm to the QKV tensors across the n_head * head_dim dimension rather than the head_dim dimension.
In the Megatron GPT implementation, QKV is first projected via the linear_qkv module into a matrix of shape (batch_size * seq_len * n_head, head_dim), and LayerNorm is then applied. In OLMoE, by contrast, the tensor must be reshaped to (batch_size * seq_len, n_head * head_dim) before applying LayerNorm.
This difference means that when converting weights from Hugging Face to Megatron format, the weights of the linear_qkv layer must be concatenated along the n_head * head_dim dimension rather than the head_dim dimension.
These are the reasons why GPTBridge needs to be overridden.
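
For illustration, a minimal sketch of that difference in normalization granularity (using torch.nn.LayerNorm and hypothetical sizes; the real models use their own norm implementations):

import torch
import torch.nn as nn

batch, seq, n_head, head_dim = 2, 8, 16, 64  # hypothetical sizes
q = torch.randn(batch, seq, n_head, head_dim)

# Megatron GPT-style qk_layernorm: normalize each head independently over head_dim.
per_head_norm = nn.LayerNorm(head_dim)
q_per_head = per_head_norm(q)  # shape stays (batch, seq, n_head, head_dim)

# OLMoE-style q_norm: normalize the whole projection over n_head * head_dim,
# which requires reshaping before the norm and back afterwards.
full_norm = nn.LayerNorm(n_head * head_dim)
q_full = full_norm(q.reshape(batch, seq, n_head * head_dim)).reshape(batch, seq, n_head, head_dim)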

@qianhao0713
Author

/gemini review

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request adds support for the OLMoE model, including model registration, a new prompt template, and integration with Megatron for training. The changes are generally well-implemented and follow existing patterns in the codebase. However, I've identified a significant issue in the prompt template definition for olmoe where an invalid placeholder |||IP_ADDRESS||| is used. This will likely cause incorrect prompt formatting and negatively impact model performance. I've provided a corrected version of the template that aligns with the OLMo model's official chat template.

Comment on lines 428 to 438
register_template(
    TemplateMeta(
        LLMTemplateType.olmoe,
        prefix=['|||IP_ADDRESS|||'],
        system_prefix=['|||IP_ADDRESS|||<|system|>\n{{SYSTEM}}\n'],
        prompt=['<|user|>\n{{QUERY}}\n<|assistant|>\n'],
        chat_sep=['|||IP_ADDRESS|||\n'],
        suffix=['|||IP_ADDRESS|||'],
        default_system='You are a helpful assistant.',
        stop_words=['<|endoftext|>'],
    ))
Contributor


high

The provided template for olmoe uses |||IP_ADDRESS||| as a placeholder in prefix, system_prefix, chat_sep, and suffix. This appears to be a copy-paste error, as it's not a valid token for this model and will lead to incorrect prompt formatting.

Based on the official chat template for OLMo models, I suggest replacing the entire register_template block with a corrected version that uses the appropriate special tokens and structure.

Suggested change
register_template(
    TemplateMeta(
        LLMTemplateType.olmoe,
        prefix=['|||IP_ADDRESS|||'],
        system_prefix=['|||IP_ADDRESS|||<|system|>\n{{SYSTEM}}\n'],
        prompt=['<|user|>\n{{QUERY}}\n<|assistant|>\n'],
        chat_sep=['|||IP_ADDRESS|||\n'],
        suffix=['|||IP_ADDRESS|||'],
        default_system='You are a helpful assistant.',
        stop_words=['<|endoftext|>'],
    ))

register_template(
    TemplateMeta(
        LLMTemplateType.olmoe,
        prefix=[],
        system_prefix=['<|system|>\n{{SYSTEM}}\n'],
        prompt=['<|user|>\n{{QUERY}}\n<|assistant|>\n'],
        chat_sep=['<|endoftext|>\n'],
        suffix=['<|endoftext|>'],
        default_system='You are a helpful assistant.',
        stop_words=['<|endoftext|>'],
    ))

@qianhao0713
Author

/gemini review

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request adds support for the OLMoE model series, including its integration with Megatron for training. The changes span model registration, template definitions, and Megatron-specific implementations.

My review focuses on the correctness of the model and template registrations. While most of the implementation looks good, I've found significant issues with the chat templates for both olmoe and olmoe_0924 models, which do not align with the official Hugging Face implementations. I've provided suggestions to correct them. I also found a minor code cleanup opportunity in the new Megatron bridge file.

Please review the suggested changes to ensure the models behave as expected.

Comment on lines 428 to 450
register_template(
    TemplateMeta(
        LLMTemplateType.olmoe,
        prefix=['|||IP_ADDRESS|||'],
        system_prefix=['|||IP_ADDRESS|||<|system|>\n{{SYSTEM}}\n'],
        prompt=['<|user|>\n{{QUERY}}\n<|assistant|>\n'],
        chat_sep=['|||IP_ADDRESS|||\n'],
        suffix=['|||IP_ADDRESS|||'],
        default_system='You are a helpful assistant.',
        stop_words=['<|endoftext|>'],
    ))

register_template(
    TemplateMeta(
        LLMTemplateType.olmoe_0924,
        prefix=[],
        system_prefix=['<|system|>\n{{SYSTEM}}\n'],
        prompt=['<|user|>\n{{QUERY}}\n<|assistant|>\n'],
        chat_sep=['<|endoftext|>\n'],
        suffix=['<|endoftext|>'],
        default_system='You are a helpful assistant.',
        stop_words=['<|endoftext|>'],
    ))
Contributor


high

The templates for olmoe and olmoe_0924 are not correctly implemented according to their respective Hugging Face model cards and tokenizer_config.json files.

  • The olmoe template incorrectly uses |||IP_ADDRESS||| as a separator. This token is used for PII redaction in the training data, not as a chat separator.
  • Both templates do not correctly place the <|endoftext|> token after each message part (system, user).
  • The olmoe template should have a newline \n after <|endoftext|>, while olmoe_0924 should not, based on their respective chat_template definitions.

Here is a corrected implementation for both templates that aligns with the official chat templates:

register_template(
    TemplateMeta(
        LLMTemplateType.olmoe,
        prefix=[],
        system_prefix=['<|system|>\n{{SYSTEM}}<|endoftext|>\n'],
        prompt=['<|user|>\n{{QUERY}}<|endoftext|>\n<|assistant|>\n'],
        chat_sep=['<|endoftext|>\n'],
        suffix=['<|endoftext|>\n'],
        default_system='You are a helpful assistant.',
        stop_words=['<|endoftext|>'],
    ))

register_template(
    TemplateMeta(
        LLMTemplateType.olmoe_0924,
        prefix=[],
        system_prefix=['<|system|>\n{{SYSTEM}}<|endoftext|>'],
        prompt=['<|user|>\n{{QUERY}}<|endoftext|><|assistant|>\n'],
        chat_sep=['<|endoftext|>'],
        suffix=['<|endoftext|>'],
        default_system='You are a helpful assistant.',
        stop_words=['<|endoftext|>'],
    ))
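
For reference, a hedged sketch of what a single system + user turn would render to under the corrected olmoe_0924 template above, assuming the template engine simply substitutes {{SYSTEM}} and {{QUERY}} and appends the suffix after the assistant reply:

# Hypothetical rendering of one system + one user turn with the corrected
# olmoe_0924 template; not the project's actual encoding code.
system = 'You are a helpful assistant.'
query = 'What is OLMoE?'
prompt = (
    f'<|system|>\n{system}<|endoftext|>'
    f'<|user|>\n{query}<|endoftext|>'
    '<|assistant|>\n'
)
# The model generates the assistant reply after '<|assistant|>\n'; the
# '<|endoftext|>' suffix/stop word then terminates the turn.
print(prompt)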

hf_state_dict = {}
hf_attn = self.hf_layers[layer_idx].self_attn
args = self.args
num_query_groups = (args.num_query_groups if args.group_query_attention else args.num_attention_heads)
Contributor


medium

The variable num_query_groups is defined but never used within this method. It can be safely removed to improve code clarity.

@Jintao-Huang
Collaborator

Hello, please run:

pip install pre-commit
pre-commit run --all-files

@Jintao-Huang
Collaborator

There are incompatibility issues with mcore 0.15.

[Screenshot: 2025-12-25 16:41:24]

@Jintao-Huang
Collaborator

[Screenshot: 2025-12-25 16:47:17]

Hello, it seems the forward pass precision is not aligned.
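
A sketch of the kind of precision check implied here (a hypothetical helper, not the project's actual test utility): run the same batch through the Hugging Face model and the converted Megatron model and compare logits within a tolerance.

import torch

def check_forward_alignment(hf_model, mg_model, input_ids, atol=1e-3):
    # Compare logits from both models on the same input within a tolerance.
    with torch.inference_mode():
        hf_logits = hf_model(input_ids=input_ids).logits
        mg_logits = mg_model(input_ids=input_ids).logits
    max_diff = (hf_logits - mg_logits).abs().max().item()
    print(f'max |hf - mg| = {max_diff:.6f}')
    return torch.allclose(hf_logits, mg_logits, atol=atol)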

@Jintao-Huang
Collaborator

def test_olmoe():
    _test_model('allenai/OLMoE-1B-7B-0924-Instruct')
