6 changes: 5 additions & 1 deletion fast_llm/core/distributed.py
@@ -185,7 +185,11 @@ def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, ta
 @contextlib.contextmanager
 def set_generator(generator: torch.Generator) -> typing.Generator[None, None, None]:
     """Use the generator as default, for ops that don't support a generator argument."""
-    default_generator: torch.Generator = torch.cuda.default_generators[torch.cuda.current_device()]
+    default_generator: torch.Generator = (
+        torch.cuda.default_generators[generator.device.index]
+        if generator.device.type == "cuda"
+        else torch.default_generator
+    )
     assert generator is not default_generator
     old_state = default_generator.get_state()
     default_generator.set_state(generator.get_state())
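
The change above makes `set_generator` usable with CPU generators as well as CUDA ones. A hedged usage sketch (the calling code below is illustrative, not part of the PR):

```python
# Illustrative only: temporarily let `gen` drive ops that take no generator
# argument. On CPU the context manager now swaps torch.default_generator's
# state; on CUDA it targets the default generator of `gen`'s device index.
import torch
import torch.nn.functional as F

from fast_llm.core.distributed import set_generator

gen = torch.Generator(device="cpu").manual_seed(1234)
with set_generator(gen):
    # dropout has no generator argument; it now consumes `gen`'s state
    x = F.dropout(torch.ones(8), p=0.5, training=True)
```
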
78 changes: 51 additions & 27 deletions fast_llm/core/kernels.py
@@ -18,24 +18,29 @@


 def l2_norm(tensors: list[torch.Tensor], noop_flag: torch.Tensor) -> torch.Tensor:
-    assert _apex_available
-    norm, _ = _multi_tensor_applier(
-        _multi_tensor_l2norm,
-        noop_flag,
-        [tensors],
-        False, # no per-parameter norm
-    )
+    if _apex_available:
+        norm, _ = _multi_tensor_applier(
+            _multi_tensor_l2norm,
+            noop_flag,
+            [tensors],
+            False, # no per-parameter norm
+        )
+    else:
+        norm = sum(torch.norm(tensor) ** 2 for tensor in tensors) ** 0.5
     return norm


 def scale_(tensors: list[torch.Tensor], noop_flag: torch.Tensor, scale: torch.Tensor | float) -> None:
-    assert _apex_available
-    _multi_tensor_applier(
-        _multi_tensor_scale,
-        noop_flag,
-        [tensors, tensors],
-        scale,
-    )
+    if _apex_available:
+        _multi_tensor_applier(
+            _multi_tensor_scale,
+            noop_flag,
+            [tensors, tensors],
+            scale,
+        )
+    else:
+        for tensor in tensors:
+            tensor.mul_(scale)


 # TODO: Same as torch._fused_adam_?
@@ -52,16 +57,35 @@ def fused_adam(
     eps: float,
     step: int,
 ) -> None:
-    _multi_tensor_applier(
-        _multi_tensor_adam,
-        noop_flag,
-        [grads, params, exp_avgs, exp_avg_sqs],
-        lr,
-        beta1,
-        beta2,
-        eps,
-        step,
-        1, # adamw
-        1, # bias correction
-        wd,
-    )
+    if _apex_available:
+        _multi_tensor_applier(
+            _multi_tensor_adam,
+            noop_flag,
+            [grads, params, exp_avgs, exp_avg_sqs],
+            lr,
+            beta1,
+            beta2,
+            eps,
+            step,
+            1, # adamw
+            1, # bias correction
+            wd,
+        )
+    else:
+        import torch.optim.adamw as adamw
+
+        adamw.adamw(
+            params,
+            grads,
+            exp_avgs,
+            exp_avg_sqs,
+            None,
+            lr=lr,
+            beta1=beta1,
+            beta2=beta2,
+            eps=eps,
+            state_steps=torch.full([len(params)], step, dtype=torch.int64, device=params[0].device).unbind(),
+            weight_decay=wd,
+            amsgrad=False,
+            maximize=False,
+        )
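
The new `else` branches give pure-PyTorch fallbacks when Apex is not installed. A hedged sanity sketch, not from the PR: it assumes a CPU-only environment where `_apex_available` is False and the PyTorch 2.x signature of the functional `torch.optim.adamw.adamw` that the fallback delegates to.

```python
# Illustrative only: check the fallback reduction and drive one AdamW step
# through the functional API used by the non-Apex branch.
import torch
from torch.optim.adamw import adamw  # functional AdamW (PyTorch 2.x)

# l2_norm fallback: sqrt of summed squared norms == norm of all elements.
tensors = [torch.randn(3, 3), torch.randn(5)]
expected = torch.cat([t.flatten() for t in tensors]).norm()
actual = sum(torch.norm(t) ** 2 for t in tensors) ** 0.5
torch.testing.assert_close(actual, expected)

# One optimizer step on CPU state tensors.
params = [torch.randn(8)]
grads = [torch.randn(8)]
exp_avgs = [torch.zeros(8)]
exp_avg_sqs = [torch.zeros(8)]
state_steps = [torch.tensor(1, dtype=torch.int64)]  # one singleton step tensor per param

adamw(
    params,
    grads,
    exp_avgs,
    exp_avg_sqs,
    [],  # max_exp_avg_sqs, unused with amsgrad=False
    state_steps,
    amsgrad=False,
    beta1=0.9,
    beta2=0.999,
    lr=1e-3,
    weight_decay=0.01,
    eps=1e-8,
    maximize=False,
)
```
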
8 changes: 2 additions & 6 deletions fast_llm/engine/checkpoint/convert.py
@@ -8,7 +8,6 @@
 from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig
 from fast_llm.engine.config_utils.runnable import RunnableConfig
 from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode
-from fast_llm.functional.config import TritonConfig
 from fast_llm.utils import Assert

 if typing.TYPE_CHECKING:
@@ -64,8 +63,8 @@ def _convert_model_partial(
         logger.info(f"Loading {self.input.format} checkpoint from {self.input.path}...")
         model = model_class.from_pretrained(
             self.input,
+            {("distributed", "use_cuda"): not self.use_cpu},
             mode=StageMode.weights,
-            use_cpu=self.use_cpu,
             stage_filter=stage_filter,
         )
         logger.info(f"Saving {output.format} checkpoint to {output.path}...")
@@ -78,9 +77,6 @@ def run(self):
         # TODO: Set logging in tests
         logging.getLogger().setLevel(logging.INFO)
         self.to_logs()
-        # Disable Triton to convert model on CPU
-        if self.use_cpu:
-            TritonConfig.TRITON_ENABLED = False
         # Skip on exist_ok=False if the model has already been processed
         if not self.exist_ok and (self.output.path / "ok").exists():
             logger.info(
@@ -100,8 +96,8 @@
         # Create a dummy version to determine the stage split.
         model = model_class.from_pretrained(
             self.input.to_copy({"model_weights": False}),
+            {("distributed", "use_cuda"): not self.use_cpu},
             mode=StageMode.off_device,
-            use_cpu=self.use_cpu,
         )
         stages_per_step = math.ceil(self.layers_per_step / model._config.multi_stage.layers_per_stage)
         num_stages = len(model.stages)
2 changes: 1 addition & 1 deletion fast_llm/engine/config_utils/run.py
@@ -101,7 +101,7 @@ def configure_logging(
     def get_run(self, distributed: "Distributed") -> "Run":
         from fast_llm.functional.config import TritonConfig

-        TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels
+        TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels # and distributed.config.use_cuda
         TritonConfig.TRITON_LINEAR = self.run.triton_linear_kernels
         run = Run(config=self, distributed=distributed)
         set_global_variables(not self.run.torch_dynamo_enable)
5 changes: 5 additions & 0 deletions fast_llm/engine/distributed/config.py
@@ -201,6 +201,11 @@ class DistributedConfig(Config):
         hint=FieldHint.optional,
         valid=check_field(Assert.gt, 0),
     )
+    use_cuda: bool = Field(
+        default=True,
+        desc="Enable CUDA device.",
+        hint=FieldHint.expert,
+    )
     seed: int = Field(default=1234, desc="A seed for training.", hint=FieldHint.optional)
     # TODO: Rename to compute_dtype (not just for training), move elsewhere
     compute_dtype: DataType = Field(
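
The new field is what the `{("distributed", "use_cuda"): ...}` overrides in convert.py above target. A hedged sketch of the same pattern outside the converter (the `FastLLMModel` class name is assumed from the module path edited below; only the override dict is taken from the PR):

```python
# Illustrative only: disable CUDA through the nested-config override pattern
# used in convert.py, without editing the checkpoint's stored config.
from fast_llm.engine.checkpoint.config import CheckpointLoadConfig
from fast_llm.engine.multi_stage.config import StageMode
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel


def load_on_cpu(checkpoint_load_config: CheckpointLoadConfig) -> FastLLMModel:
    # The second argument overrides nested config fields at load time.
    return FastLLMModel.from_pretrained(
        checkpoint_load_config,
        {("distributed", "use_cuda"): False},  # run on CPU
        mode=StageMode.weights,
    )
```
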
23 changes: 12 additions & 11 deletions fast_llm/engine/distributed/distributed.py
@@ -27,7 +27,7 @@ def __init__(
         world_size: int | None = None,
         local_world_size: int | None = None,
         timeout: float = 60,
-        use_cpu: bool = False,
+        use_cuda: bool = True,
         backend: DistributedBackend = DistributedBackend.nccl,
     ):

@@ -37,19 +37,20 @@ def __init__(
             DistributedConfig.default_local_world_size if local_world_size is None else local_world_size
         )
         self._timeout = timeout
-        self._use_cpu = use_cpu
+        self._use_cuda = use_cuda
         self._backend = backend
+        self._process_groups = {}

-        if self._use_cpu:
-            if backend == DistributedBackend.nccl:
-                Assert.eq(self._world_size, 1)
-            self._device = torch.device("cpu")
-        else:
+        if self._use_cuda:
             assert torch.cuda.is_available()
             Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count())
             torch.cuda.init()
             self._device = torch.device(self._rank % self._local_world_size)
             torch.cuda.set_device(self._device)
+        else:
+            if backend == DistributedBackend.nccl:
+                Assert.eq(self._world_size, 1)
+            self._device = torch.device("cpu")

         if self._world_size > 1:
             if self._rank == 0:
@@ -152,7 +153,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]):
     TODO: Clarify cpu support.
     """

-    def __init__(self, config: DistributedConfig, use_cpu: bool = False):
+    def __init__(self, config: DistributedConfig):
         super().__init__(config)
         assert self._config.reference_config is None

@@ -163,15 +164,15 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False):
                 self._config.world_size,
                 self._config.local_world_size,
                 self._config.timeout,
-                use_cpu,
+                self._config.use_cuda,
                 self._config.backend,
             )
         else:
             self._pool = _default_pool
             Assert.geq(self._pool.world_size, self._config.world_size)
             Assert.eq(self._pool.rank, self._config.rank)
             Assert.geq(self._pool.local_world_size, self._config.local_world_size)
-            Assert.eq(self._pool.device.type, "cpu" if use_cpu else "cuda")
+            Assert.eq(self._pool.device.type, "cuda" if self._config.use_cuda else "cpu")
            Assert.eq(self._pool.backend, self._config.backend)

         self.world_group = self.add_group(self._config.distributed_dims[DistributedDimNames.world])
@@ -259,5 +260,5 @@ def set_step(self, step: int, phase: PhaseType) -> None:
         self.tp_generator.manual_seed((self._tp_seed + seed_shift) % MAX_SEED)

     def __del__(self):
-        if self._local_pool:
+        if getattr(self, "_local_pool", False) and hasattr(self, "_pool"):
             self._pool.shutdown()
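
With the flag in the config, a single-process CPU setup looks roughly like this. This is a hedged sketch: `Config.from_dict` as the construction path and the gloo backend string are assumptions, not shown in this PR.

```python
# Illustrative only: build a CPU-only Distributed instance. NCCL on CPU is
# limited to world_size == 1 by the check above, so gloo is assumed here.
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed

config = DistributedConfig.from_dict({"use_cuda": False, "backend": "gloo"})
distributed = Distributed(config)  # pool device resolves to torch.device("cpu")
```
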
2 changes: 0 additions & 2 deletions fast_llm/engine/inference/huggingface.py
@@ -85,7 +85,6 @@ def from_pretrained(
         optimizer_state_names: tuple[str, ...] | None = None,
         # setup: bool = True,
         mode: StageMode = StageMode.training,
-        use_cpu: bool = False,
         stage_filter: set | None = None,
         **kwargs,
     ) -> typing.Self:
@@ -104,7 +103,6 @@ def from_pretrained(
             optimizer_state_names=optimizer_state_names,
             setup=True,
             mode=mode,
-            use_cpu=use_cpu,
             stage_filter=stage_filter,
         )

3 changes: 1 addition & 2 deletions fast_llm/engine/multi_stage/fast_llm_model.py
@@ -48,7 +48,6 @@ def from_pretrained(
         optimizer_state_names: tuple[str, ...] | None = None,
         setup: bool = True,
         mode: StageMode = StageMode.training,
-        use_cpu: bool = False,
         stage_filter: set | None = None,
     ) -> typing.Self:
         metadata = cls.config_class.load_metadata(pretrained_config)
@@ -69,7 +68,7 @@ def from_pretrained(
         )

         if setup:
-            model.setup(Distributed(config.distributed, use_cpu=use_cpu), mode=mode)
+            model.setup(Distributed(config.distributed), mode=mode)

         if mode.on_device:
             if pretrained_config.model_weights:
18 changes: 18 additions & 0 deletions fast_llm/engine/schedule/config.py
@@ -191,3 +191,21 @@ class EventType(str, enum.Enum):
     send = "send"
     recv = "recv"
     pipe_wait_compute = "pipe_wait_compute"
+
+
+class MockStream:
+    stream_id: int = 0
+
+    def wait_stream(self, stream):
+        pass
+
+    def __eq__(self, other):
+        return isinstance(other, MockStream)
+
+
+class MockEvent:
+    def record(self, stream=None):
+        pass
+
+    def wait(self):
+        pass
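
These mocks mirror the small subset of the `torch.cuda.Stream` / `torch.cuda.Event` API that the schedule touches. A hedged sketch of the intended substitution (the selection logic below is illustrative, not the runner's actual code):

```python
# Illustrative only: fall back to the mocks when CUDA is unavailable so that
# stream/event bookkeeping becomes a no-op on CPU.
import torch

from fast_llm.engine.schedule.config import MockEvent, MockStream

use_cuda = torch.cuda.is_available()
compute_stream = torch.cuda.Stream() if use_cuda else MockStream()
copy_stream = torch.cuda.Stream() if use_cuda else MockStream()
copy_done = torch.cuda.Event() if use_cuda else MockEvent()

copy_done.record(copy_stream)            # no-op on the mock
copy_done.wait()                         # no-op on the mock
compute_stream.wait_stream(copy_stream)  # no-op on the mock
```
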