Description
Hi, I'm encountering a CUDA illegal memory access error when training the DMA model using flash_dmattn. The error occurs consistently at step 44 during training, specifically in the flash_dmattn_gpu.fwd() function.
Training Command
export CUDA_VISIBLE_DEVICES=0
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero1.yaml ./trainer/pt_dma.py --config recipes/dma/config_80M.yaml 2>&1 | tee dma.log
Environment Information
PyTorch: 2.6.0+cu124
CUDA: 12.4
GPU: NVIDIA A100-SXM4-40GB
Additional context
- OS: Ubuntu 22.04.5
- Python version: 3.10
- Flash-DMA version: 1.0.6
- CUDA Compute Capability: 8.0
Error Details
Error Message:
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Error Location:
The error occurs during the forward pass, in the DogeAttention layer's call to flash_dynamic_mask_attention_forward.
Complete Stack Trace
Traceback (most recent call last):
File "/share/project/zhangjialing/A-scripts/dma/./trainer/pt_dma.py", line 194, in <module>
main(script_args, training_args, model_args, model_config)
File "/share/project/zhangjialing/A-scripts/dma/./trainer/pt_dma.py", line 139, in main
train_result = trainer.train(resume_from_checkpoint=checkpoint)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/trainer.py", line 2328, in train
return inner_training_loop(
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/trainer.py", line 2672, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/trainer.py", line 4009, in training_step
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/trainer.py", line 4099, in compute_loss
outputs = model(**inputs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
ret_val = func(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 2131, in forward
loss = self.module(*inputs, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
return inner()
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1793, in inner
result = forward_call(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/utils/generic.py", line 940, in wrapper
output = func(self, *args, **kwargs)
File "/share/project/zhangjialing/A-scripts/dma/models/dma/modeling_doge.py", line 660, in forward
outputs: MoeModelOutputWithPast = self.model(
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/utils/generic.py", line 1064, in wrapper
outputs = func(self, *args, **kwargs)
File "/share/project/zhangjialing/A-scripts/dma/models/dma/modeling_doge.py", line 468, in forward
hidden_states = decoder_layer(
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/transformers/modeling_layers.py", line 94, in __call__
return super().__call__(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/share/project/zhangjialing/A-scripts/dma/models/dma/modeling_doge.py", line 345, in forward
hidden_states, self_attn_weights = self.self_attn(
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
File "/share/project/zhangjialing/A-scripts/dma/models/dma/modeling_doge.py", line 230, in forward
attn_output, attn_weights = attention_interface(
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/flash_dmattn/integrations/flash_dynamic_mask_attention.py", line 89, in flash_dynamic_mask_attention_forward
attn_output = _flash_dynamic_mask_attention_forward(
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py", line 108, in _flash_dynamic_mask_attention_forward
out = flash_fn(
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/flash_dmattn/flash_dmattn_interface.py", line 728, in flash_dmattn_func
return FlashDMAttnFunc.apply(
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/flash_dmattn/flash_dmattn_interface.py", line 451, in forward
out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 641, in __call__
return self._opoverload(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_ops.py", line 723, in __call__
return self._op(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_library/autograd.py", line 113, in autograd_impl
result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_library/autograd.py", line 40, in forward_no_grad
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_ops.py", line 728, in redispatch
return self._handle.redispatch_boxed(keyset, *args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 305, in backend_impl
result = self._backend_fns[device_type](*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
return disable_fn(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
return fn(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 337, in wrapped_fn
return fn(*args, **kwargs)
File "/share/project/zhaohuxing/anaconda3/envs/dma/lib/python3.10/site-packages/flash_dmattn/flash_dmattn_interface.py", line 86, in _flash_dmattn_forward
out, softmax_lse, S_dmask = flash_dmattn_gpu.fwd(
RuntimeError: CUDA error: an illegal memory access was encountered
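Note that CUDA kernel errors are reported asynchronously, so the frame shown above (flash_dmattn_gpu.fwd) may not be the exact launch that faulted. A minimal sketch for getting a synchronous, more precise trace, assuming the environment variable is set before torch initializes CUDA (e.g. at the very top of trainer/pt_dma.py, or exported before the accelerate launch command above):

# Hypothetical addition at the very top of trainer/pt_dma.py, before torch (or
# anything that imports torch) is loaded, so every kernel launch synchronizes
# and the illegal access is reported at the launch that actually caused it.
import os
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")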
Training Progress Before Error
The training runs successfully for the first 43 steps, then fails at step 44. The loss values appear normal before the crash:
- Step 43: loss=10.186, grad_norm=1.8719213008880615, learning_rate=6.3e-05
- Step 44: (crashes during forward pass)
Code Context
The error occurs in the attention mechanism implementation:
# File: /share/project/zhangjialing/A-scripts/dma/models/dma/modeling_doge.py, line 230
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=attention_mask,
attention_bias=attn_bias,
scale=self.scaling,
)
Where attention_interface is set to flash_dynamic_mask_attention_forward.
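To capture the inputs that trigger the fault at step 44, one option is to log shapes, dtypes, strides, and contiguity just before the attention_interface call. A minimal sketch, assuming it is pasted into modeling_doge.py next to the call above; the helper name describe_attention_inputs is illustrative, not part of the codebase:

import torch
from typing import Optional

def describe_attention_inputs(**tensors: Optional[torch.Tensor]) -> None:
    # Print shape/dtype/stride/contiguity for each named tensor; None is allowed
    # because attention_mask / attn_bias may be absent for some batches.
    for name, t in tensors.items():
        if t is None:
            print(f"{name}: None")
        else:
            print(f"{name}: shape={tuple(t.shape)} dtype={t.dtype} "
                  f"stride={t.stride()} contiguous={t.is_contiguous()}")

# Call site, just above the attention_interface(...) call shown above:
# describe_attention_inputs(
#     query_states=query_states, key_states=key_states, value_states=value_states,
#     attention_mask=attention_mask, attn_bias=attn_bias,
# )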
Debugging Information
Additional context:
- Sequence lengths and batch sizes you're using: max_position_embeddings=2048, per_device_train_batch_size=1, gradient_accumulation_steps=8
- Whether this works with standard PyTorch SDPA: Not tested yet (a possible isolation check is sketched after this list)
- Any custom modifications to the code: Using custom DogeAttention layer with flash_dynamic_mask_attention_forward
- Training progress: Error occurs consistently at step 44, training runs fine for first 43 steps
- Model architecture: DMA-80M based on Qwen3-1.7B with custom attention mechanism
- Training framework: DeepSpeed ZeRO stage 1 with bfloat16 precision
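As noted above, standard SDPA has not been tested yet. One way to isolate whether the fault lies in the flash_dmattn kernel or in the inputs is to run the same shapes through torch.nn.functional.scaled_dot_product_attention. A minimal sketch under assumed dimensions: batch 1 and 2048 tokens follow the training config, while the 8 heads of size 64 are placeholders for the real DMA-80M head layout:

import torch
import torch.nn.functional as F

# Placeholder shapes: (batch, heads, seq_len, head_dim); bf16 matches training.
q = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 2048, 64, device="cuda", dtype=torch.bfloat16)

# Causal SDPA reference; if this also runs cleanly on the captured step-44
# tensors while flash_dmattn faults, the issue is isolated to the kernel.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.cuda.synchronize()  # surface any asynchronous CUDA error here
print("SDPA ok:", out.shape, out.dtype)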
Thank you for your help in resolving this issue!