
Commit 0395ed5

[CB] Refactors the way we access paged attention (#41370)
* up
* refactor the way we handle paged attention
* affect serve as well
* update
* fix
* cup
1 parent 39b0c94 commit 0395ed5

8 files changed: +33 additions, -43 deletions
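At a glance, the change collapses the separate paged implementation names into a single `paged|<backend>` scheme. A minimal summary of the renames as they appear in the diffs below (the dict name here is purely illustrative):

# Old registry key -> new "paged|<backend>" key (see modeling_utils.py and the tests below)
PAGED_ATTENTION_RENAMES = {
    "eager_paged": "paged|eager",
    "sdpa_paged": "paged|sdpa",
    "paged_attention": "paged|flash_attention_2",
    "flex_paged": "paged|flex_attention",
}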

examples/pytorch/continuous_batching.py

Lines changed: 1 addition & 3 deletions
@@ -184,9 +184,7 @@ def batch_generate(
     parser.add_argument("--num-blocks", "-n", type=int, default=None)
     parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)

-    parser.add_argument(
-        "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation"
-    )
+    parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
     parser.add_argument("--matmul-precision", "-mp", type=str, default="high")  # set to "none" to disable
     parser.add_argument("--no-slice-inputs", action="store_true")  # slicing is enabled by default because much faster
     parser.add_argument("--use-cuda-graph", "-cg", action="store_true")

examples/pytorch/continuous_batching_simple.py

Lines changed: 1 addition & 3 deletions
@@ -31,9 +31,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--num-blocks", "-n", type=int, default=None)
 parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
-parser.add_argument(
-    "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation"
-)
+parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
 parser.add_argument("--samples", type=int, default=500)
 args = parser.parse_args()

src/transformers/commands/serving.py

Lines changed: 0 additions & 13 deletions
@@ -489,19 +489,6 @@ def __init__(self, args: ServeArguments):
         # Store and process input arguments
         self.args = args
         self.use_continuous_batching = self.args.continuous_batching
-        if self.use_continuous_batching:
-            default_attn_impl = ContinuousBatchingManager.default_attention_implementation()
-            # checking if attn_implementation is supported by continuous batching
-            if self.args.attn_implementation is None:
-                self.args.attn_implementation = default_attn_impl  # default to sdpa_paged
-                logger.info(f"No attn_implementation passed, defaulting to {default_attn_impl}")
-            supported_attn_impl = ContinuousBatchingManager.supported_attention_implementations()
-            if self.args.attn_implementation not in supported_attn_impl:
-                raise ValueError(
-                    f"Continuous batching only supports {supported_attn_impl} as attn_implementation, got "
-                    f"{self.args.attn_implementation}"
-                    f"Try setting `--attn_implementation={default_attn_impl}`"
-                )
         self.enable_cors = self.args.enable_cors

         if self.args.default_seed is not None:

src/transformers/generation/continuous_batching/cache.py

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ def __init__(
         # Infer number of blocks and max batch tokens
         page_size = self.head_dim * self.num_key_value_heads

-        if getattr(config, "attn_implementation", None) == "paged_attention":
+        if "flash" in self.config._attn_implementation:
             num_attention_masks = 0
         else:
             # TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
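The new check keys off the backend half of the string: flash-based paged attention works from cumulative sequence lengths and allocates no attention-mask buffers, while the eager/sdpa paged paths still do. A minimal sketch of that branch, with the helper name and the non-flash count of 1 invented purely for illustration:

def num_attention_mask_buffers(attn_implementation: str) -> int:
    # Flash varlen kernels need no dense mask; other paged backends build one.
    if "flash" in attn_implementation:
        return 0
    return 1  # illustrative; see the TODO in the hunk above for the generalized count

assert num_attention_mask_buffers("paged|flash_attention_2") == 0
assert num_attention_mask_buffers("paged|sdpa") == 1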

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 9 additions & 9 deletions
@@ -27,6 +27,7 @@

 from ...configuration_utils import PreTrainedConfig
 from ...generation.configuration_utils import GenerationConfig
+from ...integrations.hub_kernels import load_and_register_kernel
 from ...utils.logging import logging
 from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
 from .cache import PagedAttentionCache

@@ -241,7 +242,10 @@ def setup_static_tensors(self, num_groups: int) -> None:
         self.reset_static_tensors(full_reset=True)

     def return_attention_mask(self) -> bool:
-        return self.config._attn_implementation != "paged_attention"  # we set `is_causal` to True in paged call
+        return self.config._attn_implementation in [
+            "paged|eager",
+            "paged|sdpa",
+        ]  # we set `is_causal` to True in paged call

     @traced
     @torch.no_grad()

@@ -604,6 +608,10 @@ def __init__(
             streaming: Whether to stream tokens as they are generated
         """
         self.model = model.eval()
+        if "paged|" not in model.config._attn_implementation:
+            attn_implementation = "paged|" + self.model.config._attn_implementation
+            load_and_register_kernel(attn_implementation)
+            model.set_attn_implementation(attn_implementation)
         generation_config = model.generation_config if generation_config is None else generation_config
         self.generation_config = generation_config
         self.input_queue = queue.Queue(maxsize=max_queue_size)

@@ -758,14 +766,6 @@ def request_id_iter(self, request_id):
             if self.batch_processor is not None:
                 request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)

-    @staticmethod
-    def supported_attention_implementations() -> set[str]:
-        return {"eager_paged", "sdpa_paged", "flash_attention_2"}
-
-    @staticmethod
-    def default_attention_implementation() -> str:
-        return "sdpa_paged"
-
     @traced
     def warmup(self, batch_processor):
         stream = torch.cuda.Stream(device=self.model.device)
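The net effect for callers (including `serving.py` above) is that no paged-specific name has to be chosen up front: the manager upgrades whatever implementation the model already uses. A small usage sketch, grounded in the new test further down:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # loads with the default sdpa implementation
manager = model.init_continuous_batching()            # __init__ rewrites it to "paged|sdpa"
print(manager.model.config._attn_implementation)      # -> "paged|sdpa"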

src/transformers/modeling_utils.py

Lines changed: 4 additions & 6 deletions
@@ -2509,7 +2509,7 @@ def _check_and_adjust_attn_implementation(
         # If FA not installed, do not fail but use kernels instead
         if (
             attn_implementation is not None
-            and attn_implementation.startswith("flash_attention")
+            and "flash" in attn_implementation
             and self._supports_flash_attn
             and not (is_flash_attn_2_available() or is_flash_attn_3_available())
             and is_kernels_available()

@@ -2617,8 +2617,6 @@ def set_attn_implementation(self, attn_implementation: Union[str, dict]):
             else attn_implementation.get("", self.config._attn_implementation)
         )

-        # At this point, the model was already instantiated, so instead of crashing on bad value, let's simply
-        # warn the user that the requested value is not working
         if requested_implementation != self.config._attn_implementation:
             # In this case, raise
             if not self._can_set_attn_implementation():

@@ -5834,10 +5832,10 @@ class AttentionInterface(GeneralInterface):
         "flash_attention_3": flash_attention_forward,
         "flash_attention_2": flash_attention_forward,
         "flex_attention": flex_attention_forward,
-        "paged_attention": paged_attention_forward,
         "sdpa": sdpa_attention_forward,
-        "sdpa_paged": sdpa_attention_paged_forward,
-        "eager_paged": eager_paged_attention_forward,
+        "paged|flash_attention_2": paged_attention_forward,
+        "paged|sdpa": sdpa_attention_paged_forward,
+        "paged|eager": eager_paged_attention_forward,
     }
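With the keys renamed, the paged forward functions are resolved through the same registry as the regular ones. A quick lookup sketch, assuming the module-level `ALL_ATTENTION_FUNCTIONS` instance of `AttentionInterface` that `modeling_utils` exposes:

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

# Dict-style lookup on the shared registry; the key is one of the renamed entries above.
paged_sdpa_forward = ALL_ATTENTION_FUNCTIONS["paged|sdpa"]
print(paged_sdpa_forward.__name__)  # expected: sdpa_attention_paged_forward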

tests/generation/test_continuous_batching.py

Lines changed: 9 additions & 0 deletions
@@ -328,6 +328,15 @@ def test_continuous_batching_parity_gpt_oss_flash(self) -> None:
             "openai/gpt-oss-20b", "paged_attention|kernels-community/flash-attn", expected_outputs
         )

+    def test_attn_implementation(self) -> None:
+        model = AutoModelForCausalLM.from_pretrained("gpt2")
+        manager = model.init_continuous_batching()
+        assert "paged|sdpa" == manager.model.config._attn_implementation
+
+        model = AutoModelForCausalLM.from_pretrained("gpt2", _attn_implementation="eager")
+        manager = model.init_continuous_batching()
+        assert "paged|eager" == manager.model.config._attn_implementation
+

 # FIXME: the gemma test seem broken, there is a message about cuda graphs and the sdpa and flash expecteations are
 # inverted on CUDA. On AMD they do fine.

tests/generation/test_paged_attention.py

Lines changed: 8 additions & 8 deletions
@@ -44,10 +44,10 @@ def setUpClass(cls):

     @parameterized.expand(
         [
-            ("eager_paged", 64, 128, 64),
-            ("sdpa_paged", 32, 256, 128),
-            ("paged_attention", 16, 512, 256),
-            ("flex_paged", 64, 128, 64),
+            ("paged|eager", 64, 128, 64),
+            ("paged|sdpa", 32, 256, 128),
+            ("paged|flash_attention_2", 16, 512, 256),
+            ("paged|flex_attention", 64, 128, 64),
         ]
     )
     def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):

@@ -89,10 +89,10 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max

     @parameterized.expand(
         [
-            ("eager_paged", 64, 128, 64),
-            ("sdpa_paged", 32, 256, 128),
-            ("paged_attention", 16, 512, 256),
-            ("flex_paged", 64, 128, 64),
+            ("paged|eager", 64, 128, 64),
+            ("paged|sdpa", 32, 256, 128),
+            ("paged|flash_attention_2", 16, 512, 256),
+            ("paged|flex_attention", 64, 128, 64),
         ]
     )
     def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens):
