@@ -183,6 +183,10 @@ def __init__(
         # - Packed sequence requires FA2 and CP must be 1
         # - CP > 1 requires SDPA
         cp_size_cfg = self.cfg["dtensor_cfg"]["context_parallel_size"]
+
+        # NeMoAutoModelForCausalLM uses flash_attention_2 by default
+        # so we need to set it to None if sequence packing is disabled
+        # https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180
         attn_impl = (
             "flash_attention_2"
             if (self.enable_seq_packing and cp_size_cfg == 1)
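For reference, a minimal sketch (not part of this diff) of how an attn_impl value like the one computed above is typically forwarded to a from_pretrained call. It uses the plain Hugging Face AutoModelForCausalLM and placeholder flags and checkpoint name, since NeMoAutoModelForCausalLM's exact call site is not shown here.

    from transformers import AutoModelForCausalLM

    # Placeholder config values, mirroring the selection rule above.
    enable_seq_packing = True
    cp_size_cfg = 1

    attn_impl = (
        "flash_attention_2"
        if (enable_seq_packing and cp_size_cfg == 1)
        else None  # None lets the library fall back to its own default (sdpa/eager)
    )

    # "meta-llama/Llama-3.1-8B" is a placeholder checkpoint, not one from this PR.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B",
        attn_implementation=attn_impl,
    )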
@@ -273,23 +277,26 @@ def __init__(
             automodel_kwargs["use_liger_kernel"] = False
 
         with init_empty_weights():
-            # NeMoAutoModelForCausalLM uses flash_attention_2 by default
-            # so we need to set it to None if sequence packing is disabled
-            # https://github.com/NVIDIA-NeMo/Automodel/blob/7e748be260651349307862426c0c168cebdeeec3/nemo_automodel/components/_transformers/auto_model.py#L180
-            if cp_size > 1 or self.cfg["dtensor_cfg"]["activation_checkpointing"]:
-                # For cp, match Automodel's `get_train_context` in `cp_utils.py` where only
+            from torch.nn.attention import SDPBackend
+
+            if cp_size > 1:
+                # Match Automodel's `get_train_context` in `cp_utils.py` where only
                 # flash and efficient backends are supported
                 # Ref: https://github.com/NVIDIA-NeMo/Automodel/blob/81788d6f4848f5f066c4a6a2bece4689a6a83687/nemo_automodel/components/distributed/cp_utils.py#L57
-
-                # For activation_checkpointing, CUDNN_ATTENTION must be excluded
-                # since it results in an error:
-                # "Recomputed values for the following tensors have different metadata than during the forward pass."
-                from torch.nn.attention import SDPBackend
-
                 sdpa_method = [
                     SDPBackend.FLASH_ATTENTION,
                     SDPBackend.EFFICIENT_ATTENTION,
                 ]
+            elif self.cfg["dtensor_cfg"]["activation_checkpointing"]:
+                # For activation checkpointing, we must disable the cudnn SDPA backend because
+                # it may not be selected during recomputation.
+                # In that case, we will get the following error:
+                # "Recomputed values have different metadata than during forward pass."
+                sdpa_method = [
+                    SDPBackend.FLASH_ATTENTION,
+                    SDPBackend.EFFICIENT_ATTENTION,
+                    SDPBackend.MATH,
+                ]
             else:
                 sdpa_method = None
 
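For reference, a minimal sketch (not part of this diff) of how a restricted sdpa_method list like the one above is applied via torch.nn.attention.sdpa_kernel; the tensor shapes are arbitrary and a CUDA device is assumed.

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Arbitrary (batch, heads, seq, head_dim) tensors; assumes a CUDA device.
    q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

    # Only the flash and memory-efficient backends may be selected inside this
    # context, mirroring the cp_size > 1 branch above; cuDNN and math are excluded.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        out = F.scaled_dot_product_attention(q, k, v)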
@@ -305,6 +312,13 @@ def __init__(
         if self.lora_enabled:
             apply_lora_to_linear_modules(self.model, self.peft_config)
 
+        # For activation checkpointing, we also must globally disable the cudnn SDPA backend
+        # to ensure that cudnn does not get selected during recomputation.
+        if self.cfg["dtensor_cfg"]["activation_checkpointing"]:
+            from torch.backends import cuda
+
+            cuda.enable_cudnn_sdp(False)
+
         # Hold a copy of model state_dict keys before any parallelization
         self.model_state_dict_keys = list(self.model.state_dict().keys())
 
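For reference, a minimal sketch (not part of this diff) of the global toggle used above: torch.backends.cuda.enable_cudnn_sdp(False) switches off only the cuDNN SDPA backend, and the *_sdp_enabled helpers report which backends remain active.

    from torch.backends import cuda

    # Globally disable only the cuDNN SDPA backend, as the diff does for
    # activation checkpointing; the other backends stay enabled.
    cuda.enable_cudnn_sdp(False)

    print(cuda.cudnn_sdp_enabled())          # False
    print(cuda.flash_sdp_enabled())          # True
    print(cuda.mem_efficient_sdp_enabled())  # True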