Checklist
- I have checked FAQs and existing issues for similar problems
- Please report this bug in English to ensure wider understanding and support
Describe the Bug
I hit a RuntimeError: CUDA error: an illegal memory access was encountered during the backward pass of the KDA operator (chunk_kda) while training a 1.3B model on an 8x NVIDIA B200 node.
The error occurs in the chunk_kda_bwd_wy_dqkg_fused function in fla/ops/kda/chunk_bwd.py and is triggered during Triton autotuning of the chunk_kda_bwd_kernel_wy_dqkg_fused kernel.
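For context, the operator is reached through the model's attention layers, but the crash can presumably be reproduced by calling the public chunk_kda wrapper directly. The sketch below is a hypothetical minimal reproduction only: the import path, argument names (the gate g and the delta-rule beta), the return value, and the tensor shapes are assumptions based on the stack trace and the related gated delta-rule ops, not the confirmed API.

# Hypothetical minimal repro for the chunk_kda backward crash.
# NOTE: the signature of chunk_kda below is an assumption; adjust to the actual API.
import torch
from fla.ops.kda import chunk_kda  # assumed import path

B, T, H, D = 1, 4096, 16, 128  # illustrative shapes, roughly a 1.3B-scale config (assumption)
device, dtype = "cuda", torch.bfloat16

q = torch.randn(B, T, H, D, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(B, T, H, D, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(B, T, H, D, device=device, dtype=dtype, requires_grad=True)
g = torch.randn(B, T, H, D, device=device, dtype=dtype, requires_grad=True)  # decay/gate term (assumed)
beta = torch.rand(B, T, H, device=device, dtype=dtype, requires_grad=True)   # delta-rule beta (assumed)

out = chunk_kda(q, k, v, g, beta)
o = out[0] if isinstance(out, tuple) else out  # be defensive about the assumed return type
o.sum().backward()  # the illegal memory access is raised here during Triton autotuning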
Steps to Reproduce the Bug
The error happens during the training loop when computing gradients. Below is the stack trace captured from the crash:
Traceback (most recent call last):
File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
return f(*args, **kwargs)
File ".../flame/flame/train.py", line 514, in main
loss.backward()
File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/_tensor.py", line 648, in backward
torch.autograd.backward(
File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/autograd/__init__.py", line 353, in backward
_engine_run_backward(
File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/autograd/function.py", line 307, in apply
return user_fn(self, *args)
File ".../flash-linear-attention/fla/utils.py", line 214, in wrapper
return fn(*processed_args, **processed_kwargs)
File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 556, in decorate_bwd
return bwd(*args, **kwargs)
File ".../flash-linear-attention/fla/ops/kda/chunk.py", line 367, in backward
dq, dk, dv, db, dg, dh0 = chunk_kda_bwd(
File ".../flash-linear-attention/fla/ops/kda/chunk.py", line 211, in chunk_kda_bwd
dq, dk, dv, db, dg, dAkk = chunk_kda_bwd_wy_dqkg_fused(
File ".../flash-linear-attention/fla/ops/kda/chunk_bwd.py", line 358, in chunk_kda_bwd_wy_dqkg_fused
chunk_kda_bwd_kernel_wy_dqkg_fused[grid](
File ".../miniconda3/envs/flame/lib/python3.11/site-packages/triton/runtime/jit.py", line 347, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File ".../miniconda3/envs/flame/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 395, in run
return self.fn.run(*args, **kwargs)
File ".../miniconda3/envs/flame/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 192, in run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File ".../miniconda3/envs/flame/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 170, in _bench
return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
File ".../miniconda3/envs/flame/lib/python3.11/site-packages/triton/testing.py", line 146, in do_bench
di.synchronize()
File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/cuda/__init__.py", line 1040, in synchronize
return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
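Note: because the failure is reported at the device synchronize inside Triton's do_bench, the Python frame above is not necessarily where the faulting launch happened. Re-running with synchronous kernel launches, a generic CUDA debugging step rather than anything specific to this repo, should pinpoint the offending kernel/config; a minimal way to set this up (assuming the variable is exported before any CUDA initialization):

# Standard CUDA debugging aid: synchronous launches make the illegal access
# surface at the faulting kernel launch instead of at a later synchronize.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set before the CUDA context is created
import torch  # import torch (and fla) only after setting the variable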
Expected Behavior
The backward pass for chunk_kda should complete successfully without triggering illegal memory accesses, allowing training to proceed.
Environment Information
- Platform/GPU: NVIDIA B200 (8x)
- Torch: 2.7.0+cu128
- Triton: 3.3.0