Skip to content

Commit f696fbd

Browse files
committed
Import flash_dmattn_varlen_func alongside flash_dmattn_func for backend availability
1 parent ddc8cd5 commit f696fbd

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

flash_dmattn/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77

88
# Import CUDA functions when available
99
try:
10-
from flash_dmattn.flash_dmattn_interface import flash_dmattn_func
10+
from flash_dmattn.flash_dmattn_interface import flash_dmattn_func, flash_dmattn_varlen_func
1111
CUDA_AVAILABLE = True
1212
except ImportError:
1313
CUDA_AVAILABLE = False
14-
flash_dmattn_func = None
14+
flash_dmattn_func, flash_dmattn_varlen_func = None, None
1515

1616
# Import Triton functions when available
1717
try:
@@ -89,6 +89,7 @@ def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs):
8989
"TRITON_AVAILABLE",
9090
"FLEX_AVAILABLE",
9191
"flash_dmattn_func",
92+
"flash_dmattn_varlen_func",
9293
"triton_dmattn_func",
9394
"flex_dmattn_func",
9495
"get_available_backends",

0 commit comments

Comments
 (0)