Skip to content

Commit 88e946e

Browse files
authored
Fix early CUDA initialisation (#41409)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 93464a0 commit 88e946e

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727

2828
from ...configuration_utils import PreTrainedConfig
2929
from ...generation.configuration_utils import GenerationConfig
30-
from ...integrations.hub_kernels import load_and_register_kernel
3130
from ...utils.logging import logging
3231
from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
3332
from .cache import PagedAttentionCache
@@ -609,6 +608,8 @@ def __init__(
609608
"""
610609
self.model = model.eval()
611610
if "paged|" not in model.config._attn_implementation:
611+
from ...integrations.hub_kernels import load_and_register_kernel
612+
612613
attn_implementation = "paged|" + self.model.config._attn_implementation
613614
load_and_register_kernel(attn_implementation)
614615
model.set_attn_implementation(attn_implementation)

0 commit comments

Comments
 (0)