diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d9f89997c..737f6fa3c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -32,7 +32,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]" - name: Run tests run: pytest -v -ra . diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 5f5e59288..0893de47d 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -34,7 +34,7 @@ jobs: pip install pybind11 FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE \ MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE \ - pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV,DOCS,VISION]" + pip install --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV,DOCS]" - name: Build the documentation run: mkdocs build diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index dbd770895..e572e8e61 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -9,11 +9,11 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data from fast_llm.data.data.gpt.config import GPTDataConfig +from fast_llm.data.data_loader import SampledDatasetIterator from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.monitor import DatasetMonitor -from fast_llm.data.iterator import SampledDatasetIterator from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.config_utils.run import log_main_rank diff --git a/fast_llm/data/iterator.py b/fast_llm/data/data_loader.py similarity index 100% rename from fast_llm/data/iterator.py rename to fast_llm/data/data_loader.py diff --git a/fast_llm/engine/distributed/config.py b/fast_llm/engine/distributed/config.py index 7f4b7bc38..c2d6d1405 100644 --- a/fast_llm/engine/distributed/config.py +++ b/fast_llm/engine/distributed/config.py @@ -97,6 +97,31 @@ def setup(self, group: "ProcessGroup|None"): def check_ranks_in_range(self, start, stop): check_ranks_in_range(self.global_ranks, start, stop) + @classmethod + def from_sizes_and_strides(cls, name: str, global_rank: int, *sizes_and_strides: tuple[int, int]) -> typing.Self: + start = global_rank + rank = 0 + world_size = 1 + for i, (size, stride) in enumerate(sizes_and_strides): + if i > 0: + Assert.multiple(stride, sizes_and_strides[i - 1][1]) + rank_ = global_rank // stride % size + start -= rank_ * stride + rank += world_size * rank_ + world_size *= size + global_ranks = [start] + for size, stride in sizes_and_strides: + if size == 1: + continue + if len(global_ranks) == 1: + global_ranks = range(start, start + size * stride, stride) + elif isinstance(global_ranks, range) and stride == global_ranks.stop - global_ranks.start: + global_ranks = range(start, start + size * stride, global_ranks.step) + else: + 
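+                # General case: offset every rank gathered so far by each multiple of this dimension's stride (cross product).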
global_ranks = [rank0 + rank1 for rank1 in range(0, size * stride, stride) for rank0 in global_ranks] + Assert.eq(len(global_ranks), world_size) + return DistributedDim(name=name, size=world_size, rank=rank, global_ranks=global_ranks) + def check_ranks_in_range(global_ranks, start, stop): Assert.geq(min(global_ranks), start) @@ -112,6 +137,7 @@ class DistributedDimNames: sequence_data = "sequence_data" batch_data = "batch_data" tensor_and_sequence_data = "tensor_and_sequence_data" + model_and_sequence_data = "model_and_sequence_data" tensor_and_data = "tensor_and_data" @@ -300,88 +326,68 @@ def _validate(self) -> None: else: self.distributed_dims = {} - data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) + tensor_stride = 1 + sequence_data_stride = self.tensor_parallel * (self.pipeline_parallel if self.pipeline_first else 1) + batch_data_stride = sequence_data_stride * self.sequence_data_parallel pipeline_stride = self.tensor_parallel * (1 if self.pipeline_first else self.data_parallel) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.world, - size=self.world_size, - rank=self.rank, - global_ranks=range(self.world_size), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.world, + (self.world_size, 1), + ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.data, + (self.sequence_data_parallel, sequence_data_stride), + (self.batch_data_parallel, batch_data_stride), + ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.pipeline, (self.pipeline_parallel, pipeline_stride) ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.data, - size=self.data_parallel, - rank=self.data_rank, - global_ranks=self._get_global_ranks(self.data_parallel, data_stride), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.tensor, (self.tensor_parallel, tensor_stride) ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.pipeline, - size=self.pipeline_parallel, - rank=self.pipeline_rank, - global_ranks=self._get_global_ranks(self.pipeline_parallel, pipeline_stride), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.sequence_data, + (self.sequence_data_parallel, sequence_data_stride), ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.tensor, - size=self.tensor_parallel, - rank=self.tensor_rank, - global_ranks=self._get_global_ranks(self.tensor_parallel, 1), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.batch_data, (self.batch_data_parallel, batch_data_stride) ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.sequence_data, - size=self.sequence_data_parallel, - rank=self.sequence_data_rank, - global_ranks=self._get_global_ranks(self.sequence_data_parallel, data_stride), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.tensor_and_sequence_data, + (self.tensor_parallel, tensor_stride), + (self.sequence_data_parallel, sequence_data_stride), ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.batch_data, - size=self.batch_data_parallel, - rank=self.batch_data_rank, - global_ranks=self._get_global_ranks( - self.batch_data_parallel, data_stride * self.sequence_data_parallel - ), - ) + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.tensor_and_data, + (self.tensor_parallel, tensor_stride), + (self.sequence_data_parallel, 
sequence_data_stride), + (self.batch_data_parallel, batch_data_stride), ) - # Global ranks wrong with pipeline first, so we hide the dims as a safety check. - if not self.pipeline_first: - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.tensor_and_sequence_data, - size=self.sequence_data_parallel * self.tensor_parallel, - rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel, - global_ranks=self._get_global_ranks(self.sequence_data_parallel * self.tensor_parallel, 1), - ) - ) - self._add_distributed_dim( - DistributedDim( - name=DistributedDimNames.tensor_and_data, - size=self.data_parallel * self.tensor_parallel, - rank=self.tensor_rank + self.data_rank * self.tensor_parallel, - global_ranks=self._get_global_ranks(self.data_parallel * self.tensor_parallel, 1), - ) - ) - super()._validate() + self._add_distributed_dim_from_sizes_and_strides( + DistributedDimNames.model_and_sequence_data, + (self.tensor_parallel, tensor_stride), + ( + (self.pipeline_parallel, pipeline_stride) + if self.pipeline_first + else (self.sequence_data_parallel, sequence_data_stride) + ), + ( + (self.sequence_data_parallel, sequence_data_stride) + if self.pipeline_first + else (self.pipeline_parallel, pipeline_stride) + ), + ) + super()._validate() if self.reference_config is not None: self.compare(self.reference_config, ValueError) Assert.in_range(self.rank, 0, self.world_size) Assert.in_range(self.local_rank, 0, self.local_world_size) - def _get_global_ranks(self, size: int, stride: int) -> range: - start = self.rank // (size * stride) * size * stride + self.rank % stride - return range(start, start + size * stride, stride) + def _add_distributed_dim_from_sizes_and_strides(self, name: str, *sizes_and_strides: tuple[int, int]) -> None: + self._add_distributed_dim(DistributedDim.from_sizes_and_strides(name, self.rank, *sizes_and_strides)) def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None: Assert.eq(distributed_dim.global_ranks[distributed_dim.rank], self.rank, msg=distributed_dim) diff --git a/fast_llm/engine/distributed/distributed.py b/fast_llm/engine/distributed/distributed.py index aa2be6ce7..d93e17d1c 100644 --- a/fast_llm/engine/distributed/distributed.py +++ b/fast_llm/engine/distributed/distributed.py @@ -28,6 +28,7 @@ def __init__( local_world_size: int | None = None, timeout: float = 60, use_cpu: bool = False, + init_method: str = "env://", backend: DistributedBackend = DistributedBackend.nccl, ): @@ -58,7 +59,7 @@ def __init__( # TODO: Allow other init methods? self.store, _, _ = next( torch.distributed.rendezvous( - "env://", + init_method, self._rank, self._world_size, timeout=datetime.timedelta(seconds=timeout), @@ -180,14 +181,13 @@ def __init__(self, config: DistributedConfig, use_cpu: bool = False): self.tensor_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor]) self.sequence_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.sequence_data]) self.batch_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.batch_data]) - # Global ranks wrong with pipeline first, so we hide the dims as a safety check. 
- if not self._config.pipeline_first: - self.tensor_and_sequence_data_group = self.add_group( - self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data] - ) - self.tensor_and_data_group = self.add_group( - self._config.distributed_dims[DistributedDimNames.tensor_and_data] - ) + self.tensor_and_sequence_data_group = self.add_group( + self._config.distributed_dims[DistributedDimNames.tensor_and_sequence_data] + ) + self.tensor_and_data_group = self.add_group(self._config.distributed_dims[DistributedDimNames.tensor_and_data]) + self.model_and_sequence_data_group = self.add_group( + self._config.distributed_dims[DistributedDimNames.model_and_sequence_data] + ) self._config.log_first_rank(f"Setting random seeds...") diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 2ca61aa0e..fa4ea4c2f 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -167,7 +167,7 @@ def rms_close_relative(x, y, threshold, min_threshold=0, *, msg=None): ) @staticmethod - def all_equal(x, *args): + def all_equal(x, *args, msg=None): import torch # Make it work for lists and numpy arrays. @@ -181,7 +181,9 @@ def all_equal(x, *args): index = None if x.numel() == 1 else torch.where(neq) # noqa raise AssertionError( f"Tensors have {index[0].numel()} different entries out of " - f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}" + f"{x.numel()}: {x[index]} != {arg[index]} at index {torch.stack(index, -1)}" + "" + if msg is None + else f"| {msg}" ) @staticmethod diff --git a/tests/conftest.py b/tests/conftest.py index ba2927c64..e3a2df9a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,7 +32,8 @@ ) from tests.utils.model_configs import model_testing_config, ModelTestingConfig, testing_group_enabled # isort: skip -from tests.utils.utils import result_path, format_resource_report, report_subtest # isort: skip +from tests.utils.utils import result_path # isort: skip +from tests.utils.subtest import format_resource_report, report_subtest, run_parallel_script # isort: skip # Import all dynamic classes. 
import fast_llm.cli # isort: skip diff --git a/tests/models/distributed_test_checkpoint.py b/tests/models/distributed_test_checkpoint.py deleted file mode 100644 index 407946545..000000000 --- a/tests/models/distributed_test_checkpoint.py +++ /dev/null @@ -1,93 +0,0 @@ -import gc -import logging - -import torch - -from fast_llm.cli import fast_llm_main_wrapper -from fast_llm.config import NoAutoValidate -from fast_llm.core.distributed import safe_barrier -from fast_llm.engine.checkpoint.config import ( - CheckpointLoadConfig, - CheckpointSaveConfig, - DistributedCheckpointFormat, - FastLLMCheckpointFormat, -) -from fast_llm.engine.distributed.config import DistributedBackend, DistributedConfig -from fast_llm.engine.distributed.distributed import ProcessGroupPool -from fast_llm.engine.multi_stage.config import StageMode -from fast_llm.utils import Assert, header -from tests.utils.model_configs import ModelTestingConfig -from tests.utils.run_test_script import parse_run_distributed_script -from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig -from tests.utils.utils import DistributedSubtestContext - -logger = logging.getLogger(__name__) - - -def _test_load_and_save_parallel( - model_testing_config: ModelTestingConfig, - config: DistributedSaveLoadConfig, -): - logger.info(header(config.name)) - logger.info(f"Loading {config.load_format} checkpoint from {config.load_path}") - with NoAutoValidate(): - load_config = CheckpointLoadConfig(path=config.load_path, format=config.load_format) - load_config.setup(model_testing_config.model_config_class) - load_config.validate() - model = model_testing_config.model_class.from_pretrained( - load_config, - # The world size and rank are already set through environment variable. - {"distributed": {**config.distributed, "backend": model_testing_config.distributed_backend}}, - mode=StageMode.inference, - ) - for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat): - logger.info(f"Saving {save_format.name} checkpoint to {config.save_path / save_format.name}") - model.save_checkpoint(CheckpointSaveConfig(path=config.save_path / save_format.name, format=save_format)) - del model - gc.collect() - torch.cuda.empty_cache() - - -def main(args: list[str] | None = None) -> None: - base_path, model_testing_config, do_capture = parse_run_distributed_script(args) - - if do_capture: - logger.warning( - "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." - ) - - with ProcessGroupPool( - timeout=20, - backend=DistributedBackend(model_testing_config.distributed_backend), - ) as pool: - failures = [] - world_size = DistributedConfig.default_world_size - rank = DistributedConfig.default_rank - group = pool.get_process_group(range(world_size), rank) - - for config in DISTRIBUTED_SAVE_LOAD_CONFIGS.values(): - if config.load_format == "{checkpoint_format}" and model_testing_config.checkpoint_format is None: - continue - config = config.resolve(base_path, model_testing_config) - Assert.eq(world_size, config.num_gpus) - with DistributedSubtestContext(base_path, config.name, group, world_size, enabled=do_capture) as subtest: - _test_load_and_save_parallel( - model_testing_config=model_testing_config, - config=config, - ) - if not subtest.success: - failures.append(config.name) - - # Final barrier to ensure everything is done before torchrun potentially kills workers. - safe_barrier(group, "testing end") - # Let pytest know how things went. 
- # These should already be reported above, we repeat for convenience. - if failures: - raise RuntimeError(f"The following subtests failed: {", ".join(failures)}") - else: - logger.warning("All tests passed") - - -if __name__ == "__main__": - with fast_llm_main_wrapper(): - main() diff --git a/tests/models/distributed_test_model.py b/tests/models/distributed_test_model.py deleted file mode 100644 index 29b68366d..000000000 --- a/tests/models/distributed_test_model.py +++ /dev/null @@ -1,57 +0,0 @@ -import logging - -from fast_llm.cli import fast_llm_main_wrapper -from fast_llm.core.distributed import safe_barrier -from fast_llm.engine.distributed.config import DistributedBackend, DistributedConfig -from fast_llm.engine.distributed.distributed import ProcessGroupPool -from tests.utils.distributed_configs import DISTRIBUTED_TESTING_CONFIGS -from tests.utils.run_test_script import do_run_test_script_for_all_models, parse_run_distributed_script -from tests.utils.utils import DistributedSubtestContext - -logger = logging.getLogger(__name__) - - -def main(args: list[str] | None = None) -> None: - base_path, model_testing_config, do_capture = parse_run_distributed_script(args) - - if do_capture: - logger.warning( - "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." - ) - - # TODO: Why are barriers needed? - with ProcessGroupPool( - timeout=60, - backend=DistributedBackend(model_testing_config.distributed_backend), - ) as pool: - failures = [] - world_size = DistributedConfig.default_world_size - rank = DistributedConfig.default_rank - group = pool.get_process_group(range(world_size), rank) - safe_barrier(group, "start") - - for name, config in DISTRIBUTED_TESTING_CONFIGS.items(): - if model_testing_config.should_skip(config): - continue - if world_size < config.num_gpus: - logger.warning(f"{name} {f"SKIPPED (not enough GPUs: {world_size} < {config.num_gpus})"})") - continue - with DistributedSubtestContext(base_path, name, group, config.num_gpus, enabled=do_capture) as subtest: - if rank < config.num_gpus: - do_run_test_script_for_all_models(config, model_testing_config, base_path) - if not subtest.success: - failures.append(name) - - # Final barrier to ensure everything is done before torchrun potentially kills workers. - safe_barrier(group, "testing end") - # Let pytest know how things went. - # These should already be reported above, we repeat for convenience. 
- if failures: - raise RuntimeError(f"The following subtests failed: {", ".join(failures)}") - else: - logger.warning("All tests passed") - - -if __name__ == "__main__": - with fast_llm_main_wrapper(): - main() diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index bb53de29e..9a3bc4345 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -1,3 +1,4 @@ +import gc import logging import pathlib import shutil @@ -7,6 +8,7 @@ import torch import yaml +from fast_llm.config import NoAutoValidate from fast_llm.engine.checkpoint.config import ( CheckpointFormat, CheckpointLoadConfig, @@ -16,12 +18,13 @@ ModelConfigType, ) from fast_llm.engine.checkpoint.convert import ConvertConfig -from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName -from fast_llm.utils import Assert +from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode +from fast_llm.utils import Assert, header from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup from tests.utils.save_load_configs import DISTRIBUTED_SAVE_LOAD_CONFIGS, DistributedSaveLoadConfig +from tests.utils.subtest import DistributedTestContext from tests.utils.utils import requires_cuda logger = logging.getLogger(__name__) @@ -152,7 +155,7 @@ def test_conversion(model_testing_config, run_conversion, get_convert_path): ) -def _compare_safetensor_files( +def compare_safetensor_files( reference: pathlib.Path | dict[str, torch.Tensor], *other_paths: pathlib.Path, expected_keys: set[str] | None = None, @@ -166,9 +169,10 @@ def _compare_safetensor_files( for other_path in other_paths: other = safetensors.torch.load_file(other_path) - Assert.eq(other.keys(), expected_keys) + if other.keys() != expected_keys: + raise ValueError(f"Expected keys {expected_keys} but got {other.keys()} in {other_path}") for key in expected_keys: - Assert.all_equal(reference[key], other[key]) + Assert.all_equal(reference[key], other[key], msg=f"tensor = {key}, path = {other_path}") @requires_cuda @@ -177,24 +181,24 @@ def _compare_safetensor_files( def test_converted_round_trip(model_testing_config, get_convert_path): # Test that the various possible conversion paths yield identical results. 
if model_testing_config.checkpoint_format is None: - _compare_safetensor_files( + compare_safetensor_files( get_convert_path() / "rank_0.safetensors", get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors", expected_keys={_WEIGHT_SHARD_SAVE_NAME}, ) else: - _compare_safetensor_files( + compare_safetensor_files( get_convert_path() / "rank_0.safetensors", get_convert_path(DistributedCheckpointFormat, FastLLMCheckpointFormat) / "rank_0.safetensors", get_convert_path(DistributedCheckpointFormat, model_testing_config.checkpoint_format) / "rank_0.safetensors", expected_keys={_WEIGHT_SHARD_SAVE_NAME}, ) - _compare_safetensor_files( + compare_safetensor_files( get_convert_path(FastLLMCheckpointFormat, DistributedCheckpointFormat) / "model_0.safetensors", get_convert_path(FastLLMCheckpointFormat, model_testing_config.checkpoint_format) / "model_0.safetensors", ) - _compare_safetensor_files( + compare_safetensor_files( get_convert_path(model_testing_config.checkpoint_format, DistributedCheckpointFormat) / "model_0.safetensors", get_convert_path(model_testing_config.checkpoint_format, FastLLMCheckpointFormat) / "model_0.safetensors", @@ -391,31 +395,55 @@ def test_huggingface_model(model_testing_config, get_convert_path): raise ValueError(f"Comparison failed ({len(errors)} errors)") +def _save_and_load_in_parallel( + test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig +) -> None: + # Import all dynamic classes. + import fast_llm.cli # noqa + + for config in DISTRIBUTED_SAVE_LOAD_CONFIGS.values(): + if config.load_format == "{checkpoint_format}" and model_testing_config.checkpoint_format is None: + continue + config = config.resolve(base_path, model_testing_config) + with test_context.subtest(base_path, config.name, config.num_gpus) as subtest: + if subtest.do_run: + logger.info(header(config.name)) + logger.info(f"Loading {config.load_format} checkpoint from {config.load_path}") + with NoAutoValidate(): + load_config = CheckpointLoadConfig(path=config.load_path, format=config.load_format) + load_config.setup(model_testing_config.model_config_class) + load_config.validate() + model = model_testing_config.model_class.from_pretrained( + load_config, + # The world size and rank are already set through environment variable. + {"distributed": config.distributed}, + mode=StageMode.inference, + ) + for save_format in (DistributedCheckpointFormat, FastLLMCheckpointFormat): + logger.info(f"Saving {save_format.name} checkpoint to {config.save_path / save_format.name}") + model.save_checkpoint( + CheckpointSaveConfig(path=config.save_path / save_format.name, format=save_format) + ) + del model + gc.collect() + torch.cuda.empty_cache() + + @requires_cuda @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert, ModelTestingGroup.distributed) -def test_save_and_load_in_parallel(run_distributed_script, run_test_script_base_path, model_testing_config, request): +def test_save_and_load_in_parallel(run_parallel_script, run_test_script_base_path, model_testing_config): # Save and load checkpoints to and from various distributed configurations. # Combined in a single test to mitigate process creation overhead. # TODO: Test beyond 2 gpu configs? 
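+    # The save/load body now runs in worker processes spawned by `run_parallel_script` (tests/utils/subtest.py) rather than through a separate torchrun script.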
- import tests.models.distributed_test_checkpoint - if torch.cuda.device_count() < 2: - pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < 2") - - script = [ - "-m", - tests.models.distributed_test_checkpoint.__name__, - str(run_test_script_base_path), - model_testing_config.name, - ] - if request.config.getoption("distributed_capture"): - logger.warning( - "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." - ) - else: - script.append("--no-distributed-capture") - run_distributed_script(script, num_gpus=2) + pytest.skip(f"Not enough GPUs2") + run_parallel_script( + _save_and_load_in_parallel, + (run_test_script_base_path, model_testing_config), + world_size=2, + backend=model_testing_config.distributed_backend, + ) @pytest.fixture(scope="module") @@ -472,7 +500,7 @@ def test_parallel_checkpoint_consistency(model_testing_config, run_test_script_b # Compare Distributed checkpoints for config in ("dp2", "tp2", "stp2", "pp2"): for rank in range(2): - _compare_safetensor_files( + compare_safetensor_files( *[ DISTRIBUTED_SAVE_LOAD_CONFIGS[f"load_{format_}_in_{config}"] .resolve(base_path=run_test_script_base_path, model_testing_config=model_testing_config) @@ -510,7 +538,7 @@ def test_multi_gpu_fast_llm_checkpoint( base_path=run_test_script_base_path, model_testing_config=model_testing_config ) - _compare_safetensor_files( + compare_safetensor_files( reference_fast_llm_shard, distributed_save_load_config_non_pp.save_path / f"{FastLLMCheckpointFormat.name}/model_0.safetensors", ) diff --git a/tests/models/test_model.py b/tests/models/test_model.py index d14721142..58768bc52 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -1,4 +1,5 @@ import logging +import pathlib import pytest import torch @@ -8,8 +9,10 @@ SIMPLE_TESTING_CONFIG, SINGLE_GPU_TESTING_CONFIGS, ) -from tests.utils.model_configs import ModelTestingGroup -from tests.utils.utils import check_subtest_success, requires_cuda, set_subtest_success +from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup +from tests.utils.run_test_script import do_run_test_script_for_all_models +from tests.utils.subtest import DistributedTestContext, check_subtest_success, set_subtest_success +from tests.utils.utils import requires_cuda logger = logging.getLogger(__name__) @@ -49,27 +52,34 @@ def test_and_compare_model( compare_results_for_all_models(config) +def _run_model_distributed( + test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig +) -> None: + # Import all dynamic classes. + import fast_llm.cli # noqa + + for name, config in DISTRIBUTED_TESTING_CONFIGS.items(): + if model_testing_config.should_skip(config): + continue + with test_context.subtest(base_path, name, config.num_gpus) as subtest: + if subtest.do_run: + do_run_test_script_for_all_models(config, model_testing_config, base_path) + + @requires_cuda @pytest.mark.depends_on(on=["test_model_simple[{model_testing_config}]"]) @pytest.mark.model_testing_group( ModelTestingGroup.distributed, ) -def test_run_model_distributed(run_distributed_script, model_testing_config, run_test_script_base_path, request): - import tests.models.distributed_test_model - - script = [ - "-m", - tests.models.distributed_test_model.__name__, - str(run_test_script_base_path), - model_testing_config.name, - ] - if request.config.getoption("distributed_capture"): - logger.warning( - "Capturing output and forwarding to associated tests. 
Run with `--no-distributed-capture` to disable." - ) - else: - script.append("--no-distributed-capture") - run_distributed_script(script, num_gpus=torch.cuda.device_count()) +def test_run_model_distributed(run_parallel_script, model_testing_config, run_test_script_base_path): + if torch.cuda.device_count() < 2: + pytest.skip(f"Not enough GPUs") + run_parallel_script( + _run_model_distributed, + (run_test_script_base_path, model_testing_config), + world_size=torch.cuda.device_count(), + backend=model_testing_config.distributed_backend, + ) # We don't want to depend on `test_model_distributed` because we still want to run this in cas of failure. diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1248a1117..f399cda4c 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -12,6 +12,7 @@ from fast_llm.config import set_nested_dict_value from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.distributed.config import DistributedBackend from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.models.gpt.conversion.config import ( @@ -149,15 +150,15 @@ def base_model_config_class(self): @functools.cached_property def distributed_backend(self): - return self.config_dict["model"]["distributed"]["backend"] + return DistributedBackend(self.config_dict["model"]["distributed"]["backend"]) def should_skip(self, distributed_config: DistributedTestingConfig) -> bool: return any(re.search(pattern, distributed_config.name) for pattern in self.skip_tests) -def _update_and_add_testing_config( - old_name: str, - new_name: str, +def update_and_add_testing_config( + old_name: str | ModelTestingConfig, + new_name: str | None, *, model_type: str | None = None, updates: dict[str | tuple[str, ...], typing.Any] | None = None, @@ -166,7 +167,7 @@ def _update_and_add_testing_config( **kwargs, ) -> ModelTestingConfig: - config = MODEL_CONFIGS[old_name] + config = old_name if isinstance(old_name, ModelTestingConfig) else MODEL_CONFIGS[old_name] config_dict = copy.deepcopy(config.config_dict) if updates is not None: for keys, update in updates.items(): @@ -178,14 +179,15 @@ def _update_and_add_testing_config( megatron_args = config.megatron_args + megatron_args new_config = dataclasses.replace( config, - name=new_name, + name=config.name if new_name is None else new_name, model_type=config.model_type if model_type is None else model_type, groups=groups, config_dict=config_dict, megatron_args=megatron_args, **kwargs, ) - MODEL_CONFIGS[new_name] = new_config + if new_name is not None: + MODEL_CONFIGS[new_name] = new_config return new_config @@ -309,7 +311,7 @@ def _update_and_add_testing_config( }, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests MQA. "gpt_2", "starcoder", @@ -328,7 +330,7 @@ def _update_and_add_testing_config( }, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests intermediate between gpt2 and llama, closest converter to gpt2. "gpt_2", "starcoder_2", @@ -357,7 +359,7 @@ def _update_and_add_testing_config( del MODEL_CONFIGS["starcoder_2"].config_dict["model"]["base_model"]["embeddings"]["num_position_embeddings"] -_update_and_add_testing_config( +update_and_add_testing_config( # Main tested model. "starcoder_2", "llama", @@ -389,7 +391,7 @@ def _update_and_add_testing_config( }, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests llama3-style rotary embeddings. 
"llama", "llama_3", @@ -409,7 +411,7 @@ def _update_and_add_testing_config( }, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests yarn-style rotary embeddings. "llama", "llama_yarn", @@ -429,7 +431,7 @@ def _update_and_add_testing_config( }, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests diffusion llama converter. "llama_yarn", "diffusion_llama", @@ -454,7 +456,7 @@ def _update_and_add_testing_config( _llama_block = MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["decoder"]["block"] -_update_and_add_testing_config( +update_and_add_testing_config( # Tests multi-token prediction, custom HF model and converter. "llama", "mtp_llama", @@ -484,7 +486,7 @@ def _update_and_add_testing_config( skip_tests=(r"ce4", r"ms"), ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests partial linear biases, Qwen2 converter. "llama", "qwen_2", @@ -506,7 +508,7 @@ def _update_and_add_testing_config( }, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests diffusion dream converter. "qwen_2", "dream", @@ -528,7 +530,7 @@ def _update_and_add_testing_config( auto_model_class=transformers.AutoModel, ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests sliding window attention, mistral converter. "llama", "mistral", @@ -551,7 +553,7 @@ def _update_and_add_testing_config( _mistral_base_model = MODEL_CONFIGS["mistral"].config_dict["model"]["base_model"] -_update_and_add_testing_config( +update_and_add_testing_config( # Tests logit distillation. "mistral", "mistral_distill_logits", @@ -579,7 +581,7 @@ def _update_and_add_testing_config( skip_tests=("ms", "pp2s1_bf4", "pp2s2_bf4", "sdp2"), ) -_update_and_add_testing_config( +update_and_add_testing_config( "mistral_distill_logits", "mistral_reverse_kl", updates={ @@ -600,7 +602,7 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms", "pp"), ) -_update_and_add_testing_config( +update_and_add_testing_config( "mistral_distill_logits", "mistral_distill_activations", updates={ @@ -631,7 +633,7 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms", "pp", "tp", GRAD_ACC, "fp16"), ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests mixture of experts, mixtral converter. "llama", "mixtral", @@ -659,7 +661,7 @@ def _update_and_add_testing_config( ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests hybrid Mamba 2. "llama", "hybrid_mamba", @@ -700,7 +702,7 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms"), ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests vision multimodal. "llama", "llava", @@ -743,7 +745,7 @@ def _update_and_add_testing_config( ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests hybrid with attention + gated delta net mixer. "llama", "apriel2_text_gdn_hybrid", @@ -794,7 +796,7 @@ def _update_and_add_testing_config( skip_tests=("sdp", "ms", TP_NO_STP), ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests apriel2 format with pattern decoder mixing all mixer types. # This comprehensive test exercises: attention, mamba, stochastic mixer, sliding window attention, gdn. "llama", @@ -917,7 +919,7 @@ def _update_and_add_testing_config( ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests apriel2 multimodal format combining pattern decoder with vision encoder. # Uses the same decoder as apriel2_text_all_hybrid but adds vision capabilities. 
"apriel2_text_all_hybrid", @@ -960,7 +962,7 @@ def _update_and_add_testing_config( ) -_update_and_add_testing_config( +update_and_add_testing_config( # Tests hybrid with KDA mixer. "llama", "hybrid_kda", @@ -1014,7 +1016,7 @@ def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slo model_testing_config = item.callspec.params["model_testing_config"] model_config: ModelTestingConfig = MODEL_CONFIGS[model_testing_config] for group in groups: - action = model_config.groups[group] + action = model_config.groups.get(group, ModelTestingGroupAction.unimportant) if action == ModelTestingGroupAction.main: pass elif action == ModelTestingGroupAction.normal and not skip_slow: diff --git a/tests/utils/run_test_script.py b/tests/utils/run_test_script.py index 5c07324cf..0b8232cf7 100644 --- a/tests/utils/run_test_script.py +++ b/tests/utils/run_test_script.py @@ -1,4 +1,3 @@ -import argparse import functools import os import pathlib @@ -12,7 +11,7 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import Assert from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.model_configs import MODEL_CONFIGS, ModelTestingConfig +from tests.utils.model_configs import ModelTestingConfig if typing.TYPE_CHECKING: from tests.conftest import WorkerResources @@ -47,11 +46,7 @@ def do_run_distributed_script( @pytest.fixture(scope="session") -def run_distributed_script( - worker_resources: "WorkerResources", - run_test_script_base_path: pathlib.Path, - model_testing_config: ModelTestingConfig, -): +def run_distributed_script(worker_resources: "WorkerResources"): return functools.partial( do_run_distributed_script, rendezvous_port=worker_resources.rendezvous_port, @@ -144,16 +139,6 @@ def run_test_script_for_all_models( ) -def parse_run_distributed_script(args: list[str] | None = None): - parser = argparse.ArgumentParser() - parser.add_argument("base_path", type=pathlib.Path) - parser.add_argument("model_testing_config", type=str) - parser.add_argument("--no-distributed-capture", dest="distributed_capture", action="store_false") - - parsed = parser.parse_args(args) - return parsed.base_path, MODEL_CONFIGS[parsed.model_testing_config], parsed.distributed_capture - - @pytest.fixture(scope="session") def compare_results_for_all_models( worker_resources: "WorkerResources", diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py new file mode 100644 index 000000000..4fea1fbba --- /dev/null +++ b/tests/utils/subtest.py @@ -0,0 +1,273 @@ +import functools +import json +import logging +import math +import pathlib +import sys +import time +import traceback +import typing + +import pytest +import torch + +from fast_llm.core.distributed import allreduce_scalar, safe_barrier +from fast_llm.engine.config_utils.logging import configure_logging +from fast_llm.engine.distributed.config import DistributedBackend, DistributedConfig +from fast_llm.engine.distributed.distributed import ProcessGroupPool +from fast_llm.utils import Assert, get_and_reset_memory_usage_mib, header + +logger = logging.getLogger(__name__) + + +class DistributedTestContext: + def __init__( + self, + do_capture: bool, + timeout: float = 20.0, + init_method: str = "env://", + backend: DistributedBackend = DistributedBackend.nccl, + ) -> None: + self._do_capture = do_capture + self._timeout = timeout + self._init_method = init_method + self._backend = backend + + def __enter__(self): + if self._do_capture: + logger.warning( + "Capturing output and forwarding to associated 
tests. Run with `--no-distributed-capture` to disable." + ) + + self._pool = ProcessGroupPool( + timeout=self._timeout, init_method=self._init_method, backend=self._backend + ).__enter__() + self._rank = self._pool.rank + self._world_size = self._pool.world_size + self._failures = [] + self._configure_logging() + self._group = self._pool.get_process_group(range(self._world_size), self._rank) + # TODO: Barriers needed? + safe_barrier(self._group, "start") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Final barrier to ensure everything is done before torchrun potentially kills workers. + safe_barrier(self._group, "testing end") + # Let pytest know how things went. + # These should already be reported above, we repeat for convenience. + if self._failures: + raise RuntimeError(f"The following subtests failed: {", ".join(self._failures)}") + else: + logger.warning("All tests passed") + + def subtest(self, base_path: pathlib.Path, name: str, num_gpus: int): + return self.DistributedSubtestContext(self, base_path, name, num_gpus) + + def _configure_logging(self): + configure_logging(rank=self._rank, world_size=self._world_size) + + @property + def rank(self) -> int: + return self._rank + + @property + def world_size(self) -> int: + return self._world_size + + class DistributedSubtestContext: + def __init__( + self, test_context: "DistributedTestContext", base_path: pathlib.Path, name: str, num_gpus: int + ) -> None: + self._test_context = test_context + self._path = base_path / name + self._name = name + self._num_gpus = num_gpus + self._skip = self._test_context._world_size < self._num_gpus + self._do_run = self._test_context._rank < num_gpus and not self._skip + self._do_capture = self._test_context._do_capture and self._do_run + self._success = False + + def __enter__(self) -> typing.Self: + if self._do_capture: + self._sys_stdout = sys.stdout + self._sys_stderr = sys.stderr + self._path.mkdir(parents=True, exist_ok=True) + sys.stdout = self._path.joinpath(f"pytest_stdout_{self._test_context._rank}").open("w") + sys.stderr = self._path.joinpath(f"pytest_stderr_{self._test_context._rank}").open("w") + self._test_context._configure_logging() + # Logging is set to log to the old stdout, so we need to reconfigure. + self._start = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._skip: + # Skipped tests should exit right away. + Assert.none(exc_val) + logger.warning( + f"{self._name} {f"SKIPPED (not enough GPUs: {self._test_context._world_size} < {self._num_gpus})"})" + ) + return + + if self._do_capture: + try: + stdout_handle = sys.stdout + stderr_handle = sys.stderr + sys.stdout = self._sys_stdout + sys.stderr = self._sys_stderr + stdout_handle.close() + stderr_handle.close() + finally: + assert DistributedConfig.default_world_size > 1 + self._test_context._configure_logging() + + if exc_type is None: + self._success = True + else: + self._path.mkdir(parents=True, exist_ok=True) + self._path.joinpath(f"pytest_traceback_{self._test_context._rank}").write_text(traceback.format_exc()) + + logger.warning(f"{self._name} done, waiting for other ranks ({"PASSED" if self._success else "FAILED"})") + + if (group := self._test_context._group) is not None: + # Barrier so `allreduce_scalar` doesn't go crazy in case of desync. + safe_barrier(group, self._name) + self._success = allreduce_scalar(self._success, dtype=torch.int64, group=group) == group.size() + + if self._do_capture: + # Free resources to limit memory usage. 
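+                # The per-rank JSON report written below is read back and aggregated by the `report_subtest` fixture.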
+ report = get_and_reset_memory_usage_mib(clear_cache=True, global_stats=True, reset_global_stats=True) + report["duration"] = time.perf_counter() - self._start + + json.dump(report, self._path.joinpath(f"pytest_report_{self._test_context._rank}").open("w")) + + if self._test_context._rank == 0: + set_subtest_success(self._path, self._success) + logger.warning(f"{self._name} {"PASSED" if self._success else "FAILED"}") + if not self._success: + self._test_context._failures.append(self._name) + + return True + + @property + def do_run(self) -> bool: + return self._do_run and not self._skip + + +def set_subtest_success(path: pathlib.Path, success: bool = True): + path.joinpath("pytest_success").write_text(str(int(success))) + + +def check_subtest_success(path: pathlib, fail: bool = True) -> bool: + if not path.is_dir(): + if fail: + pytest.fail(f"Test {path.name} did not run", pytrace=False) + else: + return False + try: + return bool(int(path.joinpath("pytest_success").read_text())) + except OSError: + return False + + +def format_resource_report(title: str, report: dict[str, float]) -> str: + return "".join( + [ + f"{title}:\n ", + f"Max Reserved: {report.get("max_reserved", math.nan):.0f} MiB", + f"| Max Allocated: {report.get("max_allocated", math.nan):.0f} MiB".ljust(26), + f"| End Reserved: {report.get("reserved", math.nan):.0f} MiB".ljust(25), + f"| End Allocated: {report.get("allocated", math.nan):.0f} MiB".ljust(26), + f"| Duration: {report.get("duration", math.nan):.2f}".ljust(18), + f"| GPUs: {report["gpus"]:.0f}" if "gpus" in report else "", + ] + ) + + +@pytest.fixture(scope="function") +def report_subtest(request: pytest.FixtureRequest): + verbose = request.config.getoption("verbose") + do_capture = request.config.getoption("distributed_capture") + + def do_report_subtest(path: pathlib.Path, world_size: int) -> None: + success = check_subtest_success(path) + if not do_capture: + logger.warning("Distributed capture is disabled. 
See distributed test for run output.") + elif verbose > 1 or not success: + for rank in range(world_size): + for fd, file_ in (("stdout", sys.stdout), ("stderr", sys.stdout), ("traceback", sys.stderr)): + print(header(f"{fd} rank {rank}", 80), file=file_) + file_path = path / f"pytest_{fd}_{rank}" + try: + print(file_path.read_text(), file=file_) + except OSError: + print(f"<<< not found {file_path}>>>", file=file_) + else: + print("Set verbose > 1 to show run output.") + + reports = {} + for rank in range(world_size): + try: + reports[f"rank_{rank}"] = json.load(path.joinpath(f"pytest_report_{rank}").open("r")) + except OSError: + reports[rank] = {} + keys = {key for report in reports.values() for key in report} + report = {key: max(report[key] for report in reports.values() if key in report) for key in keys} + report["gpus"] = world_size + reports["global"] = report + + print(header(f"Resource usage", 80), file=sys.stderr) + for name, report in reports.items(): + print(format_resource_report(name, report), file=sys.stderr) + setattr(request.node, "fast_llm_resource_report", report) + + if not success: + raise RuntimeError(f"test {path.name} failed") + + return do_report_subtest + + +def parallel_worker( + rank: int, + world_size: int, + init_method: str, + backend: DistributedBackend, + do_capture: bool, + fn: typing.Callable, + fn_args: typing.Sequence[typing.Any], +): + DistributedConfig.default_rank = rank + DistributedConfig.default_world_size = world_size + DistributedConfig.default_local_world_size = world_size + with DistributedTestContext(do_capture, 60, init_method, backend) as test_context: + fn(test_context, *fn_args) + + +def do_run_parallel_script( + fn: typing.Callable, + fn_args: typing.Sequence[typing.Any], + port: int, + do_capture: bool, + world_size: int, + timeout: float = 240, + backend: DistributedBackend = DistributedBackend.nccl, +): + if do_capture: + logger.warning( + "Capturing output and forwarding to associated tests. Run with `--no-distributed-capture` to disable." 
+ ) + torch.multiprocessing.spawn( + parallel_worker, + args=(world_size, f"tcp://localhost:{port}", backend, do_capture, fn, fn_args), + nprocs=world_size, + join=False, + ).join(timeout, grace_period=5) + + +@pytest.fixture(scope="session") +def run_parallel_script(worker_resources: "WorkerResources", request: pytest.FixtureRequest): + return functools.partial( + do_run_parallel_script, + port=worker_resources.rendezvous_port, + do_capture=request.config.getoption("distributed_capture"), + ) diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 3b79f7607..f0ca20db8 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -1,23 +1,14 @@ -import json import logging -import math -import pathlib -import sys -import time -import traceback import typing import pytest import torch -from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.base_model.config import set_model_names -from fast_llm.engine.config_utils.logging import configure_logging from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig from fast_llm.engine.multi_stage.stage import Stage -from fast_llm.utils import get_and_reset_memory_usage_mib, header from tests.utils.global_variables import TEST_RESULTS_PATH logger = logging.getLogger(__name__) @@ -65,137 +56,3 @@ def get_stage( stage.restore_parameters() stage.reset_gradients() return stage - - -class DistributedSubtestContext: - def __init__( - self, base_path: pathlib.Path, name: str, group: ProcessGroup | None, num_gpus: int, enabled: bool = True - ) -> None: - self._path = base_path / name - self._name = name - self._group = group - self._rank = 0 if group is None else group.rank() - self._rank_enabled = self._rank < num_gpus - self._enabled = enabled and self._rank_enabled - self.success = False - - def __enter__(self) -> typing.Self: - if self._enabled: - self._sys_stdout = sys.stdout - self._sys_stderr = sys.stderr - self._path.mkdir(parents=True, exist_ok=True) - sys.stdout = self._path.joinpath(f"pytest_stdout_{self._rank}").open("w") - sys.stderr = self._path.joinpath(f"pytest_stderr_{self._rank}").open("w") - # Logging is set to log to the old stdout, so we need to reconfigure. - configure_logging() - self._start = time.perf_counter() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._enabled: - try: - stdout_handle = sys.stdout - stderr_handle = sys.stderr - sys.stdout = self._sys_stdout - sys.stderr = self._sys_stderr - stdout_handle.close() - stderr_handle.close() - finally: - configure_logging() - - if exc_type is None: - self.success = True - else: - self._path.mkdir(parents=True, exist_ok=True) - self._path.joinpath(f"pytest_traceback_{self._rank}").write_text(traceback.format_exc()) - - if self._group is not None: - # Barrier so `allreduce_scalar` doesn't go crazy in case of desync. - safe_barrier(self._group, self._name) - self.success = allreduce_scalar(self.success, dtype=torch.int64, group=self._group) == self._group.size() - - if self._rank_enabled: - # Free resources to limit memory usage. 
- report = get_and_reset_memory_usage_mib(clear_cache=True, global_stats=True, reset_global_stats=True) - report["duration"] = time.perf_counter() - self._start - - json.dump(report, self._path.joinpath(f"pytest_report_{self._rank}").open("w")) - - logger.warning(f"{self._name} {"PASSED" if self.success else "FAILED"})") - if self._rank == 0: - set_subtest_success(self._path, self.success) - - return True - - -def set_subtest_success(path: pathlib.Path, success: bool = True): - path.joinpath("pytest_success").write_text(str(int(success))) - - -def check_subtest_success(path: pathlib, fail: bool = True) -> bool: - if not path.is_dir(): - if fail: - pytest.fail(f"Test {path.name} did not run", pytrace=False) - else: - return False - try: - return bool(int(path.joinpath("pytest_success").read_text())) - except OSError: - return False - - -def format_resource_report(title: str, report: dict[str, float]) -> str: - return "".join( - [ - f"{title}:\n ", - f"Max Reserved: {report.get("max_reserved", math.nan):.0f} MiB", - f"| Max Allocated: {report.get("max_allocated", math.nan):.0f} MiB".ljust(26), - f"| End Reserved: {report.get("reserved", math.nan):.0f} MiB".ljust(25), - f"| End Allocated: {report.get("allocated", math.nan):.0f} MiB".ljust(26), - f"| Duration: {report.get("duration", math.nan):.2f}".ljust(18), - f"| GPUs: {report["gpus"]:.0f}" if "gpus" in report else "", - ] - ) - - -@pytest.fixture(scope="function") -def report_subtest(request: pytest.FixtureRequest): - verbose = request.config.getoption("verbose") - do_capture = request.config.getoption("distributed_capture") - - def do_report_subtest(path: pathlib.Path, world_size: int) -> None: - success = check_subtest_success(path) - if not do_capture: - logger.warning("Distributed capture is disabled. See distributed test for run output.") - elif verbose > 1 or not success: - for rank in range(world_size): - for fd, file_ in (("stdout", sys.stdout), ("stderr", sys.stdout), ("traceback", sys.stderr)): - print(header(f"{fd} rank {rank}", 80), file=file_) - file_path = path / f"pytest_{fd}_{rank}" - try: - print(file_path.read_text(), file=file_) - except OSError: - print(f"<<< not found {file_path}>>>", file=file_) - else: - print("Set verbose > 1 to show run output.") - - reports = {} - for rank in range(world_size): - try: - reports[f"rank_{rank}"] = json.load(path.joinpath(f"pytest_report_{rank}").open("r")) - except OSError: - reports[rank] = {} - keys = {key for report in reports.values() for key in report} - report = {key: max(report[key] for report in reports.values() if key in report) for key in keys} - report["gpus"] = world_size - reports["global"] = report - - print(header(f"Resource usage", 80), file=sys.stderr) - for name, report in reports.items(): - print(format_resource_report(name, report), file=sys.stderr) - setattr(request.node, "fast_llm_resource_report", report) - - if not success: - raise RuntimeError(f"test {path.name} failed") - - return do_report_subtest
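For reference, a minimal standalone sketch of the (size, stride) bookkeeping performed by the new `DistributedDim.from_sizes_and_strides` classmethod, with the `range` fast paths left out. The helper name `group_from_sizes_and_strides` and the sample dimension layout are illustrative only, not part of Fast-LLM's API.

    def group_from_sizes_and_strides(global_rank: int, *sizes_and_strides: tuple[int, int]):
        # Decompose the global rank along each (size, stride) dimension to find the group's
        # first rank (`start`), this rank's position within the group, and the group size.
        start, rank, world_size = global_rank, 0, 1
        for size, stride in sizes_and_strides:
            rank_ = global_rank // stride % size
            start -= rank_ * stride
            rank += world_size * rank_
            world_size *= size
        # Enumerate the group's global ranks as the cross product of the per-dimension offsets.
        global_ranks = [start]
        for size, stride in sizes_and_strides:
            global_ranks = [rank0 + rank1 for rank1 in range(0, size * stride, stride) for rank0 in global_ranks]
        return rank, world_size, global_ranks

    # Example: two dimensions of size 2 with strides 1 and 4, viewed from global rank 5 (e.g. on 8 GPUs).
    print(group_from_sizes_and_strides(5, (2, 1), (2, 4)))
    # -> (3, 4, [0, 1, 4, 5]): rank 5 is member 3 of the 4-rank group {0, 1, 4, 5}.

The real classmethod additionally collapses contiguous or evenly strided groups into `range` objects and validates consecutive strides with `Assert.multiple`, but the resulting ranks are the same.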