Skip to content

Commit 812b163

Browse files
committed
inspect
1 parent 0a3b7a1 commit 812b163

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed

utils/pipeline_utils.py

Lines changed: 22 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from diffusers import FluxPipeline
66
from torch._inductor.package import load_package as inductor_load_package
77
from typing import List, Optional, Tuple
8+
import inspect
89

910

1011
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
@@ -36,23 +37,28 @@ def flash_attn_func(
3637
import flash_attn_interface
3738

3839
dtype = torch.float8_e4m3fn
40+
41+
sig = inspect.signature(flash_attn_interface.flash_attn_func)
42+
accepted = set(sig.parameters)
43+
all_kwargs = {
44+
"softmax_scale": softmax_scale,
45+
"causal": causal,
46+
"qv": qv,
47+
"q_descale": q_descale,
48+
"k_descale": k_descale,
49+
"v_descale": v_descale,
50+
"window_size": window_size,
51+
"sink_token_length": sink_token_length,
52+
"softcap": softcap,
53+
"num_splits": num_splits,
54+
"pack_gqa": pack_gqa,
55+
"deterministic": deterministic,
56+
"sm_margin": sm_margin,
57+
}
58+
kwargs = {k: v for k, v in all_kwargs.items() if k in accepted}
59+
3960
outputs = flash_attn_interface.flash_attn_func(
40-
q.to(dtype),
41-
k.to(dtype),
42-
v.to(dtype),
43-
softmax_scale=softmax_scale,
44-
causal=causal,
45-
qv=qv,
46-
q_descale=q_descale,
47-
k_descale=k_descale,
48-
v_descale=v_descale,
49-
window_size=window_size,
50-
sink_token_length=sink_token_length,
51-
softcap=softcap,
52-
num_splits=num_splits,
53-
pack_gqa=pack_gqa,
54-
deterministic=deterministic,
55-
sm_margin=sm_margin,
61+
q.to(dtype), k.to(dtype), v.to(dtype), **kwargs,
5662
)
5763
return outputs[0]
5864

0 commit comments

Comments (0)