Merged
2 changes: 1 addition & 1 deletion lm_engine/finetune.py
@@ -6,7 +6,6 @@

 import torch
 from torch.distributed.tensor.parallel import loss_parallel
-from transformers import set_seed

 from .arguments import TrainingArgs, get_args
 from .checkpointing import ensure_last_checkpoint_is_saved, load_checkpoint_for_training, save_checkpoint
@@ -29,6 +28,7 @@
     StepTracker,
     TorchProfiler,
     init_distributed,
+    set_seed,
     setup_tf32,
 )

2 changes: 1 addition & 1 deletion lm_engine/pretrain.py
@@ -12,7 +12,6 @@
 from torch.distributed.pipelining.schedules import _PipelineSchedule
 from torch.distributed.tensor.parallel import loss_parallel
 from torch.utils.data import DataLoader
-from transformers import set_seed

 from .arguments import DistillationArgs, TrainingArgs, get_args
 from .checkpointing import ensure_last_checkpoint_is_saved, load_checkpoint_for_training, save_checkpoint
@@ -39,6 +38,7 @@
     is_torchao_available,
     log_environment,
     log_rank_0,
+    set_seed,
     setup_tf32,
 )

1 change: 1 addition & 0 deletions lm_engine/utils/__init__.py
@@ -34,6 +34,7 @@
 from .parallel import ProcessGroupManager, get_pipeline_stage_ids_on_current_rank, run_rank_n
 from .profiler import TorchProfiler
 from .pydantic import BaseArgs
+from .random import set_seed
 from .safetensors import SafeTensorsWeightsManager
 from .step_tracker import StepTracker
 from .tracking import ExperimentsTracker, ProgressBar
15 changes: 15 additions & 0 deletions lm_engine/utils/random.py
@@ -0,0 +1,15 @@
+# **************************************************
+# Copyright (c) 2025, Mayank Mishra
+# **************************************************
+
+import random
+
+import numpy as np
+import torch
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
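
The new `lm_engine/utils/random.py` seeds Python's built-in `random`, NumPy, and PyTorch (the CPU generator plus all visible CUDA devices) in one call, replacing the dependency on `transformers.set_seed` for this purpose. A minimal usage sketch of the re-exported helper follows, assuming the package is importable; the tensor draws are purely illustrative:

import torch

from lm_engine.utils import set_seed

# Re-seeding with the same value resets every RNG this helper covers,
# so the two draws below produce identical tensors.
set_seed(42)
a = torch.randn(3)

set_seed(42)
b = torch.randn(3)

assert torch.equal(a, b)

Note that the helper only seeds the generators; it does not toggle deterministic kernels, so fully reproducible CUDA runs may still need `torch.use_deterministic_algorithms(True)` on the caller's side.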
9 changes: 7 additions & 2 deletions [file path not captured]
@@ -7,12 +7,17 @@

 import torch
 import torch.distributed
-from transformers import set_seed

 from lm_engine.enums import Kernel
 from lm_engine.hf_models import GPTBaseConfig, get_model_parallel_class
 from lm_engine.kernels import enable_kernels
-from lm_engine.utils import Communication, ProcessGroupManager, SafeTensorsWeightsManager, string_to_torch_dtype
+from lm_engine.utils import (
+    Communication,
+    ProcessGroupManager,
+    SafeTensorsWeightsManager,
+    set_seed,
+    string_to_torch_dtype,
+)

 from ...test_common import TestCommons

2 changes: 1 addition & 1 deletion tests/hf_models/single_gpu/gpt_base_test.py
@@ -6,10 +6,10 @@

 import torch
 from parameterized import parameterized
-from transformers import set_seed

 from lm_engine.enums import Kernel
 from lm_engine.kernels import enable_kernels
+from lm_engine.utils import set_seed

 from ..test_common import TestCommons

2 changes: 1 addition & 1 deletion [file path not captured]
@@ -4,11 +4,11 @@

 import torch
 from parameterized import parameterized
-from transformers import set_seed

 from lm_engine.enums import Kernel
 from lm_engine.hf_models import GPTBaseConfig
 from lm_engine.kernels import enable_kernels
+from lm_engine.utils import set_seed

 from ..test_common import TestCommons

2 changes: 1 addition & 1 deletion tests/hf_models/single_gpu/scattermoe_test.py
@@ -4,10 +4,10 @@

 import torch
 from parameterized import parameterized
-from transformers import set_seed

 from lm_engine.enums import Kernel
 from lm_engine.kernels import enable_kernels
+from lm_engine.utils import set_seed

 from ..test_common import TestCommons

3 changes: 1 addition & 2 deletions tests/training/params_group/efficient_init_test.py
@@ -5,13 +5,12 @@
 import os

 import torch
-from transformers import set_seed

 from lm_engine.arguments import UnshardingArgs
 from lm_engine.checkpointing import load_checkpoint_and_unshard, save_checkpoint
 from lm_engine.distributed import wrap_model_container_for_distributed_training
 from lm_engine.model_wrapper import get_model_container
-from lm_engine.utils import ProcessGroupManager, load_yaml
+from lm_engine.utils import ProcessGroupManager, load_yaml, set_seed

 from ..test_commons import TestCommons

Expand Down