
Commit 70142eb

chtruong814 and yfw authored
cp: fix: Disable cudnn sdpa backend when using activation checkpointing (#1717) into r0.5.0 (#1727)
Signed-off-by: Yi-Fu Wu <[email protected]>
Signed-off-by: NeMo Bot <[email protected]>
Co-authored-by: Yi-Fu Wu <[email protected]>
1 parent edfc23d commit 70142eb

File tree

1 file changed (+25, -11 lines)


nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Lines changed: 25 additions & 11 deletions
@@ -183,6 +183,10 @@ def __init__(
         # - Packed sequence requires FA2 and CP must be 1
         # - CP > 1 requires SDPA
         cp_size_cfg = self.cfg["dtensor_cfg"]["context_parallel_size"]
+
+        # NeMoAutoModelForCausalLM uses flash_attention_2 by default
+        # so we need to set it to None if sequence packing is disabled
+        # https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180
         attn_impl = (
             "flash_attention_2"
             if (self.enable_seq_packing and cp_size_cfg == 1)
@@ -273,23 +277,26 @@ def __init__(
             automodel_kwargs["use_liger_kernel"] = False
 
         with init_empty_weights():
-            # NeMoAutoModelForCausalLM uses flash_attention_2 by default
-            # so we need to set it to None if sequence packing is disabled
-            # https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180
-            if cp_size > 1 or self.cfg["dtensor_cfg"]["activation_checkpointing"]:
-                # For cp, match Automodel's `get_train_context` in `cp_utils.py` where only
+            from torch.nn.attention import SDPBackend
+
+            if cp_size > 1:
+                # Match Automodel's `get_train_context` in `cp_utils.py` where only
                 # flash and efficient backends are supported
                 # Ref: https://github.com/NVIDIA-NeMo/Automodel/blob/81788d6f4848f5f066c4a6a2bece4689a6a83687/nemo_automodel/components/distributed/cp_utils.py#L57
-
-                # For activation_checkpointing, CUDNN_ATTENTION must be excluded
-                # since it results in an error:
-                # "Recomputed values for the following tensors have different metadata than during the forward pass."
-                from torch.nn.attention import SDPBackend
-
                 sdpa_method = [
                     SDPBackend.FLASH_ATTENTION,
                     SDPBackend.EFFICIENT_ATTENTION,
                 ]
+            elif self.cfg["dtensor_cfg"]["activation_checkpointing"]:
+                # For activation checkpointing, we must disable the cudnn SDPA backend because
+                # it may not be selected during recomputation.
+                # In that case, we will get the following error:
+                # "Recomputed values have different metadata than during forward pass."
+                sdpa_method = [
+                    SDPBackend.FLASH_ATTENTION,
+                    SDPBackend.EFFICIENT_ATTENTION,
+                    SDPBackend.MATH,
+                ]
             else:
                 sdpa_method = None
 
@@ -305,6 +312,13 @@ def __init__(
         if self.lora_enabled:
            apply_lora_to_linear_modules(self.model, self.peft_config)
 
+        # For activation checkpointing, we also must globally disable the cudnn SDPA backend
+        # to ensure that cudnn does not get selected during recomputation.
+        if self.cfg["dtensor_cfg"]["activation_checkpointing"]:
+            from torch.backends import cuda
+
+            cuda.enable_cudnn_sdp(False)
+
         # Hold a copy of model state_dict keys before any parallelization
         self.model_state_dict_keys = list(self.model.state_dict().keys())
 
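For context, a backend allowlist like `sdpa_method` is normally enforced through PyTorch's `torch.nn.attention.sdpa_kernel` context manager. This diff only builds the list, so the standalone sketch below (made-up tensor shapes, CUDA device assumed) is an illustration of how such a restriction takes effect around `scaled_dot_product_attention`, not the worker's actual plumbing.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Same allowlist as the activation-checkpointing branch above:
# flash, efficient, and math are permitted; the cuDNN backend is excluded.
sdpa_method = [
    SDPBackend.FLASH_ATTENTION,
    SDPBackend.EFFICIENT_ATTENTION,
    SDPBackend.MATH,
]

# Example tensors (batch, heads, seq, head_dim); shapes are illustrative only.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# Every SDPA call inside this context is limited to the listed backends,
# so the forward pass and any checkpoint recomputation draw from the same set.
with sdpa_kernel(sdpa_method):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)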

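The second half of the fix is a process-wide toggle rather than a context manager, since checkpoint recomputation during backward can run outside any `sdpa_kernel` context. A minimal sketch of that switch, with a hypothetical `activation_checkpointing` flag standing in for `cfg["dtensor_cfg"]["activation_checkpointing"]`:

from torch.backends import cuda

# Hypothetical flag standing in for cfg["dtensor_cfg"]["activation_checkpointing"].
activation_checkpointing = True

if activation_checkpointing:
    # Globally exclude the cuDNN SDPA backend so checkpoint recomputation
    # cannot silently pick a different backend than the forward pass did.
    cuda.enable_cudnn_sdp(False)

# The other SDPA backends stay enabled; only cuDNN is reported as off.
assert not cuda.cudnn_sdp_enabled()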
0 commit comments