|
5 | 5 | from diffusers import FluxPipeline |
6 | 6 | from torch._inductor.package import load_package as inductor_load_package |
7 | 7 | from typing import List, Optional, Tuple |
| 8 | +import inspect |
8 | 9 |
9 | 10 |
10 | 11 | @torch.library.custom_op("flash::flash_attn_func", mutates_args=()) |
@@ -36,23 +37,28 @@ def flash_attn_func( |
36 | 37 | import flash_attn_interface |
37 | 38 |
38 | 39 | dtype = torch.float8_e4m3fn |
| 40 | + |
| 41 | + sig = inspect.signature(flash_attn_interface.flash_attn_func) |
| 42 | + accepted = set(sig.parameters) |
| 43 | + all_kwargs = { |
| 44 | + "softmax_scale": softmax_scale, |
| 45 | + "causal": causal, |
| 46 | + "qv": qv, |
| 47 | + "q_descale": q_descale, |
| 48 | + "k_descale": k_descale, |
| 49 | + "v_descale": v_descale, |
| 50 | + "window_size": window_size, |
| 51 | + "sink_token_length": sink_token_length, |
| 52 | + "softcap": softcap, |
| 53 | + "num_splits": num_splits, |
| 54 | + "pack_gqa": pack_gqa, |
| 55 | + "deterministic": deterministic, |
| 56 | + "sm_margin": sm_margin, |
| 57 | + } |
| 58 | + kwargs = {k: v for k, v in all_kwargs.items() if k in accepted} |
| 59 | + |
39 | 60 | outputs = flash_attn_interface.flash_attn_func( |
40 | | - q.to(dtype), |
41 | | - k.to(dtype), |
42 | | - v.to(dtype), |
43 | | - softmax_scale=softmax_scale, |
44 | | - causal=causal, |
45 | | - qv=qv, |
46 | | - q_descale=q_descale, |
47 | | - k_descale=k_descale, |
48 | | - v_descale=v_descale, |
49 | | - window_size=window_size, |
50 | | - sink_token_length=sink_token_length, |
51 | | - softcap=softcap, |
52 | | - num_splits=num_splits, |
53 | | - pack_gqa=pack_gqa, |
54 | | - deterministic=deterministic, |
55 | | - sm_margin=sm_margin, |
| 61 | + q.to(dtype), k.to(dtype), v.to(dtype), **kwargs, |
56 | 62 | ) |
57 | 63 | return outputs[0] |
58 | 64 |
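Not part of the diff itself — below is a minimal, self-contained sketch of the pattern the patch introduces: inspect the callee's signature and pass only the keyword arguments the installed version actually accepts, so one wrapper works across releases whose APIs differ. `new_api` and `call_with_supported_kwargs` are hypothetical stand-ins for illustration, not flash-attn APIs.

```python
import inspect

def new_api(x, scale=1.0, causal=False, softcap=0.0):
    # Stand-in for a callee whose keyword arguments vary across releases.
    return x * scale

def call_with_supported_kwargs(fn, x, **all_kwargs):
    # Keep only the kwargs this build of `fn` actually accepts, so callers
    # can pass the full superset without raising TypeError on older versions.
    accepted = set(inspect.signature(fn).parameters)
    kwargs = {k: v for k, v in all_kwargs.items() if k in accepted}
    return fn(x, **kwargs)

# `sink_token_length` is silently dropped because `new_api` does not take it.
print(call_with_supported_kwargs(new_api, 2, scale=3.0, sink_token_length=4))  # 6.0
```

One caveat with this pattern: if the callee declares a var-keyword parameter (`**kwargs`), `inspect.signature(fn).parameters` contains only that parameter's name, so the filter would drop every argument; it assumes the target function spells out its keyword arguments explicitly, as `flash_attn_interface.flash_attn_func` does here.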