diff --git a/.devcontainer/recipes/Dockerfile b/.devcontainer/recipes/Dockerfile
index 3ab24811eb..5a51176458 100644
--- a/.devcontainer/recipes/Dockerfile
+++ b/.devcontainer/recipes/Dockerfile
@@ -1,4 +1,4 @@
-FROM nvcr.io/nvidia/pytorch:25.09-py3
+FROM gitlab-master.nvidia.com:5005/dl/transformerengine/transformerengine:main-pytorch-py3-devel
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=bind,source=requirements.txt,target=/workspace/requirements.txt \
     PIP_CONSTRAINT= pip install -r /workspace/requirements.txt
diff --git a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py
index 5477a424b5..e957b9b1d0 100644
--- a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py
+++ b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py
@@ -105,6 +105,20 @@ def test_multi_process_fp8_recipes_are_synced(strategy):
 
 from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM
 
+def recursive_assert(a, b, path=""):
+    if isinstance(a, dict) and isinstance(b, dict):
+        assert a.keys() == b.keys(), f"Dictionary keys mismatch: {a.keys()} != {b.keys()} at {path}"
+        for k in a:
+            recursive_assert(a[k], b[k], path=f"{path}.{k}")
+    elif isinstance(a, list) and isinstance(b, list):
+        assert len(a) == len(b), f"List lengths mismatch: {len(a)} != {len(b)} at {path}"
+        for i in range(len(a)):
+            recursive_assert(a[i], b[i], path=f"{path}.{i}")
+    elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
+        torch.testing.assert_close(a, b, msg=f"Tensor mismatch at {path}")
+    else:
+        assert a == b, f"Value mismatch at {path}: {a} != {b}"
+
 class Strategy(enum.StrEnum):
     DDP = "ddp"
     FSDP2 = "fsdp2"
@@ -213,10 +227,7 @@ def is_main_process(self) -> bool:
         assert len(state_2) > 0, f"No FP8 extra states for {key}, rank 1"
         dict_1 = pickle.loads(state_1.detach().numpy(force=True).tobytes())
         dict_2 = pickle.loads(state_2.detach().numpy(force=True).tobytes())
-        recipe_1 = dict_1.pop("recipe")
-        recipe_2 = dict_2.pop("recipe")
-        torch.testing.assert_close(dict_1, dict_2)
-        assert recipe_1 == recipe_2
+        recursive_assert(dict_1, dict_2)
 
 # One rank, test to ensure the correct FP8 extra states are saved
 if torch.distributed.get_world_size() == 1:
diff --git a/bionemo-recipes/models/esm2/tests/test_thd.py b/bionemo-recipes/models/esm2/tests/test_thd.py
index 502fb20da1..4357e3fe6a 100644
--- a/bionemo-recipes/models/esm2/tests/test_thd.py
+++ b/bionemo-recipes/models/esm2/tests/test_thd.py
@@ -109,6 +109,8 @@ def attn_impl(request, monkeypatch):
 def test_thd_losses_match(te_model_checkpoint, input_data, input_data_thd, attn_impl):
     if attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
         pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
+    elif attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
+        pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")
 
     torch.testing.assert_close(
         input_data["input_ids"][input_data["attention_mask"].to(bool)],
@@ -139,6 +141,8 @@ def test_thd_logits_match_with_bf16_autocast(te_model_checkpoint, input_data, in
         pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
     elif attn_impl == "flash_attn" and torch.cuda.get_device_capability()[0] == 8:
         pytest.xfail("BIONEMO-2801: On Ada and Ampere, the flash attention logits don't seem to match.")
+    elif attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
+        pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")
 
     # Ensure the input data is the same
     torch.testing.assert_close(
diff --git a/bionemo-recipes/recipes/esm2_native_te/Dockerfile b/bionemo-recipes/recipes/esm2_native_te/Dockerfile
index 53f79f7726..f4b89bbe22 100644
--- a/bionemo-recipes/recipes/esm2_native_te/Dockerfile
+++ b/bionemo-recipes/recipes/esm2_native_te/Dockerfile
@@ -1,5 +1,5 @@
 # syntax=docker/dockerfile:1.4
-FROM nvcr.io/nvidia/pytorch:25.09-py3
+FROM gitlab-master.nvidia.com:5005/dl/transformerengine/transformerengine:main-pytorch-py3-devel
 
 RUN --mount=type=secret,id=netrc,target=/root/.netrc \
     --mount=type=cache,target=/root/.cache/pip \
diff --git a/bionemo-recipes/recipes/esm2_native_te/dataset.py b/bionemo-recipes/recipes/esm2_native_te/dataset.py
index 6c8db1c507..c08d693ae5 100644
--- a/bionemo-recipes/recipes/esm2_native_te/dataset.py
+++ b/bionemo-recipes/recipes/esm2_native_te/dataset.py
@@ -17,8 +17,9 @@
 
 import datasets
 import datasets.distributed
-from torch.utils.data import DistributedSampler
-from torchdata.stateful_dataloader import StatefulDataLoader
+from torch.utils.data import DataLoader, DistributedSampler
+
+# from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import AutoTokenizer
 from transformers.data.data_collator import DataCollatorForLanguageModeling
 
@@ -132,13 +133,13 @@ def create_bshd_dataloader(
         seed=seed,
     )
 
-    train_dataloader = StatefulDataLoader(
+    train_dataloader = DataLoader(
         tokenized_dataset,
         sampler=sampler,
         batch_size=micro_batch_size,
         collate_fn=data_collator,
         num_workers=num_workers,
-        pin_memory=True,
+        pin_memory=True,  # TODO: Revert to StatefulDataLoader once we figure out why _pin_memory_thread's API changed.
         persistent_workers=True,
     )
 
@@ -201,12 +202,12 @@ def create_thd_dataloader(
         seed=seed,
     )
 
-    train_dataloader = StatefulDataLoader(
+    train_dataloader = DataLoader(
         TokenPackingDataset(tokenized_dataset, max_tokens_per_batch=token_micro_batch_size),
         batch_size=None,  # The TokenPackingDataset will handle the batching.
         collate_fn=data_collator,
         num_workers=num_workers,
-        pin_memory=True,
+        pin_memory=True,  # TODO: Revert to StatefulDataLoader once we figure out why _pin_memory_thread's API changed.
         persistent_workers=True,
     )
 
diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/conftest.py b/bionemo-recipes/recipes/esm2_native_te/tests/conftest.py
index 3e296b859b..9bc870997b 100644
--- a/bionemo-recipes/recipes/esm2_native_te/tests/conftest.py
+++ b/bionemo-recipes/recipes/esm2_native_te/tests/conftest.py
@@ -65,7 +65,7 @@ def device_mesh():
     _mesh_resources.mesh_stack.clear()
     _mesh_resources.child_to_root_mapping.clear()
     _mesh_resources.root_to_flatten_mapping.clear()
-    _mesh_resources.flatten_name_to_root_dims.clear()
+    # _mesh_resources.flatten_name_to_root_dims.clear()
     _mesh_resources.mesh_dim_group_options.clear()
     torch.cuda.empty_cache()
     torch.cuda.synchronize()
diff --git a/ci/scripts/recipes_local_test.py b/ci/scripts/recipes_local_test.py
index ca8cf6c258..8186e995c5 100755
--- a/ci/scripts/recipes_local_test.py
+++ b/ci/scripts/recipes_local_test.py
@@ -55,7 +55,8 @@
 # --output type=registry,compression=zstd,force-compression=true,oci-mediatypes=true,compression-level=15
 # and pushed to the dockerhub registry. Our github actions are able to cache image pulls from dockerhub but not nvcr, so
 # hopefully this cuts down slightly on CI time at the expense of having a slightly in-directed image location.
-DEFAULT_CONTAINER = "svcbionemo023/bionemo-framework:pytorch25.09-py3-squashed"
+# DEFAULT_CONTAINER = "svcbionemo023/bionemo-framework:pytorch25.09-py3-squashed"
+DEFAULT_CONTAINER = "gitlab-master.nvidia.com:5005/dl/transformerengine/transformerengine:main-pytorch-py3-devel"
 
 
 def get_git_root() -> str:
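
Aside on the recursive_assert change in test_distributed_fp8.py above: the old code had to pop the non-tensor recipe entry before calling torch.testing.assert_close, since assert_close only understands tensor-like and numeric leaves, while the recursive helper walks arbitrary nesting, applies assert_close only at tensor leaves and plain equality elsewhere, and reports the path to the first mismatch. Below is a minimal standalone sketch of the same round trip, assuming only torch and pickle; the payload is a hypothetical stand-in rather than the exact structure TransformerEngine stores in _extra_state, and the frombuffer encode step is only an illustration of how such a blob could be formed.

import pickle

import torch


def recursive_assert(a, b, path=""):
    # Same leaf-wise comparison as the helper added in test_distributed_fp8.py.
    if isinstance(a, dict) and isinstance(b, dict):
        assert a.keys() == b.keys(), f"Dictionary keys mismatch: {a.keys()} != {b.keys()} at {path}"
        for k in a:
            recursive_assert(a[k], b[k], path=f"{path}.{k}")
    elif isinstance(a, list) and isinstance(b, list):
        assert len(a) == len(b), f"List lengths mismatch: {len(a)} != {len(b)} at {path}"
        for i in range(len(a)):
            recursive_assert(a[i], b[i], path=f"{path}.{i}")
    elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        torch.testing.assert_close(a, b, msg=f"Tensor mismatch at {path}")
    else:
        assert a == b, f"Value mismatch at {path}: {a} != {b}"


# Hypothetical extra-state payload: a nested dict mixing tensors and plain values.
payload = {"recipe": {"fp8_format": "HYBRID", "amax_history_len": 1024}, "scale": torch.ones(4)}

# Serialize into a uint8 tensor, then decode it the way the test decodes state_1/state_2.
blob = torch.frombuffer(bytearray(pickle.dumps(payload)), dtype=torch.uint8)
decoded = pickle.loads(blob.detach().numpy(force=True).tobytes())

recursive_assert(payload, decoded)  # Passes only if every leaf matches exactly.

If two ranks ever disagree, the failure message names the offending leaf (for example ".recipe.amax_history_len" or ".scale") instead of a generic dictionary mismatch, which is what you want when debugging FP8 recipe sync across processes.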