From 66f16d5e235ae778878da8bd614ace549dc4bdd3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 9 Jan 2026 01:58:06 -0500 Subject: [PATCH 1/2] stuff --- fast_llm/core/distributed.py | 6 +- fast_llm/core/kernels.py | 78 ++++++++++++------- fast_llm/engine/checkpoint/convert.py | 7 -- fast_llm/engine/config_utils/run.py | 2 +- fast_llm/engine/distributed/config.py | 5 ++ fast_llm/engine/distributed/distributed.py | 23 +++--- fast_llm/engine/inference/huggingface.py | 2 - fast_llm/engine/multi_stage/fast_llm_model.py | 3 +- fast_llm/engine/schedule/config.py | 18 +++++ fast_llm/engine/schedule/runner.py | 61 +++++++++------ fast_llm/engine/schedule/schedule.py | 6 +- fast_llm/engine/training/trainer.py | 3 +- fast_llm/functional/triton/mlp.py | 15 ++-- fast_llm/functional/triton/normalization.py | 2 +- fast_llm/functional/triton/pointwise.py | 6 +- fast_llm/layers/attention/attention.py | 6 ++ fast_llm/layers/attention/rotary/config.py | 18 ----- fast_llm/layers/attention/rotary/rotary.py | 46 ++++++++--- fast_llm/utils.py | 9 +++ tests/data/common.py | 6 +- tests/functional/test_cross_entropy.py | 9 +-- tests/functional/test_functional.py | 27 ++++--- tests/functional/test_triton_kernels.py | 31 ++++---- tests/layers/test_attention.py | 14 ++-- tests/layers/test_lm_head.py | 9 ++- tests/layers/test_rotary.py | 11 ++- tests/layers/test_ssm.py | 29 +++---- tests/layers/test_varlen.py | 24 ++++-- tests/models/test_model.py | 1 - tests/test_multi_stage.py | 2 - tests/utils/model_configs.py | 12 ++- tests/utils/utils.py | 2 +- 32 files changed, 290 insertions(+), 203 deletions(-) diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index 4dcc53d55..9d1f16fbe 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -185,7 +185,11 @@ def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, ta @contextlib.contextmanager def set_generator(generator: torch.Generator) -> typing.Generator[None, None, None]: """Use the generator as default, for ops that don't support a generator argument.""" - default_generator: torch.Generator = torch.cuda.default_generators[torch.cuda.current_device()] + default_generator: torch.Generator = ( + torch.cuda.default_generators[generator.device.index] + if generator.device.type == "cuda" + else torch.default_generator + ) assert generator is not default_generator old_state = default_generator.get_state() default_generator.set_state(generator.get_state()) diff --git a/fast_llm/core/kernels.py b/fast_llm/core/kernels.py index 9ead051d7..93371a654 100644 --- a/fast_llm/core/kernels.py +++ b/fast_llm/core/kernels.py @@ -18,24 +18,29 @@ def l2_norm(tensors: list[torch.Tensor], noop_flag: torch.Tensor) -> torch.Tensor: - assert _apex_available - norm, _ = _multi_tensor_applier( - _multi_tensor_l2norm, - noop_flag, - [tensors], - False, # no per-parameter norm - ) + if _apex_available: + norm, _ = _multi_tensor_applier( + _multi_tensor_l2norm, + noop_flag, + [tensors], + False, # no per-parameter norm + ) + else: + norm = sum(torch.norm(tensor) ** 2 for tensor in tensors) ** 0.5 return norm def scale_(tensors: list[torch.Tensor], noop_flag: torch.Tensor, scale: torch.Tensor | float) -> None: - assert _apex_available - _multi_tensor_applier( - _multi_tensor_scale, - noop_flag, - [tensors, tensors], - scale, - ) + if _apex_available: + _multi_tensor_applier( + _multi_tensor_scale, + noop_flag, + [tensors, tensors], + scale, + ) + else: + for tensor in tensors: + tensor.mul_(scale) # TODO: Same as torch._fused_adam_? @@ -52,16 +57,35 @@ def fused_adam( eps: float, step: int, ) -> None: - _multi_tensor_applier( - _multi_tensor_adam, - noop_flag, - [grads, params, exp_avgs, exp_avg_sqs], - lr, - beta1, - beta2, - eps, - step, - 1, # adamw - 1, # bias correction - wd, - ) + if _apex_available: + _multi_tensor_applier( + _multi_tensor_adam, + noop_flag, + [grads, params, exp_avgs, exp_avg_sqs], + lr, + beta1, + beta2, + eps, + step, + 1, # adamw + 1, # bias correction + wd, + ) + else: + import torch.optim.adamw as adamw + + adamw.adamw( + params, + grads, + exp_avgs, + exp_avg_sqs, + None, + lr=lr, + beta1=beta1, + beta2=beta2, + eps=eps, + state_steps=torch.full([len(params)], step, dtype=torch.int64, device=params[0].device).unbind(), + weight_decay=wd, + amsgrad=False, + maximize=False, + ) diff --git a/fast_llm/engine/checkpoint/convert.py b/fast_llm/engine/checkpoint/convert.py index 4ab7b3d54..b40d8a1b3 100644 --- a/fast_llm/engine/checkpoint/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -8,7 +8,6 @@ from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode -from fast_llm.functional.config import TritonConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -21,7 +20,6 @@ class ConvertConfig(RunnableConfig): input: CheckpointLoadConfig = Field() output: CheckpointSaveConfig = Field() - use_cpu: bool = Field(default=False) exist_ok: bool = Field(default=False) layers_per_step: int | None = Field(default=None) model: type[FastLLMModelConfig] = Field(default=None) @@ -65,7 +63,6 @@ def _convert_model_partial( model = model_class.from_pretrained( self.input, mode=StageMode.weights, - use_cpu=self.use_cpu, stage_filter=stage_filter, ) logger.info(f"Saving {output.format} checkpoint to {output.path}...") @@ -78,9 +75,6 @@ def run(self): # TODO: Set logging in tests logging.getLogger().setLevel(logging.INFO) self.to_logs() - # Disable Triton to convert model on CPU - if self.use_cpu: - TritonConfig.TRITON_ENABLED = False # Skip on exist_ok=False if the model has already been processed if not self.exist_ok and (self.output.path / "ok").exists(): logger.info( @@ -101,7 +95,6 @@ def run(self): model = model_class.from_pretrained( self.input.to_copy({"model_weights": False}), mode=StageMode.off_device, - use_cpu=self.use_cpu, ) stages_per_step = math.ceil(self.layers_per_step / model._config.multi_stage.layers_per_stage) num_stages = len(model.stages) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 1849a2316..415147d06 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -101,7 +101,7 @@ def configure_logging( def get_run(self, distributed: "Distributed") -> "Run": from fast_llm.functional.config import TritonConfig - TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels + TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels and distributed.config.use_cuda TritonConfig.TRITON_LINEAR = self.run.triton_linear_kernels run = Run(config=self, distributed=distributed) set_global_variables(not self.run.torch_dynamo_enable) diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 7f4b7bc38..d0a078812 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -201,6 +201,11 @@ class DistributedConfig(Config): hint=FieldHint.optional, valid=check_field(Assert.gt, 0), ) + use_cuda: bool = Field( + default=True, + desc="Enable CUDA device.", + hint=FieldHint.expert, + ) seed: int = Field(default=1234, desc="A seed for training.", hint=FieldHint.optional) # TODO: Rename to compute_dtype (not just for training), move elsewhere compute_dtype: DataType = Field( diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index aa2be6ce7..e2f8daaa4 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -27,7 +27,7 @@ def __init__( world_size: int | None = None, local_world_size: int | None = None, timeout: float = 60, - use_cpu: bool = False, + use_cuda: bool = True, backend: DistributedBackend = DistributedBackend.nccl, ): @@ -37,19 +37,20 @@ def __init__( DistributedConfig.default_local_world_size if local_world_size is None else local_world_size ) self._timeout = timeout - self._use_cpu = use_cpu + self._use_cuda = use_cuda self._backend = backend self._process_groups = {} - if self._use_cpu: - if backend == DistributedBackend.nccl: - Assert.eq(self._world_size, 1) - self._device = torch.device("cpu") - else: + if self._use_cuda: + assert torch.cuda.is_available() Assert.in_range_incl(self._local_world_size, 1, torch.cuda.device_count()) torch.cuda.init() self._device = torch.device(self._rank % self._local_world_size) torch.cuda.set_device(self._device) + else: + if backend == DistributedBackend.nccl: + Assert.eq(self._world_size, 1) + self._device = torch.device("cpu") if self._world_size > 1: if self._rank == 0: @@ -152,7 +153,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]): TODO: Clarify cpu support. """ - def __init__(self, config: DistributedConfig, use_cpu: bool = False): + def __init__(self, config: DistributedConfig): super().__init__(config) assert self._config.reference_config is None @@ -163,7 +164,7 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): self._config.world_size, self._config.local_world_size, self._config.timeout, - use_cpu, + self._config.use_cuda, self._config.backend, ) else: @@ -171,7 +172,7 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): Assert.geq(self._pool.world_size, self._config.world_size) Assert.eq(self._pool.rank, self._config.rank) Assert.geq(self._pool.local_world_size, self._config.local_world_size) - Assert.eq(self._pool.device.type, "cpu" if use_cpu else "cuda") + Assert.eq(self._pool.device.type, "cuda" if self._config.use_cuda else "cpu") Assert.eq(self._pool.backend, self._config.backend) self.world_group = self.add_group(self._config.distributed_dims[DistributedDimNames.world]) @@ -259,5 +260,5 @@ def set_step(self, step: int, phase: PhaseType) -> None: self.tp_generator.manual_seed((self._tp_seed + seed_shift) % MAX_SEED) def __del__(self): - if self._local_pool: + if getattr(self, "_local_pool", False) and hasattr(self, "_pool"): self._pool.shutdown() diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 3ffed4533..aa1eaa401 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -85,7 +85,6 @@ def from_pretrained( optimizer_state_names: tuple[str, ...] | None = None, # setup: bool = True, mode: StageMode = StageMode.training, - use_cpu: bool = False, stage_filter: set | None = None, **kwargs, ) -> typing.Self: @@ -104,7 +103,6 @@ def from_pretrained( optimizer_state_names=optimizer_state_names, setup=True, mode=mode, - use_cpu=use_cpu, stage_filter=stage_filter, ) diff --git a/fast_llm/engine/multi_stage/fast_llm_model.py b/fast_llm/engine/multi_stage/fast_llm_model.py index 6a6223cb7..ccde838e8 100644 --- a/fast_llm/engine/multi_stage/fast_llm_model.py +++ b/fast_llm/engine/multi_stage/fast_llm_model.py @@ -48,7 +48,6 @@ def from_pretrained( optimizer_state_names: tuple[str, ...] | None = None, setup: bool = True, mode: StageMode = StageMode.training, - use_cpu: bool = False, stage_filter: set | None = None, ) -> typing.Self: metadata = cls.config_class.load_metadata(pretrained_config) @@ -69,7 +68,7 @@ def from_pretrained( ) if setup: - model.setup(Distributed(config.distributed, use_cpu=use_cpu), mode=mode) + model.setup(Distributed(config.distributed), mode=mode) if mode.on_device: if pretrained_config.model_weights: diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 272b7c6ae..8696f0a59 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -191,3 +191,21 @@ class EventType(str, enum.Enum): send = "send" recv = "recv" pipe_wait_compute = "pipe_wait_compute" + + +class MockStream: + stream_id: int = 0 + + def wait_stream(self, stream): + pass + + def __eq__(self, other): + return isinstance(other, MockStream) + + +class MockEvent: + def record(self, stream=None): + pass + + def wait(self): + pass diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 133b3206b..5ddf2ff98 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -1,4 +1,5 @@ import collections +import contextlib import dataclasses import logging import time @@ -16,7 +17,7 @@ from fast_llm.engine.multi_stage.multi_stage import MultiStageModel from fast_llm.engine.multi_stage.stage import Stage from fast_llm.engine.optimizer.optimizer import Optimizer -from fast_llm.engine.schedule.config import EventType, ScheduleConfig, StepType, StreamType +from fast_llm.engine.schedule.config import EventType, MockEvent, MockStream, ScheduleConfig, StepType, StreamType from fast_llm.engine.schedule.schedule import Schedule, Step from fast_llm.logging import log_memory_usage from fast_llm.utils import Assert @@ -36,7 +37,7 @@ class BatchContext: # Dictionary of losses, purely for logging purposes. # Losses will be reduced over DP and PP, and aggregated over steps. losses: dict | None = None - profile: list[tuple[EventType, Step | None, torch.cuda.Event, StreamType, float]] = dataclasses.field( + profile: list[tuple[EventType, Step | None, torch.cuda.Event | MockEvent, StreamType, float]] = dataclasses.field( default_factory=list ) # Store metrics like: grad norm, loss scale, learning-rate, etc. @@ -65,15 +66,15 @@ def __repr__(self): class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ConfigType]): _is_setup: bool = False - _compute_stream: torch.cuda.Stream - _data_stream: torch.cuda.Stream - _pipeline_stream: torch.cuda.Stream + _compute_stream: torch.cuda.Stream | MockStream + _data_stream: torch.cuda.Stream | MockStream + _pipeline_stream: torch.cuda.Stream | MockStream _streams: dict[int, StreamType] - _compute_event: torch.cuda.Event - _reduce_event: torch.cuda.Event - _send_event: torch.cuda.Event + _compute_event: torch.cuda.Event | MockEvent + _reduce_event: torch.cuda.Event | MockEvent + _send_event: torch.cuda.Event | MockEvent _data_stream_needs_sync: bool - _profile_events: dict[tuple[EventType, tuple | None], torch.cuda.Event] + _profile_events: dict[tuple[EventType, tuple | None], torch.cuda.Event | MockEvent] _distributed: Distributed _optimizer: Optimizer | None _stages_on_device: list[Stage] @@ -111,12 +112,16 @@ def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> self._stages_owned = [stage.mode.on_device and not stage.is_tied_weight_copy for stage in self._stages] # Setup the streams - self._compute_stream = torch.cuda.current_stream(self._distributed.device) + self._compute_stream = self._get_current_stream() self._data_stream = ( - torch.cuda.Stream(self._distributed.device) if self._config.data_overlap else self._compute_stream + torch.cuda.Stream(self._distributed.device) + if self._config.data_overlap and self._distributed_config.use_cuda + else self._compute_stream ) self._pipeline_stream = ( - torch.cuda.Stream(self._distributed.device) if self._config.pipeline_overlap else self._compute_stream + torch.cuda.Stream(self._distributed.device) + if self._config.pipeline_overlap and self._distributed_config.use_cuda + else self._compute_stream ) # Putting compute stream last in the dict in case it's the same id. self._streams = { @@ -126,10 +131,12 @@ def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> } # Setup the synchronization and profiling events - self._profile_events = collections.defaultdict(lambda: torch.cuda.Event(enable_timing=True)) - self._compute_event = torch.cuda.Event() - self._reduce_event = torch.cuda.Event() - self._send_event = torch.cuda.Event() + self._profile_events = collections.defaultdict( + lambda: torch.cuda.Event(enable_timing=True) if self._distributed_config.use_cuda else MockEvent() + ) + self._compute_event = torch.cuda.Event() if self._distributed_config.use_cuda else MockEvent() + self._reduce_event = torch.cuda.Event() if self._distributed_config.use_cuda else MockEvent() + self._send_event = torch.cuda.Event() if self._distributed_config.use_cuda else MockEvent() self._data_stream_needs_sync = False def run_step( @@ -164,7 +171,7 @@ def run_step( self._distributed.set_step(iteration, schedule.phase) # Synchronize streams - Assert.eq(torch.cuda.current_stream(self._distributed.device), self._compute_stream) + Assert.eq(self._get_current_stream(), self._compute_stream) if self._config.profile_schedule: # Synchronize clocks safe_barrier(self._distributed.world_group, f"clock sync {iteration}") @@ -354,7 +361,7 @@ def _preprocess_data( def _restore(self, context: BatchContext, step: Step) -> None: if step.restore_launch: - with torch.cuda.stream(self._data_stream): + with self._with_stream(self._data_stream): self._sync_data_stream(context, step) for restore_step in step.restore_launch: self._stages[restore_step.stage].restore_parameters() @@ -368,7 +375,7 @@ def _restore(self, context: BatchContext, step: Step) -> None: def _recv(self, context: BatchContext, step: Step) -> None: if step.recv_launch: - with torch.cuda.stream(self._pipeline_stream): + with self._with_stream(self._pipeline_stream): for recv_step in step.recv_launch: # TODO: Pre-allocated buffers context.inputs[recv_step.global_index] = torch.empty_like( @@ -432,7 +439,7 @@ def _send(self, context: BatchContext, step: Step, output: torch.Tensor) -> None if step.next_step.recv_step is None: context.inputs[step.next_step.global_index] = output else: - with torch.cuda.stream(self._pipeline_stream): + with self._with_stream(self._pipeline_stream): self._compute_event.wait() self._record_event(context, EventType.pipe_wait_compute, step, self._pipeline_stream) if self._config.debug_send_recv: @@ -452,7 +459,7 @@ def _send(self, context: BatchContext, step: Step, output: torch.Tensor) -> None def _reduce(self, context: BatchContext, step: Step) -> None: if step.reduce: - with torch.cuda.stream(self._data_stream): + with self._with_stream(self._data_stream): self._sync_data_stream(context, step) stage = self._stages[step.stage] if not self._config.skip_step: @@ -462,12 +469,12 @@ def _reduce(self, context: BatchContext, step: Step) -> None: self._record_event(context, EventType.reduce, step) def _record_event( - self, context: BatchContext, type_: EventType, step: Step | None, stream: torch.cuda.Stream = None + self, context: BatchContext, type_: EventType, step: Step | None, stream: torch.cuda.Stream | MockStream = None ) -> None: if not self._config.profile_schedule: return if stream is None: - stream = torch.cuda.current_stream() + stream = self._get_current_stream() event = self._profile_events[(type_, None if step is None else step.map_index)] event.record(stream) cpu_time = time.perf_counter() @@ -529,3 +536,11 @@ def _record_compute(self, context: BatchContext, step: Step) -> None: self._record_event(context, EventType.run, step) if self._config.data_overlap: self._data_stream_needs_sync = True + + def _get_current_stream(self): + return ( + torch.cuda.current_stream(self._distributed.device) if self._distributed_config.use_cuda else MockStream() + ) + + def _with_stream(self, stream: torch.cuda.Stream | MockStream): + return torch.cuda.stream(stream) if self._distributed_config.use_cuda else contextlib.nullcontext() diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 18ca44b78..fa25c914d 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -281,7 +281,7 @@ def _setup_restore_steps(self, weight_buffer_indices: dict[int, int]) -> None: for step in device_steps: buffer_index = weight_buffer_indices[step.stage] if buffer_contents.get(buffer_index) != step.stage: - if self._schedule_config.data_overlap: + if self._schedule_config.data_overlap and self._distributed_config.use_cuda: step.restore_step = device_steps[buffer_last_used.get(buffer_index, -1) + 1] step.restore_event = torch.cuda.Event() else: @@ -378,7 +378,7 @@ def _setup_send_recv_steps(self) -> None: launch_step.recv_launch.append(recv_step) send_step.send_to = launch_step recv_step.recv_step = launch_step - if self._schedule_config.pipeline_overlap: + if self._schedule_config.pipeline_overlap and self._distributed_config.use_cuda: recv_step.recv_event = torch.cuda.Event() def _validate_send_recv_steps(self) -> None: @@ -449,7 +449,7 @@ def _validate_send_recv_steps(self) -> None: raise RuntimeError(f"Cannot find valid timeline for {self}, \nStatuses:{msg}") def _setup_throttle_steps(self) -> None: - if not self._schedule_config.throttle_cpu: + if not self._schedule_config.throttle_cpu or not self._distributed_config.use_cuda: return for device_steps in self._device_steps: for i, step in enumerate(device_steps): diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 7225ed20a..b35733cc7 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -358,7 +358,8 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # TODO: Synchronization is probably unnecessary. safe_barrier(self._distributed.world_group, "train begin") - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() start_time = time.perf_counter() last_time = start_time start_iteration = self._completed_steps diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index 1d2d0b3d6..dc3ee0f04 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -235,14 +235,13 @@ def mlp_forward( input_ = None # Activation - if TritonConfig.TRITON_ENABLED: + if TritonConfig.TRITON_ENABLED and intermediate_1.device.type == "cuda": intermediate_2, _ = triton_mlp_activation_forward(intermediate_1, gated, activation_type) else: do_grad = training and not recompute_level.recompute_activation with torch.set_grad_enabled(do_grad): - intermediate_2 = torch_mlp_activation( - intermediate_1.detach().requires_grad_(do_grad), gated, activation_type - ) + intermediate_1 = intermediate_1.detach().requires_grad_(do_grad) + intermediate_2 = torch_mlp_activation(intermediate_1, gated, activation_type) if recompute_level.recompute_layer_1: intermediate_1 = None @@ -345,20 +344,20 @@ def mlp_backward(grad_output: torch.Tensor, context: list[typing.Any]) -> tuple[ )[0] # Activation recomputation and/or backward - if TritonConfig.TRITON_ENABLED: + if TritonConfig.TRITON_ENABLED and grad_output.device.type == "cuda": grad_intermediate_1, intermediate_2_ = triton_mlp_activation_backward( grad_intermediate_2, (intermediate_1, gated, activation_type), intermediate_2 is None ) else: if intermediate_2 is None: + intermediate_1 = intermediate_1.detach().requires_grad_(True) with torch.set_grad_enabled(True): - intermediate_2_ = torch_mlp_activation( - intermediate_1.detach().requires_grad_(True), gated, activation_type - ) + intermediate_2_ = torch_mlp_activation(intermediate_1, gated, activation_type) else: intermediate_2_ = intermediate_2 intermediate_2_.backward(grad_intermediate_2) grad_intermediate_1 = intermediate_1.grad + print("AAAAA", intermediate_2 is None, grad_intermediate_1) # Layer 2 parameter grad del grad_intermediate_2, intermediate_1 diff --git a/fast_llm/functional/triton/normalization.py b/fast_llm/functional/triton/normalization.py index 96d1663f7..a018ad44b 100644 --- a/fast_llm/functional/triton/normalization.py +++ b/fast_llm/functional/triton/normalization.py @@ -187,7 +187,7 @@ def triton_normalization_forward( n_cols = weight.numel() output = torch.empty_like(input_, dtype=weight.dtype) - inv_var = torch.empty(n_rows, dtype=torch.float32, device="cuda") + inv_var = torch.empty(n_rows, dtype=torch.float32, device=input_.device) block_size = triton.next_power_of_2(n_cols) assert block_size * input_.element_size() <= TritonConfig.MAX_BLOCK_SIZE_BYTES diff --git a/fast_llm/functional/triton/pointwise.py b/fast_llm/functional/triton/pointwise.py index bd14de9e2..22676ae1a 100644 --- a/fast_llm/functional/triton/pointwise.py +++ b/fast_llm/functional/triton/pointwise.py @@ -32,7 +32,7 @@ def triton_copy( """ A triton implementation of tensor copying (`torch.Tensor.copy_()`). """ - if not TritonConfig.TRITON_ENABLED: + if not TritonConfig.TRITON_ENABLED or input_.device.type != "cuda": return out.copy_(input_) # TODO: Improve assumptions. assert input_.is_contiguous() @@ -65,7 +65,7 @@ def triton_fill( """ A faster triton implementation of tensor copying (`torch.Tensor.fill_()`). """ - if not TritonConfig.TRITON_ENABLED: + if not TritonConfig.TRITON_ENABLED or input_.device.type != "cuda": return input_.fill_(value) # TODO: Improve assumptions. assert input_.is_contiguous() @@ -106,7 +106,7 @@ def triton_add( """ A faster triton implementation of tensor addition (`torch.Tensor.add()`). """ - if not TritonConfig.TRITON_ENABLED: + if not TritonConfig.TRITON_ENABLED or input_.device.type != "cuda": return torch.add(input_, other, out=out) # TODO: Improve assumptions. assert input_.is_contiguous() diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 073599479..902352c25 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -459,6 +459,11 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size if self._config.causal: + print( + "WWWWW", + kwargs[AttentionKwargs.sequence_length], + self._backup_attention_tensor_cache_max_sequence_length, + ) if ( sequence_length := kwargs[AttentionKwargs.sequence_length] ) > self._backup_attention_tensor_cache_max_sequence_length: @@ -491,6 +496,7 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non if attention_mask is None: attention_mask = document_mask else: + print("AAAA", attention_mask.shape, document_mask.shape, kwargs) attention_mask = attention_mask & document_mask kwargs[AttentionKwargs.attention_mask] = attention_mask diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 92adc880e..1ec35ae0c 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -1,12 +1,10 @@ import abc import math import typing -import warnings from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.functional.config import TritonConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -59,22 +57,6 @@ class DefaultRotaryConfig(RotaryConfig): desc="Scale for the rotary positional embeddings", hint=FieldHint.architecture, ) - # TODO: Make a backup implementation that doesn't affect the layout. - triton: bool = Field( - default=True, - desc="Enable the triton implementation of the rotary embeddings. Affects the model layout.", - hint=FieldHint.architecture, - ) - - @property - def complex_format(self) -> bool: - # TODO: Make a backup implementation that doesn't affect the layout. - return not self.triton - - def _validate(self) -> None: - super()._validate() - if self.triton and not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") def _get_configurable_class(self) -> "type[DefaultRotary]": from fast_llm.layers.attention.rotary.rotary import DefaultRotary diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 258f9d8bc..304f96b83 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -6,6 +6,7 @@ from fast_llm.config import Configurable from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.rotary import triton_rotary_autograd_ from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.attention.rotary.config import ( @@ -28,7 +29,7 @@ def convert_rotary_real_to_complex(tensor: torch.Tensor, head_size: int, dim: in return tensor.unflatten(dim, (-1, 2, div(head_size, 2))).movedim(dim + 1, dim + 2).flatten(dim, dim + 2) -def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor) -> torch.Tensor: +def rotary_embeddings_complex(tensor: torch.Tensor, rope_frequencies: torch.Tensor) -> torch.Tensor: """ Apply rotary embeddings to a tensor: * Convert it to a complex, full-precision tensor @@ -41,6 +42,23 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor) +@torch.compile +def rotary_embeddings_real(tensor: torch.Tensor, rope_frequencies: torch.Tensor) -> torch.Tensor: + """ + Apply rotary embeddings to a tensor. + """ + tensor_re, tensor_im = torch.chunk(tensor, 2, dim=-1) + frequencies_re, frequencies_im = torch.chunk(rope_frequencies, 2, dim=-1) + + return torch.cat( + [ + tensor_re * frequencies_re - tensor_im * frequencies_im, + tensor_im * frequencies_re + tensor_re * frequencies_im, + ], + dim=-1, + ) + + class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module): def __init__( self, @@ -82,7 +100,11 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: - rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings + rotary_fn = ( + triton_rotary_autograd_ + if TritonConfig.TRITON_ENABLED and query.device.type == "cuda" + else rotary_embeddings_real + ) query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key @@ -107,10 +129,9 @@ def _get_frequencies(self, sequence_length: int, head_size: int, device: torch.d positions = torch.arange(sequence_length, device=device, dtype=torch.float64) angles = torch.outer(positions, self._get_angle_scales(head_size, device)) frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) - if not self._config.complex_format: - frequencies = convert_rotary_complex_to_real( - torch.view_as_real(frequencies).flatten(-2), head_size, 3 - ).contiguous() + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), head_size, 3 + ).contiguous() return frequencies def _get_angle_scales(self, head_size: int, device: torch.device) -> torch.Tensor: @@ -207,10 +228,9 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: out=angles.view(-1, 2, self._head_size // 4).permute(1, 0, 2), ) frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) - if not self._config.complex_format: - frequencies = convert_rotary_complex_to_real( - torch.view_as_real(frequencies).flatten(-2), self._head_size, 3 - ).contiguous() + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), self._head_size, 3 + ).contiguous() # TODO: Support different q and k frequencies. kwargs[AttentionKwargs.rotary_freq_q] = frequencies kwargs[AttentionKwargs.rotary_freq_k] = frequencies @@ -218,7 +238,11 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: - rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings + rotary_fn = ( + triton_rotary_autograd_ + if TritonConfig.TRITON_ENABLED and query.device.type == "cuda" + else rotary_embeddings_real + ) query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 2ca61aa0e..408441b95 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -421,6 +421,15 @@ def get_and_reset_memory_usage_mib( global _global_max_allocated, _global_max_reserved import torch + if not torch.cuda.is_available(): + return { + "reserved": 0.0, + "allocated": 0.0, + "max_reserved": 0.0, + "max_allocated": 0.0, + "global_max_reserved": 0.0, + } + if clear_cache: # Free memory for more accurate reporting, and to reduce OOM risk with lots of workers. # Cublas workspace can unnecessarily keep 100s of MBs of reserved memory. diff --git a/tests/data/common.py b/tests/data/common.py index 34fdba321..7ec4a9018 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -32,7 +32,7 @@ def get_sampling_data( preprocessing: LanguageModelPreprocessingConfig | None = None, ) -> GPTSamplingData: # Config with convenient defaults. - distributed = Distributed(DistributedConfig(), use_cpu=True) + distributed = Distributed(DistributedConfig(use_cuda=torch.cuda.is_available())) if preprocessing is None: preprocessing = LanguageModelPreprocessingConfig() return GPTSamplingData( @@ -71,8 +71,8 @@ def get_test_data_and_compare_samples( expected_samples: dict[str, list[list[int]]] | list[list[int]], preprocessing: LanguageModelPreprocessingConfig, ) -> GPTData: - distributed_config = DistributedConfig(seed=87522) - distributed = Distributed(distributed_config, use_cpu=True) + distributed_config = DistributedConfig(seed=87522, use_cuda=torch.cuda.is_available()) + distributed = Distributed(distributed_config) if isinstance(samples_per_dataset, int): samples_per_dataset = {PhaseType.training.value.lower(): samples_per_dataset} diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 20d16bb96..1bfde36ed 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -10,12 +10,12 @@ from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.utils import Assert -from tests.utils.utils import requires_cuda def _get_cross_entropy_inputs( - num_columns: int, loss_masking: bool, target_format: TargetFormat, device="cuda" + num_columns: int, loss_masking: bool, target_format: TargetFormat ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + device = "cuda" if torch.cuda.is_available() else "cpu" # We want something moderately close to the target for the test to be meaningful logits_var = torch.randn(256, num_columns, dtype=torch.bfloat16, device=device) / 3 loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device=device) if loss_masking else None @@ -49,7 +49,6 @@ def _compare_cross_entropy_outputs( assert ref_grad is None -@requires_cuda @pytest.mark.slow @pytest.mark.parametrize( ("num_columns", "grad_output", "logits_scale_factor", "loss_masking"), @@ -85,6 +84,8 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_maski threshold = 2e-5 if logits_scale_factor == 1.0 else 1e-2 _compare_cross_entropy_outputs(out_fused, out_torch, grad_output is not None, grad_fused, grad_torch, threshold) + if not torch.cuda.is_available(): + return if num_columns > 65536: with pytest.raises(AssertionError): cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) @@ -111,7 +112,6 @@ def _reverse_kl_forward_backward_torch(logits: torch.Tensor, target: torch.Tenso return output, logits.grad -@requires_cuda @pytest.mark.slow # TODO: Support the same parameterization as above in the reference implementation. @pytest.mark.parametrize("loss_masking", [False, True]) @@ -206,7 +206,6 @@ def compare_parallel_cross_entropy(rank: int, group: torch.distributed.ProcessGr raise RuntimeError("Test failed") -@requires_cuda @pytest.mark.slow def test_distillation_losses(): _spawn_dist(2, compare_parallel_cross_entropy) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index c48a0a531..76c0841d9 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -8,7 +8,6 @@ from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert from tests.utils.dataset import get_random_spans -from tests.utils.utils import requires_cuda def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): @@ -78,22 +77,22 @@ def test_dpo_loss(): Assert.rms_close(fast_llm_grad, logits.grad, 1e-5) -@requires_cuda @pytest.mark.parametrize("gated", [True, False]) @pytest.mark.parametrize( "activation", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu] ) def test_mlp_recomputation(gated, activation): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokens = 1024 hidden_size = 2048 intermediate_size = 4096 std = 1 / 64 - input_ = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) - output_grad = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) - weight_1 = torch.normal(0, std, (intermediate_size * (gated + 1), hidden_size), device="cuda", requires_grad=True) - bias_1 = torch.normal(0, std, (intermediate_size * (gated + 1),), device="cuda", requires_grad=True) - weight_2 = torch.normal(0, std, (intermediate_size, hidden_size), device="cuda", requires_grad=True) - bias_2 = torch.normal(0, std, (hidden_size,), device="cuda", requires_grad=True) + input_ = torch.randn(tokens, hidden_size, device=device, requires_grad=True) + output_grad = torch.randn(tokens, hidden_size, device=device, requires_grad=True) + weight_1 = torch.normal(0, std, (intermediate_size * (gated + 1), hidden_size), device=device, requires_grad=True) + bias_1 = torch.normal(0, std, (intermediate_size * (gated + 1),), device=device, requires_grad=True) + weight_2 = torch.normal(0, std, (intermediate_size, hidden_size), device=device, requires_grad=True) + bias_2 = torch.normal(0, std, (hidden_size,), device=device, requires_grad=True) params = (weight_1, bias_1, weight_2, bias_2) output_ref = torch.nn.functional.linear( @@ -137,27 +136,27 @@ def test_mlp_recomputation(gated, activation): # Takes ~6s, much more if it needs to compile, reducing the hidden size doesn't help. @pytest.mark.slow @pytest.mark.skip("Dropless MoE is broken") -@requires_cuda def test_dropless_mlp(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") num_experts = 4 experts_per_token = 4 tokens = 256 hidden_size = 512 intermediate_size = 1024 std = 1 / 64 - input_ = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) - router_weight = torch.normal(0, std, (num_experts, hidden_size), device="cuda") + input_ = torch.randn(tokens, hidden_size, device=device, requires_grad=True) + router_weight = torch.normal(0, std, (num_experts, hidden_size), device=device) top_logits, top_experts = torch.topk( torch.nn.functional.linear(input_.detach(), router_weight), k=experts_per_token, dim=-1 ) scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32).detach().requires_grad_() - output_grad = torch.randn(tokens, hidden_size, device="cuda", requires_grad=True) + output_grad = torch.randn(tokens, hidden_size, device=device, requires_grad=True) weight_1 = torch.normal( - 0, std, (intermediate_size * 2 * num_experts, hidden_size), device="cuda", requires_grad=True + 0, std, (intermediate_size * 2 * num_experts, hidden_size), device=device, requires_grad=True ) - weight_2 = torch.normal(0, std, (intermediate_size * num_experts, hidden_size), device="cuda", requires_grad=True) + weight_2 = torch.normal(0, std, (intermediate_size * num_experts, hidden_size), device=device, requires_grad=True) params = (weight_1, weight_2) for param in params: diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index a5a693be6..79817bb03 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -19,9 +19,10 @@ from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig from fast_llm.layers.attention.rotary.rotary import ( - apply_rotary_embeddings, convert_rotary_complex_to_real, convert_rotary_real_to_complex, + rotary_embeddings_complex, + rotary_embeddings_real, ) from fast_llm.utils import Assert, rms_diff from tests.utils.utils import requires_cuda @@ -81,30 +82,32 @@ def test_triton_add(): ) def test_triton_rotary(batch_size, sequence_length, num_heads, head_size): assert TritonConfig.TRITON_ENABLED - x = torch.randn(batch_size, sequence_length, num_heads, head_size, dtype=torch.bfloat16, device="cuda") - - y1 = apply_rotary_embeddings( - x, - DefaultRotaryConfig(triton=False) + x = torch.randn(batch_size, sequence_length, num_heads, head_size, dtype=torch.float32, device="cuda") + frequencies = ( + DefaultRotaryConfig() .get_layer(TensorDim("", head_size)) ._get_frequencies( sequence_length, head_size, device="cuda", - ), + ) ) - y2 = convert_rotary_real_to_complex( - triton_rotary_( - convert_rotary_complex_to_real(x, head_size, 3), - DefaultRotaryConfig(triton=True) - .get_layer(TensorDim("", head_size)) - ._get_frequencies(sequence_length, head_size, device="cuda"), + y_real = rotary_embeddings_real(x, frequencies) + + y_complex = convert_rotary_complex_to_real( + rotary_embeddings_complex( + convert_rotary_real_to_complex(x, head_size, 3), + torch.view_as_complex(convert_rotary_real_to_complex(frequencies, head_size, 3).unflatten(-1, (-1, 2))), ), head_size, 3, ) - Assert.rms_close(y1, y2, 1e-3) + + y_triton = triton_rotary_(x, frequencies) + + Assert.rms_close(y_real, y_complex, 1e-4) + Assert.rms_close(y_real, y_triton, 1e-4) @requires_cuda diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index 508597173..924c2cc7f 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -3,19 +3,19 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.attention.attention import Attention +from fast_llm.layers.attention.attention import Attention, _flash_available from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.utils import Assert -from tests.utils.utils import requires_cuda -@requires_cuda @pytest.mark.parametrize("cross_document_attention", (True, False)) @pytest.mark.parametrize(("causal", "window_size"), ((True, None), (True, 50), (False, None))) +@pytest.mark.skipif(not _flash_available, reason="Flash attention not available") def test_attention_implementations(cross_document_attention: bool, causal: bool, window_size: int | None): """ Check that the flash and backup attention implementation give the same result. """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") attention: Attention = AttentionConfig( head_size=32, heads=4, @@ -29,11 +29,11 @@ def test_attention_implementations(cross_document_attention: bool, causal: bool, lr_scale=None, peft=None, ) - query = torch.empty(4, 100, 4, 32, dtype=torch.bfloat16, device="cuda").normal_() - key = torch.empty(4, 100, 2, 32, dtype=torch.bfloat16, device="cuda").normal_() - value = torch.empty(4, 100, 2, 32, dtype=torch.bfloat16, device="cuda").normal_() + query = torch.empty(4, 100, 4, 32, dtype=torch.bfloat16, device=device).normal_() + key = torch.empty(4, 100, 2, 32, dtype=torch.bfloat16, device=device).normal_() + value = torch.empty(4, 100, 2, 32, dtype=torch.bfloat16, device=device).normal_() kwargs = { - AttentionKwargs.device: torch.device("cuda"), + AttentionKwargs.device: device, AttentionKwargs.sequence_length: 100, AttentionKwargs.sequence_lengths: [ [20, 32, 10, 11, 9, 18], diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 623a30d82..86ce0253d 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -11,7 +11,7 @@ from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert -from tests.utils.utils import get_base_model, get_stage, requires_cuda +from tests.utils.utils import get_base_model, get_stage def _reverse_kl_loss( @@ -94,7 +94,6 @@ def _lm_head( VOCAB_SIZE = 500 -@requires_cuda @pytest.mark.slow @pytest.mark.parametrize("cross_entropy_impl", tuple(CrossEntropyImpl)) @pytest.mark.parametrize( @@ -163,9 +162,11 @@ def test_lm_head( loss_masking: bool, prediction_heads: int, ): + if cross_entropy_impl in (CrossEntropyImpl.auto, CrossEntropyImpl.triton) and not torch.cuda.is_available(): + pytest.skip("Cuda is not available") head_config = { "cross_entropy_implementation": cross_entropy_impl, - "normalization": {"type": "rms_norm"}, + "normalization": {"type": "rms_norm", "implementation": "auto" if torch.cuda.is_available() else "torch"}, } config = GPTBaseModelConfig.from_dict( { @@ -191,7 +192,7 @@ def test_lm_head( GPTModelConfig.from_dict( { "base_model": config, - "distributed": distributed_config_dict, + "distributed": {**distributed_config_dict, "use_cuda": torch.cuda.is_available()}, }, ) ) diff --git a/tests/layers/test_rotary.py b/tests/layers/test_rotary.py index 85d72b316..112c88a66 100644 --- a/tests/layers/test_rotary.py +++ b/tests/layers/test_rotary.py @@ -6,27 +6,26 @@ from fast_llm.layers.attention.rotary.config import Rotary2DConfig from fast_llm.layers.vision.config import VisionKwargs from fast_llm.utils import Assert -from tests.utils.utils import requires_cuda -@requires_cuda def test_rotary_2d(): """ Compare Fast-LLM's implementation of 2d rotary embeddings with Pixtral. """ head_dim = 16 num_heads = 8 + device = "cuda" if torch.cuda.is_available() else "cpu" patch_positions = torch.tensor( [[h, w] for h in range(4) for w in range(4)], dtype=torch.int64, - device="cuda", + device=device, ) - query = torch.empty(2, len(patch_positions), num_heads, head_dim, dtype=torch.float32, device="cuda").normal_() + query = torch.empty(2, len(patch_positions), num_heads, head_dim, dtype=torch.float32, device=device).normal_() key = torch.empty_like(query).normal_() pixtral_config = transformers.PixtralVisionConfig(hidden_size=head_dim * num_heads, num_attention_heads=num_heads) - pixtral_rotary = transformers.models.pixtral.modeling_pixtral.PixtralRotaryEmbedding(pixtral_config).to("cuda") + pixtral_rotary = transformers.models.pixtral.modeling_pixtral.PixtralRotaryEmbedding(pixtral_config).to(device) # Convert patch positions (h, w) to Pixtral's linear position IDs # Pixtral expects: position_id = h * max_patches_per_side + w position_ids = ( @@ -38,7 +37,7 @@ def test_rotary_2d(): ) fast_llm_rotary = Rotary2DConfig().get_layer(TensorDim("head_dim", head_dim)) - kwargs = {VisionKwargs.patch_positions: patch_positions, AttentionKwargs.device: "cuda"} + kwargs = {VisionKwargs.patch_positions: patch_positions, AttentionKwargs.device: device} fast_llm_rotary.preprocess(kwargs) output_fast_llm_query, output_fast_llm_key = fast_llm_rotary.forward(query, key, kwargs) diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index b371ba086..f6be506ca 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -1,5 +1,6 @@ import pytest import torch +import transformers from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.tensor_dim import TensorDim @@ -11,18 +12,7 @@ from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig from fast_llm.utils import Assert from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet, Apriel2Mamba, KimiDeltaAttention -from tests.utils.utils import get_stage, requires_cuda - -try: - from fast_llm_external_models.apriel2.modeling_apriel2 import Apriel2GatedDeltaNet, Apriel2Mamba -except ImportError: - Apriel2GatedDeltaNet = None - Apriel2Mamba = None - -try: - from fast_llm_external_models.apriel_hybrid_ssm.modeling_apriel_hybrid_ssm import KimiDeltaAttention -except ImportError: - KimiDeltaAttention = None +from tests.utils.utils import get_stage HIDDEN_SIZE = 16 SEQ_LEN = 65 @@ -31,7 +21,9 @@ def _compare_mixers( fast_llm_config: MixerConfig, hf_layer: torch.nn.Module, param_map: dict[str, str], threshold=1e-5 ): - distributed = Distributed(distributed_config := DistributedConfig(compute_dtype=DataType.bfloat16)) + distributed = Distributed( + distributed_config := DistributedConfig(compute_dtype=DataType.bfloat16, use_cuda=torch.cuda.is_available()) + ) fast_llm_layer = fast_llm_config.get_layer( distributed_config, TensorDim("", HIDDEN_SIZE), @@ -82,10 +74,9 @@ def _compare_mixers( @pytest.mark.slow -@pytest.mark.skipif(Apriel2GatedDeltaNet is None, reason="Apriel GDN deps missing") -@requires_cuda +# Arguments ('seq_idx',) not implemented for torch implementation of 1d convolution. +@pytest.mark.skipif(not transformers.utils.import_utils.is_causal_conv1d_available(), reason="GDN deps missing") def test_gdn(): - device = torch.device("cuda") dtype = torch.bfloat16 NUM_V_HEADS = 4 @@ -103,7 +94,7 @@ def test_gdn(): hf_layer = ( Apriel2GatedDeltaNet(HIDDEN_SIZE, {**config_common, "norm_eps": 1e-5}, layer_idx=0, dtype=dtype) - .to(device=device, dtype=dtype) + .to(device="cuda" if torch.cuda.is_available() else "cpu", dtype=dtype) .eval() ) fast_llm_config = GatedDeltaNetConfig.from_dict(config_common, {"normalization": {"epsilon": 1e-5}}) @@ -111,7 +102,6 @@ def test_gdn(): @pytest.mark.slow -@requires_cuda @pytest.mark.skipif(kda_module.chunk_kda is None, reason="KDA fused kernels not available") def test_kda(): NUM_HEADS = 4 @@ -133,10 +123,9 @@ def test_kda(): @pytest.mark.slow -@requires_cuda @pytest.mark.parametrize("add_linear_biases", [True, False]) @pytest.mark.parametrize("repeat_kv_before_conv", [True, False]) -@pytest.mark.skipif(Apriel2Mamba is None, reason="Apriel2 Mamba not available") +@pytest.mark.skipif(not transformers.utils.import_utils.is_mamba_ssm_available(), reason="Mamba not available") def test_mamba(add_linear_biases, repeat_kv_before_conv): D_INNER = 128 D_XB = 64 diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py index c8d962f40..ff27a8e8d 100644 --- a/tests/layers/test_varlen.py +++ b/tests/layers/test_varlen.py @@ -12,12 +12,11 @@ from fast_llm.layers.ssm import kda as kda_module from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig from fast_llm.utils import Assert -from tests.utils.utils import get_stage, requires_cuda +from tests.utils.utils import get_stage # TODO: include mamba varlen @pytest.mark.slow -@requires_cuda @pytest.mark.parametrize( "config", [ @@ -50,7 +49,9 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): """ hidden_size = 32 hidden_dim = TensorDim("hidden", hidden_size) - distributed = Distributed(distributed_config := DistributedConfig(compute_dtype=DataType.float16)) + distributed = Distributed( + distributed_config := DistributedConfig(compute_dtype=DataType.float16, use_cuda=torch.cuda.is_available()) + ) mixer = config.get_layer(distributed_config, hidden_dim, lr_scale=None, peft=None, return_bias=False) stage = get_stage([mixer], distributed) @@ -71,11 +72,15 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): BlockKwargs.device: distributed.device, BlockKwargs.sequence_first: False, BlockKwargs.hidden_dims: (hidden_dim,), + } + + kwargs_packed = { + **kwargs, + BlockKwargs.sequence_lengths: sequence_lengths, + BlockKwargs.sequence_length: seq_len, BlockKwargs.sequence_q_dim: TensorDim("", seq_len), BlockKwargs.sequence_k_dim: TensorDim("", seq_len), } - - kwargs_packed = {**kwargs, BlockKwargs.sequence_lengths: sequence_lengths} mixer.preprocess(kwargs_packed) out_packed, context = stage.forward(hidden_states, kwargs_packed) @@ -89,7 +94,14 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): out_refs = [] for i in range(batch_size): for seq in torch.split(hidden_states[i], sequence_lengths[i], dim=0): - kwargs_seq = {**kwargs, BlockKwargs.sequence_lengths: [[len(seq)]]} + seq_len_ = len(seq) + kwargs_seq = { + **kwargs, + BlockKwargs.sequence_lengths: [[seq_len_]], + BlockKwargs.sequence_length: seq_len_, + BlockKwargs.sequence_q_dim: TensorDim("", seq_len_), + BlockKwargs.sequence_k_dim: TensorDim("", seq_len_), + } mixer.preprocess(kwargs_seq) out, context = stage.forward(seq.unsqueeze(0), kwargs_seq) stage.backward(torch.ones_like(out), context) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index d14721142..f80b2b25f 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -14,7 +14,6 @@ logger = logging.getLogger(__name__) -@requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_model_simple(run_test_script_for_all_models, run_test_script_base_path): # A simple config to prevent unnecessary testing and creation of dependency group diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index e3870a7b1..2f476ae52 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -7,7 +7,6 @@ from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel from fast_llm.utils import Assert from tests.utils.model_configs import ModelTestingGroup -from tests.utils.utils import requires_cuda def _get_model(config_dict: dict, model_type: str = "gpt") -> FastLLMModel: @@ -18,7 +17,6 @@ def _get_model(config_dict: dict, model_type: str = "gpt") -> FastLLMModel: return model -@requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): model_testing_config.get_dataset() diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1248a1117..41f209b58 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -8,6 +8,7 @@ import typing import pytest +import torch import transformers from fast_llm.config import set_nested_dict_value @@ -97,6 +98,7 @@ class ModelTestingConfig: auto_model_class: type["transformers.models.auto.auto_factory._BaseAutoModelClass"] = ( transformers.AutoModelForCausalLM ) + requires_cuda: bool = False def __post_init__(self): _, config, _, _ = self.get_dataset(config_only=True) @@ -259,10 +261,11 @@ def _update_and_add_testing_config( "reproducible_init": True, "timeout": 20, "backend": "nccl", + "use_cuda": torch.cuda.is_available(), }, }, "batch": {"batch_size": 8, "sequence_length": 512}, - "data": {}, + "data": {"sampling": {"gpu": torch.cuda.is_available()}}, "optimizer": {"learning_rate": {"base": 0.0001}}, }, megatron_args=[ @@ -698,6 +701,7 @@ def _update_and_add_testing_config( compare_factor=2.0, # Micro-sequence split not supported. skip_tests=("sdp", "ms"), + requires_cuda=True, ) _update_and_add_testing_config( @@ -792,6 +796,7 @@ def _update_and_add_testing_config( # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! skip_tests=("sdp", "ms", TP_NO_STP), + requires_cuda=True, ) _update_and_add_testing_config( @@ -914,6 +919,7 @@ def _update_and_add_testing_config( # Pipeline-parallel gives a different mixer selection. # TP excluded because no gradient reductions implemented for TP norm in GDN (use STP instead). skip_tests=("sdp", "ms", "pp", TP_NO_STP), + requires_cuda=True, ) @@ -957,6 +963,7 @@ def _update_and_add_testing_config( # Micro-sequence split and sequence-first not supported for Mamba. # TP excluded because no gradient reductions implemented for TP norm in GDN (use STP instead). skip_tests=("sdp", "ms", GRAD_ACC, TP_NO_STP), + requires_cuda=True, ) @@ -996,6 +1003,7 @@ def _update_and_add_testing_config( # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! skip_tests=("sdp", "ms", TP_NO_STP), + requires_cuda=True, ) @@ -1013,6 +1021,8 @@ def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slo groups: tuple[ModelTestingGroup] = item.keywords["model_testing_group"].args model_testing_config = item.callspec.params["model_testing_config"] model_config: ModelTestingConfig = MODEL_CONFIGS[model_testing_config] + if model_config.requires_cuda and not torch.cuda.is_available(): + item.add_marker(pytest.mark.skip(reason=f"Cuda not available.")) for group in groups: action = model_config.groups[group] if action == ModelTestingGroupAction.main: diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 3b79f7607..e176d9b32 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +requires_cuda = pytest.mark.skipif(False, reason="CUDA is not available") @pytest.fixture(scope="session") From f144b87e6cfa83bd5e8dd80cb74b472af0c53f5c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 9 Jan 2026 22:56:47 -0500 Subject: [PATCH 2/2] fixes --- fast_llm/engine/checkpoint/convert.py | 3 ++ fast_llm/engine/config_utils/run.py | 2 +- fast_llm/functional/triton/mlp.py | 1 - fast_llm/layers/attention/attention.py | 8 +--- fast_llm/layers/attention/rotary/config.py | 7 ++++ .../common/normalization/normalization.py | 18 +++++++- fast_llm/layers/language_model/head.py | 2 +- fast_llm/logging.py | 2 +- fast_llm/models/gpt/conversion/llama.py | 9 ---- fast_llm/models/gpt/megatron.py | 9 ---- tests/conftest.py | 3 ++ tests/functional/test_cross_entropy.py | 2 +- tests/models/test_checkpoint.py | 42 +++++++++---------- tests/models/test_generate.py | 8 ---- tests/models/test_lm_eval.py | 5 --- tests/models/test_model.py | 9 ++-- tests/utils/compare_tensor_logs.py | 1 + tests/utils/distributed_configs.py | 16 +++++-- tests/utils/utils.py | 2 +- 19 files changed, 73 insertions(+), 76 deletions(-) diff --git a/fast_llm/engine/checkpoint/convert.py b/fast_llm/engine/checkpoint/convert.py index b40d8a1b3..103d9488c 100644 --- a/fast_llm/engine/checkpoint/convert.py +++ b/fast_llm/engine/checkpoint/convert.py @@ -20,6 +20,7 @@ class ConvertConfig(RunnableConfig): input: CheckpointLoadConfig = Field() output: CheckpointSaveConfig = Field() + use_cpu: bool = Field(default=False) exist_ok: bool = Field(default=False) layers_per_step: int | None = Field(default=None) model: type[FastLLMModelConfig] = Field(default=None) @@ -62,6 +63,7 @@ def _convert_model_partial( logger.info(f"Loading {self.input.format} checkpoint from {self.input.path}...") model = model_class.from_pretrained( self.input, + {("distributed", "use_cuda"): not self.use_cpu}, mode=StageMode.weights, stage_filter=stage_filter, ) @@ -94,6 +96,7 @@ def run(self): # Create a dummy version to determine the stage split. model = model_class.from_pretrained( self.input.to_copy({"model_weights": False}), + {("distributed", "use_cuda"): not self.use_cpu}, mode=StageMode.off_device, ) stages_per_step = math.ceil(self.layers_per_step / model._config.multi_stage.layers_per_stage) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 415147d06..77507afa8 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -101,7 +101,7 @@ def configure_logging( def get_run(self, distributed: "Distributed") -> "Run": from fast_llm.functional.config import TritonConfig - TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels and distributed.config.use_cuda + TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels # and distributed.config.use_cuda TritonConfig.TRITON_LINEAR = self.run.triton_linear_kernels run = Run(config=self, distributed=distributed) set_global_variables(not self.run.torch_dynamo_enable) diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index dc3ee0f04..286e7159a 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -357,7 +357,6 @@ def mlp_backward(grad_output: torch.Tensor, context: list[typing.Any]) -> tuple[ intermediate_2_ = intermediate_2 intermediate_2_.backward(grad_intermediate_2) grad_intermediate_1 = intermediate_1.grad - print("AAAAA", intermediate_2 is None, grad_intermediate_1) # Layer 2 parameter grad del grad_intermediate_2, intermediate_1 diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 902352c25..be58724ea 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -222,7 +222,7 @@ def _attn_backup( attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) attn_output = torch.bmm( - attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value + attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk).to(value.dtype), value ) if self._local_head_groups == 1: @@ -459,11 +459,6 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size if self._config.causal: - print( - "WWWWW", - kwargs[AttentionKwargs.sequence_length], - self._backup_attention_tensor_cache_max_sequence_length, - ) if ( sequence_length := kwargs[AttentionKwargs.sequence_length] ) > self._backup_attention_tensor_cache_max_sequence_length: @@ -496,7 +491,6 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non if attention_mask is None: attention_mask = document_mask else: - print("AAAA", attention_mask.shape, document_mask.shape, kwargs) attention_mask = attention_mask & document_mask kwargs[AttentionKwargs.attention_mask] = attention_mask diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 1ec35ae0c..80f499748 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -58,6 +58,13 @@ class DefaultRotaryConfig(RotaryConfig): hint=FieldHint.architecture, ) + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if "complex_format" in default: + Assert.is_(default["complex_format"], False) + del default["complex_format"] + return super()._from_dict(default, strict=strict) + def _get_configurable_class(self) -> "type[DefaultRotary]": from fast_llm.layers.attention.rotary.rotary import DefaultRotary diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index b1e875707..4bd1343aa 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -183,6 +183,12 @@ class LayerNormalization[ConfigType: LayerNormalizationConfig](Normalization[Con def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): super().__init__(config, hidden_dim, lr_scale) implementation = self._config.implementation + print( + "IKUEGBNHIUWGBN", + implementation, + TritonConfig.TRITON_ENABLED, + (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered, + ) if implementation == NormalizationImplementation.auto: if ( _fast_normalization_available @@ -190,7 +196,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | and not self._config.zero_centered ): implementation = NormalizationImplementation.fast - elif TritonConfig.TRITON_ENABLED or self._config.zero_centered: + elif (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered: log_main_rank("Fast layer norm unavailable, using backup triton implementation.") implementation = NormalizationImplementation.triton elif _fused_normalization_available: @@ -199,6 +205,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | else: log_main_rank("Fast and fused layer norm unavailable, using backup pytorch implementation.") implementation = NormalizationImplementation.torch + print("BNHTHERDGRG", implementation) if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: @@ -258,8 +265,14 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | super().__init__(config, hidden_dim, lr_scale) assert not hidden_dim.is_parallel implementation = self._config.implementation + print( + "IKUEGBNHIUWGBN", + implementation, + TritonConfig.TRITON_ENABLED, + (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered, + ) if implementation == NormalizationImplementation.auto: - if TritonConfig.TRITON_ENABLED or self._config.zero_centered: + if (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered: implementation = NormalizationImplementation.triton elif _fused_normalization_available: log_main_rank("Triton RMS norm unavailable, using fused implementation.") @@ -267,6 +280,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | else: log_main_rank("Fused RMS norm unavailable, using backup implementation.") implementation = NormalizationImplementation.torch + print("BNHTHERDGRG", implementation) if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b1d0c2acd..9f3b6506f 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -92,7 +92,7 @@ def __init__( if self._cross_entropy_impl == CrossEntropyImpl.auto: if self._vocab_parallel: self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: + elif TritonConfig.TRITON_ENABLED and torch.cuda.is_available(): self._cross_entropy_impl = CrossEntropyImpl.triton else: self._cross_entropy_impl = CrossEntropyImpl.fused diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 931c7f644..5a2ff2dac 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -196,7 +196,7 @@ def log_tensor[ step = max(tensor.numel() // target_samples, 1) while step > 1 and any(step % s == 0 and s > 1 for s in shape): step -= 1 - samples = tensor.flatten()[: target_samples * step : step].cpu() + samples = tensor.flatten()[: target_samples * step : step].to("cpu", copy=True) stats.update(samples=samples, step=step) # Crop the list in the logs. The full tensor is still in stats. samples = [format_number(x) for x in samples.tolist()[: TensorLogs.config.max_elements]] diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index bc75f6236..00d871dbf 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -15,7 +15,6 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig -from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.common.normalization.config import RMSNormalizationConfig from fast_llm.layers.decoder.config import DecoderBlockConfig @@ -314,16 +313,12 @@ def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight - if self._config.rotary.complex_format: - query = convert_rotary_complex_to_real(query[:], self._config.head_size, 0) return (query,) def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (query,) = weight - if self._config.rotary.complex_format: - query = convert_rotary_real_to_complex(query[:], self._config.head_size, 0) return (query,) @@ -336,16 +331,12 @@ def export_weight( ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: (key_value,) = weight key, value = key_value[:].chunk(2) - if self._config.rotary.complex_format: - key = convert_rotary_complex_to_real(key, self._config.head_size, 0) return key, value def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: key, value = weight - if self._config.rotary.complex_format: - key = convert_rotary_real_to_complex(key[:], self._config.head_size, 0) key_value = torch.cat([key[:], value[:]]) return (key_value,) diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index f63bd76f8..3b97df3d1 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -1,6 +1,5 @@ import typing -from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MoEMLPConfig from fast_llm.utils import Assert, div @@ -84,12 +83,10 @@ def _init_attention_megatron( generator, ) if "dense" in meta.tensor_name: - kv_dim = 1 tensor_ = dense_tensor_ else: # Keep the original random state for key_value and dense. generator.set_state(state) - kv_dim = 0 if "query" in meta.tensor_name: # We want to generate the same tensor for key_value. tensor_ = qkv_tensor_[:, :heads_per_group] @@ -98,12 +95,6 @@ def _init_attention_megatron( else: raise NotImplementedError(meta.tensor_name) - if isinstance(config.mixer.rotary, DefaultRotaryConfig) and config.mixer.rotary.complex_format: - from fast_llm.layers.attention.rotary.config import convert_rotary_real_to_complex - - # Megatron uses (2, head_size/2) for the complex split; we use (head_size/2, 2). - # TODO: Avoid unnecessarily changing the value and dense tensors. - tensor_ = convert_rotary_real_to_complex(tensor_.view_as(meta), config.mixer.head_size, kv_dim) return tensor_ diff --git a/tests/conftest.py b/tests/conftest.py index ba2927c64..28bab0ad5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,7 @@ import pytest import xdist.scheduler +from fast_llm.functional.config import TritonConfig from fast_llm.utils import get_and_reset_memory_usage_mib from tests.utils.depends import DependencyManager from tests.utils.global_variables import TEST_RESULTS_PATH, set_testing_global_variables @@ -259,6 +260,8 @@ def pytest_runtest_call(item: pytest.Function): except RuntimeError: pytest.skip("Cuda runtime unavailable due to an error in an earlier test.") manager.handle_missing(item) + # Some tests may modify this global variable. + TritonConfig.TRITON_ENABLED = True def pytest_unconfigure(): diff --git a/tests/functional/test_cross_entropy.py b/tests/functional/test_cross_entropy.py index 1bfde36ed..420316ce3 100644 --- a/tests/functional/test_cross_entropy.py +++ b/tests/functional/test_cross_entropy.py @@ -66,7 +66,6 @@ def _compare_cross_entropy_outputs( @pytest.mark.parametrize("target_format", (TargetFormat.labels, TargetFormat.logits, TargetFormat.probabilities)) def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_masking, target_format): # TODO: Test tensor-parallel implementation. - assert TritonConfig.TRITON_ENABLED logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) kwargs = { "logits": logits, @@ -86,6 +85,7 @@ def test_cross_entropy(num_columns, grad_output, logits_scale_factor, loss_maski if not torch.cuda.is_available(): return + assert TritonConfig.TRITON_ENABLED if num_columns > 65536: with pytest.raises(AssertionError): cross_entropy_forward_backward(**kwargs, implementation=CrossEntropyImpl.triton) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index bb53de29e..5c31dde16 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -22,7 +22,6 @@ from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig -from tests.utils.utils import requires_cuda logger = logging.getLogger(__name__) @@ -35,7 +34,6 @@ ] -@requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_checkpoint_and_eval(run_test_script_for_all_models, model_testing_config): # A baseline config (single-gpu, bf16, flash-attn). @@ -59,7 +57,6 @@ def do_prepare_resume(distributed_testing_config: DistributedTestingConfig): return do_prepare_resume -@requires_cuda @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_resume(run_test_script_for_all_models, compare_results_for_all_models, prepare_resume): @@ -75,7 +72,6 @@ def test_resume(run_test_script_for_all_models, compare_results_for_all_models, compare_results_for_all_models(distributed_testing_config) -@requires_cuda @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.checkpoint) def test_resume_frozen(run_test_script_for_all_models, prepare_resume): @@ -102,13 +98,13 @@ def do_run_conversion( path=get_convert_path(save_format, load_format), format=save_format, ), + use_cpu=not torch.cuda.is_available(), model=model_testing_config.model_config_class, ).run() return do_run_conversion -@requires_cuda @pytest.mark.depends_on(on=["test_checkpoint_and_eval[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_conversion(model_testing_config, run_conversion, get_convert_path): @@ -171,7 +167,6 @@ def _compare_safetensor_files( Assert.all_equal(reference[key], other[key]) -@requires_cuda @pytest.mark.depends_on(on=["test_conversion[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_converted_round_trip(model_testing_config, get_convert_path): @@ -218,7 +213,8 @@ def do_load_and_compare_checkpoints( CheckpointLoadConfig( path=load_path, format=load_format, - ) + ), + {("distributed", "use_cuda"): torch.cuda.is_available()}, ) if reference_config is not None: _compare_model_configs(reference_config, model.config) @@ -228,7 +224,6 @@ def do_load_and_compare_checkpoints( return do_load_and_compare_checkpoints -@requires_cuda @pytest.mark.depends_on(on=["test_conversion[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_load_pretrained( @@ -238,9 +233,9 @@ def test_load_pretrained( reference_config = model_testing_config.model_config_class.from_dict( yaml.safe_load(get_convert_path().parents[1].joinpath("config.yaml").open("r"))["model"] ) - reference_shard = safetensors.torch.load_file(get_convert_path() / "rank_0.safetensors", device="cuda")[ - _WEIGHT_SHARD_SAVE_NAME - ] + reference_shard = safetensors.torch.load_file( + get_convert_path() / "rank_0.safetensors", device="cuda" if torch.cuda.is_available() else "cpu" + )[_WEIGHT_SHARD_SAVE_NAME] load_and_compare_checkpoints( FastLLMCheckpointFormat, get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat), @@ -303,10 +298,11 @@ def test_load_pretrained( ) -@requires_cuda @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_huggingface_model(model_testing_config, get_convert_path): + device = "cuda" if torch.cuda.is_available() else "cpu" + distributed_update = {("distributed", "use_cuda"): torch.cuda.is_available()} if model_testing_config.checkpoint_format is None: return # Test that Fast-LLM's Hugging Face wrapper produces the same results as the converted Hugging Face model. @@ -323,18 +319,19 @@ def test_huggingface_model(model_testing_config, get_convert_path): path=get_convert_path(), format=DistributedCheckpointFormat, load_config=ModelConfigType.model, - ) + ), + distributed_update, ).eval() test_input = torch.randint( 0, 384, size=(4, 100), dtype=torch.int64, - device="cuda", + device=device, ) kwargs = {} if model_testing_config.model_type == "multimodal": - kwargs["pixel_values"] = torch.rand([6, 3, 20, 20]).cuda() + kwargs["pixel_values"] = torch.rand([6, 3, 20, 20]).to(device) kwargs["image_sizes"] = torch.tensor( [ [20, 20], # Full image, 25 patches @@ -360,16 +357,21 @@ def test_huggingface_model(model_testing_config, get_convert_path): # Last one cropped out. output_ref = model_ref(test_input, **kwargs) - model_from_fast_llm = hf_class.from_pretrained(fast_llm_path).eval() + model_from_fast_llm = hf_class.from_pretrained(fast_llm_path, distributed_update).eval() model_from_hf = hf_class.from_pretrained( CheckpointLoadConfig( path=hf_path, format=model_testing_config.checkpoint_format, load_config=ModelConfigType.model, - ) + ), + distributed_update, ).eval() errors = [] - model_as_hf = model_testing_config.auto_model_class.from_pretrained(hf_path, trust_remote_code=True).cuda().eval() + model_as_hf = ( + model_testing_config.auto_model_class.from_pretrained(hf_path, trust_remote_code=True) + .to("cuda" if torch.cuda.is_available() else "cpu") + .eval() + ) for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), (model_from_fast_llm, model_from_hf, model_as_hf), @@ -391,7 +393,6 @@ def test_huggingface_model(model_testing_config, get_convert_path): raise ValueError(f"Comparison failed ({len(errors)} errors)") -@requires_cuda @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_path, model_testing_config, request): @@ -430,7 +431,6 @@ def reference_distributed_shard(get_convert_path) -> torch.Tensor | None: # We don't want to depend on `test_save_and_load_in_parallel` because we still want to run this in cas of failure. # This should still run after `test_save_and_load_in_parallel` -@requires_cuda @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_load_parallel_checkpoint_in_single_gpu( @@ -464,7 +464,6 @@ def test_load_parallel_checkpoint_in_single_gpu( ) -@requires_cuda @pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_parallel_checkpoint_consistency(model_testing_config, run_test_script_base_path): @@ -496,7 +495,6 @@ def reference_fast_llm_shard(get_convert_path) -> dict[str, torch.Tensor] | None return None -@requires_cuda @pytest.mark.depends_on(on=["test_save_and_load_in_parallel[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) def test_multi_gpu_fast_llm_checkpoint( diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py index bce77d4f2..c595b5148 100644 --- a/tests/models/test_generate.py +++ b/tests/models/test_generate.py @@ -11,7 +11,6 @@ from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM from tests.utils.model_configs import ModelTestingGroup -from tests.utils.utils import requires_cuda def _prepare_data(tokenizer, use_batch_size2: bool): @@ -206,7 +205,6 @@ def _test_generate( @pytest.mark.extra_slow -@requires_cuda @pytest.mark.parametrize( "use_flash_attention, use_bf16, max_new_tokens, min_matching_tokens_batch_size_1, min_matching_tokens_batch_size_2", [ @@ -238,7 +236,6 @@ def test_generate( ) -@pytest.mark.slow @pytest.mark.model_testing_group(ModelTestingGroup.generate) def test_export_for_generate(run_test_script_for_all_models, model_testing_config): # Not really testing, anything, but handles dependencies more easily than a fixture. @@ -254,7 +251,6 @@ def test_export_for_generate(run_test_script_for_all_models, model_testing_confi @pytest.mark.slow -@requires_cuda @pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) @pytest.mark.parametrize( "use_flash_attention, use_bf16, max_new_tokens, min_matching_tokens_batch_size_1, min_matching_tokens_batch_size_2", @@ -307,7 +303,6 @@ def _test_generate_from_model(model_path, tokenizer, fast_llm_checkpoint_format) ) -@requires_cuda @pytest.mark.extra_slow def test_generate_from_model( model_path, @@ -315,7 +310,6 @@ def test_generate_from_model( _test_generate_from_model(model_path, AutoTokenizer.from_pretrained(model_path), LlamaCheckpointFormat) -@requires_cuda @pytest.mark.slow @pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.generate) @@ -356,7 +350,6 @@ def _test_forward_return_hidden_states( @pytest.mark.extra_slow -@requires_cuda def test_forward_return_hidden_states(model_path): _test_forward_return_hidden_states( model_path, LlamaCheckpointFormat, AutoTokenizer.from_pretrained(model_path).vocab_size @@ -364,7 +357,6 @@ def test_forward_return_hidden_states(model_path): @pytest.mark.slow -@requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.generate) @pytest.mark.depends_on(on=["test_export_for_generate[{model_testing_config}]"]) def test_small_forward_return_hidden_states(model_testing_config, run_test_script_base_path): diff --git a/tests/models/test_lm_eval.py b/tests/models/test_lm_eval.py index 8011b5bbc..7ae26c2d6 100644 --- a/tests/models/test_lm_eval.py +++ b/tests/models/test_lm_eval.py @@ -7,7 +7,6 @@ from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.global_variables import TOKENIZER_PATH from tests.utils.model_configs import ModelTestingGroup -from tests.utils.utils import requires_cuda # NOTE: These tests only verify that the functionality runs without crashing. # NOTE: The tokenizer is from a LLaMA-style model, which may not be suitable for all models, @@ -55,7 +54,6 @@ def do_get_lm_eval_config(base_path): # "gsm8k,xnli_en,wikitext" -@requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.generate) def test_lm_eval_in_training(run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config): run_test_script_for_all_models( @@ -76,7 +74,6 @@ def do_copy_training_output(distributed_testing_config: DistributedTestingConfig return do_copy_training_output -@requires_cuda @pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.generate) def test_lm_eval_evaluation_last_checkpoint( @@ -91,7 +88,6 @@ def test_lm_eval_evaluation_last_checkpoint( run_test_script_for_all_models(distributed_testing_config=distributed_testing_config, runnable_type="evaluate") -@requires_cuda @pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.generate) def test_lm_eval_evaluation_from_pretrained( @@ -111,7 +107,6 @@ def test_lm_eval_evaluation_from_pretrained( # TODO: rewrite for a new distributed test function -# @requires_cuda # @pytest.mark.depends_on(on=["test_lm_eval_in_training[{model_testing_config}]"]) # @pytest.mark.model_testing_group(ModelTestingGroup.generate, ModelTestingGroup.distributed) # def test_lm_eval_in_training_dp2(run_test_script_for_all_models, run_test_script_base_path, get_lm_eval_config): diff --git a/tests/models/test_model.py b/tests/models/test_model.py index f80b2b25f..b3247102b 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -9,7 +9,7 @@ SINGLE_GPU_TESTING_CONFIGS, ) from tests.utils.model_configs import ModelTestingGroup -from tests.utils.utils import check_subtest_success, requires_cuda, set_subtest_success +from tests.utils.utils import check_subtest_success, set_subtest_success logger = logging.getLogger(__name__) @@ -21,7 +21,6 @@ def test_model_simple(run_test_script_for_all_models, run_test_script_base_path) set_subtest_success(run_test_script_base_path / SIMPLE_TESTING_CONFIG.name) -@requires_cuda @pytest.mark.depends_on(on=["test_model_simple[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.basic) # Parametrize with config name so it shows in test name. @@ -46,9 +45,9 @@ def test_and_compare_model( if config.compare is not None: compare_results_for_all_models(config) + # raise ValueError() -@requires_cuda @pytest.mark.depends_on(on=["test_model_simple[{model_testing_config}]"]) @pytest.mark.model_testing_group( ModelTestingGroup.distributed, @@ -56,6 +55,9 @@ def test_and_compare_model( def test_run_model_distributed(run_distributed_script, model_testing_config, run_test_script_base_path, request): import tests.models.distributed_test_model + if torch.cuda.device_count() < 2: + pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2") + script = [ "-m", tests.models.distributed_test_model.__name__, @@ -73,7 +75,6 @@ def test_run_model_distributed(run_distributed_script, model_testing_config, run # We don't want to depend on `test_model_distributed` because we still want to run this in cas of failure. # This should still run after `test_model_distributed` -@requires_cuda @pytest.mark.depends_on(on=["test_model_simple[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.distributed) @pytest.mark.parametrize("config_name", list(DISTRIBUTED_TESTING_CONFIGS)) diff --git a/tests/utils/compare_tensor_logs.py b/tests/utils/compare_tensor_logs.py index 1c8ebd76a..9a13fd13f 100644 --- a/tests/utils/compare_tensor_logs.py +++ b/tests/utils/compare_tensor_logs.py @@ -140,6 +140,7 @@ def compare_tensors(self, tensor_ref, tensor_test, errors, step_name, tensor_nam [ f" Test samples: " + "".join(f"{x:12.4e}" for x in samples_test[: self.show_samples].tolist()), f" Ref samples: " + "".join(f"{x:12.4e}" for x in samples_ref[: self.show_samples].tolist()), + f"scale={sub_config.scale}", ] ) errors.append("\n".join([f">>>> [{step_name}] Excessive diff for tensor {tensor_name}:"] + tensor_errors)) diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index 9c1cc9369..5b45371e6 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -2,6 +2,8 @@ import dataclasses import logging +import torch + from tests.utils.compare_tensor_logs import CompareConfig logger = logging.getLogger(__name__) @@ -58,8 +60,9 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon ("init", None): get_config(), (None, "fw"): get_config(1.5e-2, 1.5e-3), (None, "bw"): get_config(1.5e-2, 1e-5), - (None, "bias"): get_config(2e-2, 1e-3), - (None, "gradient"): get_config(2e-2, 5e-5), + # Error is higher on cpu. TODO: Diff too big, especially for bias. + (None, "bias"): get_config(2e-2, 1e-3) if torch.cuda.is_available() else get_config(0.25, 2e-3), + (None, "gradient"): get_config(2e-2, 5e-5) if torch.cuda.is_available() else get_config(8e-2, 1e-4), } ) @@ -69,8 +72,13 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Saved gradient include the gradient scaling by 2**16 (default initial value) (None, "fw"): get_config(1.2e-3, 3e-4), (None, "bw"): get_config(3e-3, 1e-5, scale=2**16), - (None, "bias"): get_config(3e-3, 1e-4, scale=2**16), - (None, "gradient"): get_config(3e-3, 5e-5, scale=2**16), + # Error is higher on cpu. + (None, "bias"): ( + get_config(3e-3, 1e-4, scale=2**16) if torch.cuda.is_available() else get_config(1e-2, 2e-4, scale=2**16) + ), + (None, "gradient"): ( + get_config(3e-3, 5e-5, scale=2**16) if torch.cuda.is_available() else get_config(1e-2, 1e-4, scale=2**16) + ), } ) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index e176d9b32..3b79f7607 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) -requires_cuda = pytest.mark.skipif(False, reason="CUDA is not available") +requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") @pytest.fixture(scope="session")