Something wrong with autograd func while using jvp in compute_nll_ode #3

@xinyangATK


Hi Chaoran,

Thanks for this great work! I have recently been trying to reproduce and modify SFM, but I get the following error when I use compute_nll_ode to compute the BPC metric on text8. Do you have any idea what causes it?

  File "/root/anaconda3/envs/mdlm/lib/python3.9/site-packages/flash_attn/layers/rotary.py", line 233, in apply_rotary_emb_qkv_
    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
  File "/root/anaconda3/envs/mdlm/lib/python3.9/site-packages/torch/autograd/function.py", line 556, in apply
    raise RuntimeError(
RuntimeError: In order to use an autograd.Function with functorch transforms (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod. For more details, please see https://pytorch.org/docs/master/notes/extending.func.html

It seems that an autograd.Function in flash-attn (ApplyRotaryEmbQKV_, called from apply_rotary_emb_qkv_) does not implement the setup_context staticmethod.
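For reference, the PyTorch page linked in the error describes the shape an autograd.Function needs before torch.func transforms accept it: forward takes no ctx, all ctx bookkeeping moves into a separate setup_context, and, because torch.func.jvp runs forward-mode AD, a jvp staticmethod is also required. The sketch below shows that pattern on a hypothetical toy op (Cube is illustrative only; it is not part of flash-attn or the SFM code, and this is not a patch for ApplyRotaryEmbQKV_).

```python
import torch
from torch.func import jvp


class Cube(torch.autograd.Function):
    """Hypothetical toy op y = x**3, written in the torch.func-compatible style."""

    @staticmethod
    def forward(x):
        # No ctx argument: forward only computes the output.
        return x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        # All ctx bookkeeping lives here instead of in forward.
        (x,) = inputs
        ctx.save_for_backward(x)  # read by backward (reverse mode)
        ctx.save_for_forward(x)   # read by jvp (forward mode)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 3 * x ** 2

    @staticmethod
    def jvp(ctx, x_t):
        # Forward-mode rule: d(x^3) = 3 x^2 dx.
        (x,) = ctx.saved_tensors
        return 3 * x ** 2 * x_t


if __name__ == "__main__":
    x = torch.tensor([1.0, 2.0, 3.0])
    v = torch.ones_like(x)
    out, tangent = jvp(Cube.apply, (x,), (v,))
    print(out)      # expected values: [1, 8, 27]
    print(tangent)  # expected values: [3, 12, 27]
```

Defining setup_context alone is enough for reverse-mode transforms such as torch.func.grad; torch.func.jvp additionally needs the jvp staticmethod, so a flash-attn Function used inside a jvp call would need both.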

Here are my main package versions:
python 3.9.21
torch 2.2.2
torchdiffeq 0.2.5
flash-attn 2.6.3

Best,
Xinyang
