From eedf916ae6e4f0d0bf07e15f98c88348103f5e42 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 6 Oct 2025 14:30:27 +0200 Subject: [PATCH 1/6] up --- .../generation/continuous_batching/continuous_api.py | 7 +++++++ src/transformers/modeling_utils.py | 2 -- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 8fad631f3915..b25442c4e819 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -604,6 +604,13 @@ def __init__( streaming: Whether to stream tokens as they are generated """ self.model = model.eval() + attn_implementation = self.model.config._attn_implementation + # We need to use the wrapper around `paged_attention` but the implementation set. + # The user could be using the flash fallback `kernels-community/flash-attn` if fa2 is not installed + # this does: kernel_function = partial(attention_wrapper, implementation=kernel) + # which passes the loaded kernel. + # If the user selected "flash_attention2" but does not have it -> set_xxx will replace it + self.model.set_attn_implementation(f"paged_attention|{attn_implementation}") generation_config = model.generation_config if generation_config is None else generation_config self.generation_config = generation_config self.input_queue = queue.Queue(maxsize=max_queue_size) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4ce8cd01de4d..a0d28555e6b7 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2635,8 +2635,6 @@ def set_attn_implementation(self, attn_implementation: Union[str, dict]): else attn_implementation.get("", self.config._attn_implementation) ) - # At this point, the model was already instantiated, so instead of crashing on bad value, let's simply - # warn the user that the requested value is not working if requested_implementation != self.config._attn_implementation: # In this case, raise if not self._can_set_attn_implementation(): From 4e090433cb30c8b8c2e7de79ee34cb31f41dfd7d Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 6 Oct 2025 15:15:35 +0200 Subject: [PATCH 2/6] refactor the way we handle paged attention --- examples/pytorch/continuous_batching.py | 4 +--- examples/pytorch/continuous_batching_simple.py | 4 +--- .../generation/continuous_batching/cache.py | 2 +- .../continuous_batching/continuous_api.py | 16 ++++++++-------- src/transformers/modeling_utils.py | 8 ++++---- tests/generation/test_continuous_batching.py | 9 +++++++++ tests/generation/test_paged_attention.py | 16 ++++++++-------- 7 files changed, 32 insertions(+), 27 deletions(-) diff --git a/examples/pytorch/continuous_batching.py b/examples/pytorch/continuous_batching.py index cf5379fc619c..39fad6cb7a4e 100644 --- a/examples/pytorch/continuous_batching.py +++ b/examples/pytorch/continuous_batching.py @@ -184,9 +184,7 @@ def batch_generate( parser.add_argument("--num-blocks", "-n", type=int, default=None) parser.add_argument("--max-batch-tokens", "-b", type=int, default=None) - parser.add_argument( - "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation" - ) + parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation") parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable 
parser.add_argument("--no-slice-inputs", action="store_true") # slicing is enabled by default because much faster parser.add_argument("--use-cuda-graph", "-cg", action="store_true") diff --git a/examples/pytorch/continuous_batching_simple.py b/examples/pytorch/continuous_batching_simple.py index 3ae5e3d83870..8048042eb485 100644 --- a/examples/pytorch/continuous_batching_simple.py +++ b/examples/pytorch/continuous_batching_simple.py @@ -31,9 +31,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--num-blocks", "-n", type=int, default=None) parser.add_argument("--max-batch-tokens", "-b", type=int, default=None) - parser.add_argument( - "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation" - ) + parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation") parser.add_argument("--samples", type=int, default=500) args = parser.parse_args() diff --git a/src/transformers/generation/continuous_batching/cache.py b/src/transformers/generation/continuous_batching/cache.py index f46c6fa811fc..e0373e8a65a0 100644 --- a/src/transformers/generation/continuous_batching/cache.py +++ b/src/transformers/generation/continuous_batching/cache.py @@ -174,7 +174,7 @@ def __init__( # Infer number of blocks and max batch tokens page_size = self.head_dim * self.num_key_value_heads - if getattr(config, "attn_implementation", None) == "paged_attention": + if "flash" in self.config._attn_implementation: num_attention_masks = 0 else: # TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))` diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index b25442c4e819..38b00d29e0a8 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -27,6 +27,7 @@ from ...configuration_utils import PreTrainedConfig from ...generation.configuration_utils import GenerationConfig +from ...integrations.hub_kernels import load_and_register_kernel from ...utils.logging import logging from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced from .cache import PagedAttentionCache @@ -241,7 +242,10 @@ def setup_static_tensors(self, num_groups: int) -> None: self.reset_static_tensors(full_reset=True) def return_attention_mask(self) -> bool: - return self.config._attn_implementation != "paged_attention" # we set `is_causal` to True in paged call + return self.config._attn_implementation in [ + "paged|eager", + "paged|sdpa", + ] # we set `is_causal` to True in paged call @traced @torch.no_grad() @@ -604,13 +608,9 @@ def __init__( streaming: Whether to stream tokens as they are generated """ self.model = model.eval() - attn_implementation = self.model.config._attn_implementation - # We need to use the wrapper around `paged_attention` but the implementation set. - # The user could be using the flash fallback `kernels-community/flash-attn` if fa2 is not installed - # this does: kernel_function = partial(attention_wrapper, implementation=kernel) - # which passes the loaded kernel. 
- # If the user selected "flash_attention2" but does not have it -> set_xxx will replace it - self.model.set_attn_implementation(f"paged_attention|{attn_implementation}") + attn_implementation = "paged|" + self.model.config._attn_implementation + load_and_register_kernel(attn_implementation) + model.set_attn_implementation(attn_implementation) generation_config = model.generation_config if generation_config is None else generation_config self.generation_config = generation_config self.input_queue = queue.Queue(maxsize=max_queue_size) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a0d28555e6b7..3e112eeafae4 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2527,7 +2527,7 @@ def _check_and_adjust_attn_implementation( # If FA not installed, do not fail but use kernels instead if ( attn_implementation is not None - and attn_implementation.startswith("flash_attention") + and "flash" in attn_implementation and self._supports_flash_attn and not (is_flash_attn_2_available() or is_flash_attn_3_available()) and is_kernels_available() @@ -5902,10 +5902,10 @@ class AttentionInterface(GeneralInterface): "flash_attention_3": flash_attention_forward, "flash_attention_2": flash_attention_forward, "flex_attention": flex_attention_forward, - "paged_attention": paged_attention_forward, "sdpa": sdpa_attention_forward, - "sdpa_paged": sdpa_attention_paged_forward, - "eager_paged": eager_paged_attention_forward, + "paged|flash_attention2": paged_attention_forward, + "paged|sdpa": sdpa_attention_paged_forward, + "paged|eager": eager_paged_attention_forward, } diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index 943320bfe00b..14f2946d3dfe 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -328,6 +328,15 @@ def test_continuous_batching_parity_gpt_oss_flash(self) -> None: "openai/gpt-oss-20b", "paged_attention|kernels-community/flash-attn", expected_outputs ) + def test_attn_implementation(self) -> None: + model = AutoModelForCausalLM.from_pretrained("gpt2") + manager = model.init_continuous_batching() + assert "paged|sdpa" in manager.model.config._attn_implementation + + model = AutoModelForCausalLM.from_pretrained("gpt2", _attn_implementation="eager") + manager = model.init_continuous_batching() + assert "paged|eager" in manager.model.config._attn_implementation + # FIXME: the gemma test seem broken, there is a message about cuda graphs and the sdpa and flash expecteations are # inverted on CUDA. On AMD they do fine. 
diff --git a/tests/generation/test_paged_attention.py b/tests/generation/test_paged_attention.py index e7673f5f08cd..837da1d73587 100644 --- a/tests/generation/test_paged_attention.py +++ b/tests/generation/test_paged_attention.py @@ -44,10 +44,10 @@ def setUpClass(cls): @parameterized.expand( [ - ("eager_paged", 64, 128, 64), - ("sdpa_paged", 32, 256, 128), - ("paged_attention", 16, 512, 256), - ("flex_paged", 64, 128, 64), + ("paged|eager", 64, 128, 64), + ("paged|sdpa", 32, 256, 128), + ("paged|flash_attention2", 16, 512, 256), + ("paged|flex_attention", 64, 128, 64), ] ) def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens): @@ -89,10 +89,10 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max @parameterized.expand( [ - ("eager_paged", 64, 128, 64), - ("sdpa_paged", 32, 256, 128), - ("paged_attention", 16, 512, 256), - ("flex_paged", 64, 128, 64), + ("paged|eager", 64, 128, 64), + ("paged|sdpa", 32, 256, 128), + ("paged|flash_attention2", 16, 512, 256), + ("paged|flex_attention", 64, 128, 64), ] ) def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens): From c44815be15e333a801b5db3db12a89daed168516 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 6 Oct 2025 15:33:14 +0200 Subject: [PATCH 3/6] affect serve as well --- src/transformers/commands/serving.py | 13 ------------- .../continuous_batching/continuous_api.py | 8 -------- 2 files changed, 21 deletions(-) diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py index 970d59c96e74..d46b8b55694a 100644 --- a/src/transformers/commands/serving.py +++ b/src/transformers/commands/serving.py @@ -489,19 +489,6 @@ def __init__(self, args: ServeArguments): # Store and process input arguments self.args = args self.use_continuous_batching = self.args.continuous_batching - if self.use_continuous_batching: - default_attn_impl = ContinuousBatchingManager.default_attention_implementation() - # checking if attn_implementation is supported by continuous batching - if self.args.attn_implementation is None: - self.args.attn_implementation = default_attn_impl # default to sdpa_paged - logger.info(f"No attn_implementation passed, defaulting to {default_attn_impl}") - supported_attn_impl = ContinuousBatchingManager.supported_attention_implementations() - if self.args.attn_implementation not in supported_attn_impl: - raise ValueError( - f"Continuous batching only supports {supported_attn_impl} as attn_implementation, got " - f"{self.args.attn_implementation}" - f"Try setting `--attn_implementation={default_attn_impl}`" - ) self.enable_cors = self.args.enable_cors if self.args.default_seed is not None: diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 38b00d29e0a8..3457da423e01 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -765,14 +765,6 @@ def request_id_iter(self, request_id): if self.batch_processor is not None: request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id) - @staticmethod - def supported_attention_implementations() -> set[str]: - return {"eager_paged", "sdpa_paged", "flash_attention_2"} - - @staticmethod - def default_attention_implementation() -> str: - return "sdpa_paged" - @traced def warmup(self, batch_processor): stream = 
torch.cuda.Stream(device=self.model.device) From 7ed8e2bb2494e3d1ce8ed95b669cc7550a93aa85 Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 6 Oct 2025 16:43:02 +0200 Subject: [PATCH 4/6] update --- .../generation/continuous_batching/continuous_api.py | 7 ++++--- src/transformers/modeling_utils.py | 3 ++- tests/generation/test_paged_attention.py | 6 ++++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/continuous_batching/continuous_api.py b/src/transformers/generation/continuous_batching/continuous_api.py index 3457da423e01..e9adc98fc6af 100644 --- a/src/transformers/generation/continuous_batching/continuous_api.py +++ b/src/transformers/generation/continuous_batching/continuous_api.py @@ -608,9 +608,10 @@ def __init__( streaming: Whether to stream tokens as they are generated """ self.model = model.eval() - attn_implementation = "paged|" + self.model.config._attn_implementation - load_and_register_kernel(attn_implementation) - model.set_attn_implementation(attn_implementation) + if "paged|" not in model.config._attn_implementation: + attn_implementation = "paged|" + self.model.config._attn_implementation + load_and_register_kernel(attn_implementation) + model.set_attn_implementation(attn_implementation) generation_config = model.generation_config if generation_config is None else generation_config self.generation_config = generation_config self.input_queue = queue.Queue(maxsize=max_queue_size) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3e112eeafae4..4a03070a9013 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5903,7 +5903,8 @@ class AttentionInterface(GeneralInterface): "flash_attention_2": flash_attention_forward, "flex_attention": flex_attention_forward, "sdpa": sdpa_attention_forward, - "paged|flash_attention2": paged_attention_forward, + "paged|flash_attention_2 +": paged_attention_forward, "paged|sdpa": sdpa_attention_paged_forward, "paged|eager": eager_paged_attention_forward, } diff --git a/tests/generation/test_paged_attention.py b/tests/generation/test_paged_attention.py index 837da1d73587..bf6e6f10c48a 100644 --- a/tests/generation/test_paged_attention.py +++ b/tests/generation/test_paged_attention.py @@ -46,7 +46,8 @@ def setUpClass(cls): [ ("paged|eager", 64, 128, 64), ("paged|sdpa", 32, 256, 128), - ("paged|flash_attention2", 16, 512, 256), + ("paged|flash_attention_2 +", 16, 512, 256), ("paged|flex_attention", 64, 128, 64), ] ) @@ -91,7 +92,8 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max [ ("paged|eager", 64, 128, 64), ("paged|sdpa", 32, 256, 128), - ("paged|flash_attention2", 16, 512, 256), + ("paged|flash_attention_2 +", 16, 512, 256), ("paged|flex_attention", 64, 128, 64), ] ) From b3071c6e0c36db3e040acc4b338ae4b662befbda Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 6 Oct 2025 16:44:52 +0200 Subject: [PATCH 5/6] fix --- src/transformers/modeling_utils.py | 3 +-- tests/generation/test_continuous_batching.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4a03070a9013..76fe9a305fdc 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5903,8 +5903,7 @@ class AttentionInterface(GeneralInterface): "flash_attention_2": flash_attention_forward, "flex_attention": flex_attention_forward, "sdpa": sdpa_attention_forward, - "paged|flash_attention_2 -": 
paged_attention_forward, + "paged|flash_attention_2": paged_attention_forward, "paged|sdpa": sdpa_attention_paged_forward, "paged|eager": eager_paged_attention_forward, } diff --git a/tests/generation/test_continuous_batching.py b/tests/generation/test_continuous_batching.py index 14f2946d3dfe..b5476aa5f398 100644 --- a/tests/generation/test_continuous_batching.py +++ b/tests/generation/test_continuous_batching.py @@ -331,11 +331,11 @@ def test_continuous_batching_parity_gpt_oss_flash(self) -> None: def test_attn_implementation(self) -> None: model = AutoModelForCausalLM.from_pretrained("gpt2") manager = model.init_continuous_batching() - assert "paged|sdpa" in manager.model.config._attn_implementation + assert "paged|sdpa" == manager.model.config._attn_implementation model = AutoModelForCausalLM.from_pretrained("gpt2", _attn_implementation="eager") manager = model.init_continuous_batching() - assert "paged|eager" in manager.model.config._attn_implementation + assert "paged|eager" == manager.model.config._attn_implementation # FIXME: the gemma test seem broken, there is a message about cuda graphs and the sdpa and flash expecteations are From d3913951c92cea85d3974a2061939e3237d86a0a Mon Sep 17 00:00:00 2001 From: Arthur Date: Mon, 6 Oct 2025 16:47:59 +0200 Subject: [PATCH 6/6] cup --- tests/generation/test_paged_attention.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/generation/test_paged_attention.py b/tests/generation/test_paged_attention.py index bf6e6f10c48a..0cb13eb1dc23 100644 --- a/tests/generation/test_paged_attention.py +++ b/tests/generation/test_paged_attention.py @@ -46,8 +46,7 @@ def setUpClass(cls): [ ("paged|eager", 64, 128, 64), ("paged|sdpa", 32, 256, 128), - ("paged|flash_attention_2 -", 16, 512, 256), + ("paged|flash_attention_2", 16, 512, 256), ("paged|flex_attention", 64, 128, 64), ] ) @@ -92,8 +91,7 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max [ ("paged|eager", 64, 128, 64), ("paged|sdpa", 32, 256, 128), - ("paged|flash_attention_2 -", 16, 512, 256), + ("paged|flash_attention_2", 16, 512, 256), ("paged|flex_attention", 64, 128, 64), ] )
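
A minimal usage sketch (illustrative only, not part of the patch series): after these commits, ContinuousBatchingManager.__init__ prefixes the model's existing attention implementation with "paged|" and registers the matching kernel, so a default (SDPA) model ends up on "paged|sdpa" and an eager model on "paged|eager", as exercised by the test_attn_implementation test added in PATCH 2/6. The gpt2 checkpoint mirrors that test; running this assumes a transformers build that already includes these patches.

    # Sketch assuming a transformers checkout with the "paged|<impl>" refactor applied.
    from transformers import AutoModelForCausalLM

    # Default loading resolves to SDPA; continuous batching swaps in its paged SDPA variant.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    manager = model.init_continuous_batching()
    assert manager.model.config._attn_implementation == "paged|sdpa"

    # Requesting eager attention instead yields the paged eager variant.
    model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")
    manager = model.init_continuous_batching()
    assert manager.model.config._attn_implementation == "paged|eager"

The same selection can be made by hand with model.set_attn_implementation("paged|sdpa"); the "paged|<impl>" keys replace the former "sdpa_paged", "eager_paged" and "paged_attention" entries in AttentionInterface.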