diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index e9adc98fc6af..dcc6f287c51c 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -27,7 +27,6 @@ from ...configuration_utils import PreTrainedConfig from ...generation.configuration_utils import GenerationConfig -from ...integrations.hub_kernels import load_and_register_kernel from ...utils.logging import logging from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced from .cache import PagedAttentionCache @@ -609,6 +608,8 @@ def __init__( """ self.model = model.eval() if "paged|" not in model.config._attn_implementation: + from ...integrations.hub_kernels import load_and_register_kernel + attn_implementation = "paged|" + self.model.config._attn_implementation load_and_register_kernel(attn_implementation) model.set_attn_implementation(attn_implementation)