
[Bug] CUDA Illegal Memory Access in chunk_kda_bwd on NVIDIA B200 #727

@ReyJerry

Description

Checklist

  • I have checked FAQs and existing issues for similar problems
  • Please report this bug in English to ensure wider understanding and support

Describe the Bug

While training a 1.3B model on an 8x NVIDIA B200 node, I hit a RuntimeError: CUDA error: an illegal memory access was encountered during the backward pass of the KDA operator (chunk_kda).

The error specifically occurs in the chunk_kda_bwd_wy_dqkg_fused function within fla/ops/kda/chunk_bwd.py, triggered during the Triton autotuning phase of the chunk_kda_bwd_kernel_wy_dqkg_fused kernel.

Steps to Reproduce the Bug

The error happens during the training loop when computing gradients. Below is the stack trace captured from the crash:

Traceback (most recent call last):
  File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File ".../flame/flame/train.py", line 514, in main
    loss.backward()
  File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/autograd/__init__.py", line 353, in backward
    _engine_run_backward(
  File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
  File ".../flash-linear-attention/fla/utils.py", line 214, in wrapper
    return fn(*processed_args, **processed_kwargs)
  File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 556, in decorate_bwd
    return bwd(*args, **kwargs)
  File ".../flash-linear-attention/fla/ops/kda/chunk.py", line 367, in backward
    dq, dk, dv, db, dg, dh0 = chunk_kda_bwd(
  File ".../flash-linear-attention/fla/ops/kda/chunk.py", line 211, in chunk_kda_bwd
    dq, dk, dv, db, dg, dAkk = chunk_kda_bwd_wy_dqkg_fused(
  File ".../flash-linear-attention/fla/ops/kda/chunk_bwd.py", line 358, in chunk_kda_bwd_wy_dqkg_fused
    chunk_kda_bwd_kernel_wy_dqkg_fused[grid](
  File ".../miniconda3/envs/flame/lib/python3.11/site-packages/triton/runtime/jit.py", line 347, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File ".../miniconda3/envs/flame/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 395, in run
    return self.fn.run(*args, **kwargs)
  File ".../miniconda3/envs/flame/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 192, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File ".../miniconda3/envs/flame/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 170, in _bench
    return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
  File ".../miniconda3/envs/flame/lib/python3.11/site-packages/triton/testing.py", line 146, in do_bench
    di.synchronize()
  File ".../miniconda3/envs/flame/lib/python3.11/site-packages/torch/cuda/__init__.py", line 1040, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
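
For reference, here is a minimal standalone sketch of the call pattern that triggers it. This is not my exact training code: the chunk_kda import path, argument order, tensor layouts, and return structure below are assumptions modeled on other fla chunked ops, so they may need adjusting to the actual signature in fla/ops/kda/chunk.py.

import torch
import torch.nn.functional as F
from fla.ops.kda import chunk_kda  # assumed public import path for the op in the trace

# Illustrative shapes/dtypes only; my crash occurred with a 1.3B-model configuration.
B, T, H, K, V = 1, 4096, 16, 128, 128
device, dtype = 'cuda', torch.bfloat16

q = torch.randn(B, T, H, K, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(B, T, H, K, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(B, T, H, V, device=device, dtype=dtype, requires_grad=True)
# log-decay gate and beta; layouts are assumptions based on other fla delta-rule ops
g = F.logsigmoid(torch.randn(B, T, H, K, device=device, dtype=dtype)).requires_grad_()
beta = torch.rand(B, T, H, device=device, dtype=dtype).requires_grad_()

o, final_state = chunk_kda(q, k, v, g, beta)  # argument order and return tuple are assumptions
o.sum().backward()          # the illegal memory access fires here, in chunk_kda_bwd_wy_dqkg_fused
torch.cuda.synchronize()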

Expected Behavior

The backward pass for chunk_kda should complete successfully without triggering illegal memory accesses, allowing training to proceed.

Environment Information

  • Platform/GPU: NVIDIA B200 (8x)

  • Torch: 2.7.0+cu128

  • Triton: 3.3.0
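
Note: because the crash only surfaces at di.synchronize() inside the Triton autotuner, the config being benchmarked in the trace may not be the launch that actually faulted. If it helps, I can re-run with synchronous launches so the failing kernel is reported at its call site, e.g. by setting the environment variable before CUDA is initialized:

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set before torch initializes CUDA

import torch  # noqa: E402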
