2 changes: 1 addition & 1 deletion .devcontainer/recipes/Dockerfile
@@ -1,4 +1,4 @@
-FROM nvcr.io/nvidia/pytorch:25.09-py3
+FROM gitlab-master.nvidia.com:5005/dl/transformerengine/transformerengine:main-pytorch-py3-devel
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=requirements.txt,target=/workspace/requirements.txt \
PIP_CONSTRAINT= pip install -r /workspace/requirements.txt
19 changes: 15 additions & 4 deletions bionemo-recipes/models/esm2/tests/test_distributed_fp8.py
@@ -105,6 +105,20 @@ def test_multi_process_fp8_recipes_are_synced(strategy):

from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM

def recursive_assert(a, b, path=""):
if isinstance(a, dict) and isinstance(b, dict):
assert a.keys() == b.keys(), f"Dictionary keys mismatch: {a.keys()} != {b.keys()} at {path}"
for k in a:
recursive_assert(a[k], b[k], path=f"{path}.{k}")
elif isinstance(a, list) and isinstance(b, list):
assert len(a) == len(b), f"List lengths mismatch: {len(a)} != {len(b)} at {path}"
for i in range(len(a)):
recursive_assert(a[i], b[i], path=f"{path}.{i}")
elif isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
torch.testing.assert_close(a, b, msg=f"Tensor mismatch at {path}")
else:
assert a == b, f"Value mismatch at {path}: {a} != {b}"

class Strategy(enum.StrEnum):
DDP = "ddp"
FSDP2 = "fsdp2"
@@ -213,10 +227,7 @@ def is_main_process(self) -> bool:
assert len(state_2) > 0, f"No FP8 extra states for {key}, rank 1"
dict_1 = pickle.loads(state_1.detach().numpy(force=True).tobytes())
dict_2 = pickle.loads(state_2.detach().numpy(force=True).tobytes())
-recipe_1 = dict_1.pop("recipe")
-recipe_2 = dict_2.pop("recipe")
-torch.testing.assert_close(dict_1, dict_2)
-assert recipe_1 == recipe_2
+recursive_assert(dict_1, dict_2)

# One rank, test to ensure the correct FP8 extra states are saved
if torch.distributed.get_world_size() == 1:
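For context, the assertion above first decodes TE's FP8 extra_state blobs, which are pickled dicts serialized into raw tensors. A minimal standalone sketch of the decode-and-compare flow (names are illustrative; recursive_assert is the helper added in this diff):

    import pickle

    def compare_fp8_extra_states(state_1, state_2):
        # TE serializes FP8 metadata as a pickled dict stored in a raw tensor,
        # so the blob must be decoded before any structural comparison.
        dict_1 = pickle.loads(state_1.detach().numpy(force=True).tobytes())
        dict_2 = pickle.loads(state_2.detach().numpy(force=True).tobytes())
        # The decoded dict mixes tensor leaves (e.g. amax history, scales) with
        # non-tensor leaves (e.g. the recipe), which is why a recursive walk
        # replaces the earlier pop-and-compare approach.
        recursive_assert(dict_1, dict_2)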
4 changes: 4 additions & 0 deletions bionemo-recipes/models/esm2/tests/test_thd.py
@@ -109,6 +109,8 @@ def attn_impl(request, monkeypatch):
def test_thd_losses_match(te_model_checkpoint, input_data, input_data_thd, attn_impl):
if attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 8:
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
elif attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")

torch.testing.assert_close(
input_data["input_ids"][input_data["attention_mask"].to(bool)],
@@ -139,6 +141,8 @@ def test_thd_logits_match_with_bf16_autocast(te_model_checkpoint, input_data, in
pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
elif attn_impl == "flash_attn" and torch.cuda.get_device_capability()[0] == 8:
pytest.xfail("BIONEMO-2801: On Ada and Ampere, the flash attention logits don't seem to match.")
elif attn_impl == "fused_attn" and torch.cuda.get_device_capability()[0] == 12:
pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")

# Ensure the input data is the same
torch.testing.assert_close(
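The same capability gate now appears in several tests; one possible consolidation is a small shared helper (hypothetical, not part of this diff):

    import pytest
    import torch

    def xfail_unsupported_thd(attn_impl: str) -> None:
        # Collect the THD capability gates used across these tests in one place.
        major = torch.cuda.get_device_capability()[0]
        if attn_impl == "fused_attn" and major == 8:
            pytest.xfail("On Ada and Ampere, no THD implementation is available for fused attn.")
        if attn_impl == "fused_attn" and major == 12:
            pytest.xfail("BIONEMO-2840: On sm120, the THD implementation is not available for fused attn.")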
2 changes: 1 addition & 1 deletion bionemo-recipes/recipes/esm2_native_te/Dockerfile
@@ -1,5 +1,5 @@
# syntax=docker/dockerfile:1.4
-FROM nvcr.io/nvidia/pytorch:25.09-py3
+FROM gitlab-master.nvidia.com:5005/dl/transformerengine/transformerengine:main-pytorch-py3-devel

RUN --mount=type=secret,id=netrc,target=/root/.netrc \
--mount=type=cache,target=/root/.cache/pip \
13 changes: 7 additions & 6 deletions bionemo-recipes/recipes/esm2_native_te/dataset.py
@@ -17,8 +17,9 @@

import datasets
import datasets.distributed
-from torch.utils.data import DistributedSampler
-from torchdata.stateful_dataloader import StatefulDataLoader
+from torch.utils.data import DataLoader, DistributedSampler
+
+# from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorForLanguageModeling

@@ -132,13 +133,13 @@ def create_bshd_dataloader(
seed=seed,
)

-train_dataloader = StatefulDataLoader(
+train_dataloader = DataLoader(
tokenized_dataset,
sampler=sampler,
batch_size=micro_batch_size,
collate_fn=data_collator,
num_workers=num_workers,
-pin_memory=True,
+pin_memory=True,  # TODO: Revert to StatefulDataLoader when we figure out why _pin_memory_thread's API changed.
persistent_workers=True,
)

@@ -201,12 +202,12 @@ def create_thd_dataloader(
seed=seed,
)

-train_dataloader = StatefulDataLoader(
+train_dataloader = DataLoader(
TokenPackingDataset(tokenized_dataset, max_tokens_per_batch=token_micro_batch_size),
batch_size=None, # The TokenPackingDataset will handle the batching.
collate_fn=data_collator,
num_workers=num_workers,
-pin_memory=True,
+pin_memory=True,  # TODO: Revert to StatefulDataLoader when we figure out why _pin_memory_thread's API changed.
persistent_workers=True,
)

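Note what the StatefulDataLoader-to-DataLoader swap gives up: a checkpointable dataloader. A minimal sketch of the resume pattern that returns once StatefulDataLoader is restored (illustrative; the batch size and worker count are placeholders, and tokenized_dataset is the dataset built above):

    from torchdata.stateful_dataloader import StatefulDataLoader

    loader = StatefulDataLoader(tokenized_dataset, batch_size=8, num_workers=2)
    for step, batch in enumerate(loader):
        if step == 100:
            state = loader.state_dict()  # capture the mid-epoch position
            break

    resumed = StatefulDataLoader(tokenized_dataset, batch_size=8, num_workers=2)
    resumed.load_state_dict(state)  # iteration continues after step 100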
2 changes: 1 addition & 1 deletion bionemo-recipes/recipes/esm2_native_te/tests/conftest.py
@@ -65,7 +65,7 @@ def device_mesh():
_mesh_resources.mesh_stack.clear()
_mesh_resources.child_to_root_mapping.clear()
_mesh_resources.root_to_flatten_mapping.clear()
-_mesh_resources.flatten_name_to_root_dims.clear()
+# _mesh_resources.flatten_name_to_root_dims.clear()
_mesh_resources.mesh_dim_group_options.clear()
torch.cuda.empty_cache()
torch.cuda.synchronize()
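If flatten_name_to_root_dims was removed or renamed in the newer torch build, a version-tolerant alternative to commenting the line out would be a guarded clear (a sketch against the private _mesh_resources registry used above):

    # Clear the attribute only on torch versions that still expose it.
    if hasattr(_mesh_resources, "flatten_name_to_root_dims"):
        _mesh_resources.flatten_name_to_root_dims.clear()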
3 changes: 2 additions & 1 deletion ci/scripts/recipes_local_test.py
@@ -55,7 +55,8 @@
# --output type=registry,compression=zstd,force-compression=true,oci-mediatypes=true,compression-level=15
# and pushed to the dockerhub registry. Our github actions are able to cache image pulls from dockerhub but not nvcr, so
# hopefully this cuts down slightly on CI time at the expense of a slightly indirect image location.
DEFAULT_CONTAINER = "svcbionemo023/bionemo-framework:pytorch25.09-py3-squashed"
# DEFAULT_CONTAINER = "svcbionemo023/bionemo-framework:pytorch25.09-py3-squashed"
DEFAULT_CONTAINER = "gitlab-master.nvidia.com:5005/dl/transformerengine/transformerengine:main-pytorch-py3-devel"


def get_git_root() -> str: