@@ -187,13 +187,16 @@ def list_backends(cls):
 
 
 @contextlib.contextmanager
-def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE):
+def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
     Context manager to set the active attention backend.
     """
     if backend not in _AttentionBackendRegistry._backends:
         raise ValueError(f"Backend {backend} is not registered.")
 
+    backend = AttentionBackendName(backend)
+    _check_attention_backend_requirements(backend)
+
     old_backend = _AttentionBackendRegistry._active_backend
     _AttentionBackendRegistry._active_backend = backend
 
@@ -226,8 +229,6 @@ def dispatch_attention_fn(
     backend_name = AttentionBackendName(backend)
     backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
 
-    _check_backend_requirements(backend_name)
-
     kwargs = {
         "query": query,
         "key": key,
@@ -316,10 +317,7 @@ def _check_shape(
 # ===== Helper functions =====
 
 
-# LRU cache is hack to avoid checking the backend requirements multiple times. Maybe not needed
-# because CPU is running much farther ahead of the accelerator and this will not be blocking anyway.
-@functools.lru_cache(maxsize=16)
-def _check_backend_requirements(backend: AttentionBackendName) -> None:
+def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
     if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
         if not _CAN_USE_FLASH_ATTN:
             raise RuntimeError(
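
A minimal usage sketch of the change above, assuming the `attention_backend` and `dispatch_attention_fn` definitions shown in the diff (the tensors and the `"flash"` backend string are illustrative placeholders, not taken from this patch). Because `_check_attention_backend_requirements` now runs once on context entry, the per-call check in `dispatch_attention_fn` and its `functools.lru_cache` workaround are no longer needed:

```python
import torch

# Illustrative tensors; assumes a (batch, heads, seq_len, head_dim) layout.
# Flash attention requires a CUDA device and half-precision dtypes.
query = key = value = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# Plain strings are now accepted: the context manager coerces them with
# AttentionBackendName(backend) and validates backend requirements once
# on entry, instead of on every dispatch_attention_fn call.
with attention_backend("flash"):
    out = dispatch_attention_fn(query, key, value)

# On exit, the previously active backend is restored.
```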