forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm_args.py
More file actions
2386 lines (1977 loc) · 93 KB
/
llm_args.py
File metadata and controls
2386 lines (1977 loc) · 93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import copy
import functools
import json
import math
import os
import types
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, EnumMeta
from pathlib import Path
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
Type, TypeAlias, TypeVar, Union, get_args, get_origin)
import torch
import yaml
from pydantic import BaseModel
from pydantic import Field as PydanticField
from pydantic import PrivateAttr, field_validator, model_validator
from strenum import StrEnum
from transformers import PreTrainedTokenizerBase
from tensorrt_llm.lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules)
from .._utils import mpi_rank
from ..auto_parallel import AutoParallelConfig, infer_cluster_config
if TYPE_CHECKING:
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
# yapf: disable
# isort: off
from ..bindings.executor import (BatchingType as _BatchingType,
CacheTransceiverBackendType as _CacheTransceiverBackendType,
CacheTransceiverConfig as _CacheTransceiverConfig,
CapacitySchedulerPolicy as _CapacitySchedulerPolicy,
ContextChunkingPolicy as _ContextChunkingPolicy,
DecodingConfig,
DecodingMode,
DynamicBatchConfig as _DynamicBatchConfig,
EagleConfig as _EagleConfig,
ExecutorConfig as _ExecutorConfig,
ExtendedRuntimePerfKnobConfig as _ExtendedRuntimePerfKnobConfig,
KvCacheConfig as _KvCacheConfig,
LookaheadDecodingConfig as _LookaheadDecodingConfig,
PeftCacheConfig as _PeftCacheConfig,
SchedulerConfig as _SchedulerConfig) # isort: skip
# isort: on
# yapf: enable
from ..builder import BuildConfig, EngineConfig
from ..logger import logger
from ..mapping import Mapping
from ..models.automodel import AutoConfig
from ..models.modeling_utils import (PretrainedConfig, QuantAlgo, QuantConfig,
SpeculativeDecodingMode)
from ..sampling_params import BatchedLogitsProcessor
from .build_cache import BuildCacheConfig
from .tokenizer import TokenizerBase, tokenizer_factory
from .utils import generate_api_docs_as_docstring, get_type_repr
# TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
# Type variable standing for "some pydantic BaseModel subclass". The name
# string passed to TypeVar must match the variable it is bound to, otherwise
# static type checkers reject the declaration.
TypeBaseModel = TypeVar("TypeBaseModel", bound=BaseModel)
def Field(default: Any = ...,
          *,
          status: Optional[Literal["prototype", "beta", "deprecated"]] = None,
          **kwargs: Any) -> Any:
    """Custom Field wrapper that adds status to json_schema_extra.

    Args:
        default: The default value for the field. The Ellipsis default
            (``...``) makes the field required, exactly as with pydantic's Field.
        status: Optional status indicator that gets added to json_schema_extra.
            - None: Stable.
            - "beta": Recommended for use per the latest documentation.
            - "prototype": Not yet stable and subject to breaking changes;
              intended for experimentation only.
            - "deprecated": Flagged as deprecated.
        **kwargs: All other arguments passed to the original Pydantic Field.

    Returns:
        A Pydantic FieldInfo object with the status added to json_schema_extra
        if provided.
    """
    if status is not None:
        json_schema_extra = kwargs.get('json_schema_extra', {})
        if isinstance(json_schema_extra, dict):
            # Build a new dict rather than updating in place: the caller's dict
            # may be shared between several fields and must not be mutated.
            json_schema_extra = {**json_schema_extra, 'status': status}
        else:
            # If json_schema_extra is not a dict, create a new dict with the
            # status. NOTE(review): a callable json_schema_extra is dropped here.
            json_schema_extra = {'status': status}
        kwargs['json_schema_extra'] = json_schema_extra
    return PydanticField(default, **kwargs)
class StrictBaseModel(BaseModel):
    """
    A base model that forbids arbitrary fields.

    Passing a keyword that is not a declared field is rejected by validation
    instead of being accepted or silently ignored.
    """
    # Pydantic v2 configuration. The class-based `class Config:` form is
    # deprecated in pydantic v2. A plain dict is accepted wherever a
    # `ConfigDict` (a TypedDict) is expected.
    model_config = {"extra": "forbid"}  # globally forbid arbitrary fields
class CudaGraphConfig(StrictBaseModel):
    """
    Configuration for CUDA graphs.
    """
    # List of batch sizes to create CUDA graphs for.
    batch_sizes: Optional[List[int]] = Field(
        default=None,
        description="List of batch sizes to create CUDA graphs for.")

    max_batch_size: int = Field(
        default=0, description="Maximum batch size for CUDA graphs.")

    enable_padding: bool = Field(
        default=False,
        description=
        "If true, batches are rounded up to the nearest cuda_graph_batch_size. This is usually a net win for performance."
    )

    @field_validator('max_batch_size')
    @classmethod
    def validate_cuda_graph_max_batch_size(cls, v):
        """Validate cuda_graph_config.max_batch_size is non-negative."""
        if v < 0:
            raise ValueError(
                "cuda_graph_config.max_batch_size must be non-negative")
        return v

    @staticmethod
    def _generate_cuda_graph_batch_sizes(max_batch_size: int,
                                         enable_padding: bool) -> List[int]:
        """Generate a list of batch sizes for CUDA graphs.

        Args:
            max_batch_size: Maximum batch size to generate up to; must be >= 1.
            enable_padding: Whether padding is enabled, which affects the
                batch size distribution.

        Returns:
            Sorted list of batch sizes, all <= max_batch_size, whose last
            element is always max_batch_size.

        Raises:
            ValueError: If max_batch_size is smaller than 1.
        """
        if max_batch_size < 1:
            # Previously surfaced as an opaque "math domain error" from
            # math.log; keep the exception type but explain the cause.
            raise ValueError(
                f"max_batch_size must be at least 1 to generate CUDA graph batch sizes, got {max_batch_size}"
            )

        if enable_padding:
            batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
        else:
            batch_sizes = list(range(1, 32)) + [32, 64, 128]

        # Add powers of 2 from 256 up to, but excluding, 2**floor(log2(max));
        # max_batch_size itself is appended below. For n >= 1,
        # n.bit_length() - 1 == floor(log2(n)) exactly, with no float rounding.
        floor_log2 = int(max_batch_size).bit_length() - 1
        batch_sizes += [2**i for i in range(8, floor_log2)]

        # Filter and sort batch sizes
        batch_sizes = sorted(
            size for size in batch_sizes if size <= max_batch_size)

        # Add max_batch_size if not already included
        if max_batch_size != batch_sizes[-1]:
            batch_sizes.append(max_batch_size)
        return batch_sizes
class MoeConfig(StrictBaseModel):
    """
    Configuration for MoE.
    """

    backend: Literal["CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM",
                     "VANILLA"] = Field(default='CUTLASS',
                                        description="MoE backend to use.")

    max_num_tokens: Optional[int] = Field(
        default=None,
        description=
        "If set, at most max_num_tokens tokens will be sent to torch.ops.trtllm.fused_moe at the same time. If the number of tokens exceeds max_num_tokens, the input tensors will be split into chunks and a for loop will be used."
    )

    # Typed loosely as `object` here; the schema hint below records the
    # intended type.
    load_balancer: Optional[Union[object, str]] = Field(
        default=None,
        description="Configuration for MoE load balancing.",
        json_schema_extra={"type": "Union[MoeLoadBalancerConfig, str]"})

    @classmethod
    def from_dict(cls, data: dict):
        """Construct a MoeConfig by passing the dict entries as keyword arguments."""
        instance = cls(**data)
        return instance
@dataclass
class _ParallelConfig:
    ''' The model distribution configs for LLM. '''
    tp_size: int = 1
    pp_size: int = 1
    cp_size: int = 1
    gpus_per_node: int = 8
    moe_cluster_size: int = 1
    moe_tp_size: int = 1
    moe_ep_size: int = 1
    cp_config: dict = field(default_factory=dict)
    enable_attention_dp: bool = False
    auto_parallel: bool = False

    # Only consulted in auto-parallel mode; otherwise world_size is derived
    # from tp_size * pp_size * cp_size.
    _world_size: int = field(default=1, init=False)
    # Explicit device list; None means "use range(world_size)".
    _devices: Optional[List[int]] = field(default=None, init=False)

    @property
    def devices(self) -> List[int]:
        """Device ids to use; defaults to 0..world_size-1 when never set."""
        if self._devices is None:
            return list(range(self.world_size))
        return self._devices

    @devices.setter
    def devices(self, devices: List[int]):
        """Set the device list.

        Raises:
            ValueError: If len(devices) differs from world_size.
        """
        if len(devices) != self.world_size:
            raise ValueError(
                f"devices {devices} should have the same length as world_size {self.world_size}"
            )
        self._devices = devices

    @property
    def world_size(self) -> int:
        """Total number of ranks.

        In auto-parallel mode this is the explicitly set world size, and manual
        TP/PP/CP sizes are rejected. Otherwise it is tp_size * pp_size * cp_size.

        Raises:
            RuntimeError: On manual TP/PP/CP sizes in auto-parallel mode, or on
                an explicitly set world size > 1 outside auto-parallel mode.
        """
        if self.auto_parallel:
            if self.tp_size > 1 or self.pp_size > 1 or self.cp_size > 1:
                raise RuntimeError(
                    "manually TP, PP and CP are not supported in auto parallel mode."
                )
            return self._world_size

        if self._world_size > 1:
            raise RuntimeError(
                "world_size > 1 is only supported in auto parallel mode.")
        return self.tp_size * self.pp_size * self.cp_size

    @world_size.setter
    def world_size(self, world_size: int):
        """Set the world size.

        In auto-parallel mode the value is stored. Otherwise it is only
        validated against tp_size * pp_size * cp_size.

        Raises:
            ValueError: Outside auto-parallel mode, if the value does not equal
                tp_size * pp_size * cp_size.
        """
        if self.auto_parallel:
            self._world_size = world_size
        elif world_size != self.tp_size * self.pp_size * self.cp_size:
            raise ValueError(
                f"world_size {world_size} should be equal to tp_size * pp_size * cp_size {self.tp_size * self.pp_size * self.cp_size}"
            )

    @property
    def world_size_per_node(self) -> int:
        """Ranks per node, assuming ranks are spread evenly over the minimum
        number of nodes able to host them."""
        world_size = self.world_size
        total_nodes = math.ceil(world_size / self.gpus_per_node)
        # NOTE(review): carried-over TODO. When world_size is not a multiple of
        # the per-node count this floors the result; confirm that is intended.
        return world_size // total_nodes

    @property
    def is_multi_gpu(self) -> bool:
        return self.world_size > 1

    def to_mapping(self) -> Mapping:
        """Build a Mapping for the current MPI rank from this config."""
        return Mapping(world_size=self.world_size,
                       rank=mpi_rank(),
                       gpus_per_node=self.gpus_per_node,
                       tp_size=self.tp_size,
                       pp_size=self.pp_size,
                       cp_size=self.cp_size,
                       cp_config=self.cp_config,
                       enable_attention_dp=self.enable_attention_dp,
                       moe_cluster_size=self.moe_cluster_size,
                       moe_tp_size=self.moe_tp_size,
                       moe_ep_size=self.moe_ep_size,
                       auto_parallel=self.auto_parallel)
class CalibConfig(StrictBaseModel):
    """
    Calibration configuration.
    """

    device: Literal['cuda', 'cpu'] = Field(
        default='cuda', description="The device to run calibration.")

    calib_dataset: str = Field(
        default='cnn_dailymail',
        description="The name or local path of calibration dataset.")

    calib_batches: int = Field(
        default=512,
        description="The number of batches that the calibration runs.")

    calib_batch_size: int = Field(
        default=1, description="The batch size that the calibration runs.")

    calib_max_seq_length: int = Field(
        default=512,
        description="The maximum sequence length that the calibration runs.")

    random_seed: int = Field(
        default=1234, description="The random seed used for calibration.")

    tokenizer_max_seq_length: int = Field(
        default=2048,
        description=
        "The maximum sequence length to initialize tokenizer for calibration.")

    @classmethod
    def from_dict(cls, config: dict) -> 'CalibConfig':
        """Create a CalibConfig instance from a dict.

        Args:
            config (dict): Field values keyed by field name.

        Returns:
            tensorrt_llm.llmapi.CalibConfig: The CalibConfig built from the dict.
        """
        calib_config = cls(**config)
        return calib_config

    def to_dict(self) -> dict:
        """Dump this CalibConfig to a dict.

        Returns:
            dict: Field values keyed by field name (pydantic ``model_dump``).
        """
        dumped = self.model_dump()
        return dumped
class _ModelFormatKind(Enum):
HF = 0
TLLM_CKPT = 1
TLLM_ENGINE = 2
class DecodingBaseConfig(StrictBaseModel):
    """Common base for all speculative decoding configs."""
    max_draft_len: Optional[int] = None
    speculative_model_dir: Optional[Union[str, Path]] = None

    @classmethod
    def from_dict(cls, data: dict):
        """Dispatch to the concrete decoding config named by data["decoding_type"].

        The caller's dict is left unmodified. Every entry other than
        "decoding_type" is forwarded as a keyword argument.

        Raises:
            ValueError: If "decoding_type" is missing or not a known type.
        """
        # dispatch to the correct decoding config
        decoding_type = data.get("decoding_type")
        config_classes = {
            "MTP": MTPDecodingConfig,
            "Medusa": MedusaDecodingConfig,
            "Eagle": EagleDecodingConfig,
            "Lookahead": LookaheadDecodingConfig,
            "NGram": NGramDecodingConfig,
            "DraftTarget": DraftTargetDecodingConfig,
            "UserProvided": UserProvidedDecodingConfig,
            "AUTO": AutoDecodingConfig,
        }

        config_class = config_classes.get(decoding_type)
        if config_class is None:
            raise ValueError(f"Invalid decoding type: {decoding_type}")

        # Do not pop from `data`: mutating the caller's dict made a second
        # from_dict call on the same dict fail with "Invalid decoding type".
        kwargs = {
            key: value
            for key, value in data.items() if key != "decoding_type"
        }
        return config_class(**kwargs)

    def _check_fields(self):
        pass

    def supports_backend(self, backend: str) -> bool:
        """
        Override if the speculation algorithm does not support
        a subset of the possible backends.
        """
        return True

    def validate(self) -> None:
        """
        Do any additional error checking here.
        """

    @functools.cached_property
    def spec_dec_mode(self):
        # spec_dec_mode has more functionality than the raw decoding_mode string.
        # Use an alias for the import here to avoid name collisions with the one for the
        # TRT backend.
        # Relies on the subclass defining the `decoding_type` ClassVar.
        from tensorrt_llm._torch.speculative.interface import \
            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
        return TorchSpeculativeDecodingMode.from_string(
            self.decoding_type.upper())
class MedusaDecodingConfig(DecodingBaseConfig):
    """Configuration for Medusa speculative decoding."""

    decoding_type: ClassVar[str] = "Medusa"

    medusa_choices: Optional[List[List[int]]] = None
    num_medusa_heads: Optional[int] = None

    @classmethod
    def from_dict(cls, data: dict):
        """Construct directly from keyword values (no type dispatch)."""
        instance = cls(**data)
        return instance

    def supports_backend(self, backend: str) -> bool:
        # Unavailable on the PyTorch and auto-deploy backends.
        unsupported = ("pytorch", "_autodeploy")
        return backend not in unsupported
class EagleDecodingConfig(DecodingBaseConfig):
    """Configuration for EAGLE speculative decoding."""

    decoding_type: ClassVar[str] = "Eagle"

    eagle_choices: Optional[List[List[int]]] = None
    greedy_sampling: Optional[bool] = True
    posterior_threshold: Optional[float] = None
    use_dynamic_tree: Optional[bool] = False
    dynamic_tree_max_topK: Optional[int] = None
    num_eagle_layers: Optional[int] = None
    max_non_leaves_per_layer: Optional[int] = None
    eagle3_one_model: Optional[bool] = True

    @classmethod
    def from_dict(cls, data: dict):
        """Construct directly from keyword values (no type dispatch)."""
        instance = cls(**data)
        return instance

    def validate(self) -> None:
        """Require a draft model directory."""
        if self.speculative_model_dir is not None:
            return
        raise ValueError("Draft model must be provided for EAGLE")

    @functools.cached_property
    def spec_dec_mode(self):
        # Torch-side enum, aliased to avoid clashing with the TRT-backend name.
        from tensorrt_llm._torch.speculative.interface import \
            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
        return (TorchSpeculativeDecodingMode.EAGLE3_ONE_MODEL
                if self.eagle3_one_model else
                TorchSpeculativeDecodingMode.EAGLE3)
class UserProvidedDecodingConfig(DecodingBaseConfig):
    """Speculative decoding that uses a caller-supplied drafter object."""

    decoding_type: ClassVar[str] = "User_Provided"

    # Annotated as `object` because importing the real types here would
    # create a circular import.
    drafter: object  # Type is Drafter
    resource_manager: object = None  # Type is Optional[ResourceManager]

    @classmethod
    def from_dict(cls, data: dict):
        """Construct directly from keyword values (no type dispatch)."""
        instance = cls(**data)
        return instance
class NGramDecodingConfig(DecodingBaseConfig):
    """
    Configuration for NGram drafter speculative decoding.

    Arguments:
        max_draft_len: int
            Maximum number of draft tokens produced (the maximum length of the
            output draft).
        max_matching_ngram_size: int
            Maximum number of tokens used as the search pattern (the maximum
            length of the input that is matched).
        is_keep_all: bool = True
            Whether to keep every candidate pattern-match pair. If False, only
            one match is kept per pattern.
        is_use_oldest: bool = True
            Whether to return the oldest match when a pattern is hit. If False,
            the newest one is returned.
        is_public_pool: bool = True
            Whether to use one pool shared by all requests. If False, each
            request gets a private pool.
    """

    decoding_type: ClassVar[str] = "NGram"

    max_matching_ngram_size: int = 0
    is_keep_all: bool = True
    is_use_oldest: bool = True
    is_public_pool: bool = True
    # Set only when this config was created by the auto heuristic.
    # Users should not set it themselves; use AutoDecodingConfig instead.
    is_auto_heuristic: bool = False

    @classmethod
    def from_dict(cls, data: dict):
        """Construct directly from keyword values (no type dispatch)."""
        instance = cls(**data)
        return instance

    def supports_backend(self, backend: str) -> bool:
        # PyTorch backend only.
        return "pytorch" == backend
class DraftTargetDecodingConfig(DecodingBaseConfig):
    """Configuration for draft-model/target-model speculative decoding."""

    decoding_type: ClassVar[str] = "Draft_Target"

    @classmethod
    def from_dict(cls, data: dict):
        """Construct directly from keyword values (no type dispatch)."""
        instance = cls(**data)
        return instance

    def supports_backend(self, backend: str) -> bool:
        # PyTorch backend only.
        return "pytorch" == backend
class MTPDecodingConfig(DecodingBaseConfig):
    """Configuration for MTP speculative decoding."""

    decoding_type: ClassVar[str] = "MTP"

    num_nextn_predict_layers: int = 1
    use_relaxed_acceptance_for_thinking: bool = False
    relaxed_topk: int = 1
    relaxed_delta: float = 0.0
    use_mtp_vanilla: bool = False

    # TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers`
    # For now a separate field is needed because PyTorchModelEngine updates
    # this config from the model config.
    num_nextn_predict_layers_from_model_config: int = 1

    # TODO: Hard code for DeepSeek R1
    # Output shape: <think> [thinking phase] </think> [real output]
    # <think> opens the thinking phase and </think> closes it.
    BEGIN_THINKING_PHASE_TOKEN: int = 128798
    END_THINKING_PHASE_TOKEN: int = 128799

    @classmethod
    def from_dict(cls, data: dict):
        """Construct directly from keyword values (no type dispatch)."""
        instance = cls(**data)
        return instance

    def supports_backend(self, backend: str) -> bool:
        # PyTorch backend only.
        return "pytorch" == backend

    @functools.cached_property
    def spec_dec_mode(self):
        from tensorrt_llm._torch.speculative.interface import \
            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
        # MTP_EAGLE applies only to a single next-n layer without vanilla MTP.
        single_layer_eagle = (
            self.num_nextn_predict_layers_from_model_config == 1
            and not self.use_mtp_vanilla)
        if not single_layer_eagle:
            return TorchSpeculativeDecodingMode.MTP
        return TorchSpeculativeDecodingMode.MTP_EAGLE
class AutoDecodingConfig(DecodingBaseConfig):
    """
    Configuration for auto speculative decoding.

    Lets the speculative decoding algorithm be chosen automatically.
    According to benchmark results, NGRAM is generally the best choice at low
    concurrency (<= 32).

    Default heuristic:
        concurrency <= 4:  max_draft_len = 5, max_matching_ngram_size = 3
        concurrency <= 32: max_draft_len = 3, max_matching_ngram_size = 5
        concurrency > 32:  speculative decoding is disabled.
    """

    decoding_type: ClassVar[str] = "AUTO"

    @classmethod
    def from_dict(cls, data: dict):
        """Construct directly from keyword values (no type dispatch)."""
        instance = cls(**data)
        return instance

    def supports_backend(self, backend: str) -> bool:
        # PyTorch backend only.
        return "pytorch" == backend
class PybindMirror(ABC):
    ''' A class containing the utilities for mirroring Python classes to
    pybinding classes.
    '''

    @abstractmethod
    def _to_pybind(self):
        pass

    @staticmethod
    def maybe_to_pybind(ins):
        """Convert ``ins`` to its pybind counterpart when it is a mirror.

        Handles PybindMirror instances and members of enums whose metaclass
        is PybindMirrorEnumMeta. Any other value is returned unchanged.
        """
        if isinstance(ins, PybindMirror):
            return ins._to_pybind()
        # Mirrored enums do not subclass PybindMirror; their metaclass marks them.
        if type(ins).__class__ == PybindMirrorEnumMeta:
            return ins._to_pybind()
        return ins

    @staticmethod
    def mirror_pybind_fields(pybind_class):
        """Build a class decorator that checks field mirroring.

        The decorated class must be a StrictBaseModel. Every public,
        non-callable attribute of ``pybind_class`` must appear among its
        pydantic model fields.

        Args:
            pybind_class: The C++ class whose fields should be mirrored.

        Returns:
            A decorator that returns the class unchanged or raises ValueError
            naming the first C++ field missing from the Python class.
        """

        def decorator(cls):
            assert issubclass(cls, StrictBaseModel)
            declared = set(cls.model_fields.keys())
            missing = next(
                (name for name in PybindMirror.get_pybind_variable_fields(
                    pybind_class) if name not in declared), None)
            if missing is not None:
                raise ValueError(
                    f"Field {missing} is not mirrored in Python class {cls.__name__} from C++ class {pybind_class.__name__}. Please update the class."
                )
            return cls

        return decorator

    @staticmethod
    def get_pybind_enum_fields(pybind_class):
        ''' Get all the enum fields from the pybind class. '''
        return [
            name for name in pybind_class.__members__
            if not (name.startswith('_')
                    or callable(getattr(pybind_class, name)))
        ]

    @staticmethod
    def mirror_pybind_enum(pybind_class):
        ''' Mirror the enum fields from the pybind class to the Python class. '''

        def decorator(cls):
            assert issubclass(cls, Enum)
            declared = set(cls.__members__.keys())
            missing = next(
                (name
                 for name in PybindMirror.get_pybind_enum_fields(pybind_class)
                 if name not in declared), None)
            if missing is not None:
                raise ValueError(
                    f"Field {missing} is not mirrored in Python class {cls.__name__} from C++ class {pybind_class.__name__}. Please update the class."
                )
            return cls

        return decorator

    @staticmethod
    def get_pybind_variable_fields(config_cls):
        ''' Get all the variable fields from the pybind class. '''
        return [
            name for name in dir(config_cls)
            if not (name.startswith('_')
                    or callable(getattr(config_cls, name)))
        ]

    @staticmethod
    def pybind_equals(obj0, obj1):
        ''' Check if two pybind objects are equal. '''
        assert type(obj0) is type(obj1)
        return not any(
            getattr(obj0, name) != getattr(obj1, name)
            for name in PybindMirror.get_pybind_variable_fields(type(obj0)))

    @classmethod
    def from_pybind(cls: Type[TypeBaseModel],
                    pybind_instance: "PybindMirror") -> TypeBaseModel:
        """Construct an instance of ``cls`` from a pybind instance's fields.

        Args:
            cls: Class to construct; must be a subclass of pydantic BaseModel.
            pybind_instance: Pybind object whose public, non-callable
                attributes supply the field values.

        Notes:
            A None value from the pybind object is dropped, so the Python
            default applies, unless the target field is both Optional and
            required. In that case the None is passed through.

        Returns:
            Instance of ``cls`` populated from ``pybind_instance``.
        """ # noqa: D205
        assert issubclass(cls, BaseModel)

        def _is_optional_type(annotation: Any) -> bool:
            """True for Optional[X] / Union[..., None] / X | None annotations."""
            origin = get_origin(annotation)
            # Union covers Optional[X]; types.UnionType covers PEP 604 `X | None`.
            is_union = origin is Union or origin is types.UnionType
            return is_union and type(None) in get_args(annotation)

        # Fields where a None coming from C++ should fall back to the Python default.
        keep_default_on_none = set()
        for name, info in cls.model_fields.items():
            if _is_optional_type(info.annotation) and info.is_required():
                continue
            keep_default_on_none.add(name)

        kwargs = {}
        for name in PybindMirror.get_pybind_variable_fields(
                type(pybind_instance)):
            value = getattr(pybind_instance, name)
            if value is None and name in keep_default_on_none:
                continue
            kwargs[name] = value
        return cls(**kwargs)
class PybindMirrorMeta(type(PybindMirror)):
    """Metaclass derived from PybindMirror's own metaclass (ABCMeta, since
    PybindMirror subclasses ABC). It adds no behavior of its own."""
class PybindMirrorEnumMeta(EnumMeta, PybindMirrorMeta):
    """
    Combined metaclass for Enum and PybindMirror.

    Enums that mirror a pybind enum use this metaclass. They do not subclass
    PybindMirror, so PybindMirror.maybe_to_pybind recognises their members by
    this metaclass and calls their ``_to_pybind``. Combining it with EnumMeta
    keeps the usual Enum class behavior.
    """
@PybindMirror.mirror_pybind_enum(_BatchingType)
class BatchingType(StrEnum, metaclass=PybindMirrorEnumMeta):
    """Batching type; mirrors the pybind ``BatchingType`` enum."""

    STATIC = "STATIC"
    INFLIGHT = "INFLIGHT"

    def _to_pybind(self):
        # The member's value doubles as the attribute name on the pybind enum.
        pybind_member = getattr(_BatchingType, self.value)
        return pybind_member
@PybindMirror.mirror_pybind_enum(_CapacitySchedulerPolicy)
class CapacitySchedulerPolicy(StrEnum, metaclass=PybindMirrorEnumMeta):
    """Capacity scheduler policy; mirrors the pybind enum of the same name."""

    MAX_UTILIZATION = "MAX_UTILIZATION"
    GUARANTEED_NO_EVICT = "GUARANTEED_NO_EVICT"
    STATIC_BATCH = "STATIC_BATCH"

    def _to_pybind(self):
        # The member's value doubles as the attribute name on the pybind enum.
        pybind_member = getattr(_CapacitySchedulerPolicy, self.value)
        return pybind_member
@PybindMirror.mirror_pybind_enum(_ContextChunkingPolicy)
class ContextChunkingPolicy(StrEnum, metaclass=PybindMirrorEnumMeta):
    ''' Context chunking policy. '''

    FIRST_COME_FIRST_SERVED = "FIRST_COME_FIRST_SERVED"
    EQUAL_PROGRESS = "EQUAL_PROGRESS"

    def _to_pybind(self):
        # The member's value doubles as the attribute name on the pybind enum.
        pybind_member = getattr(_ContextChunkingPolicy, self.value)
        return pybind_member
@PybindMirror.mirror_pybind_fields(_DynamicBatchConfig)
class DynamicBatchConfig(StrictBaseModel, PybindMirror):
    """Dynamic batch configuration.

    Controls how batch size and token limits are dynamically adjusted at
    runtime. All fields are required (none has a default).
    """

    enable_batch_size_tuning: bool = Field(
        description="Controls if the batch size should be tuned dynamically")

    enable_max_num_tokens_tuning: bool = Field(
        description="Controls if the max num tokens should be tuned dynamically"
    )

    dynamic_batch_moving_average_window: int = Field(
        description=
        "The window size for moving average of input and output length which is used to calculate dynamic batch size and max num tokens"
    )

    def _to_pybind(self):
        kwargs = dict(
            enable_batch_size_tuning=self.enable_batch_size_tuning,
            enable_max_num_tokens_tuning=self.enable_max_num_tokens_tuning,
            dynamic_batch_moving_average_window=self.
            dynamic_batch_moving_average_window,
        )
        return _DynamicBatchConfig(**kwargs)
@PybindMirror.mirror_pybind_fields(_SchedulerConfig)
class SchedulerConfig(StrictBaseModel, PybindMirror):
    """Scheduler configuration; mirrors the pybind ``SchedulerConfig``."""

    capacity_scheduler_policy: CapacitySchedulerPolicy = Field(
        default=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
        description="The capacity scheduler policy to use")

    context_chunking_policy: Optional[ContextChunkingPolicy] = Field(
        default=None, description="The context chunking policy to use")

    dynamic_batch_config: Optional[DynamicBatchConfig] = Field(
        default=None, description="The dynamic batch config to use")

    def _to_pybind(self):
        capacity = self.capacity_scheduler_policy._to_pybind()
        # Unset optional sub-configs are forwarded as None.
        chunking = None
        if self.context_chunking_policy:
            chunking = self.context_chunking_policy._to_pybind()
        dynamic_batch = None
        if self.dynamic_batch_config:
            dynamic_batch = self.dynamic_batch_config._to_pybind()
        return _SchedulerConfig(capacity_scheduler_policy=capacity,
                                context_chunking_policy=chunking,
                                dynamic_batch_config=dynamic_batch)
@PybindMirror.mirror_pybind_fields(_PeftCacheConfig)
class PeftCacheConfig(StrictBaseModel, PybindMirror):
    """
    Configuration for the PEFT cache.
    """

    num_host_module_layer: int = Field(
        default=0,
        description=
        "number of max sized 1-layer 1-module adapterSize=1 sets of weights that can be stored in host cache"
        ", affects host cache size and overrides value of host_cache_size")

    num_device_module_layer: int = Field(
        default=0,
        description=
        "number of max sized 1-layer 1-module sets of weights that can be stored in device cache"
        ", affects device cache size and overrides value of device_cache_percent"
    )

    # There are tests to keep the default value consistent with the pybind default value
    optimal_adapter_size: int = Field(
        default=8, description="optimal adapter size used to set page width")

    max_adapter_size: int = Field(
        default=64,
        description="max supported adapter size. Used to compute minimum")

    num_put_workers: int = Field(
        default=1,
        description=
        "number of worker threads used to put weights into host cache")

    num_ensure_workers: int = Field(
        default=1,
        description=
        "number of worker threads used to copy weights from host to device")

    num_copy_streams: int = Field(
        default=1,
        description="number of streams used to copy weights from host to device"
    )

    max_pages_per_block_host: int = Field(
        default=24,
        description="Number of cache pages per allocation block (host)")

    max_pages_per_block_device: int = Field(
        default=8,
        description="Number of cache pages per allocation block (device)")

    device_cache_percent: float = Field(
        default=0.02,
        description=
        "Proportion of free device memory after engine load to use for cache, as a fraction from 0 to 1"
    )

    host_cache_size: int = Field(
        default=1024**3, description="size in bytes to use for host cache")

    lora_prefetch_dir: Optional[str] = Field(
        default=None,
        description=
        "folder to store the LoRA weights we hope to load during engine initialization, currently not supported"
    )

    def _to_pybind(self):
        # Each field is passed to the pybind constructor as a keyword of the same name.
        forwarded = (
            "num_host_module_layer",
            "num_device_module_layer",
            "optimal_adapter_size",
            "max_adapter_size",
            "num_put_workers",
            "num_ensure_workers",
            "num_copy_streams",
            "max_pages_per_block_host",
            "max_pages_per_block_device",
            "device_cache_percent",
            "host_cache_size",
            "lora_prefetch_dir",
        )
        return _PeftCacheConfig(
            **{name: getattr(self, name)
               for name in forwarded})
@PybindMirror.mirror_pybind_fields(_LookaheadDecodingConfig)
class LookaheadDecodingConfig(DecodingBaseConfig, PybindMirror):
    """
    Configuration for lookahead speculative decoding.
    """

    decoding_type: ClassVar[str] = "Lookahead"

    # Defaults are taken from the pybind class so both sides stay in sync.
    max_window_size: int = Field(
        default=_LookaheadDecodingConfig.get_default_lookahead_decoding_window(
        ),
        description="Number of NGrams in lookahead branch per step.")

    max_ngram_size: int = Field(
        default=_LookaheadDecodingConfig.get_default_lookahead_decoding_ngram(),
        description="Number of tokens per NGram.")

    max_verification_set_size: int = Field(
        default=_LookaheadDecodingConfig.
        get_default_lookahead_decoding_verification_set(),
        description="Number of NGrams in verification branch per step.")

    @field_validator('max_window_size', 'max_ngram_size',
                     'max_verification_set_size')
    @classmethod
    def validate_positive_values(cls, v):
        """Reject zero or negative sizes."""
        if v > 0:
            return v
        raise ValueError(f"Value must be positive, got {v}")

    def __init__(self, **data):
        super().__init__(**data)
        self._check_fields()

    @classmethod
    def from_dict(cls, data: dict):
        """Construct directly from keyword values (no type dispatch)."""
        instance = cls(**data)
        return instance

    def calculate_speculative_resource(self):
        """Delegate the resource calculation to the pybind helper."""
        window, ngram, verification = (self.max_window_size,
                                       self.max_ngram_size,
                                       self.max_verification_set_size)
        return _LookaheadDecodingConfig.calculate_speculative_resource_tuple(
            window, ngram, verification)

    def _to_pybind(self):
        window, ngram, verification = (self.max_window_size,
                                       self.max_ngram_size,
                                       self.max_verification_set_size)
        return _LookaheadDecodingConfig(window, ngram, verification)

    def supports_backend(self, backend: str) -> bool:
        # Unavailable on the PyTorch and auto-deploy backends.
        unsupported = ("pytorch", "_autodeploy")
        return backend not in unsupported
# Any concrete speculative decoding config, or None. Optional[Union[...]]
# flattens to this same Union with None appended. Member order is preserved,
# since pydantic's union validation may depend on it.
SpeculativeConfig: TypeAlias = Union[
    DraftTargetDecodingConfig,
    EagleDecodingConfig,
    LookaheadDecodingConfig,
    MedusaDecodingConfig,
    MTPDecodingConfig,
    NGramDecodingConfig,
    UserProvidedDecodingConfig,
    AutoDecodingConfig,
    None,
]
@PybindMirror.mirror_pybind_fields(_KvCacheConfig)
class KvCacheConfig(StrictBaseModel, PybindMirror):
    """
    Configuration for the KV cache.
    """
    enable_block_reuse: bool = Field(
        default=True,
        description=
        "Controls if KV cache blocks can be reused for different requests.")

    max_tokens: Optional[int] = Field(
        default=None,
        description=
        "The maximum number of tokens that should be stored in the KV cache. If both `max_tokens` and `free_gpu_memory_fraction` are specified, memory corresponding to the minimum will be used."
    )

    max_attention_window: Optional[List[int]] = Field(
        default=None,
        description=
        "Size of the attention window for each sequence. Only the last tokens will be stored in the KV cache. If the number of elements in `max_attention_window` is less than the number of layers, `max_attention_window` will be repeated multiple times to the number of layers."
    )

    sink_token_length: Optional[int] = Field(
        default=None,
        description=
        "Number of sink tokens (tokens to always keep in attention window).")

    # Description previously read "fraction of GPU memory fraction".
    free_gpu_memory_fraction: Optional[float] = Field(
        default=None,
        description=
        "The fraction of GPU memory that should be allocated for the KV cache. Default is 90%. If both `max_tokens` and `free_gpu_memory_fraction` are specified, memory corresponding to the minimum will be used."
    )

    host_cache_size: Optional[int] = Field(
        default=None,
        description=
        "Size of the host cache in bytes. If both `max_tokens` and `host_cache_size` are specified, memory corresponding to the minimum will be used."
    )

    onboard_blocks: bool = Field(
        default=True, description="Controls if blocks are onboarded.")

    cross_kv_cache_fraction: Optional[float] = Field(
        default=None,
        description=
        "The fraction of the KV Cache memory that should be reserved for cross attention. If set to p, self attention will use 1-p of KV Cache memory and cross attention will use p of KV Cache memory. Default is 50%. Should only be set when using encoder-decoder model."
    )

    # Description previously referenced the C++ member name
    # (mSecondaryOfflineMinPriority) instead of this Python field.
    secondary_offload_min_priority: Optional[int] = Field(
        default=None,
        description=
        "Only blocks with priority > secondary_offload_min_priority can be offloaded to secondary memory."
    )

    event_buffer_max_size: int = Field(
        default=0,
        description=
        "Maximum size of the event buffer. If set to 0, the event buffer will not be used."
    )

    enable_partial_reuse: bool = Field(
        default=True,
        description=
        "Whether blocks that are only partially matched can be reused.")

    copy_on_partial_reuse: bool = Field(
        default=True,
        description=
        "Whether partially matched blocks that are in use can be reused after copying them."
    )

    use_uvm: bool = Field(default=False,
                          description="Whether to use UVM for the KV cache.")

    # This is a pure python field, not a pybind field. It is only for the Pytorch backend.
    # It is therefore not forwarded by _to_pybind below.
    dtype: str = Field(default="auto",
                       description="The data type to use for the KV cache.")

    def _to_pybind(self):
        """Convert to the pybind KvCacheConfig (every field except dtype)."""
        return _KvCacheConfig(
            enable_block_reuse=self.enable_block_reuse,
            max_tokens=self.max_tokens,
            max_attention_window=self.max_attention_window,
            sink_token_length=self.sink_token_length,
            free_gpu_memory_fraction=self.free_gpu_memory_fraction,
            host_cache_size=self.host_cache_size,
            onboard_blocks=self.onboard_blocks,
            cross_kv_cache_fraction=self.cross_kv_cache_fraction,
            secondary_offload_min_priority=self.secondary_offload_min_priority,
            event_buffer_max_size=self.event_buffer_max_size,
            enable_partial_reuse=self.enable_partial_reuse,
            copy_on_partial_reuse=self.copy_on_partial_reuse,
            use_uvm=self.use_uvm)
@PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig)
class ExtendedRuntimePerfKnobConfig(StrictBaseModel, PybindMirror):
"""
Configuration for extended runtime performance knobs.
"""
multi_block_mode: bool = Field(
default=True, description="Whether to use multi-block mode.")
enable_context_fmha_fp32_acc: bool = Field(
default=False,
description="Whether to enable context FMHA FP32 accumulation.")