diff --git a/src/axolotl/integrations/diffusion/generation.py b/src/axolotl/integrations/diffusion/generation.py
index 49e3cdfae8..ec517fd238 100644
--- a/src/axolotl/integrations/diffusion/generation.py
+++ b/src/axolotl/integrations/diffusion/generation.py
@@ -7,7 +7,7 @@
 from axolotl.utils.logging import get_logger
 
-from .utils import create_bidirectional_attention_mask
+from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
 
 LOG = get_logger(__name__)
@@ -360,7 +360,7 @@ def _diffusion_step(
     # Forward pass
     outputs = model(input_ids=sequence, attention_mask=attention_mask)
-    logits = outputs.logits
+    logits = shift_logits_to_input_positions(outputs.logits)
 
     # Only sample at currently masked positions
     if current_mask.any():
diff --git a/src/axolotl/integrations/diffusion/trainer.py b/src/axolotl/integrations/diffusion/trainer.py
index 42b2468f41..dfaef2a48c 100644
--- a/src/axolotl/integrations/diffusion/trainer.py
+++ b/src/axolotl/integrations/diffusion/trainer.py
@@ -11,7 +11,7 @@
 from axolotl.utils.logging import get_logger
 
 from .callbacks import DiffusionGenerationCallback
-from .utils import create_bidirectional_attention_mask
+from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions
 
 LOG = get_logger(__name__)
@@ -207,7 +207,7 @@ def _compute_diffusion_loss(
            input_ids=noisy_batch.long(),
            attention_mask=bidirectional_mask,
        )
-        logits = outputs.logits
+        logits = shift_logits_to_input_positions(outputs.logits)
 
        if masked_indices.sum() > 0:
            valid_indices = torch.where(masked_indices)
diff --git a/src/axolotl/integrations/diffusion/utils.py b/src/axolotl/integrations/diffusion/utils.py
index 47abf6fecb..b6f71c07b9 100644
--- a/src/axolotl/integrations/diffusion/utils.py
+++ b/src/axolotl/integrations/diffusion/utils.py
@@ -157,3 +157,10 @@
    # Add head dimension: [batch_size, 1, seq_len, seq_len]
    return bidirectional_mask.unsqueeze(1)
+
+
+def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
+    """Align next-token logits with their input token positions for diffusion."""
+    if logits.size(1) <= 1:
+        return logits
+    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
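
For reviewers, a minimal self-contained sketch of what the new helper is for. The function body is copied from the patch above; the toy tensors and assertions are illustrative assumptions, not part of the change. A causal LM emits `logits[:, i]` as its prediction for the token at position `i + 1`, while the diffusion loss and sampler score tokens at their own input positions, so the logits are shifted right by one and the first position is left untouched.

```python
import torch


def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
    """Align next-token logits with their input token positions for diffusion."""
    if logits.size(1) <= 1:
        return logits
    # Keep position 0 as-is; every position i >= 1 receives the logits the
    # model produced at position i - 1 (i.e. its next-token prediction).
    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)


# Toy example (hypothetical shapes): give each sequence position a distinct
# logit value so the shift is easy to verify.
batch, seq_len, vocab = 2, 5, 11
logits = torch.arange(seq_len, dtype=torch.float32).view(1, seq_len, 1).expand(batch, seq_len, vocab)
shifted = shift_logits_to_input_positions(logits)

assert torch.equal(shifted[:, 0], logits[:, 0])      # first position unchanged
assert torch.equal(shifted[:, 1:], logits[:, :-1])   # position i holds the prediction from i - 1
```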