Commit 9d4d39e

Diffusion trainer fix: shift logits to align with input tokens (#3191)
* shift logits for diffusion generate
* delete unused
* diffusion trainer: token shift

1 parent bb33fda, commit 9d4d39e
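
Why the shift is needed: the underlying LM head keeps next-token semantics, so the logits emitted at position i-1 describe the token at position i, while the diffusion code indexes logits by input position. Below is a hedged sketch of that alignment, assuming next-token output semantics; the tensor names, shapes, and mask values are illustrative, not taken from the repository.

# Hedged sketch of the alignment this commit restores; `tokens`, `vocab`,
# and the mask are illustrative values, not from the repository.
import torch

vocab = 10
tokens = torch.tensor([[3, 7, 2, 5]])           # [batch, seq_len]
logits = torch.randn(1, tokens.size(1), vocab)  # raw model output

# A next-token head scores token i+1 at position i. A masked-diffusion
# objective indexes logits by *input* position, so the prediction made at
# i-1 has to be moved to index i before masking or sampling:
shifted = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

masked = torch.tensor([[False, True, False, True]])
assert shifted[masked].shape == (2, vocab)      # one distribution per masked token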

File tree

3 files changed: +11 -4 lines

src/axolotl/integrations/diffusion/generation.py
Lines changed: 2 additions & 2 deletions

@@ -7,7 +7,7 @@

 from axolotl.utils.logging import get_logger

-from .utils import create_bidirectional_attention_mask
+from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions

 LOG = get_logger(__name__)

@@ -360,7 +360,7 @@ def _diffusion_step(

     # Forward pass
     outputs = model(input_ids=sequence, attention_mask=attention_mask)
-    logits = outputs.logits
+    logits = shift_logits_to_input_positions(outputs.logits)

     # Only sample at currently masked positions
     if current_mask.any():
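
With the logits shifted, the sampling step that follows in `_diffusion_step` can index them directly by input position. A minimal sketch of that pattern is below; the function name, temperature handling, and sampling choice are assumptions, not the repository's exact code.

import torch

def sample_masked_positions(sequence, logits, current_mask, temperature=1.0):
    """Fill currently masked positions of `sequence` from position-aligned logits (hypothetical helper)."""
    if not current_mask.any():
        return sequence
    # logits are already shifted, so logits[b, i] scores the token at input position i
    probs = torch.softmax(logits[current_mask] / temperature, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    updated = sequence.clone()
    updated[current_mask] = sampled
    return updated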

src/axolotl/integrations/diffusion/trainer.py
Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,7 @@
 from axolotl.utils.logging import get_logger

 from .callbacks import DiffusionGenerationCallback
-from .utils import create_bidirectional_attention_mask
+from .utils import create_bidirectional_attention_mask, shift_logits_to_input_positions

 LOG = get_logger(__name__)

@@ -207,7 +207,7 @@ def _compute_diffusion_loss(
         input_ids=noisy_batch.long(),
         attention_mask=bidirectional_mask,
     )
-    logits = outputs.logits
+    logits = shift_logits_to_input_positions(outputs.logits)

     if masked_indices.sum() > 0:
         valid_indices = torch.where(masked_indices)
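
On the training side, the shifted logits are what get gathered at the noised positions for the loss. A hedged sketch of that masked cross-entropy follows; the trainer's actual weighting and reduction may differ.

import torch
import torch.nn.functional as F

def masked_diffusion_loss(logits, labels, masked_indices):
    """Cross-entropy over masked positions only, using position-aligned logits (illustrative sketch)."""
    if masked_indices.sum() == 0:
        return logits.new_zeros(())
    valid_indices = torch.where(masked_indices)
    picked_logits = logits[valid_indices]   # [num_masked, vocab_size]
    picked_labels = labels[valid_indices]   # [num_masked]
    return F.cross_entropy(picked_logits, picked_labels)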

src/axolotl/integrations/diffusion/utils.py
Lines changed: 7 additions & 0 deletions

@@ -157,3 +157,10 @@ def create_bidirectional_attention_mask(

     # Add head dimension: [batch_size, 1, seq_len, seq_len]
     return bidirectional_mask.unsqueeze(1)
+
+
+def shift_logits_to_input_positions(logits: torch.Tensor) -> torch.Tensor:
+    """Align next-token logits with their input token positions for diffusion."""
+    if logits.size(1) <= 1:
+        return logits
+    return torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
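
A quick check of the new helper's behavior (the import path follows the file location above; shapes are illustrative): position 0 keeps its own logits, and every later position i receives the logits emitted at i-1.

import torch

from axolotl.integrations.diffusion.utils import shift_logits_to_input_positions

logits = torch.randn(2, 5, 32)  # [batch, seq_len, vocab]
shifted = shift_logits_to_input_positions(logits)

assert torch.equal(shifted[:, 0], logits[:, 0])            # first position is unchanged
for i in range(1, logits.size(1)):
    assert torch.equal(shifted[:, i], logits[:, i - 1])    # i gets the prediction made at i-1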

0 commit comments
