11 changes: 10 additions & 1 deletion src/lmms_engine/models/qwen2/qwen2_liger.py
@@ -12,10 +12,12 @@
 except:
     print("Liger Kernel is not installed, pip install liger-kernel to use this patch")
 import torch
+import torch.distributed as dist
 
 from lmms_engine.parallel.sequence_parallel.ulysses import (
     calculate_seq_len_per_rank,
     gather_outputs_and_unpad,
+    get_ulysses_sequence_parallel_group,
     get_ulysses_sequence_parallel_world_size,
     pad_to_max_across_ranks,
     slice_input_tensor,
@@ -143,7 +145,14 @@ def qwen2_lce_forward(
         # Pad to max size across ranks, then gather and unpad
         loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
         loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
-        loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
+        # Count the valid (non-ignored) label tokens; shift_labels has shape
+        # (num_tokens,) after flattening, and -100 marks ignored positions.
+        num_valid_tokens = (shift_labels != -100).sum().float()
+        # Sum num_valid_tokens across all SP ranks to get the global count.
+        sp_group = get_ulysses_sequence_parallel_group()
+        if sp_group is not None:
+            dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
+        loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
 
         if reduction == "sum":
             loss /= loss_kwargs["num_items_in_batch"]
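Every file in this PR applies the same two-part change: import torch.distributed and get_ulysses_sequence_parallel_group, then replace the attention-mask denominator with a cross-rank count of supervised tokens. A minimal sketch of what the new normalization computes, assuming a flattened per-token loss, shift_labels with -100 marking ignored positions, and an initialized process group (sp_group stands in for the result of get_ulysses_sequence_parallel_group()):

import torch
import torch.distributed as dist

def sp_normalized_loss(loss, shift_labels, sp_group=None):
    # Positions labeled -100 are skipped by cross-entropy, so only the
    # remaining tokens belong in the mean's denominator.
    num_valid_tokens = (shift_labels != -100).sum().float()
    if sp_group is not None:
        # The gathered loss already sums contributions from every
        # sequence-parallel rank, so the token count must be summed too.
        dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
    # 1e-8 guards against division by zero when nothing is supervised.
    return torch.sum(loss) / (num_valid_tokens + 1e-8)

The replaced denominator, torch.sum(attention_mask), counted every attended token, including prompt positions whose labels are -100, so the mean loss was deflated whenever a batch contained unsupervised tokens.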
11 changes: 10 additions & 1 deletion src/lmms_engine/models/qwen2_5_omni/qwen2_5_omni_liger.py
@@ -1,6 +1,7 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+import torch.distributed as dist
 from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
     Qwen2_5OmniThinkerCausalLMOutputWithPast,
     Qwen2_5OmniThinkerForConditionalGeneration,
@@ -10,6 +11,7 @@
 from lmms_engine.parallel.sequence_parallel.ulysses import (
     calculate_seq_len_per_rank,
     gather_outputs_and_unpad,
+    get_ulysses_sequence_parallel_group,
     get_ulysses_sequence_parallel_world_size,
     pad_to_max_across_ranks,
     slice_input_tensor,
@@ -253,7 +255,14 @@ def lce_forward(
         # Pad to max size across ranks, then gather and unpad
         loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
         loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
-        loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
+        # Count the valid (non-ignored) label tokens; shift_labels has shape
+        # (num_tokens,) after flattening, and -100 marks ignored positions.
+        num_valid_tokens = (shift_labels != -100).sum().float()
+        # Sum num_valid_tokens across all SP ranks to get the global count.
+        sp_group = get_ulysses_sequence_parallel_group()
+        if sp_group is not None:
+            dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
+        loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
 
         if reduction == "sum":
             loss /= kwargs["num_items_in_batch"]
11 changes: 10 additions & 1 deletion src/lmms_engine/models/qwen2_5_vl/qwen2_5_vl_liger.py
@@ -1,6 +1,7 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+import torch.distributed as dist
 from transformers import Qwen2_5_VLForConditionalGeneration
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VLCausalLMOutputWithPast,
@@ -9,6 +10,7 @@
 from lmms_engine.parallel.sequence_parallel.ulysses import (
     calculate_seq_len_per_rank,
     gather_outputs_and_unpad,
+    get_ulysses_sequence_parallel_group,
     get_ulysses_sequence_parallel_world_size,
     pad_to_max_across_ranks,
     slice_input_tensor,
@@ -125,7 +127,14 @@ def lce_forward(
         # Pad to max size across ranks, then gather and unpad
         loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
         loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
-        loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
+        # Count the valid (non-ignored) label tokens; shift_labels has shape
+        # (num_tokens,) after flattening, and -100 marks ignored positions.
+        num_valid_tokens = (shift_labels != -100).sum().float()
+        # Sum num_valid_tokens across all SP ranks to get the global count.
+        sp_group = get_ulysses_sequence_parallel_group()
+        if sp_group is not None:
+            dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
+        loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
 
         if reduction == "sum":
             loss /= kwargs["num_items_in_batch"]
11 changes: 10 additions & 1 deletion src/lmms_engine/models/qwen3/qwen3_liger.py
@@ -12,10 +12,12 @@
 except:
     print("Liger Kernel is not installed, pip install liger-kernel to use this patch")
 import torch
+import torch.distributed as dist
 
 from lmms_engine.parallel.sequence_parallel.ulysses import (
     calculate_seq_len_per_rank,
     gather_outputs_and_unpad,
+    get_ulysses_sequence_parallel_group,
     get_ulysses_sequence_parallel_world_size,
     pad_to_max_across_ranks,
     slice_input_tensor,
@@ -143,7 +145,14 @@ def qwen3_lce_forward(
         # Pad to max size across ranks, then gather and unpad
         loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
         loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
-        loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
+        # Count the valid (non-ignored) label tokens; shift_labels has shape
+        # (num_tokens,) after flattening, and -100 marks ignored positions.
+        num_valid_tokens = (shift_labels != -100).sum().float()
+        # Sum num_valid_tokens across all SP ranks to get the global count.
+        sp_group = get_ulysses_sequence_parallel_group()
+        if sp_group is not None:
+            dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
+        loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
 
         if reduction == "sum":
             loss /= loss_kwargs["num_items_in_batch"]
11 changes: 10 additions & 1 deletion src/lmms_engine/models/qwen3_omni_moe/qwen3_omni_moe_liger.py
@@ -1,6 +1,7 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+import torch.distributed as dist
 from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
     Qwen3OmniMoeThinkerCausalLMOutputWithPast,
     Qwen3OmniMoeThinkerForConditionalGeneration,
@@ -11,6 +12,7 @@
 from lmms_engine.parallel.sequence_parallel.ulysses import (
     calculate_seq_len_per_rank,
     gather_outputs_and_unpad,
+    get_ulysses_sequence_parallel_group,
     get_ulysses_sequence_parallel_world_size,
     pad_to_max_across_ranks,
     slice_input_tensor,
@@ -266,7 +268,14 @@ def lce_forward(
         # Pad to max size across ranks, then gather and unpad
         loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
         loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
-        loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
+        # Count the valid (non-ignored) label tokens; shift_labels has shape
+        # (num_tokens,) after flattening, and -100 marks ignored positions.
+        num_valid_tokens = (shift_labels != -100).sum().float()
+        # Sum num_valid_tokens across all SP ranks to get the global count.
+        sp_group = get_ulysses_sequence_parallel_group()
+        if sp_group is not None:
+            dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
+        loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
 
         if reduction == "sum":
             loss /= kwargs["num_items_in_batch"]
11 changes: 10 additions & 1 deletion src/lmms_engine/models/qwen3_vl/qwen3_vl_liger.py
@@ -1,6 +1,7 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+import torch.distributed as dist
 from loguru import logger
 from transformers import Qwen3VLForConditionalGeneration
 from transformers.cache_utils import Cache
@@ -9,6 +10,7 @@
 from lmms_engine.parallel.sequence_parallel.ulysses import (
     calculate_seq_len_per_rank,
     gather_outputs_and_unpad,
+    get_ulysses_sequence_parallel_group,
     get_ulysses_sequence_parallel_world_size,
     pad_to_max_across_ranks,
     slice_input_tensor,
@@ -121,7 +123,14 @@ def qwen3_vl_lce_forward(
         # Pad to max size across ranks, then gather and unpad
         loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
         loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
-        loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
+        # Count the valid (non-ignored) label tokens; shift_labels has shape
+        # (num_tokens,) after flattening, and -100 marks ignored positions.
+        num_valid_tokens = (shift_labels != -100).sum().float()
+        # Sum num_valid_tokens across all SP ranks to get the global count.
+        sp_group = get_ulysses_sequence_parallel_group()
+        if sp_group is not None:
+            dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
+        loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
 
         if reduction == "sum":
             loss /= kwargs["num_items_in_batch"]
11 changes: 10 additions & 1 deletion src/lmms_engine/models/qwen3_vl_moe/qwen3_vl_moe_liger.py
@@ -1,6 +1,7 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
+import torch.distributed as dist
 from transformers.cache_utils import Cache
 from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
     Qwen3VLMoeCausalLMOutputWithPast,
@@ -11,6 +12,7 @@
 from lmms_engine.parallel.sequence_parallel.ulysses import (
     calculate_seq_len_per_rank,
     gather_outputs_and_unpad,
+    get_ulysses_sequence_parallel_group,
     get_ulysses_sequence_parallel_world_size,
     pad_to_max_across_ranks,
     slice_input_tensor,
@@ -112,7 +114,14 @@ def lce_forward(
         # Pad to max size across ranks, then gather and unpad
         loss, total_padding = pad_to_max_across_ranks(loss, dim=0)
         loss = gather_outputs_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=total_padding)
-        loss = torch.sum(loss) / (torch.sum(attention_mask) + 1e-8)
+        # Count the valid (non-ignored) label tokens; shift_labels has shape
+        # (num_tokens,) after flattening, and -100 marks ignored positions.
+        num_valid_tokens = (shift_labels != -100).sum().float()
+        # Sum num_valid_tokens across all SP ranks to get the global count.
+        sp_group = get_ulysses_sequence_parallel_group()
+        if sp_group is not None:
+            dist.all_reduce(num_valid_tokens, op=dist.ReduceOp.SUM, group=sp_group)
+        loss = torch.sum(loss) / (num_valid_tokens + 1e-8)
 
         if reduction == "sum":
             loss /= kwargs["num_items_in_batch"]
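A toy single-rank check of why the new denominator matters, with hypothetical numbers: a 15-token sequence in which only the last 5 tokens carry labels.

import torch

attention_mask = torch.ones(15)                             # 15 attended tokens
shift_labels = torch.tensor([-100] * 10 + [1, 2, 3, 4, 5])  # 5 supervised tokens
loss = torch.full((5,), 2.0)                                # per-token loss at valid positions

old = loss.sum() / (attention_mask.sum() + 1e-8)                  # 10 / 15 ~ 0.67
new = loss.sum() / ((shift_labels != -100).sum().float() + 1e-8)  # 10 / 5  = 2.0

The old denominator shrinks the reported mean by the fraction of ignored tokens; the new one recovers the true average loss per supervised token.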