Closed
Description
Bug description
I'm getting the following warning:
FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
Specifically, the issue seems to be caused by fairscale:
/home/coder/.local/lib/python3.10/site-packages/fairscale/experimental/nn/offload.py:19: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
What version are you seeing the problem on?
v2.4
How to reproduce the bug
I'm running:
- torch==2.5.0
- lightning==2.4.0
- lightning-utilities==0.11.8
- fairscale==0.4.6
The Python code is fairly standard distributed training code using the `FSDPStrategy` training strategy, and it was working before:
import lightning as L
import psutil
import torch
import torch.nn as nn
from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp import MixedPrecision
from torch.utils.data import DataLoader, TensorDataset

# modelArgs, CUSTOM_TIMEOUT, cpu_offload, preprocess_pairs_tensor, my_collate,
# logger, checkpoint_callback, accelerator, devices and num_nodes are defined
# elsewhere in the original script.


def custom_auto_wrap_policy(module, recurse, nonwrapped_numel, **kwargs):
    # Wrap only Embedding layers
    if isinstance(module, nn.Embedding):
        return True
    return False


sharding_strategy = modelArgs.sharding_strategy
state_dict_type = modelArgs.state_dict_type
strategy = FSDPStrategy(
    timeout=CUSTOM_TIMEOUT,
    cpu_offload=cpu_offload,
    activation_checkpointing_policy=custom_auto_wrap_policy,
    auto_wrap_policy=custom_auto_wrap_policy,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
    process_group_backend="nccl",
    sharding_strategy=sharding_strategy,
    state_dict_type=state_dict_type,
)
pairs = preprocess_pairs_tensor
train_dataset = TensorDataset(pairs)
trainloader = DataLoader(
    train_dataset,
    batch_size=modelArgs.batch_size,
    collate_fn=my_collate,
    drop_last=False,
    shuffle=True,
    num_workers=psutil.cpu_count(),
    persistent_workers=True,
    pin_memory=True,
)
# Initialize a trainer
trainer = L.Trainer(
    logger=logger,
    log_every_n_steps=1,
    precision="bf16-true",
    callbacks=[checkpoint_callback],
    accelerator=accelerator,
    devices=devices,
    num_nodes=num_nodes,
    strategy=strategy,
    # limit_train_batches=1.0,
    max_epochs=modelArgs.epochs,
    deterministic=True,
)
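A side note on how a callable policy like `custom_auto_wrap_policy` is consumed: PyTorch's FSDP documentation describes it as being called with `recurse=True` to ask whether traversal should continue into a module's subtree, and with `recurse=False` to ask whether that module itself should be wrapped. The sketch below is a simplified, hypothetical model of that traversal (not FSDP's actual code; `modules_to_wrap` and `embedding_wrap_policy` are illustrative names). It can be used to check which submodules a policy would select on a toy model without launching distributed training.

```python
import torch.nn as nn


def embedding_wrap_policy(module: nn.Module, recurse: bool, nonwrapped_numel: int, **kwargs) -> bool:
    """Hypothetical callable policy that selects only nn.Embedding layers."""
    if recurse:
        # recurse=True asks: "keep traversing into this module's children?"
        return True
    # recurse=False asks: "wrap this module itself?"
    return isinstance(module, nn.Embedding)


def modules_to_wrap(root: nn.Module, policy) -> list:
    """Simplified model of a recursive wrap traversal; not FSDP's actual code.

    Returns the dotted names of the submodules the policy would select.
    The root is skipped because FSDP wraps it separately.
    """
    selected = []

    def visit(path: str, module: nn.Module) -> None:
        numel = sum(p.numel() for p in module.parameters())
        if not policy(module=module, recurse=True, nonwrapped_numel=numel):
            # Traversal stops: neither this module nor its subtree is considered.
            return
        for child_name, child in module.named_children():
            visit(f"{path}.{child_name}" if path else child_name, child)
        if path and policy(module=module, recurse=False, nonwrapped_numel=numel):
            selected.append(path)

    visit("", root)
    return selected
```

On a toy `nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2))`, a policy that returns `True` only for `nn.Embedding` regardless of `recurse` stops at the root and selects nothing, while `embedding_wrap_policy` selects the embedding (`["0"]`). Lightning's `FSDPStrategy` also accepts a set of layer classes for `auto_wrap_policy` and `activation_checkpointing_policy` (e.g. `{nn.Embedding}`), which avoids hand-writing the callable.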
Error messages and logs
The following warning is logged an unpredictable number of times:
/home/coder/.local/lib/python3.10/site-packages/fairscale/experimental/nn/offload.py:19: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
Environment
<details>
<summary>Current environment</summary>
* CUDA:
- GPU:
- NVIDIA L4
- NVIDIA L4
- NVIDIA L4
- NVIDIA L4
- NVIDIA L4
- NVIDIA L4
- NVIDIA L4
- NVIDIA L4
- available: True
- version: 12.4
* Lightning:
- lightning: 2.4.0
- lightning-utilities: 0.11.8
- pytorch-lightning: 1.9.5
- torch: 2.5.0
- torchmetrics: 1.5.1
- torchvision: 0.20.0
* Packages:
- aiofiles: 22.1.0
- aiohappyeyeballs: 2.4.3
- aiohttp: 3.10.10
- aiosignal: 1.3.1
- aiosqlite: 0.20.0
- anyio: 4.6.0
- appdirs: 1.4.4
- argon2-cffi: 23.1.0
- argon2-cffi-bindings: 21.2.0
- argparse: 1.4.0
- arrow: 1.3.0
- asttokens: 2.4.1
- async-timeout: 4.0.3
- attrs: 24.2.0
- autocommand: 2.2.2
- autofaiss: 2.15.8
- babel: 2.16.0
- backports.tarfile: 1.2.0
- beautifulsoup4: 4.12.3
- bleach: 6.1.0
- blinker: 1.4
- boto3: 1.26.145
- botocore: 1.29.165
- certifi: 2024.8.30
- cffi: 1.17.1
- charset-normalizer: 3.3.2
- click: 8.1.7
- coloredlogs: 15.0.1
- comm: 0.2.2
- cryptography: 3.4.8
- datasets: 2.14.4
- dbus-python: 1.2.18
- debugpy: 1.8.6
- decorator: 5.1.1
- defusedxml: 0.7.1
- dill: 0.3.7
- distro: 1.7.0
- docker-pycreds: 0.4.0
- duckdb: 1.1.1
- embedding-reader: 1.7.0
- entrypoints: 0.4
- exceptiongroup: 1.2.2
- executing: 2.1.0
- fairscale: 0.4.6
- faiss-cpu: 1.9.0
- fastjsonschema: 2.20.0
- filelock: 3.16.1
- fire: 0.4.0
- flatbuffers: 24.3.25
- fqdn: 1.5.1
- frozenlist: 1.5.0
- fsspec: 2024.10.0
- fuzzywuzzy: 0.18.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- hnswlib: 0.7.0
- httplib2: 0.20.2
- huggingface-hub: 0.25.2
- humanfriendly: 10.0
- idna: 3.10
- importlib-metadata: 4.6.4
- importlib-resources: 6.4.0
- inflect: 7.3.1
- ipykernel: 6.29.5
- ipython: 8.27.0
- ipython-genutils: 0.2.0
- isoduration: 20.11.0
- jaraco.collections: 5.1.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jedi: 0.19.1
- jeepney: 0.7.1
- jinja2: 3.1.4
- jmespath: 1.0.1
- joblib: 1.3.2
- json5: 0.9.25
- jsonpointer: 3.0.0
- jsonschema: 4.23.0
- jsonschema-specifications: 2023.12.1
- jupyter-client: 7.4.9
- jupyter-core: 5.7.2
- jupyter-events: 0.10.0
- jupyter-server: 2.14.2
- jupyter-server-fileid: 0.9.3
- jupyter-server-terminals: 0.5.3
- jupyter-server-ydoc: 0.8.0
- jupyter-ydoc: 0.2.5
- jupyterlab: 3.6.2
- jupyterlab-pygments: 0.3.0
- jupyterlab-server: 2.27.3
- keyring: 23.5.0
- launchpadlib: 1.10.16
- lazr.restfulclient: 0.14.4
- lazr.uri: 1.0.6
- levenshtein: 0.23.0
- lightning: 2.4.0
- lightning-utilities: 0.11.8
- markupsafe: 2.1.5
- matplotlib-inline: 0.1.7
- mistune: 3.0.2
- more-itertools: 8.10.0
- mpmath: 1.3.0
- multidict: 6.1.0
- multiprocess: 0.70.15
- nbclassic: 1.1.0
- nbclient: 0.10.0
- nbconvert: 7.16.4
- nbformat: 5.10.4
- nest-asyncio: 1.6.0
- networkx: 3.4.2
- nltk: 3.9.1
- notebook: 6.5.7
- notebook-shim: 0.2.4
- numpy: 1.26.4
- nvidia-cublas-cu12: 12.4.5.8
- nvidia-cuda-cupti-cu12: 12.4.127
- nvidia-cuda-nvrtc-cu12: 12.4.127
- nvidia-cuda-runtime-cu12: 12.4.127
- nvidia-cudnn-cu12: 9.1.0.70
- nvidia-cufft-cu12: 11.2.1.3
- nvidia-curand-cu12: 10.3.5.147
- nvidia-cusolver-cu12: 11.6.1.9
- nvidia-cusparse-cu12: 12.3.1.170
- nvidia-nccl-cu12: 2.21.5
- nvidia-nvjitlink-cu12: 12.4.127
- nvidia-nvtx-cu12: 12.4.127
- oauthlib: 3.2.0
- onnx: 1.17.0
- onnxruntime-gpu: 1.19.2
- optimum: 1.23.2
- overrides: 7.7.0
- packaging: 24.1
- pandas: 1.3.5
- pandocfilters: 1.5.1
- parso: 0.8.4
- pexpect: 4.9.0
- pillow: 11.0.0
- pip: 24.2
- platformdirs: 4.3.6
- prometheus-client: 0.21.0
- prompt-toolkit: 3.0.48
- propcache: 0.2.0
- protobuf: 4.25.3
- psutil: 5.9.5
- ptyprocess: 0.7.0
- pure-eval: 0.2.3
- pyarrow: 12.0.1
- pycparser: 2.22
- pygments: 2.18.0
- pygobject: 3.42.1
- pyjwt: 2.3.0
- pyparsing: 2.4.7
- python-apt: 2.4.0+ubuntu4
- python-dateutil: 2.9.0.post0
- python-json-logger: 2.0.7
- python-levenshtein: 0.23.0
- pytorch-lightning: 1.9.5
- pytz: 2024.2
- pyyaml: 6.0.2
- pyzmq: 26.0.3
- rapidfuzz: 3.4.0
- referencing: 0.35.1
- regex: 2024.9.11
- requests: 2.32.3
- rfc3339-validator: 0.1.4
- rfc3986-validator: 0.1.1
- rpds-py: 0.20.0
- s3transfer: 0.6.2
- safetensors: 0.4.5
- scikit-learn: 1.5.2
- scipy: 1.14.1
- secretstorage: 3.3.1
- send2trash: 1.8.3
- sentence-transformers: 2.2.2
- sentencepiece: 0.2.0
- sentry-sdk: 2.17.0
- setproctitle: 1.3.3
- setuptools: 75.1.0
- six: 1.16.0
- smmap: 5.0.1
- sniffio: 1.3.1
- soupsieve: 2.6
- stack-data: 0.6.3
- sympy: 1.13.1
- termcolor: 2.5.0
- terminado: 0.18.1
- threadpoolctl: 3.5.0
- tinycss2: 1.3.0
- tokenizers: 0.20.1
- tomli: 2.0.1
- torch: 2.5.0
- torchmetrics: 1.5.1
- torchvision: 0.20.0
- tornado: 6.4.1
- tqdm: 4.66.1
- traitlets: 5.14.3
- transformers: 4.46.0
- triton: 3.1.0
- typeguard: 4.3.0
- types-python-dateutil: 2.9.0.20240906
- typing-extensions: 4.12.2
- uri-template: 1.3.0
- urllib3: 1.26.20
- wadllib: 1.3.6
- wandb: 0.16.6
- wcwidth: 0.2.13
- webcolors: 24.8.0
- webencodings: 0.5.1
- websocket-client: 1.8.0
- wheel: 0.44.0
- xxhash: 3.5.0
- y-py: 0.6.2
- yarl: 1.16.0
- ypy-websocket: 0.8.4
- zipp: 1.0.0
* System:
- OS: Linux
- architecture:
- 64bit
-
- processor: x86_64
- python: 3.10.12
- release: 5.10.219-208.866.amzn2.x86_64
- version: #1 SMP Tue Jun 18 14:00:06 UTC 2024
</details>