CUDA Illegal Memory Access Error in Flash-DMA during Training #169

@helloworld01001

Description

Hi, I'm encountering a CUDA illegal memory access error when training the DMA model using flash_dmattn. The error occurs consistently at step 44 during training, specifically in the flash_dmattn_gpu.fwd() function.

Training Command

CUDA_VISIBLE_DEVICES=0 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero1.yaml ./trainer/pt_dma.py --config recipes/dma/config_80M.yaml 2>&1 | tee dma.log

Environment Information
PyTorch: 2.6.0+cu124
CUDA: 12.4
GPU: NVIDIA A100-SXM4-40GB

Additional context

  • OS: Ubuntu 22.04.5
  • Python version: 3.10
  • Flash-DMA version: 1.0.6
  • CUDA Compute Capability: 8.0

Error Details

Error Message:

RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Error Location:
The error is raised from flash_dmattn_gpu.fwd(), which is called from the DogeAttention layer via flash_dynamic_mask_attention_forward during the forward pass.
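
Since CUDA illegal-access errors are reported asynchronously, the Python frame above may not be the real failure site. As a next step I plan to rerun with blocking kernel launches so the traceback points at the failing kernel; a minimal sketch of the setup, assuming it runs at the very top of ./trainer/pt_dma.py before any CUDA work:

# Hypothetical debug setup: make kernel launches synchronous so the error
# surfaces at the actual launch site. Must run before torch initializes CUDA.
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # imported after setting the env var on purpose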

Complete Stack Trace

Traceback (most recent call last):
  File "/share/project/zhangjialing/A-scripts/dma/./trainer/pt_dma.py", line 194, in <module>
    main(script_args, training_args, model_args, model_config)
  File "/share/project/zhangjialing/A-scripts/dma/./trainer/pt_dma.py", line 139, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/trainer.py", line 2328, in train
    return inner_training_loop(
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/trainer.py", line 2672, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/trainer.py", line 4009, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/trainer.py", line 4099, in compute_loss
    outputs = model(**inputs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2131, in forward
    loss = self.module(*inputs, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
    result = forward_call(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/utils/generic.py", line 940, in wrapper
    output = func(self, *args, **kwargs)
  File "/share/project/zhangjialing/A-scripts/dma/models/dma/modeling_doge.py", line 660, in forward
    outputs: MoeModelOutputWithPast = self.model(
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/utils/generic.py", line 1064, in wrapper
    outputs = func(self, *args, **kwargs)
  File "/share/project/zhangjialing/A-scripts/dma/models/dma/modeling_doge.py", line 468, in forward
    hidden_states = decoder_layer(
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/share/project/zhangjialing/A-scripts/dma/models/dma/modeling_doge.py", line 345, in forward
    hidden_states, self_attn_weights = self.self_attn(
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/share/project/zhangjialing/A-scripts/dma/models/dma/modeling_doge.py", line 230, in forward
    attn_output, attn_weights = attention_interface(
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/flash_dmattn/integrations/flash_dynamic_mask_attention.py", line 89, in flash_dynamic_mask_attention_forward
    attn_output = _flash_dynamic_mask_attention_forward(
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py", line 108, in _flash_dynamic_mask_attention_forward
    out = flash_fn(
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/flash_dmattn/flash_dmattn_interface.py", line 728, in flash_dmattn_func
    return FlashDMAttnFunc.apply(
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/flash_dmattn/flash_dmattn_interface.py", line 451, in forward
    out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 641, in __call__
    return self._opoverload(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
    return self._op(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_library/autograd.py", line 113, in autograd_impl
    result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
    result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_ops.py", line 728, in redispatch
    return self._handle.redispatch_boxed(keyset, *args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 305, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
    return fn(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 337, in wrapped_fn
    return fn(*args, **kwargs)
  File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/flash_dmattn/flash_dmattn_interface.py", line 86, in _flash_dmattn_forward
    out, softmax_lse, S_dmask = flash_dmattn_gpu.fwd(
RuntimeError: CUDA error: an illegal memory access was encountered

Training Progress Before Error

The training runs successfully for the first 43 steps, then fails at step 44. The loss values appear normal before the crash:

  • Step 43: loss=10.186, grad_norm=1.8719213008880615, learning_rate=6.3e-05
  • Step 44: (crashes during forward pass)
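
To make the failing step reproducible outside the trainer, I am considering dumping each step's inputs right before the forward pass, so the last file on disk corresponds to step 44. A rough sketch (the save path and the exact hook location inside compute_loss are illustrative, not from the actual code):

# Hypothetical hook inside Trainer.compute_loss (or a subclass): persist the
# current batch before the forward pass; after the crash, last_step_inputs.pt
# holds the batch that triggered the illegal memory access.
import torch

torch.save(
    {k: v.detach().cpu() for k, v in inputs.items() if torch.is_tensor(v)},
    "last_step_inputs.pt",
)
outputs = model(**inputs)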

Code Context

The error occurs in the attention mechanism implementation:

# File: /share/project/zhangjialing/A-scripts/dma/models/dma/modeling_doge.py, line 230
attn_output, attn_weights = attention_interface(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask=attention_mask,
    attention_bias=attn_bias,
    scale=self.scaling,
)

Where attention_interface is set to flash_dynamic_mask_attention_forward.
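
To check whether step 44 feeds the kernel an unexpected tensor, I could add a sanity check right before this call. A rough sketch (the helper name is mine; the main thing I want to verify is shapes, strides, and NaNs in attn_bias):

# Hypothetical pre-flight check before the attention_interface(...) call
# in modeling_doge.py: log each tensor and assert there are no NaNs.
import torch

def check_attn_inputs(**tensors):
    for name, t in tensors.items():
        if t is None or not torch.is_tensor(t):
            continue
        print(f"{name}: shape={tuple(t.shape)}, dtype={t.dtype}, "
              f"contiguous={t.is_contiguous()}, device={t.device}")
        if t.is_floating_point():
            assert not torch.isnan(t).any(), f"{name} contains NaN"

check_attn_inputs(
    query_states=query_states,
    key_states=key_states,
    value_states=value_states,
    attention_mask=attention_mask,
    attn_bias=attn_bias,
)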

Debugging Information
Other details that may be relevant:

  • Sequence lengths and batch sizes you're using: max_position_embeddings=2048, per_device_train_batch_size=1, gradient_accumulation_steps=8
  • Whether this works with standard PyTorch SDPA: Not tested yet (see the SDPA sketch after this list)
  • Any custom modifications to the code: Using custom DogeAttention layer with flash_dynamic_mask_attention_forward
  • Training progress: Error occurs consistently at step 44, training runs fine for first 43 steps
  • Model architecture: DMA-80M based on Qwen3-1.7B with custom attention mechanism
  • Training framework: DeepSpeed ZeRO stage 1 with bfloat16 precision
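
For the SDPA comparison mentioned above, this is roughly what I would substitute for the flash call inside DogeAttention. It is only a sketch: it assumes query/key/value are laid out as (batch, num_heads, seq_len, head_dim), that attn_bias is an additive float mask broadcastable to the attention scores, and that key/value would first be repeated to the full head count if grouped-query attention is used.

# Hypothetical A/B test: swap the flash_dmattn path for PyTorch SDPA to see
# whether the illegal access is specific to the flash_dmattn kernel.
import torch.nn.functional as F

attn_output = F.scaled_dot_product_attention(
    query_states,         # (batch, num_heads, q_len, head_dim)
    key_states,
    value_states,
    attn_mask=attn_bias,  # additive bias broadcastable to (batch, heads, q_len, kv_len)
    scale=self.scaling,
)
attn_output = attn_output.transpose(1, 2).contiguous()  # back to (batch, q_len, num_heads, head_dim)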

Thank you for your help in resolving this issue!
