Hi Chaoran,
Thanks for this great work! I have recently been trying to reproduce and modify sfm, but I get the following error when I use compute_nll_ode to compute the bpc metric on text8. Do you have any idea what causes it?
File "/root/anaconda3/envs/mdlm/lib/python3.9/site-packages/flash_attn/layers/rotary.py", line 233, in apply_rotary_emb_qkv_
return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
File "/root/anaconda3/envs/mdlm/lib/python3.9/site-packages/torch/autograd/function.py", line 556, in apply
raise RuntimeError(
RuntimeError: In order to use an autograd.Function with functorch transforms (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod. For more details, please see https://pytorch.org/docs/master/notes/extending.func.html
It seems that an autograd.Function in flash-attn (ApplyRotaryEmbQKV_ in the traceback above) does not implement the setup_context staticmethod.
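For reference, here is a minimal sketch that I believe triggers the same check outside of flash-attn. `LegacySquare` and `FuncSquare` are toy classes I made up for illustration; they are not part of flash-attn or sfm. I am also assuming that `torch.func.grad` goes through the same code path as whichever transform compute_nll_ode uses internally, which I have not verified.

```python
import torch
from torch.autograd import Function


class LegacySquare(Function):
    """Toy old-style Function: forward takes ctx and there is no setup_context."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


class FuncSquare(Function):
    """Toy new-style Function: forward has no ctx, state is saved in setup_context."""

    @staticmethod
    def forward(x):
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


x = torch.tensor(3.0)

# The old-style Function is rejected as soon as a torch.func transform is active.
try:
    torch.func.grad(LegacySquare.apply)(x)
    legacy_error = None
except RuntimeError as err:
    legacy_error = str(err)

# The new-style Function works under the same transform.
new_grad = torch.func.grad(FuncSquare.apply)(x)

print("legacy error:", legacy_error)
print("new-style grad:", new_grad)
```

If this is the same mechanism, the flash-attn rotary function would need a similar forward/setup_context split, or the rotary call would have to run outside the transformed region. I am not sure which is more practical here.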
Here are my main package versions:
python 3.9.21
torch 2.2.2
torchdiffeq 0.2.5
flash-attn 2.6.3
Best,
Xinyang