
[Bug] CUDA Illegal Memory Access when using GLA gradient checkpointing (seqlen >= 128k) #728

@Applauzzz

Description


Checklist

  • I have checked FAQs and existing issues for similar problems
  • Please report this bug in English to ensure wider understanding and support

Describe the Bug

When I train a 7B GLA model with gradient checkpointing turned on, it reports:

```
[rank3]: Traceback (most recent call last):
[rank3]:   File "/mnt/zehao/ULTra/./main/train.py", line 580, in <module>
[rank3]:     main()
[rank3]:   File "/mnt/zehao/ULTra/./main/train.py", line 576, in main
[rank3]:     train(cfg)
[rank3]:   File "/mnt/zehao/ULTra/./main/train.py", line 380, in train
[rank3]:     output,_ = model(input_ids=input_ids, labels=labels)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
[rank3]:     return inner()
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1805, in inner
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank3]:     return func(*args, **kwargs)
[rank3]:   File "/mnt/zehao/flash-linear-attention/fla/models/gla/modeling_gla.py", line 326, in forward
[rank3]:     outputs = self.model(
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:   File "/mnt/zehao/flash-linear-attention/fla/models/gla/modeling_gla.py", line 230, in forward
[rank3]:     hidden_states, attentions, past_key_values = layer(
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/transformers/modeling_layers.py", line 93, in __call__
[rank3]:     return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
[rank3]:     return disable_fn(*args, **kwargs)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
[rank3]:     return fn(*args, **kwargs)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 488, in checkpoint
[rank3]:     return CheckpointFunction.apply(function, preserve, *args)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/autograd/function.py", line 575, in apply
[rank3]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 263, in forward
[rank3]:     outputs = run_function(*args)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1857, in _call_impl
[rank3]:     return inner()
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1805, in inner
[rank3]:     result = forward_call(*args, **kwargs)
[rank3]:   File "/mnt/zehao/flash-linear-attention/fla/models/gla/modeling_gla.py", line 107, in forward
[rank3]:     hidden_states = self.mlp(hidden_states, **kwargs)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:   File "/mnt/zehao/flash-linear-attention/fla/modules/mlp.py", line 68, in forward
[rank3]:     return self.down_proj(swiglu(gate, y))
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/torch/autograd/function.py", line 575, in apply
[rank3]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank3]:   File "/mnt/zehao/flash-linear-attention/fla/modules/activations.py", line 500, in forward
[rank3]:     return swiglu_fwd(x, y)
[rank3]:   File "/mnt/zehao/flash-linear-attention/fla/modules/activations.py", line 470, in swiglu_fwd
[rank3]:     swiglu_fwd_kernel[lambda meta: (triton.cdiv(T, meta['B']),)](x, y, z, T=T, D=D)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/triton/runtime/jit.py", line 347, in <lambda>
[rank3]:     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 192, in run
[rank3]:     timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 170, in _bench
[rank3]:     return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/triton/testing.py", line 145, in do_bench
[rank3]:     fn()
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 156, in kernel_call
[rank3]:     self.fn.run(
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/triton/runtime/jit.py", line 591, in run
[rank3]:     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
[rank3]:   File "/home/ubuntu/miniforge3/envs/liu/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 529, in __call__
[rank3]:     self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
[rank3]: RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
```

If I do not use gradient checkpointing, or the sequence length is less than 128k, everything works fine.

Steps to Reproduce the Bug

Use an input length >= 128k with gradient checkpointing enabled; every fuse option is turned on. The config is below, followed by a minimal reproduction sketch.

config:
```json
{
  "_name_or_path": "fla-hub/gla-7B-mistral-20B",
  "architectures": ["GLAForCausalLM"],
  "attn": null,
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "clamp_min": null,
  "conv_size": 4,
  "elementwise_affine": true,
  "eos_token_id": 2,
  "expand_k": 1,
  "expand_v": 1,
  "feature_map": "relu",
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "gla",
  "norm_eps": 1e-05,
  "num_heads": 32,
  "num_hidden_layers": 32,
  "num_kv_heads": 8,
  "share_conv_kernel": true,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.48.2",
  "use_cache": true,
  "use_gk": true,
  "use_gv": false,
  "use_output_gate": false,
  "use_short_conv": false,
  "vocab_size": 32000
}
```
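A minimal reproduction sketch following the config above. This is not my actual training script; the import path, batch size, and random inputs are assumptions made for illustration, and it assumes `fla.models` exposes `GLAConfig`/`GLAForCausalLM` as in the current repo layout:

```python
import torch
from fla.models import GLAConfig, GLAForCausalLM

# Build the 7B GLA model from the config above (loaded by hub name here;
# a local config.json works the same way).
config = GLAConfig.from_pretrained("fla-hub/gla-7B-mistral-20B")
model = GLAForCausalLM(config).to(torch.bfloat16).cuda()
model.gradient_checkpointing_enable()

# >= 128k tokens; shorter sequences, or disabling checkpointing, run fine.
seq_len = 131072
input_ids = torch.randint(0, config.vocab_size, (1, seq_len), device="cuda")

# The illegal memory access in the traceback above is raised during this
# forward call, inside the Triton swiglu kernel.
out = model(input_ids=input_ids, labels=input_ids)
out.loss.backward()
```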

Expected Behavior

The forward pass should complete without error, as it does for shorter sequences. The crash may be caused by int32 offsets overflowing in the Triton kernel at these sequence lengths.
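For illustration, a minimal elementwise SwiGLU kernel in the same spirit (not the actual `fla/modules/activations.py` kernel; the kernel name, launcher, and block size are my own) showing how int32 offset arithmetic can wrap once the flattened element count approaches 2**31, and the common fix of promoting the program id to int64 before computing addresses:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def swiglu_fwd_sketch_kernel(x_ptr, y_ptr, z_ptr, N, B: tl.constexpr):
    # One program handles one block of B contiguous elements.
    i = tl.program_id(0)
    # Promote the program id to int64 before multiplying: when N approaches
    # 2**31 (e.g. seqlen 128k x intermediate_size 14336), i * B computed in
    # int32 can wrap around and produce out-of-bounds addresses.
    o = i.to(tl.int64) * B + tl.arange(0, B)
    m = o < N
    x = tl.load(x_ptr + o, mask=m, other=0.).to(tl.float32)
    y = tl.load(y_ptr + o, mask=m, other=0.).to(tl.float32)
    tl.store(z_ptr + o, x * tl.sigmoid(x) * y, mask=m)  # swish(x) * y


def swiglu_fwd_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    z = torch.empty_like(x)
    N = x.numel()
    swiglu_fwd_sketch_kernel[(triton.cdiv(N, 1024),)](x, y, z, N, B=1024)
    return z
```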

Environment Information

  1. Torch: 2.7.1
  2. Triton: 3.3.1
