Draft
Changes from all commits
Commits
180 commits
31c40e2
cleanup configs
mayank31398 Sep 28, 2025
6191d01
cleanup configs
mayank31398 Sep 28, 2025
1b153ea
cleanup TP test
mayank31398 Sep 28, 2025
94f4020
add SWA
mayank31398 Sep 28, 2025
bd46de7
add SWA
mayank31398 Sep 28, 2025
615e105
dim in FA
mayank31398 Sep 28, 2025
8d9f7d4
dim in FA
mayank31398 Sep 28, 2025
6b4d1dc
drop SBA
mayank31398 Sep 28, 2025
3e1aef2
drop SBA
mayank31398 Sep 28, 2025
bd123a8
drop SBA
mayank31398 Sep 28, 2025
aec861b
drop SBA
mayank31398 Sep 28, 2025
c45e89f
drop SBA
mayank31398 Sep 28, 2025
97284d2
drop SBA
mayank31398 Sep 28, 2025
3a94928
drop SBA
mayank31398 Sep 28, 2025
0b0af13
drop SBA
mayank31398 Sep 28, 2025
fb01683
drop SBA
mayank31398 Sep 28, 2025
a061be4
add packed tensor
mayank31398 Sep 28, 2025
a1c8d55
add packed tensor
mayank31398 Sep 28, 2025
137618c
add packed tensor
mayank31398 Sep 28, 2025
18c9a93
add packed tensor
mayank31398 Sep 28, 2025
e77686e
add packed tensor
mayank31398 Sep 28, 2025
5232d26
add packed tensor
mayank31398 Sep 28, 2025
0ba23bf
add packed tensor
mayank31398 Sep 28, 2025
21f7521
add packed tensor
mayank31398 Sep 29, 2025
0405210
add packed tensor
mayank31398 Sep 29, 2025
156db5d
add packed tensor
mayank31398 Sep 29, 2025
bb6e1b9
add packed tensor
mayank31398 Sep 29, 2025
6e688e5
add packed tensor
mayank31398 Sep 29, 2025
86e8bed
Merge branch 'main' into sl
mayank31398 Sep 29, 2025
84cc4e1
drop SBA temporarily
mayank31398 Sep 29, 2025
796b1da
drop SBA temporarily
mayank31398 Sep 29, 2025
766bcde
cleanup model_wrappers
mayank31398 Sep 29, 2025
3dd0bd4
cleanup model_wrappers
mayank31398 Sep 29, 2025
1e19a5c
cleanup model_wrappers
mayank31398 Sep 29, 2025
fa63452
cleanup model_wrappers
mayank31398 Sep 29, 2025
c80e873
cleanup model_wrappers
mayank31398 Sep 29, 2025
3a40c9b
cleanup model_wrappers
mayank31398 Sep 29, 2025
82a9c88
cleanup model_wrappers
mayank31398 Sep 29, 2025
3df92f5
cleanup model_wrappers
mayank31398 Sep 29, 2025
8085a17
cleanup model_wrappers
mayank31398 Sep 29, 2025
06b35fe
cleanup model_wrappers
mayank31398 Sep 29, 2025
b04c4ed
cleanup model_wrappers
mayank31398 Sep 29, 2025
2ba48cd
cleanup model_wrappers
mayank31398 Sep 29, 2025
f6beead
cleanup model_wrappers
mayank31398 Sep 29, 2025
695c85a
cleanup model_wrappers
mayank31398 Sep 29, 2025
5f3d790
cleanup model_wrappers
mayank31398 Sep 29, 2025
f526bb1
cleanup model_wrappers
mayank31398 Sep 29, 2025
07aa99e
cleanup model_wrappers
mayank31398 Sep 29, 2025
bfc8eb1
cleanup model_wrappers
mayank31398 Sep 29, 2025
f98a168
cleanup model_wrappers
mayank31398 Sep 29, 2025
14e54b6
cleanup model_wrappers
mayank31398 Sep 29, 2025
88b0595
cleanup model_wrappers
mayank31398 Sep 30, 2025
b4a2b5e
cleanup model_wrappers
mayank31398 Sep 30, 2025
0f8b0aa
cleanup model_wrappers
mayank31398 Sep 30, 2025
7239f5b
cleanup model_wrappers
mayank31398 Sep 30, 2025
bf191e2
cleanup model_wrappers
mayank31398 Sep 30, 2025
42c3efb
cleanup model_wrappers
mayank31398 Sep 30, 2025
f04ae6b
cleanup model_wrappers
mayank31398 Sep 30, 2025
7f7f0a2
cleanup model_wrappers
mayank31398 Sep 30, 2025
4214755
cleanup model_wrappers
mayank31398 Sep 30, 2025
3e97087
cleanup model_wrappers
mayank31398 Sep 30, 2025
83e4063
cleanup model_wrappers
mayank31398 Sep 30, 2025
b1dff71
cleanup model_wrappers
mayank31398 Sep 30, 2025
d6dffd3
cleanup model_wrappers
mayank31398 Sep 30, 2025
f838cea
cleanup model_wrappers
mayank31398 Sep 30, 2025
15c0824
cleanup model_wrappers
mayank31398 Sep 30, 2025
3aed859
cleanup model_wrappers
mayank31398 Sep 30, 2025
012e42e
cleanup model_wrappers
mayank31398 Sep 30, 2025
14b9ff7
cleanup model_wrappers
mayank31398 Sep 30, 2025
8183244
cleanup model_wrappers
mayank31398 Sep 30, 2025
4dfb0c7
cleanup model_wrappers
mayank31398 Sep 30, 2025
8438c1b
cleanup model_wrappers
mayank31398 Sep 30, 2025
60ac022
cleanup model_wrappers
mayank31398 Sep 30, 2025
47ac638
cleanup model_wrappers
mayank31398 Sep 30, 2025
71c8585
cleanup model_wrappers
mayank31398 Sep 30, 2025
37a8142
cleanup model_wrappers
mayank31398 Sep 30, 2025
8041f6c
cleanup model_wrappers
mayank31398 Sep 30, 2025
636987e
cleanup model_wrappers
mayank31398 Sep 30, 2025
424dab1
cleanup model_wrappers
mayank31398 Sep 30, 2025
ab95037
cleanup model_wrappers
mayank31398 Sep 30, 2025
21e6416
cleanup model_wrappers
mayank31398 Sep 30, 2025
b0edda7
cleanup model_wrappers
mayank31398 Sep 30, 2025
6bedf74
cleanup model_wrappers
mayank31398 Sep 30, 2025
2b87962
cleanup model_wrappers
mayank31398 Sep 30, 2025
e75d48f
cleanup model_wrappers
mayank31398 Sep 30, 2025
9b349c6
cleanup model_wrappers
mayank31398 Sep 30, 2025
df27f32
cleanup model_wrappers
mayank31398 Sep 30, 2025
21e0ba9
cleanup model_wrappers
mayank31398 Sep 30, 2025
976a9ba
cleanup model_wrappers
mayank31398 Sep 30, 2025
1d949d5
cleanup model_wrappers
mayank31398 Sep 30, 2025
24ee79a
cleanup model_wrappers
mayank31398 Sep 30, 2025
be7104e
cleanup model_wrappers
mayank31398 Sep 30, 2025
09ea1ef
cleanup model_wrappers
mayank31398 Sep 30, 2025
836900f
cleanup model_wrappers
mayank31398 Sep 30, 2025
6ca483d
cleanup model_wrappers
mayank31398 Sep 30, 2025
4ebee0d
cleanup model_wrappers
mayank31398 Sep 30, 2025
753ade1
cleanup model_wrappers
mayank31398 Sep 30, 2025
e682a14
cleanup model_wrappers
mayank31398 Sep 30, 2025
c5825ba
cleanup model_wrappers
mayank31398 Sep 30, 2025
9af4136
cleanup model_wrappers
mayank31398 Sep 30, 2025
08d3bfe
cleanup model_wrappers
mayank31398 Sep 30, 2025
ee3862d
cleanup model_wrappers
mayank31398 Oct 1, 2025
21a3958
cleanup model_wrappers
mayank31398 Oct 1, 2025
55bb8f3
cleanup model_wrappers
mayank31398 Oct 1, 2025
0980568
cleanup model_wrappers
mayank31398 Oct 1, 2025
a90184f
cleanup model_wrappers
mayank31398 Oct 1, 2025
06fb51a
rsa torch
mayank31398 Oct 1, 2025
91a670d
merge
mayank31398 Oct 8, 2025
ee5bb33
merge
mayank31398 Oct 8, 2025
eb339e6
merge
mayank31398 Oct 8, 2025
1297653
merge
mayank31398 Oct 8, 2025
b18fd58
merge
mayank31398 Oct 8, 2025
6240658
merge
mayank31398 Oct 8, 2025
8c82591
merge
mayank31398 Oct 8, 2025
47aa2c4
merge
mayank31398 Oct 8, 2025
efc4862
merge
mayank31398 Oct 8, 2025
0bd6eb3
merge
mayank31398 Oct 8, 2025
02eb59e
merge
mayank31398 Oct 8, 2025
07fb5fd
merge
mayank31398 Oct 9, 2025
958f081
merge
mayank31398 Oct 9, 2025
38383ac
merge
mayank31398 Oct 9, 2025
c5feb37
merge
mayank31398 Oct 9, 2025
c11141b
merge
mayank31398 Oct 9, 2025
fcf6761
merge
mayank31398 Oct 9, 2025
d4ac5cb
merge
mayank31398 Oct 9, 2025
4a8fcff
merge
mayank31398 Oct 9, 2025
cb63b79
merge
mayank31398 Oct 9, 2025
224ab27
merge
mayank31398 Oct 9, 2025
ff5dc55
merge
mayank31398 Oct 9, 2025
f2c860d
merge
mayank31398 Oct 9, 2025
764c233
merge
mayank31398 Oct 9, 2025
d173775
merge
mayank31398 Oct 9, 2025
369a3b3
better
mayank31398 Oct 9, 2025
935af46
better
mayank31398 Oct 9, 2025
30f890b
better
mayank31398 Oct 9, 2025
9d6b095
better
mayank31398 Oct 9, 2025
f5d7175
better
mayank31398 Oct 9, 2025
2407caf
better
mayank31398 Oct 9, 2025
205d6ab
better
mayank31398 Oct 9, 2025
656767c
better
mayank31398 Oct 9, 2025
6a36238
better
mayank31398 Oct 9, 2025
4630666
better
mayank31398 Oct 9, 2025
96a9f86
better
mayank31398 Oct 9, 2025
5e60c0d
better
mayank31398 Oct 9, 2025
c93232a
better
mayank31398 Oct 9, 2025
12d5c83
better
mayank31398 Oct 9, 2025
ed73a44
cleanup
mayank31398 Oct 13, 2025
44c3a9a
cleanup
mayank31398 Oct 19, 2025
8e0ea46
cleanup
mayank31398 Oct 19, 2025
3855ee7
cleanup
mayank31398 Oct 19, 2025
bfa093c
cleanup
mayank31398 Oct 20, 2025
5b55c84
cleanup
mayank31398 Oct 20, 2025
52efb76
cleanup
mayank31398 Oct 20, 2025
f3df060
cleanup
mayank31398 Oct 20, 2025
5c1e169
cleanup
mayank31398 Oct 20, 2025
4622db9
cleanup
mayank31398 Oct 20, 2025
c99211b
cleanup
mayank31398 Oct 20, 2025
45ed3c5
cleanup
mayank31398 Oct 20, 2025
a65f9de
cleanup
mayank31398 Oct 20, 2025
37de101
cleanup
mayank31398 Oct 20, 2025
d762281
cleanup
mayank31398 Oct 20, 2025
dcc2245
cleanup
mayank31398 Oct 20, 2025
c93267a
cleanup
mayank31398 Oct 20, 2025
a40b459
cleanup
mayank31398 Oct 20, 2025
6a322f8
cleanup
mayank31398 Oct 20, 2025
edebff4
cleanup
mayank31398 Oct 20, 2025
aa4bcdd
cleanup
mayank31398 Oct 20, 2025
d40273a
cleanup
mayank31398 Oct 20, 2025
8cb4529
cleanup
mayank31398 Oct 20, 2025
aced695
cleanup
mayank31398 Oct 20, 2025
6f439e3
cleanup
mayank31398 Oct 20, 2025
0a3db83
cleanup
mayank31398 Oct 20, 2025
8e31633
cleanup
mayank31398 Oct 20, 2025
d9750fb
cleanup
mayank31398 Oct 20, 2025
2a2ee67
cleanup
mayank31398 Oct 20, 2025
6fcd525
cleanup
mayank31398 Oct 20, 2025
a8b9072
cleanup
mayank31398 Oct 20, 2025
5c84652
cleanup
mayank31398 Oct 20, 2025
1e3b055
cleanup
mayank31398 Oct 20, 2025
66f2883
merge
mayank31398 Nov 11, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -6,3 +6,4 @@ __pycache__
/appwrapper.yaml
*.egg-info/
build/
*.log
2 changes: 1 addition & 1 deletion README.md
@@ -86,7 +86,7 @@ labels = [[-100, -100, -100, 4, 5, 0], [-100, -100, 8, 0]]

# this will throw a warning saying that the model is of gpt_bigcode class
# ignore the warning
model = GPTBaseForCausalLM.from_pretrained(<model_path>, use_padding_free_transformer=True).cuda()
model = GPTBaseForCausalLM.from_pretrained(<model_path>).cuda()

with enable_kernels([Kernel.flash_attention_2]):
loss = model(input_ids=input_ids, labels=labels).loss
1 change: 0 additions & 1 deletion configs/distillation-example.yml
@@ -23,7 +23,6 @@ model_args:
model_class: AutoModelForCausalLM
model_name: ibm/PowerLM-3b
efficient_initialization: false
use_padding_free_transformer: false

teacher_args:
model_class: AutoModelForCausalLM
1 change: 0 additions & 1 deletion configs/finetuning-example.yml
@@ -25,7 +25,6 @@ model_args:
# padding free transformer needs a gpt_base model.
# To convert granite models to this class and convert back after training,
# take a look at the readme of this repo
use_padding_free_transformer: false

random_args:
# for replication of experiment (however, flash attention is non-deterministic so replication generally won't work)
1 change: 0 additions & 1 deletion configs/pretraining-examples/dense/pretrain-1.yml
@@ -120,7 +120,6 @@ model_args:
intermediate_size: 3072
add_bias: true
position_embedding_type: learned_absolute
use_padding_free_transformer: true

tuning_args:
tuning_method: pretraining
1 change: 0 additions & 1 deletion configs/pretraining-examples/dense/pretrain-2.yml
@@ -125,7 +125,6 @@ model_args:
intermediate_size: 3072
add_bias: true
position_embedding_type: learned_absolute
use_padding_free_transformer: true

tuning_args:
tuning_method: pretraining
1 change: 0 additions & 1 deletion configs/pretraining-examples/dense/pretrain-3.yml
@@ -138,7 +138,6 @@ model_args:
intermediate_size: 3072
add_bias: true
position_embedding_type: learned_absolute
use_padding_free_transformer: true

tuning_args:
tuning_method: pretraining
1 change: 0 additions & 1 deletion configs/pretraining-examples/dense/pretrain-tpu.yml
@@ -139,7 +139,6 @@ model_args:
intermediate_size: 3072
add_bias: true
position_embedding_type: learned_absolute
# use_padding_free_transformer: true

tuning_args:
tuning_method: pretraining
1 change: 0 additions & 1 deletion configs/research/cross-layer-attention/base.yml
@@ -249,7 +249,6 @@ model_args:
activation_function: swiglu
intermediate_size: 8192
efficient_initialization: false
use_padding_free_transformer: true

tuning_args:
tuning_method: pretraining
1 change: 0 additions & 1 deletion configs/research/cross-layer-attention/cla.yml
@@ -282,7 +282,6 @@ model_args:
activation_function: swiglu
intermediate_size: 8192
efficient_initialization: false
use_padding_free_transformer: true

tuning_args:
tuning_method: pretraining
1 change: 0 additions & 1 deletion configs/research/ladder-residual/1b-base.yml
@@ -278,7 +278,6 @@ model_args:
activation_function: swiglu
intermediate_size: 4096
efficient_initialization: false
use_padding_free_transformer: false

tuning_args:
tuning_method: pretraining
1 change: 0 additions & 1 deletion configs/research/ladder-residual/1b-ladder.yml
@@ -278,7 +278,6 @@ model_args:
activation_function: swiglu
intermediate_size: 4096
efficient_initialization: false
use_padding_free_transformer: false

tuning_args:
tuning_method: pretraining
1 change: 0 additions & 1 deletion configs/research/ladder-residual/1b-parallel.yml
@@ -278,7 +278,6 @@ model_args:
activation_function: swiglu
intermediate_size: 4096
efficient_initialization: false
use_padding_free_transformer: false

tuning_args:
tuning_method: pretraining
1 change: 0 additions & 1 deletion configs/research/ladder-residual/3b-base.yml
@@ -238,7 +238,6 @@ model_args:
- mlp_type: MLP
activation_function: swiglu
efficient_initialization: false
use_padding_free_transformer: false

tuning_args:
tuning_method: pretraining
1 change: 0 additions & 1 deletion configs/research/ladder-residual/3b-ladder.yml
@@ -238,7 +238,6 @@ model_args:
- mlp_type: MLP
activation_function: swiglu
efficient_initialization: false
use_padding_free_transformer: false

tuning_args:
tuning_method: pretraining
1 change: 0 additions & 1 deletion configs/research/ladder-residual/3b-parallel.yml
@@ -238,7 +238,6 @@ model_args:
- mlp_type: MLP
activation_function: swiglu
efficient_initialization: false
use_padding_free_transformer: false

tuning_args:
tuning_method: pretraining
2 changes: 0 additions & 2 deletions lm_engine/arguments.py
@@ -48,8 +48,6 @@ class ModelArgs(BaseArgs):
model_class: str = None
# trust remote code for models that are not directly supported by HuggingFace yet
trust_remote_code: bool = False
# whether to use padding free transformer: https://huggingface.co/blog/mayank-mishra/padding-free-transformer
use_padding_free_transformer: bool = False
# use lower memory to initialize model
efficient_initialization: bool = False
# whether to reset attention masks for pretraining
1 change: 0 additions & 1 deletion lm_engine/data/__init__.py
@@ -134,7 +134,6 @@ def get_finetuning_dataloader(
use_output=use_output,
loss_mask=args.training_parameters.loss_mask,
eos_token_id=tokenizer.eos_token_id,
use_padding_free_transformer=args.model_args.use_padding_free_transformer,
pad_to_multiple_of=ProcessGroupManager.get_tensor_parallel_world_size(),
),
)
138 changes: 79 additions & 59 deletions lm_engine/data/utils.py
@@ -8,15 +8,59 @@
import torch

from ..enums import LossMask
from ..hf_models import convert_padding_free_lists_to_tensors


def _check_list_type(list_of_list: list[list[int | float]] | None, error_message: str) -> None:
if list_of_list is None:
return

assert isinstance(list_of_list, list), error_message
assert isinstance(list_of_list[0], list), error_message


def _flatten_and_convert_to_tensors(x: list[list[int]], device: torch.device) -> torch.Tensor:
y = []
for sequence in x:
y.extend(sequence)

return torch.tensor(y, device=device)


def _convert_padding_free_lists_to_tensors(
input_ids: list[list[int]] | None = None,
position_ids: list[list[int]] | None = None,
labels: list[list[int]] | None = None,
device: torch.device = None,
) -> tuple[torch.Tensor | int]:

# check input types are correct
error_message = "{variable} should be of type List[List[{dtype}]]"
_check_list_type(input_ids, error_message.format(variable="input_ids", dtype="int"))
_check_list_type(position_ids, error_message.format(variable="position_ids", dtype="int"))
_check_list_type(labels, error_message.format(variable="labels", dtype="int"))

# prepare inputs for the model
seqlens = torch.tensor([0] + [len(x) for x in input_ids], device=device)
cu_seqlens = seqlens.cumsum(dim=-1).to(torch.int32)
max_seqlen = seqlens.max().item()

if position_ids is None:
position_ids = [list(range(len(x))) for x in input_ids]
position_ids = _flatten_and_convert_to_tensors(position_ids, device)

input_ids = _flatten_and_convert_to_tensors(input_ids, device)

if labels is not None:
labels = _flatten_and_convert_to_tensors(labels, device)

return input_ids, position_ids, labels, cu_seqlens, max_seqlen
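# Worked example (illustrative only, not part of this diff): packing two
# sequences with the helper above.
#   ids, pos, lbl, cu, max_len = _convert_padding_free_lists_to_tensors(
#       input_ids=[[1, 2, 3], [4, 5]], device=torch.device("cpu")
#   )
#   ids     -> tensor([1, 2, 3, 4, 5])          # flattened, no padding
#   pos     -> tensor([0, 1, 2, 0, 1])          # restarts at each sequence boundary
#   lbl     -> None                             # no labels were passed
#   cu      -> tensor([0, 3, 5], dtype=int32)   # cumulative sequence lengths
#   max_len -> 3                                # longest sequence in the pack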


def collate_fn(
batch: list[dict],
use_output: bool,
loss_mask: LossMask,
eos_token_id: int,
use_padding_free_transformer: bool,
labels_mask_value: int = -100,
pad_to_multiple_of: int = 1,
device: torch.device = None,
@@ -38,64 +82,40 @@ def collate_fn(

device = torch.cuda.current_device() if device is None else device

if use_padding_free_transformer:
input_ids = inputs
attention_mask = None

if loss_mask == LossMask.output_only:
labels = [
[labels_mask_value] * (len(array_in) - len(array_out)) + array_out
for array_in, array_out in zip(inputs, outputs)
]
elif loss_mask == LossMask.no_mask:
labels = inputs
else:
raise ValueError(f"unexpected loss_mask ({loss_mask})")

tokens_to_add = 0
if pad_to_multiple_of > 1:
total_tokens = sum([len(array) for array in input_ids])
tokens_to_add = (math.ceil(total_tokens / pad_to_multiple_of) * pad_to_multiple_of) - total_tokens

# we pad the last example in the batch on the right
# NOTE this can be done since the attention is causal
input_ids[-1].extend([eos_token_id] * tokens_to_add)
labels[-1].extend([labels_mask_value] * tokens_to_add)

input_ids, position_ids, _, labels, cu_seqlens, max_seqlen = convert_padding_free_lists_to_tensors(
input_ids=input_ids, labels=labels, device=device
)

result = {
"input_ids": input_ids,
"position_ids": position_ids,
"cu_seqlens": cu_seqlens,
"max_seqlen": max_seqlen,
}
if labels is not None:
result["labels"] = labels
input_ids = inputs

if loss_mask == LossMask.output_only:
labels = [
[labels_mask_value] * (len(array_in) - len(array_out)) + array_out
for array_in, array_out in zip(inputs, outputs)
]
elif loss_mask == LossMask.no_mask:
labels = inputs
else:
max_length = max(list(map(len, inputs)))
if pad_to_multiple_of > 1:
max_length = math.ceil(max_length / pad_to_multiple_of) * pad_to_multiple_of

input_ids = [[eos_token_id] * (max_length - len(array)) + array for array in inputs]
attention_mask = [[0] * (max_length - len(array)) + [1] * len(array) for array in inputs]

if outputs is not None:
if loss_mask == LossMask.output_only:
labels = [[labels_mask_value] * (max_length - len(array)) + array for array in outputs]
elif loss_mask == LossMask.no_mask:
labels = inputs
else:
raise ValueError(f"unexpected loss_mask ({loss_mask})")

result = {
"input_ids": torch.tensor(input_ids, device=device),
"attention_mask": torch.tensor(attention_mask, device=device),
}
if labels is not None:
result["labels"] = torch.tensor(labels, device=device)
raise ValueError(f"unexpected loss_mask ({loss_mask})")

tokens_to_add = 0
if pad_to_multiple_of > 1:
total_tokens = sum([len(array) for array in input_ids])
tokens_to_add = (math.ceil(total_tokens / pad_to_multiple_of) * pad_to_multiple_of) - total_tokens

# we pad the last example in the batch on the right
# NOTE this can be done since the attention is causal
input_ids[-1].extend([eos_token_id] * tokens_to_add)
labels[-1].extend([labels_mask_value] * tokens_to_add)
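    # Illustrative: with 5 packed tokens and pad_to_multiple_of=4,
    # tokens_to_add = ceil(5 / 4) * 4 - 5 = 3, so 3 extra eos tokens are appended
    # to the last example with -100 labels so they do not contribute to the loss.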

input_ids, position_ids, labels, cu_seqlens, max_seqlen = _convert_padding_free_lists_to_tensors(
input_ids=input_ids, labels=labels, device=device
)

result = {
"input_ids": input_ids,
"position_ids": position_ids,
"cu_seqlens": cu_seqlens,
"max_seqlen": max_seqlen,
}
if labels is not None:
result["labels"] = labels

return result
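# Resulting batch (illustrative): a single packed dict along the lines of
#   {"input_ids": tensor([...]), "position_ids": tensor([...]),
#    "cu_seqlens": tensor([0, 3, 5], dtype=int32), "max_seqlen": 3,
#    "labels": tensor([...])}
# with no attention_mask; sequence boundaries travel in cu_seqlens instead.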

3 changes: 2 additions & 1 deletion lm_engine/hf_models/__init__.py
@@ -2,8 +2,10 @@
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from .cache import disable_generation_cache
from .config import CommonConfig
from .loss import get_autoregressive_language_modeling_loss, is_aux_loss_zero
from .mask import AttentionMaskInfo
from .mixins import CausalLMOutputWithPast, PipelineParallelInput, PipelineParallelOutput
from .model_conversion import export_to_huggingface, import_from_huggingface
from .models import (
@@ -30,7 +32,6 @@
)
from .register_hf import get_model_parallel_class, is_custom_model, register_model_classes
from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts
from .utils import convert_padding_free_lists_to_tensors, disable_generation_cache


register_model_classes()
21 changes: 20 additions & 1 deletion lm_engine/hf_models/cache/__init__.py
@@ -4,7 +4,7 @@

from __future__ import annotations

from typing import Iterable
from typing import Any, Iterable

import torch

@@ -53,3 +53,22 @@ def get_seq_length(self, layer_idx: int = 0) -> int:
def reorder_cache(self, beam_idx: torch.Tensor) -> None:
for cache in self.cache:
cache.reorder_cache(beam_idx)


_IS_GENERATION_CACHE_ENABLED: bool = True


class disable_generation_cache:
def __enter__(self) -> Any:
global _IS_GENERATION_CACHE_ENABLED
self.original = _IS_GENERATION_CACHE_ENABLED

_IS_GENERATION_CACHE_ENABLED = False

def __exit__(self, exception_type, exception_value, exception_traceback) -> Any:
global _IS_GENERATION_CACHE_ENABLED
_IS_GENERATION_CACHE_ENABLED = self.original


def is_generation_cache_enabled() -> bool:
return _IS_GENERATION_CACHE_ENABLED
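A minimal usage sketch of the new context manager (illustrative, not part of this diff); the import paths assume the package layout shown above:

    from lm_engine.hf_models import disable_generation_cache
    from lm_engine.hf_models.cache import is_generation_cache_enabled

    assert is_generation_cache_enabled()
    with disable_generation_cache():
        # the module-level flag is flipped off inside the block
        assert not is_generation_cache_enabled()
    assert is_generation_cache_enabled()  # restored on exit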
12 changes: 5 additions & 7 deletions lm_engine/hf_models/loss.py
@@ -14,6 +14,7 @@
from ..enums import Kernel
from ..kernels import is_kernel_allowed
from ..utils import ProcessGroupManager, is_xma_available
from .mask import AttentionMaskInfo


if is_xma_available():
@@ -23,10 +24,9 @@
def get_autoregressive_language_modeling_loss(
lm_logits: torch.Tensor,
labels: torch.Tensor,
attention_mask_info: AttentionMaskInfo,
hidden_states: torch.Tensor | None = None,
vocab_weight: torch.Tensor | None = None,
cu_seqlens: torch.Tensor | None = None,
use_padding_free_transformer: bool = False,
reduction: str = "mean",
shift_logits_and_labels: bool = True,
tensor_parallel_enabled: bool = False,
@@ -40,15 +40,13 @@ def get_autoregressive_language_modeling_loss(

labels = labels[..., 1:]

if use_padding_free_transformer:
if shift_logits_and_labels:
assert cu_seqlens is not None
if shift_logits_and_labels:
cu_seqlens = attention_mask_info.get_cu_seqlens()

if cu_seqlens is not None:
# this is needed so that the last token of current example doesn't predict first token of next example
drop_loss_positions = cu_seqlens[1:-1] - 1
labels[drop_loss_positions] = -100
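    # Illustrative: for packed cu_seqlens = [0, 3, 5] (two sequences),
    # drop_loss_positions = [2], masking the position where the last token of
    # the first sequence would otherwise be trained to predict the first token
    # of the second sequence.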
else:
assert cu_seqlens is None

if is_kernel_allowed(Kernel.fused_linear_cross_entropy):
assert lm_logits is None