
Broken training llama with PyTorch FSDP example on SMHP #944

@aravneelaws

Description


I was trying to run the Llama training example with PyTorch FSDP on 8x ml.g5.8xlarge instances on a SageMaker HyperPod (SMHP) EKS cluster. The instructions are available at https://awslabs.github.io/ai-on-sagemaker-hyperpod/docs/eks-blueprints/training/fsdp/fully-sharded-data-parallel. However, the container that gets built ends up with a torch/torchvision package compatibility issue. The requirements.txt file used by the Dockerfile during the build sets the torch version to 2.7.1, but after the container is built, it contains:

torch                    2.6.0
torchaudio               2.7.1+cu128
torchvision              0.22.1+cu128
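The mismatch above can be caught before launching a job with a version check like the following. This is a minimal sketch, not part of the example repository: it assumes torchvision 0.21.x pairs with torch 2.6.x and 0.22.x with torch 2.7.x, and that torchaudio shares torch's major.minor version. The pairing table and the helper names (`major_minor`, `check_versions`) are illustrative.

```python
from importlib import metadata

# Illustrative, partial table of torch -> torchvision minor-version pairings.
# Assumption: only the two release lines relevant to this issue are listed.
TORCHVISION_FOR_TORCH = {
    "2.6": "0.21",
    "2.7": "0.22",
}


def major_minor(version: str) -> str:
    """Return 'X.Y' from a version string such as '2.7.1+cu128'."""
    public = version.split("+", 1)[0]  # drop the local label, e.g. '+cu128'
    return ".".join(public.split(".")[:2])


def check_versions(torch_v: str, torchvision_v: str, torchaudio_v: str) -> list[str]:
    """Return human-readable mismatch descriptions (empty list if consistent)."""
    problems = []
    torch_mm = major_minor(torch_v)
    expected_tv = TORCHVISION_FOR_TORCH.get(torch_mm)
    if expected_tv is None:
        problems.append(f"no known torchvision pairing for torch {torch_v}")
    elif major_minor(torchvision_v) != expected_tv:
        problems.append(
            f"torchvision {torchvision_v} does not match torch {torch_v} "
            f"(expected {expected_tv}.x)"
        )
    # Assumption: torchaudio releases share torch's major.minor version.
    if major_minor(torchaudio_v) != torch_mm:
        problems.append(
            f"torchaudio {torchaudio_v} does not match torch {torch_v} "
            f"(expected {torch_mm}.x)"
        )
    return problems


if __name__ == "__main__":
    installed = [metadata.version(name) for name in ("torch", "torchvision", "torchaudio")]
    issues = check_versions(*installed)
    for issue in issues:
        print("MISMATCH:", issue)
    raise SystemExit(1 if issues else 0)
```

With the versions listed above (torch 2.6.0, torchvision 0.22.1+cu128, torchaudio 2.7.1+cu128) the check reports two mismatches and exits non-zero, so it could run as a step at the end of the Docker build to fail early.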

When I try a simple import of torch and torchvision, the torchvision import fails with the same error I see when running the FSDP PyTorch job:

>>> import torch
/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
>>> import torchvision
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torchvision/__init__.py", line 10, in <module>
    from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils  # usort:skip
  File "/usr/local/lib/python3.10/dist-packages/torchvision/_meta_registrations.py", line 164, in <module>
    def meta_nms(dets, scores, iou_threshold):
  File "/usr/local/lib/python3.10/dist-packages/torch/library.py", line 828, in register
    use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1)
  File "/usr/local/lib/python3.10/dist-packages/torch/library.py", line 198, in _register_fake
    handle = entry.fake_impl.register(func_to_register, source)
  File "/usr/local/lib/python3.10/dist-packages/torch/_library/fake_impl.py", line 31, in register
    if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
RuntimeError: operator torchvision::nms does not exist
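Because the failure only surfaces at import time, a small smoke test that imports each package and reports the exception catches it without starting a training job. This is a minimal sketch; the helper name `import_errors` is hypothetical and not part of the example repository.

```python
import importlib


def import_errors(module_names):
    """Try to import each module; return {name: 'ExcType: message'} for failures."""
    errors = {}
    for name in module_names:
        try:
            importlib.import_module(name)
        except Exception as exc:  # e.g. RuntimeError: operator torchvision::nms does not exist
            errors[name] = f"{type(exc).__name__}: {exc}"
    return errors


if __name__ == "__main__":
    failures = import_errors(["torch", "torchvision", "torchaudio"])
    for name, err in failures.items():
        print(f"import {name} failed -> {err}")
    raise SystemExit(1 if failures else 0)
```

Catching the broad `Exception` is deliberate here: the torchvision failure is a `RuntimeError` raised during import, not an `ImportError`.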

When I manually override the torch packages in the container and import them again, the issue disappears. To start resolving this issue, we will need to pin the packages to the requested versions.
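One way to pin the packages is a pip constraints file passed with `-c`, so that no other requirement can silently move torch off the requested version; if something in requirements.txt needs a different torch, pip stops with a resolution error instead. This is a minimal sketch under assumptions: the intended versions are torch 2.7.1, torchvision 0.22.1 and torchaudio 2.7.1 (the latter two taken from the installed versions above), and the CUDA 12.8 package index is already configured in requirements.txt or the pip config. The file name and helper names are hypothetical.

```python
import subprocess
import sys
from pathlib import Path

# Assumed target versions: torch 2.7.1 per requirements.txt; torchvision and
# torchaudio chosen to match it (taken from the versions observed in the container).
PINS = {"torch": "2.7.1", "torchvision": "0.22.1", "torchaudio": "2.7.1"}


def format_constraints(pins: dict[str, str]) -> str:
    """One 'name==version' line per package; '==0.22.1' also matches '0.22.1+cu128'."""
    return "".join(f"{name}=={version}\n" for name, version in sorted(pins.items()))


def write_constraints(path: Path, pins: dict[str, str]) -> Path:
    """Write the constraints file and return its path."""
    path.write_text(format_constraints(pins))
    return path


def pip_install_command(requirements: Path, constraints: Path) -> list[str]:
    """Build a pip command that installs the requirements under the constraints."""
    return [sys.executable, "-m", "pip", "install",
            "-r", str(requirements), "-c", str(constraints)]


if __name__ == "__main__":
    constraints = write_constraints(Path("constraints.txt"), PINS)
    subprocess.run(pip_install_command(Path("requirements.txt"), constraints), check=True)
```

Constraints only apply to the pip invocation that receives them, so any later `pip install` step in the Dockerfile would need the same `-c` flag.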

Metadata

Labels: No labels
Status: In progress
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
