Merged
4 changes: 1 addition & 3 deletions examples/pytorch/continuous_batching.py
@@ -184,9 +184,7 @@ def batch_generate(
 parser.add_argument("--num-blocks", "-n", type=int, default=None)
 parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)

-parser.add_argument(
-    "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation"
-)
+parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
 parser.add_argument("--matmul-precision", "-mp", type=str, default="high") # set to "none" to disable
 parser.add_argument("--no-slice-inputs", action="store_true") # slicing is enabled by default because much faster
 parser.add_argument("--use-cuda-graph", "-cg", action="store_true")
4 changes: 1 addition & 3 deletions examples/pytorch/continuous_batching_simple.py
@@ -31,9 +31,7 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--num-blocks", "-n", type=int, default=None)
 parser.add_argument("--max-batch-tokens", "-b", type=int, default=None)
-parser.add_argument(
-    "--attn", type=str, default="paged_attention|kernels-community/flash-attn", help="Attention implementation"
-)
+parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")
 parser.add_argument("--samples", type=int, default=500)
 args = parser.parse_args()

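Both example scripts now default `--attn` to the bare kernel reference instead of the `paged_attention|`-prefixed string. A minimal standalone sketch of how the new default resolves (not part of the diff; the explicit `paged|sdpa` override is only illustrative):

```python
import argparse

# Standalone sketch mirroring the changed add_argument call in both example scripts.
parser = argparse.ArgumentParser()
parser.add_argument("--attn", type=str, default="kernels-community/flash-attn", help="Attention implementation")

print(parser.parse_args([]).attn)                        # "kernels-community/flash-attn" (new default, no prefix)
print(parser.parse_args(["--attn", "paged|sdpa"]).attn)  # an explicit implementation can still be passed
```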
13 changes: 0 additions & 13 deletions src/transformers/commands/serving.py
@@ -489,19 +489,6 @@ def __init__(self, args: ServeArguments):
         # Store and process input arguments
         self.args = args
         self.use_continuous_batching = self.args.continuous_batching
-        if self.use_continuous_batching:
-            default_attn_impl = ContinuousBatchingManager.default_attention_implementation()
-            # checking if attn_implementation is supported by continuous batching
-            if self.args.attn_implementation is None:
-                self.args.attn_implementation = default_attn_impl # default to sdpa_paged
-                logger.info(f"No attn_implementation passed, defaulting to {default_attn_impl}")
-            supported_attn_impl = ContinuousBatchingManager.supported_attention_implementations()
-            if self.args.attn_implementation not in supported_attn_impl:
-                raise ValueError(
-                    f"Continuous batching only supports {supported_attn_impl} as attn_implementation, got "
-                    f"{self.args.attn_implementation}"
-                    f"Try setting `--attn_implementation={default_attn_impl}`"
-                )
         self.enable_cors = self.args.enable_cors

         if self.args.default_seed is not None:
2 changes: 1 addition & 1 deletion src/transformers/generation/continuous_batching/cache.py
@@ -174,7 +174,7 @@ def __init__(
         # Infer number of blocks and max batch tokens
         page_size = self.head_dim * self.num_key_value_heads

-        if getattr(config, "attn_implementation", None) == "paged_attention":
+        if "flash" in self.config._attn_implementation:
             num_attention_masks = 0
         else:
             # TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
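For context on the changed condition above, a small standalone sketch (not part of the diff) of which implementation strings skip attention-mask allocation: anything containing "flash" does, while the other paged variants still get masks. The three names are taken from the AttentionInterface changes later in this PR.

```python
# Standalone sketch: which implementation strings the new check in PagedAttentionCache
# treats as "no attention mask tensors needed".
def needs_attention_mask(attn_implementation: str) -> bool:
    # Mirrors the changed line: if "flash" in self.config._attn_implementation: num_attention_masks = 0
    return "flash" not in attn_implementation

for impl in ("paged|flash_attention_2", "paged|sdpa", "paged|eager"):
    print(f"{impl}: mask needed = {needs_attention_mask(impl)}")
# paged|flash_attention_2: mask needed = False
# paged|sdpa: mask needed = True
# paged|eager: mask needed = True
```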
src/transformers/generation/continuous_batching/ (continuous batching processor and manager module)
@@ -27,6 +27,7 @@

 from ...configuration_utils import PreTrainedConfig
 from ...generation.configuration_utils import GenerationConfig
+from ...integrations.hub_kernels import load_and_register_kernel
 from ...utils.logging import logging
 from ...utils.metrics import ContinuousBatchProcessorMetrics, attach_tracer, traced
 from .cache import PagedAttentionCache
@@ -241,7 +242,10 @@ def setup_static_tensors(self, num_groups: int) -> None:
         self.reset_static_tensors(full_reset=True)

     def return_attention_mask(self) -> bool:
-        return self.config._attn_implementation != "paged_attention" # we set `is_causal` to True in paged call
+        return self.config._attn_implementation in [
+            "paged|eager",
+            "paged|sdpa",
+        ] # we set `is_causal` to True in paged call

     @traced
     @torch.no_grad()
@@ -604,6 +608,10 @@ def __init__(
             streaming: Whether to stream tokens as they are generated
         """
         self.model = model.eval()
+        if "paged|" not in model.config._attn_implementation:
+            attn_implementation = "paged|" + self.model.config._attn_implementation
+            load_and_register_kernel(attn_implementation)
+            model.set_attn_implementation(attn_implementation)
         generation_config = model.generation_config if generation_config is None else generation_config
         self.generation_config = generation_config
         self.input_queue = queue.Queue(maxsize=max_queue_size)
@@ -758,14 +766,6 @@ def request_id_iter(self, request_id):
             if self.batch_processor is not None:
                 request_cancelled = self.batch_processor.scheduler.request_is_cancelled(request_id)

-    @staticmethod
-    def supported_attention_implementations() -> set[str]:
-        return {"eager_paged", "sdpa_paged", "flash_attention_2"}
-
-    @staticmethod
-    def default_attention_implementation() -> str:
-        return "sdpa_paged"
-
     @traced
     def warmup(self, batch_processor):
         stream = torch.cuda.Stream(device=self.model.device)
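With `supported_attention_implementations()` and `default_attention_implementation()` removed, the manager's `__init__` now derives the paged implementation from whatever attention implementation the model is already using. A rough sketch of the resulting behaviour, mirroring the test added later in this diff (assumes network access to download `gpt2` and that `init_continuous_batching()` constructs a `ContinuousBatchingManager`):

```python
# Sketch of the new auto-selection: "paged|" is prepended to the model's current
# attention implementation when a continuous batching manager is created.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # typically resolves to "sdpa"
manager = model.init_continuous_batching()
print(manager.model.config._attn_implementation)      # expected: "paged|sdpa"

model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")
manager = model.init_continuous_batching()
print(manager.model.config._attn_implementation)      # expected: "paged|eager"
```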
10 changes: 4 additions & 6 deletions src/transformers/modeling_utils.py
@@ -2527,7 +2527,7 @@ def _check_and_adjust_attn_implementation(
         # If FA not installed, do not fail but use kernels instead
         if (
             attn_implementation is not None
-            and attn_implementation.startswith("flash_attention")
+            and "flash" in attn_implementation
             and self._supports_flash_attn
             and not (is_flash_attn_2_available() or is_flash_attn_3_available())
             and is_kernels_available()
@@ -2635,8 +2635,6 @@ def set_attn_implementation(self, attn_implementation: Union[str, dict]):
             else attn_implementation.get("", self.config._attn_implementation)
         )

-        # At this point, the model was already instantiated, so instead of crashing on bad value, let's simply
-        # warn the user that the requested value is not working
         if requested_implementation != self.config._attn_implementation:
             # In this case, raise
             if not self._can_set_attn_implementation():
@@ -5904,10 +5902,10 @@ class AttentionInterface(GeneralInterface):
         "flash_attention_3": flash_attention_forward,
         "flash_attention_2": flash_attention_forward,
         "flex_attention": flex_attention_forward,
-        "paged_attention": paged_attention_forward,
         "sdpa": sdpa_attention_forward,
-        "sdpa_paged": sdpa_attention_paged_forward,
-        "eager_paged": eager_paged_attention_forward,
+        "paged|flash_attention_2": paged_attention_forward,
+        "paged|sdpa": sdpa_attention_paged_forward,
+        "paged|eager": eager_paged_attention_forward,
     }

Review comment (Collaborator), on the new paged entries: should paged|flex_attention be an option as well? I see it listed below in the tests
Reply (Author): Not supported yet AFAIK
Reply (Collaborator): Ok good to know 👍 ty


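For readers tracking the rename, a small sketch of how the old registry keys map onto the new `paged|...` names, reproduced from the hunk above (the helper itself is hypothetical and not part of this PR). Per the review thread, `paged|flex_attention` is intentionally not registered here yet.

```python
# Hypothetical helper: translate pre-PR paged attention names to the new "paged|..." keys
# registered in AttentionInterface above.
RENAMED_PAGED_IMPLEMENTATIONS = {
    "eager_paged": "paged|eager",
    "sdpa_paged": "paged|sdpa",
    "paged_attention": "paged|flash_attention_2",
}

def migrate_attn_name(name: str) -> str:
    return RENAMED_PAGED_IMPLEMENTATIONS.get(name, name)

print(migrate_attn_name("sdpa_paged"))         # "paged|sdpa"
print(migrate_attn_name("flash_attention_2"))  # unchanged: non-paged names keep their keys
```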
9 changes: 9 additions & 0 deletions tests/generation/test_continuous_batching.py
@@ -328,6 +328,15 @@ def test_continuous_batching_parity_gpt_oss_flash(self) -> None:
             "openai/gpt-oss-20b", "paged_attention|kernels-community/flash-attn", expected_outputs
         )

+    def test_attn_implementation(self) -> None:
+        model = AutoModelForCausalLM.from_pretrained("gpt2")
+        manager = model.init_continuous_batching()
+        assert "paged|sdpa" == manager.model.config._attn_implementation
+
+        model = AutoModelForCausalLM.from_pretrained("gpt2", _attn_implementation="eager")
+        manager = model.init_continuous_batching()
+        assert "paged|eager" == manager.model.config._attn_implementation
+

 # FIXME: the gemma test seem broken, there is a message about cuda graphs and the sdpa and flash expecteations are
 # inverted on CUDA. On AMD they do fine.
16 changes: 8 additions & 8 deletions tests/generation/test_paged_attention.py
@@ -44,10 +44,10 @@ def setUpClass(cls):

     @parameterized.expand(
         [
-            ("eager_paged", 64, 128, 64),
-            ("sdpa_paged", 32, 256, 128),
-            ("paged_attention", 16, 512, 256),
-            ("flex_paged", 64, 128, 64),
+            ("paged|eager", 64, 128, 64),
+            ("paged|sdpa", 32, 256, 128),
+            ("paged|flash_attention_2", 16, 512, 256),
+            ("paged|flex_attention", 64, 128, 64),
         ]
     )
     def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max_batch_tokens):
@@ -89,10 +89,10 @@ def test_generate_batch_consistency(self, attn_impl, num_blocks, block_size, max

     @parameterized.expand(
         [
-            ("eager_paged", 64, 128, 64),
-            ("sdpa_paged", 32, 256, 128),
-            ("paged_attention", 16, 512, 256),
-            ("flex_paged", 64, 128, 64),
+            ("paged|eager", 64, 128, 64),
+            ("paged|sdpa", 32, 256, 128),
+            ("paged|flash_attention_2", 16, 512, 256),
+            ("paged|flex_attention", 64, 128, 64),
         ]
     )
     def test_generate_batch_with_sampling(self, attn_impl, num_blocks, block_size, max_batch_tokens):