
[QUESTION] MoE score precision #2741

@Daniel5103

Description


In megatron/core/transformer/moe/moe_utils.py, compute_routing_scores_for_aux_loss:

    if score_function == "softmax":
        scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits)
        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

Why is softmax computed in FP32 while sigmoid is left in the original dtype?
Should both branches be changed to use the same high precision?
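
For reference, here is a minimal sketch (not the Megatron-LM code) of what upcasting the sigmoid branch to FP32 could look like, plus a quick check of how much the scores differ from the current low-precision path. The helper name routing_scores and the call shapes are made up for illustration.

    import torch

    # Minimal sketch, not the Megatron-LM implementation: a hypothetical
    # routing_scores helper with the sigmoid branch upcast to float32 so
    # it mirrors the softmax branch.
    def routing_scores(logits: torch.Tensor, score_function: str) -> torch.Tensor:
        if score_function == "softmax":
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
        elif score_function == "sigmoid":
            # Upcast before sigmoid/normalization so the sum, division,
            # and the 1e-20 epsilon are applied in float32 rather than
            # bf16/fp16.
            scores = torch.sigmoid(logits.float())
            scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
        else:
            raise ValueError(f"Invalid score_function: {score_function}")
        return scores

    # Compare against the current low-precision sigmoid path on bf16 logits.
    logits = torch.randn(4, 8, dtype=torch.bfloat16)
    hi = routing_scores(logits, "sigmoid")
    lo = torch.sigmoid(logits)
    lo = lo / (lo.sum(dim=-1, keepdim=True) + 1e-20)
    print("max abs difference:", (hi - lo.float()).abs().max().item())

One possible reason for the asymmetry: softmax exponentiates and normalizes across all experts, which is more sensitive to low precision, while sigmoid is bounded in (0, 1); still, the per-token normalization sum in the sigmoid branch is done in the original dtype, so some precision loss there seems possible.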
