FutureWarning: torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead. #20370

@loretoparisi

Bug description

I'm getting the following warning:

FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.

Specifically, the issue seems to be caused by fairscale, which emits the matching `custom_fwd` warning:

/home/coder/.local/lib/python3.10/site-packages/fairscale/experimental/nn/offload.py:19: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

I'm running the following versions:


torch==2.5.0
lightning==2.4.0
lightning-utilities==0.11.8
fairscale==0.4.6

The Python code is fairly standard distributed training code using `FSDPStrategy`, and it was working before:

import psutil
import torch
import torch.nn as nn
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp import MixedPrecision
from torch.utils.data import DataLoader, TensorDataset

# modelArgs, CUSTOM_TIMEOUT, cpu_offload, preprocess_pairs_tensor, my_collate,
# logger, checkpoint_callback, accelerator, devices and num_nodes are defined
# elsewhere in the training script.


def custom_auto_wrap_policy(module, recurse, nonwrapped_numel, **kwargs):
    # Wrap only Embedding layers
    if isinstance(module, nn.Embedding):
        return True
    return False


sharding_strategy = modelArgs.sharding_strategy
state_dict_type = modelArgs.state_dict_type

# FSDP with bf16 parameters; the same policy drives both auto-wrapping
# and activation checkpointing
strategy = FSDPStrategy(
    timeout=CUSTOM_TIMEOUT,
    cpu_offload=cpu_offload,
    activation_checkpointing_policy=custom_auto_wrap_policy,
    auto_wrap_policy=custom_auto_wrap_policy,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
    process_group_backend="nccl",
    sharding_strategy=sharding_strategy,
    state_dict_type=state_dict_type,
)

pairs = preprocess_pairs_tensor
train_dataset = TensorDataset(pairs)
trainloader = DataLoader(
    train_dataset,
    batch_size=modelArgs.batch_size,
    collate_fn=my_collate,
    drop_last=False,
    shuffle=True,
    num_workers=psutil.cpu_count(),
    persistent_workers=True,
    pin_memory=True,
)

# Initialize a trainer
trainer = L.Trainer(
    logger=logger,
    log_every_n_steps=1,
    precision="bf16-true",
    callbacks=[checkpoint_callback],
    accelerator=accelerator,
    devices=devices,
    num_nodes=num_nodes,
    strategy=strategy,
    # limit_train_batches=1.0,
    max_epochs=modelArgs.epochs,
    deterministic=True,
)
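Separately from the warning, the auto-wrap policy in the snippet above may be worth double-checking. When FSDP wraps modules recursively with a callable policy, it calls the policy with `recurse=True` to decide whether to descend into a module's children and with `recurse=False` to decide whether to wrap the module itself; torch's own `lambda_auto_wrap_policy` simply returns `True` whenever `recurse` is set. A policy that returns `False` for the root module when `recurse=True` therefore stops the walk before any `nn.Embedding` is reached. The walker below, `would_wrap`, is a hypothetical simplification of that logic for illustration, not torch's implementation.

```python
import torch.nn as nn


def original_policy(module, recurse, nonwrapped_numel, **kwargs):
    # Same logic as custom_auto_wrap_policy in the report
    if isinstance(module, nn.Embedding):
        return True
    return False


def recursing_policy(module, recurse, nonwrapped_numel, **kwargs):
    # Always allow descending into children...
    if recurse:
        return True
    # ...but only wrap nn.Embedding modules themselves
    return isinstance(module, nn.Embedding)


def would_wrap(root, policy):
    """Hypothetical, simplified model of FSDP's recursive wrapping.
    Returns the dotted names of submodules the policy would wrap
    (the root is never listed, since FSDP wraps it anyway)."""
    wrapped = []

    def visit(module, path):
        # recurse=True: may we look inside this module at all?
        if not policy(module=module, recurse=True, nonwrapped_numel=0):
            return
        for name, child in module.named_children():
            visit(child, f"{path}.{name}" if path else name)
        # recurse=False: should this (non-root) module itself be wrapped?
        if path and policy(module=module, recurse=False, nonwrapped_numel=0):
            wrapped.append(path)

    visit(root, "")
    return wrapped


model = nn.Sequential(nn.Sequential(nn.Embedding(100, 16)), nn.Linear(16, 4))
print(would_wrap(model, original_policy))   # prints: []
print(would_wrap(model, recursing_policy))  # prints: ['0.0']
```

If that reading holds, adding `if recurse: return True` at the top of `custom_auto_wrap_policy` is the smallest change; Lightning's `FSDPStrategy` also accepts a set of layer classes such as `{nn.Embedding}` for both `auto_wrap_policy` and `activation_checkpointing_policy`.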

Error messages and logs

The following warning is logged an unpredictable number of times:

/home/coder/.local/lib/python3.10/site-packages/fairscale/experimental/nn/offload.py:19: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
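A plausible but unconfirmed explanation for the repetition: Python's default filter prints a given warning once per source location per process, and a multi-GPU FSDP run starts several processes, each of which emits its own copy. Until fairscale moves to the new decorators, the message can be hidden without muting other FutureWarnings. This is a minimal sketch using only the standard `warnings` module; the helper name `silence_fairscale_amp_warnings` is hypothetical, and it has to run in every process before the warning fires.

```python
import warnings


def silence_fairscale_amp_warnings() -> None:
    """Hypothetical helper: ignore only the torch.cuda.amp.custom_fwd/custom_bwd
    FutureWarnings attributed to fairscale modules."""
    warnings.filterwarnings(
        "ignore",
        # `message` is a regex matched against the start of the warning text;
        # the text begins with a backtick, hence the leading ".*"
        message=r".*torch\.cuda\.amp\.custom_(fwd|bwd)",
        category=FutureWarning,
        # `module` is a regex matched against the module the warning is
        # attributed to, e.g. fairscale.experimental.nn.offload
        module=r"fairscale\.",
    )
```

To cover every rank without touching code, an environment variable such as `PYTHONWARNINGS="ignore::FutureWarning:fairscale.experimental.nn.offload"` is inherited by child processes; note that in this form the module field is matched literally against the full module name rather than as a regex.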

Environment

<details>
  <summary>Current environment</summary>

* CUDA:
        - GPU:
                - NVIDIA L4
                - NVIDIA L4
                - NVIDIA L4
                - NVIDIA L4
                - NVIDIA L4
                - NVIDIA L4
                - NVIDIA L4
                - NVIDIA L4
        - available:         True
        - version:           12.4
* Lightning:
        - lightning:         2.4.0
        - lightning-utilities: 0.11.8
        - pytorch-lightning: 1.9.5
        - torch:             2.5.0
        - torchmetrics:      1.5.1
        - torchvision:       0.20.0
* Packages:
        - aiofiles:          22.1.0
        - aiohappyeyeballs:  2.4.3
        - aiohttp:           3.10.10
        - aiosignal:         1.3.1
        - aiosqlite:         0.20.0
        - anyio:             4.6.0
        - appdirs:           1.4.4
        - argon2-cffi:       23.1.0
        - argon2-cffi-bindings: 21.2.0
        - argparse:          1.4.0
        - arrow:             1.3.0
        - asttokens:         2.4.1
        - async-timeout:     4.0.3
        - attrs:             24.2.0
        - autocommand:       2.2.2
        - autofaiss:         2.15.8
        - babel:             2.16.0
        - backports.tarfile: 1.2.0
        - beautifulsoup4:    4.12.3
        - bleach:            6.1.0
        - blinker:           1.4
        - boto3:             1.26.145
        - botocore:          1.29.165
        - certifi:           2024.8.30
        - cffi:              1.17.1
        - charset-normalizer: 3.3.2
        - click:             8.1.7
        - coloredlogs:       15.0.1
        - comm:              0.2.2
        - cryptography:      3.4.8
        - datasets:          2.14.4
        - dbus-python:       1.2.18
        - debugpy:           1.8.6
        - decorator:         5.1.1
        - defusedxml:        0.7.1
        - dill:              0.3.7
        - distro:            1.7.0
        - docker-pycreds:    0.4.0
        - duckdb:            1.1.1
        - embedding-reader:  1.7.0
        - entrypoints:       0.4
        - exceptiongroup:    1.2.2
        - executing:         2.1.0
        - fairscale:         0.4.6
        - faiss-cpu:         1.9.0
        - fastjsonschema:    2.20.0
        - filelock:          3.16.1
        - fire:              0.4.0
        - flatbuffers:       24.3.25
        - fqdn:              1.5.1
        - frozenlist:        1.5.0
        - fsspec:            2024.10.0
        - fuzzywuzzy:        0.18.0
        - gitdb:             4.0.11
        - gitpython:         3.1.43
        - hnswlib:           0.7.0
        - httplib2:          0.20.2
        - huggingface-hub:   0.25.2
        - humanfriendly:     10.0
        - idna:              3.10
        - importlib-metadata: 4.6.4
        - importlib-resources: 6.4.0
        - inflect:           7.3.1
        - ipykernel:         6.29.5
        - ipython:           8.27.0
        - ipython-genutils:  0.2.0
        - isoduration:       20.11.0
        - jaraco.collections: 5.1.0
        - jaraco.context:    5.3.0
        - jaraco.functools:  4.0.1
        - jaraco.text:       3.12.1
        - jedi:              0.19.1
        - jeepney:           0.7.1
        - jinja2:            3.1.4
        - jmespath:          1.0.1
        - joblib:            1.3.2
        - json5:             0.9.25
        - jsonpointer:       3.0.0
        - jsonschema:        4.23.0
        - jsonschema-specifications: 2023.12.1
        - jupyter-client:    7.4.9
        - jupyter-core:      5.7.2
        - jupyter-events:    0.10.0
        - jupyter-server:    2.14.2
        - jupyter-server-fileid: 0.9.3
        - jupyter-server-terminals: 0.5.3
        - jupyter-server-ydoc: 0.8.0
        - jupyter-ydoc:      0.2.5
        - jupyterlab:        3.6.2
        - jupyterlab-pygments: 0.3.0
        - jupyterlab-server: 2.27.3
        - keyring:           23.5.0
        - launchpadlib:      1.10.16
        - lazr.restfulclient: 0.14.4
        - lazr.uri:          1.0.6
        - levenshtein:       0.23.0
        - lightning:         2.4.0
        - lightning-utilities: 0.11.8
        - markupsafe:        2.1.5
        - matplotlib-inline: 0.1.7
        - mistune:           3.0.2
        - more-itertools:    8.10.0
        - mpmath:            1.3.0
        - multidict:         6.1.0
        - multiprocess:      0.70.15
        - nbclassic:         1.1.0
        - nbclient:          0.10.0
        - nbconvert:         7.16.4
        - nbformat:          5.10.4
        - nest-asyncio:      1.6.0
        - networkx:          3.4.2
        - nltk:              3.9.1
        - notebook:          6.5.7
        - notebook-shim:     0.2.4
        - numpy:             1.26.4
        - nvidia-cublas-cu12: 12.4.5.8
        - nvidia-cuda-cupti-cu12: 12.4.127
        - nvidia-cuda-nvrtc-cu12: 12.4.127
        - nvidia-cuda-runtime-cu12: 12.4.127
        - nvidia-cudnn-cu12: 9.1.0.70
        - nvidia-cufft-cu12: 11.2.1.3
        - nvidia-curand-cu12: 10.3.5.147
        - nvidia-cusolver-cu12: 11.6.1.9
        - nvidia-cusparse-cu12: 12.3.1.170
        - nvidia-nccl-cu12:  2.21.5
        - nvidia-nvjitlink-cu12: 12.4.127
        - nvidia-nvtx-cu12:  12.4.127
        - oauthlib:          3.2.0
        - onnx:              1.17.0
        - onnxruntime-gpu:   1.19.2
        - optimum:           1.23.2
        - overrides:         7.7.0
        - packaging:         24.1
        - pandas:            1.3.5
        - pandocfilters:     1.5.1
        - parso:             0.8.4
        - pexpect:           4.9.0
        - pillow:            11.0.0
        - pip:               24.2
        - platformdirs:      4.3.6
        - prometheus-client: 0.21.0
        - prompt-toolkit:    3.0.48
        - propcache:         0.2.0
        - protobuf:          4.25.3
        - psutil:            5.9.5
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.3
        - pyarrow:           12.0.1
        - pycparser:         2.22
        - pygments:          2.18.0
        - pygobject:         3.42.1
        - pyjwt:             2.3.0
        - pyparsing:         2.4.7
        - python-apt:        2.4.0+ubuntu4
        - python-dateutil:   2.9.0.post0
        - python-json-logger: 2.0.7
        - python-levenshtein: 0.23.0
        - pytorch-lightning: 1.9.5
        - pytz:              2024.2
        - pyyaml:            6.0.2
        - pyzmq:             26.0.3
        - rapidfuzz:         3.4.0
        - referencing:       0.35.1
        - regex:             2024.9.11
        - requests:          2.32.3
        - rfc3339-validator: 0.1.4
        - rfc3986-validator: 0.1.1
        - rpds-py:           0.20.0
        - s3transfer:        0.6.2
        - safetensors:       0.4.5
        - scikit-learn:      1.5.2
        - scipy:             1.14.1
        - secretstorage:     3.3.1
        - send2trash:        1.8.3
        - sentence-transformers: 2.2.2
        - sentencepiece:     0.2.0
        - sentry-sdk:        2.17.0
        - setproctitle:      1.3.3
        - setuptools:        75.1.0
        - six:               1.16.0
        - smmap:             5.0.1
        - sniffio:           1.3.1
        - soupsieve:         2.6
        - stack-data:        0.6.3
        - sympy:             1.13.1
        - termcolor:         2.5.0
        - terminado:         0.18.1
        - threadpoolctl:     3.5.0
        - tinycss2:          1.3.0
        - tokenizers:        0.20.1
        - tomli:             2.0.1
        - torch:             2.5.0
        - torchmetrics:      1.5.1
        - torchvision:       0.20.0
        - tornado:           6.4.1
        - tqdm:              4.66.1
        - traitlets:         5.14.3
        - transformers:      4.46.0
        - triton:            3.1.0
        - typeguard:         4.3.0
        - types-python-dateutil: 2.9.0.20240906
        - typing-extensions: 4.12.2
        - uri-template:      1.3.0
        - urllib3:           1.26.20
        - wadllib:           1.3.6
        - wandb:             0.16.6
        - wcwidth:           0.2.13
        - webcolors:         24.8.0
        - webencodings:      0.5.1
        - websocket-client:  1.8.0
        - wheel:             0.44.0
        - xxhash:            3.5.0
        - y-py:              0.6.2
        - yarl:              1.16.0
        - ypy-websocket:     0.8.4
        - zipp:              1.0.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - 
        - processor:         x86_64
        - python:            3.10.12
        - release:           5.10.219-208.866.amzn2.x86_64
        - version:           #1 SMP Tue Jun 18 14:00:06 UTC 2024

</details>
