4 changes: 2 additions & 2 deletions src/axolotl/integrations/diffusion/generation.py
@@ -7,7 +7,7 @@

from axolotl.utils.logging import get_logger

-from .utils import create_bidirectional_attention_mask
+from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions

LOG = get_logger(__name__)

@@ -360,7 +360,7 @@ def _diffusion_step(

# Forward pass
outputs = model(input_ids=sequence, attention_mask=attention_mask)
-logits = outputs.logits
+logits = shift_logits_to_input_positions(outputs.logits)

# Only sample at currently masked positions
if current_mask.any():
4 changes: 2 additions & 2 deletions src/axolotl/integrations/diffusion/trainer.py
@@ -11,7 +11,7 @@
from axolotl.utils.logging import get_logger

from .callbacks import DiffusionGenerationCallback
-from .utils import create_bidirectional_attention_mask
+from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions

LOG = get_logger(__name__)

@@ -207,7 +207,7 @@ def _compute_diffusion_loss(
input_ids=noisy_batch.long(),
attention_mask=bidirectional_mask,
)
-logits = outputs.logits
+logits = shift_logits_to_input_positions(outputs.logits)

if masked_indices.sum() > 0:
valid_indices = torch.where(masked_indices)
7 changes: 7 additions & 0 deletions src/axolotl/integrations/diffusion/utils.py
@@ -157,3 +157,10 @@ def create_bidirectional_attention_mask(

# Add head dimension: [batch_size, 1, seq_len, seq_len]
return bidirectional_mask.unsqueeze(1)


+def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
+    """Align next-token logits with their input token positions for diffusion."""
+    if logits.size(1) <= 1:
+        return logits
+    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
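For reference, a minimal usage sketch of the new helper (not part of the diff; the `[batch, seq_len, vocab]` logits shape and the import path are assumed from the file layout above):

```python
import torch

# Assumed import path, based on the file location in this PR.
from axolotl.integrations.diffusion.utils import shift_logits_to_input_positions

logits = torch.randn(2, 8, 32000)  # assumed [batch, seq_len, vocab]
shifted = shift_logits_to_input_positions(logits)

assert shifted.shape == logits.shape                # shape is preserved
assert torch.equal(shifted[:, 0], logits[:, 0])     # first position's logits are duplicated
assert torch.equal(shifted[:, 1:], logits[:, :-1])  # remaining positions shift right by one
```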
Collaborator


What is this trying to do? Concat the logits' first column and columns 1..N together?

Member Author

@djsaunde Oct 2, 2025


It's a bit of a hack to use pretrained causal LMs for diffusion fine-tuning. We're shifting the logits to the right by one position so that each input position's logits correspond to the model's prediction for the token at that position (rather than for the next token).

Unfortunately we're duplicating the first position's logits, but I couldn't think of a better way to do it. Open to ideas here.
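A toy illustration of the alignment described above (a hedged sketch, not the integration's code; it assumes only the standard causal-LM convention that the logits at position t predict token t + 1):

```python
import torch


def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
    # Same operation as the helper added in utils.py above.
    if logits.size(1) <= 1:
        return logits
    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)


# Pretend each "logit" is just the index of the token the model predicts:
# a causal LM's logits at position t predict token t + 1.
seq_len = 6
causal_logits = torch.arange(1, seq_len + 1).view(1, seq_len, 1).float()

shifted = shift_logits_to_input_positions(causal_logits)

# Before the shift, the prediction for a masked position t lives at index t - 1;
# after the shift it sits at index t, where masked positions and labels are indexed.
masked_position = 3
assert torch.equal(shifted[0, masked_position], causal_logits[0, masked_position - 1])
```

The cost of the shift is that position 0 reuses its own logits (there is no "previous" prediction for it), which is the first-token duplication mentioned above.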