Merged
Changes from 3 commits
4 changes: 1 addition & 3 deletions examples/pytorch/continuous_batching.py
@@ -184,9 +184,7 @@ def batch_generate(
parser.add_argument("--num-blocks", "-n", type=int, default=None)
parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)

parser.add_argument(
"--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation"
)
parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable
parser.add_argument("--no-slice-inputs", action="store_true") # slicing is enabled by default because much faster
parser.add_argument("--use-cuda-graph", "-cg", action="store_true")
4 changes: 1 addition & 3 deletions examples/pytorch/continuous_batching_simple.py
@@ -31,9 +31,7 @@
parser = argparse.ArgumentParser()
parser.add_argument("--num-blocks", "-n", type=int, default=None)
parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
parser.add_argument(
"--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation"
)
parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
parser.add_argument("--samples", type=int, default=500)
args = parser.parse_args()

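Both example scripts above make the same change: the `--attn` default is now a kernel repo id instead of the `"paged_attention|kernels-community/flash-attn"` string, since the paged prefix is now added internally (see the manager changes further down). A minimal sketch of the updated flag, assuming the scripts forward this value to the model's attention implementation (only the default value changed in this PR):

```python
# Minimal sketch of the updated --attn flag; the forwarding to the model is an assumption.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--attn",
    type=str,
    default="kernels-community/flash-attn",  # was "paged_attention|kernels-community/flash-attn"
    help="Attention implementation (built-in name or kernel repo id)",
)
args = parser.parse_args([])  # use the default for illustration

# Continuous batching later derives the paged variant internally, roughly:
print("paged|" + args.attn)  # -> "paged|kernels-community/flash-attn"
```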
13 changes: 0 additions & 13 deletions src/transformers/commands/serving.py
@@ -489,19 +489,6 @@ def __init__(self, args: ServeArguments):
# Store and process input arguments
self.args = args
self.use_continuous_batching = self.args.continuous_batching
if self.use_continuous_batching:
default_attn_impl = ContinuousBatchingManager.default_attention_implementation()
# checking if attn_implementation is supported by continuous batching
if self.args.attn_implementation is None:
self.args.attn_implementation = default_attn_impl # default to sdpa_paged
logger.info(f"No attn_implementation passed, defaulting to {default_attn_impl}")
supported_attn_impl = ContinuousBatchingManager.supported_attention_implementations()
if self.args.attn_implementation not in supported_attn_impl:
raise ValueError(
f"Continuous batching only supports {supported_attn_impl} as attn_implementation, got "
f"{self.args.attn_implementation}"
f"Try setting `--attn_implementation={default_attn_impl}`"
)
self.enable_cors = self.args.enable_cors

if self.args.default_seed is not None:
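The block removed above was the serving-side allow-list check. For contrast, a condensed sketch of what that removed logic did (names taken from the removed lines); it becomes redundant because the manager now derives the paged implementation from the model config itself:

```python
# Condensed sketch of the removed serving-side check, for contrast only; this logic is gone.
from typing import Optional


def legacy_attn_check(attn_implementation: Optional[str]) -> str:
    supported = {"eager_paged", "sdpa_paged", "flash_attention_2"}
    default = "sdpa_paged"
    attn_implementation = attn_implementation or default  # old default when nothing was passed
    if attn_implementation not in supported:
        raise ValueError(
            f"Continuous batching only supports {supported} as attn_implementation, got "
            f"{attn_implementation}. Try setting `--attn_implementation={default}`"
        )
    return attn_implementation
```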
2 changes: 1 addition & 1 deletion src/transformers/generation/continuous_batching/cache.py
@@ -174,7 +174,7 @@ def __init__(
# Infer number of blocks and max batch tokens
page_size = self.head_dim * self.num_key_value_heads

if getattr(config, "attn_implementation", None) == "paged_attention":
if "flash" in self.config._attn_implementation:
num_attention_masks = 0
else:
# TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
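The condition above now keys off the implementation name rather than a config attribute: any flash-style paged implementation handles causality and variable-length batches inside the kernel, so no attention-mask buffers are needed. A simplified sketch of that decision (the non-flash value is not visible in this hunk, so 1 is a placeholder):

```python
# Simplified sketch of the branch changed above; not the full PagedAttentionCache logic.
def num_attention_masks_for(attn_implementation: str) -> int:
    if "flash" in attn_implementation:  # e.g. "paged|flash_attention2" or a flash-attn kernel repo
        return 0  # the kernel is causal/varlen internally, so no mask buffers are allocated
    return 1  # placeholder: the real value is derived from the attention group types (see the TODO)
```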
src/transformers/generation/continuous_batching/… (file header not rendered)
@@ -27,6 +27,7 @@

from ...configuration_utils import PreTrainedConfig
from ...generation.configuration_utils import GenerationConfig
from ...integrations.hub_kernels import load_and_register_kernel
from ...utils.logging import logging
from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
from .cache import PagedAttentionCache
@@ -241,7 +242,10 @@ def setup_static_tensors(self, num_groups: int) -> None:
self.reset_static_tensors(full_reset=True)

def return_attention_mask(self) -> bool:
return self.config._attn_implementation != "paged_attention" # we set `is_causal` to True in paged call
return self.config._attn_implementation in [
"paged|eager",
"paged|sdpa",
] # we set `is_causal` to True in paged call

@traced
@torch.no_grad()
@@ -604,6 +608,9 @@ def __init__(
streaming: Whether to stream tokens as they are generated
"""
self.model = model.eval()
attn_implementation = "paged|" + self.model.config._attn_implementation
load_and_register_kernel(attn_implementation)
model.set_attn_implementation(attn_implementation)
generation_config = model.generation_config if generation_config is None else generation_config
self.generation_config = generation_config
self.input_queue = queue.Queue(maxsize=max_queue_size)
@@ -758,14 +765,6 @@ def request_id_iter(self, request_id):
if self.batch_processor is not None:
request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)

@staticmethod
def supported_attention_implementations() -> set[str]:
return {"eager_paged", "sdpa_paged", "flash_attention_2"}

@staticmethod
def default_attention_implementation() -> str:
return "sdpa_paged"

@traced
def warmup(self, batch_processor):
stream = torch.cuda.Stream(device=self.model.device)
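Taken together, the manager now owns the attention setup: it prefixes the configured implementation with `"paged|"`, loads and registers the matching kernel, and sets it on the model, which is why the static `supported_attention_implementations` / `default_attention_implementation` helpers are removed. A usage sketch mirroring the new test added further down:

```python
# Usage sketch mirroring the new test in tests/generation/test_continuous_batching.py:
# the paged variant is derived from whatever attention implementation the model was loaded with.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")
manager = model.init_continuous_batching()
assert "paged|eager" in manager.model.config._attn_implementation
```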
10 changes: 4 additions & 6 deletions src/transformers/modeling_utils.py
@@ -2527,7 +2527,7 @@ def _check_and_adjust_attn_implementation(
# If FA not installed, do not fail but use kernels instead
if (
attn_implementation is not None
and attn_implementation.startswith("flash_attention")
and "flash" in attn_implementation
and self._supports_flash_attn
and not (is_flash_attn_2_available() or is_flash_attn_3_available())
and is_kernels_available()
@@ -2635,8 +2635,6 @@ def set_attn_implementation(self, attn_implementation: Union[str, dict]):
else attn_implementation.get("", self.config._attn_implementation)
)

# At this point, the model was already instantiated, so instead of crashing on bad value, let's simply
# warn the user that the requested value is not working
if requested_implementation != self.config._attn_implementation:
# In this case, raise
if not self._can_set_attn_implementation():
@@ -5904,10 +5902,10 @@ class AttentionInterface(GeneralInterface):
"flash_attention_3": flash_attention_forward,
"flash_attention_2": flash_attention_forward,
"flex_attention": flex_attention_forward,
"paged_attention": paged_attention_forward,
"sdpa": sdpa_attention_forward,
"sdpa_paged": sdpa_attention_paged_forward,
"eager_paged": eager_paged_attention_forward,
"paged|flash_attention2": paged_attention_forward,
"paged|sdpa": sdpa_attention_paged_forward,
"paged|eager": eager_paged_attention_forward,
Collaborator: should paged|flex_attention be an option as well? I see it listed below in the tests

Collaborator (Author): Not supported yet AFAIK

Collaborator: Ok good to know 👍 ty
}


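The renamed keys live in the `AttentionInterface` registry shown above. As a side note, custom forwards can be added under new keys via the documented `AttentionInterface.register` pattern; the key and wrapper below are purely illustrative and are not something this PR adds:

```python
# Illustrative only: registering a custom attention forward under a hypothetical key.
# Registering a name does not by itself make it usable with the paged cache.
from transformers import AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward


def my_sdpa_variant(module, query, key, value, attention_mask, **kwargs):
    # Defer to the stock SDPA forward; a real variant would change the computation.
    return sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)


AttentionInterface.register("my_sdpa_variant", my_sdpa_variant)
```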
9 changes: 9 additions & 0 deletions tests/generation/test_continuous_batching.py
@@ -328,6 +328,15 @@ def test_continuous_batching_parity_gpt_oss_flash(self) -> None:
"openai/gpt-oss-20b", "paged_attention|kernels-community/flash-attn", expected_outputs
)

def test_attn_implementation(self) -> None:
model = AutoModelForCausalLM.from_pretrained("gpt2")
manager = model.init_continuous_batching()
assert "paged|sdpa" in manager.model.config._attn_implementation

model = AutoModelForCausalLM.from_pretrained("gpt2", _attn_implementation="eager")
manager = model.init_continuous_batching()
assert "paged|eager" in manager.model.config._attn_implementation


# FIXME: the gemma test seem broken, there is a message about cuda graphs and the sdpa and flash expecteations are
# inverted on CUDA. On AMD they do fine.
16 changes: 8 additions & 8 deletions tests/generation/test_paged_attention.py
@@ -44,10 +44,10 @@ def setUpClass(cls):

@parameterized.expand(
[
("eager_paged", 64, 128, 64),
("sdpa_paged", 32, 256, 128),
("paged_attention", 16, 512, 256),
("flex_paged", 64, 128, 64),
("paged|eager", 64, 128, 64),
("paged|sdpa", 32, 256, 128),
("paged|flash_attention2", 16, 512, 256),
("paged|flex_attention", 64, 128, 64),
]
)
def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):
@@ -89,10 +89,10 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max

@parameterized.expand(
[
("eager_paged", 64, 128, 64),
("sdpa_paged", 32, 256, 128),
("paged_attention", 16, 512, 256),
("flex_paged", 64, 128, 64),
("paged|eager", 64, 128, 64),
("paged|sdpa", 32, 256, 128),
("paged|flash_attention2", 16, 512, 256),
("paged|flex_attention", 64, 128, 64),
]
)
def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens):
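For reference, the renaming exercised by the updated parameterizations above (old key on the left, new key on the right; per the review thread earlier, the flex variant currently appears only in the tests, not in the AttentionInterface mapping):

```python
# Old -> new attention-implementation keys, as exercised by the updated test parameterizations.
RENAMED_PAGED_IMPLS = {
    "eager_paged": "paged|eager",
    "sdpa_paged": "paged|sdpa",
    "paged_attention": "paged|flash_attention2",
    "flex_paged": "paged|flex_attention",
}
```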