Status: Open
Labels: community-request, module: moe, question (Further information is requested)
Description
In `megatron/core/transformer/moe/moe_utils.py`, `compute_routing_scores_for_aux_loss` contains:
```python
if score_function == "softmax":
    scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
elif score_function == "sigmoid":
    scores = torch.sigmoid(logits)
    scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
else:
    raise ValueError(f"Invalid score_function: {score_function}")
```
Why is the softmax computed in FP32, while the sigmoid runs in the logits' original dtype?
Should the sigmoid path be changed to the same high precision?
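
For concreteness, a minimal sketch of what a unified high-precision version might look like; the FP32 cast on the sigmoid path is an assumption about the suggested change, not the current Megatron-LM code:

```python
import torch


def compute_routing_scores_fp32(logits: torch.Tensor, score_function: str) -> torch.Tensor:
    # Hypothetical variant: run both score functions in FP32 for parity.
    if score_function == "softmax":
        scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
    elif score_function == "sigmoid":
        # Assumed change: cast to FP32 before the sigmoid, mirroring the
        # softmax path, so the normalizing sum also accumulates in FP32.
        scores = torch.sigmoid(logits.float())
        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
    else:
        raise ValueError(f"Invalid score_function: {score_function}")
    return scores
```

With BF16 logits, the current sigmoid scores and their normalizing sum stay in BF16's reduced mantissa precision, so the resulting routing probabilities are coarser than those produced by the FP32 softmax path.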
Skylion007