Commit 27e4cff

better type annotations (#242)

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
Parent: 6707a44

84 files changed (+508, -350 lines). This is a large commit; only part of the diff is reproduced below.
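
Every hunk below applies the same two-part pattern: add `from __future__ import annotations` (PEP 563), which stores annotations as strings and defers their evaluation, and then reference class names directly in annotations even where the name is not yet bound at definition time. Here is a minimal self-contained sketch of the mechanism; the `Layer` class is hypothetical, used only for illustration. The diff also applies the unquoted style to `__init__` signatures throughout.

```python
# Hypothetical example: why the future import lets a class appear
# unquoted in its own method annotations.
from __future__ import annotations


class Layer:
    def freeze(self) -> Layer:  # evaluated lazily, so no quotes needed
        self.frozen = True
        return self


# Without the future import, `-> Layer` would raise NameError while the
# class body executes; the workaround was the string form `-> "Layer"`.
print(Layer().freeze().frozen)  # True
```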

lm_engine/checkpointing/lr_scheduler.py (4 additions, 2 deletions)

```diff
@@ -2,6 +2,8 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+from __future__ import annotations
+
 import os
 
 import torch.nn as nn
@@ -16,10 +18,10 @@
 
 
 class _LRSchedulerSaver(Stateful):
-    def __init__(self, lr_scheduler_container: LRSchedulerContainer) -> None:
+    def __init__(self, lr_scheduler_container: LRSchedulerContainer) -> _LRSchedulerSaver:
         self.lr_scheduler_container = lr_scheduler_container
 
-    def state_dict(self) -> dict:
+    def state_dict(self) -> list[dict]:
         return [lr_scheduler.state_dict() for lr_scheduler in self.lr_scheduler_container]
 
     def load_state_dict(self, state_dict: list[dict]) -> None:
```
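
The corrected `state_dict` annotation reflects what the saver actually returns: one state dict per scheduler in the container. Below is a runnable sketch of the same shape, assuming only the `Stateful` protocol from `torch.distributed.checkpoint.stateful`; `DummyScheduler` and `SchedulerListSaver` are hypothetical stand-ins, not repository classes.

```python
# Hedged sketch: a Stateful wrapper over a list of schedulers, so its
# state_dict() is naturally a list[dict]. DummyScheduler is hypothetical.
from torch.distributed.checkpoint.stateful import Stateful


class DummyScheduler:
    def __init__(self, lr: float) -> None:
        self.lr = lr

    def state_dict(self) -> dict:
        return {"lr": self.lr}

    def load_state_dict(self, state_dict: dict) -> None:
        self.lr = state_dict["lr"]


class SchedulerListSaver(Stateful):
    def __init__(self, schedulers: list[DummyScheduler]) -> None:
        self.schedulers = schedulers

    def state_dict(self) -> list[dict]:
        # one entry per wrapped scheduler, matching the annotation above
        return [s.state_dict() for s in self.schedulers]

    def load_state_dict(self, state_dict: list[dict]) -> None:
        for s, sd in zip(self.schedulers, state_dict):
            s.load_state_dict(sd)


saver = SchedulerListSaver([DummyScheduler(1e-4), DummyScheduler(3e-4)])
assert saver.state_dict() == [{"lr": 1e-4}, {"lr": 3e-4}]
```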

lm_engine/checkpointing/model.py (3 additions, 1 deletion)

```diff
@@ -2,6 +2,8 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+from __future__ import annotations
+
 import os
 
 from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
@@ -11,7 +13,7 @@
 
 
 class _ModelSaver(Stateful):
-    def __init__(self, model_container: ModelContainer) -> None:
+    def __init__(self, model_container: ModelContainer) -> _ModelSaver:
         self.model_container = model_container
 
     def state_dict(self) -> dict:
```

lm_engine/checkpointing/model_optimizer.py (5 additions, 1 deletion)

```diff
@@ -2,6 +2,8 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+from __future__ import annotations
+
 import os
 
 from torch.distributed.checkpoint.state_dict import (
@@ -17,7 +19,9 @@
 
 
 class _ModelOptimizerSaver(Stateful):
-    def __init__(self, model_container: ModelContainer, optimizer_container: OptimizerContainer) -> None:
+    def __init__(
+        self, model_container: ModelContainer, optimizer_container: OptimizerContainer
+    ) -> _ModelOptimizerSaver:
         self.model_container = model_container
         self.optimizer_container = optimizer_container
 
```

lm_engine/checkpointing/optimizer.py (3 additions, 1 deletion)

```diff
@@ -2,6 +2,8 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+from __future__ import annotations
+
 import os
 
 from torch.distributed.checkpoint.state_dict import (
@@ -15,7 +17,7 @@
 
 
 class _OptimizerSaver(Stateful):
-    def __init__(self, model_container: ModelContainer, optimizer_container: OptimizerContainer) -> None:
+    def __init__(self, model_container: ModelContainer, optimizer_container: OptimizerContainer) -> _OptimizerSaver:
         self.model_container = model_container
         self.optimizer_container = optimizer_container
 
```

lm_engine/containers.py (5 additions, 3 deletions)

```diff
@@ -2,6 +2,8 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+from __future__ import annotations
+
 import logging
 
 import torch.nn as nn
@@ -10,7 +12,7 @@
 
 
 class _Container:
-    def __init__(self, model_list: list[nn.Module]) -> None:
+    def __init__(self, model_list: list[nn.Module]) -> _Container:
         self.model_list = model_list
 
     def __iter__(self):
@@ -31,11 +33,11 @@ def __str__(self):
 
 
 class ModelContainer(_Container):
-    def train(self) -> "ModelContainer":
+    def train(self) -> ModelContainer:
         for model in self:
             model.train()
 
-    def eval(self) -> "ModelContainer":
+    def eval(self) -> ModelContainer:
         for model in self:
             model.eval()
 
```
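
Dropping the quotes around `ModelContainer` is the concrete payoff of the future import: a class can now name itself in its own return annotations. A self-contained sketch of the container pattern follows; `TinyContainer` is hypothetical, and it assumes the truncated method bodies end with `return self`, which is what the return annotations imply.

```python
# Hypothetical sketch of the container pattern with unquoted
# self-referencing return types; assumes the methods return self.
from __future__ import annotations

import torch.nn as nn


class TinyContainer:
    def __init__(self, model_list: list[nn.Module]) -> None:
        self.model_list = model_list

    def __iter__(self):
        return iter(self.model_list)

    def train(self) -> TinyContainer:
        for model in self:
            model.train()
        return self

    def eval(self) -> TinyContainer:
        for model in self:
            model.eval()
        return self


container = TinyContainer([nn.Linear(4, 4), nn.Linear(4, 2)])
container.eval().train()  # returning self makes chaining possible
```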

lm_engine/data/__init__.py (5 additions, 5 deletions)

```diff
@@ -89,7 +89,7 @@ def get_datasets_list(
 
 def get_finetuning_dataloader(
     args: TrainingArgs | InferenceArgs, split: DatasetSplit, mode: Mode, tokenizer: TOKENIZER_TYPE
-) -> tuple[ResumableDataLoader]:
+) -> ResumableDataLoader:
     """prepares datasets and sampler
 
     Args:
@@ -99,7 +99,7 @@
         tokenizer (TOKENIZER_TYPE): tokenizer
 
     Returns:
-        tuple[ResumableDataLoader]: dataloader for a blended dataset
+        ResumableDataLoader: dataloader for a blended dataset
     """
 
     assert mode == Mode.training, "blended dataset is only supported in training mode"
@@ -121,7 +121,7 @@
 
 def get_pretraining_dataloaders(
     args: TrainingArgs, tokenizer: TOKENIZER_TYPE, consumed_samples: int
-) -> tuple[ResumableDataLoader]:
+) -> tuple[ResumableDataLoader, list[ResumableDataLoader], list[ResumableDataLoader]]:
     if args.datasets[0].class_name == "MegatronDataset":
         dataloaders = get_megatron_gpt_dataloaders(args, tokenizer, consumed_samples=consumed_samples)
     elif args.datasets[0].class_name == "IBMDataset":
@@ -132,7 +132,7 @@
 
 def _get_dispatching_dataloader(
     args: TrainingArgs | InferenceArgs, split: DatasetSplit, mode: Mode, tokenizer: TOKENIZER_TYPE
-) -> tuple[ResumableDataLoader]:
+) -> ResumableDataLoader:
     micro_batch_size = args.training_parameters.micro_batch_size
 
     num_ranks_per_node = torch.cuda.device_count()
@@ -211,7 +211,7 @@ def _get_source_broadcast_mapping() -> dict:
 
 def _get_non_dispatching_dataloader(
     args: TrainingArgs | InferenceArgs, split: DatasetSplit, mode: Mode, tokenizer: TOKENIZER_TYPE
-) -> tuple[ResumableDataLoader]:
+) -> ResumableDataLoader:
     micro_batch_size = args.training_parameters.micro_batch_size
 
     datasets_list, data_sampling_ratios = get_datasets_list(
```
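
These hunks fix a recurring mistake: in Python typing, `tuple[ResumableDataLoader]` is a tuple of exactly one element, not "some tuple of dataloaders". Three of the functions return a bare `ResumableDataLoader`, and `get_pretraining_dataloaders` returns a 3-tuple, which the new annotations now state precisely. A minimal illustration with plain types:

```python
# Illustration only: tuple[X] annotates a fixed-shape 1-tuple.
def returns_one() -> tuple[int]:
    return (1,)  # exactly one element


def returns_three() -> tuple[int, list[int], list[int]]:
    return 1, [2, 3], [4, 5]  # fixed-shape 3-tuple


# 3-way unpacking type-checks against the precise annotation.
first, seconds, thirds = returns_three()
```

A caller can therefore unpack `get_pretraining_dataloaders` into three names; the hunk does not say what the three elements mean, though train, validation, and test loaders is the natural reading.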

lm_engine/data/base.py (5 additions, 7 deletions)

```diff
@@ -2,6 +2,8 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+from __future__ import annotations
+
 import torch
 
 from ..defaults import INPUT_FORMAT, OUTPUT_FORMAT
@@ -23,7 +25,7 @@ def __init__(
         output_format: str,
         max_input_tokens: int,
         max_output_tokens: int,
-    ) -> None:
+    ) -> BaseDataset:
         super().__init__()
 
         self.split = split
@@ -39,11 +41,7 @@
         self.do_format_output = self.output_format != OUTPUT_FORMAT
 
         # length to use for trimming (excludes eos)
-        if max_input_tokens is None:
-            self.max_input_tokens = None
-        else:
-            self.max_input_tokens = max_input_tokens
-
+        self.max_input_tokens = max_input_tokens
        self.max_output_tokens = None if max_output_tokens is None else max_output_tokens - 1
 
         self.examples = []
@@ -124,7 +122,7 @@ def __len__(self) -> int:
 class BlendedDatasets(torch.utils.data.Dataset):
     """Concatenated list of datasets for training or inference"""
 
-    def __init__(self, datasets: list[BaseDataset], split: DatasetSplit) -> None:
+    def __init__(self, datasets: list[BaseDataset], split: DatasetSplit) -> BlendedDatasets:
         super().__init__()
 
         self.split = split
```
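
The deleted `if`/`else` was a no-op: both branches assign the parameter through unchanged. Note that the `max_output_tokens` line below it keeps its conditional expression, since `max_output_tokens - 1` would raise a `TypeError` on `None`. A tiny equivalence check with hypothetical helper names:

```python
# Hypothetical helpers demonstrating the refactor is behavior-preserving.
def old_style(max_input_tokens: int | None) -> int | None:
    if max_input_tokens is None:
        return None
    else:
        return max_input_tokens


def new_style(max_input_tokens: int | None) -> int | None:
    return max_input_tokens


assert old_style(None) == new_style(None)
assert old_style(128) == new_style(128)
```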

lm_engine/data/dataloader.py (5 additions, 3 deletions)

```diff
@@ -2,7 +2,9 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
-from typing import Callable, Iterable
+from __future__ import annotations
+
+from typing import Callable, Iterable, Iterator
 
 import torch
 import torch.distributed
@@ -37,7 +39,7 @@ def __init__(
         broadcast_world_size: int | None = None,
         static_shape_per_rank: tuple[int, int] | None = None,
         keys: list[str] = ["input_ids", "attention_mask", "labels"],
-    ) -> None:
+    ) -> DispatchingDataLoader:
         self.broadcast_world_size = broadcast_world_size
 
         self.is_source, self.source_rank, self.local_rank_in_broadcast_group, self.broadcast_group = (
@@ -67,7 +69,7 @@ def __init__(
 
         self.keys = keys
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[dict]:
         iterator = super().__iter__() if self.is_source else range(self._length)
 
         for batch in iterator:
```
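
`Iterator[dict]` is the appropriate annotation for a generator-style `__iter__`: calling a generator function returns an iterator, and each `yield` produces a batch dict. A self-contained sketch follows; `ToyLoader` is hypothetical and much simpler than the repository's `DispatchingDataLoader`.

```python
# Hypothetical sketch: typing a generator-style __iter__ as Iterator[dict].
from __future__ import annotations

from typing import Iterator


class ToyLoader:
    def __init__(self, batches: list[dict]) -> None:
        self.batches = batches

    def __iter__(self) -> Iterator[dict]:
        # the generator object returned by calling this method
        # satisfies Iterator[dict]
        yield from self.batches


for batch in ToyLoader([{"input_ids": [1, 2]}, {"input_ids": [3]}]):
    print(batch)
```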

lm_engine/data/debug.py (3 additions, 1 deletion)

```diff
@@ -2,6 +2,8 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+from __future__ import annotations
+
 from ..enums import DatasetSplit, Mode
 from ..tokenizers import TOKENIZER_TYPE
 from .base import BaseDataset
@@ -21,7 +23,7 @@ def __init__(
         output_format: str,
         max_input_tokens: int,
         max_output_tokens: int,
-    ) -> None:
+    ) -> DebugDataset:
         super().__init__(
             class_args=class_args,
             split=split,
```

lm_engine/data/huggingface.py (3 additions, 1 deletion)

```diff
@@ -2,6 +2,8 @@
 # Copyright (c) 2025, Mayank Mishra
 # **************************************************
 
+from __future__ import annotations
+
 from datasets import load_dataset
 
 from lm_engine.tokenizers import TOKENIZER_TYPE
@@ -24,7 +26,7 @@ def __init__(
         output_format: str,
         max_input_tokens: int,
         max_output_tokens: int,
-    ) -> None:
+    ) -> HuggingFaceDataset:
         super().__init__(
             class_args=class_args,
             split=split,
```
