|
16 | 16 |
|
17 | 17 | import functools
|
18 | 18 | from functools import cache
|
19 |
| -from typing import Any, List, Optional, Tuple |
| 19 | +from typing import Any, List, Optional, Tuple, Union |
20 | 20 |
|
21 | 21 | import torch
|
22 | 22 |
|
@@ -418,6 +418,7 @@ def plan(
|
418 | 418 | rope_scale: Optional[float] = None,
|
419 | 419 | rope_theta: Optional[float] = None,
|
420 | 420 | q_data_type: str = "float16",
|
| 421 | + kv_data_type: Optional[Union[str, torch.dtype]] = None, |
421 | 422 | ):
|
422 | 423 | r"""Create auxiliary data structures for multi-level cascade attention for multiple
|
423 | 424 | forward calls within the same decode step. Please check
|
@@ -476,6 +477,8 @@ def plan(
|
476 | 477 | The theta used in RoPE, if not provided, will be set to ``1e4``.
|
477 | 478 | q_data_type : Optional[Union[str, torch.dtype]]
|
478 | 479 | The data type of the query tensor. If None, will be set to torch.float16.
|
| 480 | + kv_data_type : Optional[Union[str, torch.dtype]] |
| 481 | + The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`. |
479 | 482 | """
|
480 | 483 | for i, (
|
481 | 484 | wrapper,
|
@@ -510,6 +513,7 @@ def plan(
|
510 | 513 | rope_scale=rope_scale,
|
511 | 514 | rope_theta=rope_theta,
|
512 | 515 | q_data_type=q_data_type,
|
| 516 | + kv_data_type=kv_data_type, |
513 | 517 | )
|
514 | 518 |
|
515 | 519 | begin_forward = plan
|
|
0 commit comments