forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathllm.py
More file actions
101 lines (77 loc) · 3.63 KB
/
llm.py
File metadata and controls
101 lines (77 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import types
from typing import List, Optional
from ...executor.result import CompletionOutput
from ...inputs.registry import create_input_processor
from ...llmapi.llm import RequestOutput, _TorchLLM
from ...llmapi.tokenizer import TokenizerBase, tokenizer_factory
from .distributed import common as dist_ad
from .llm_args import LlmArgs
from .shim.demollm import DemoGenerationExecutor
class LLM(_TorchLLM):
    """Main entry point for running an LLM model via the AutoDeploy backend.

    Thin wrapper over ``_TorchLLM`` that pins the backend to ``_autodeploy``
    and routes tokenizer creation and checkpoint prefetching through the
    AutoDeploy model factory.
    """

    args: LlmArgs

    def __init__(self, *args, **kwargs):
        # Always run with the AutoDeploy backend, overriding any caller value.
        kwargs["backend"] = "_autodeploy"
        super().__init__(*args, **kwargs)

    def _try_load_tokenizer(self) -> Optional[TokenizerBase]:
        """Create a tokenizer via the factory, or ``None`` when init is skipped."""
        if self.args.skip_tokenizer_init:
            return None
        return tokenizer_factory(self.args.create_factory().init_tokenizer())

    def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
        """No-op: AutoDeploy performs no torch-backend argument validation for now."""

    def _prefetch_model(self):
        """Prefetch the model checkpoint through the AutoDeploy factory."""
        self.args.create_factory().prefetch_checkpoint()

    def _build_model(self):
        """Prefetch the checkpoint, then delegate to the regular build.

        NOTE (lucaslie): the ``_autodeploy`` backend bypasses the regular LLM
        ``CachedModelLoader``; prefetching via the factory replaces it here.
        """
        self._prefetch_model()
        super()._build_model()
class DemoLLM(LLM):
    """A simple LLM class to demo the LLM interface while debugging the e2e workflow.

    This is a very simple implementation of an LLM class that can be hacked and used for debugging.
    """

    def __init__(self, **kwargs):
        # NOTE(review): deliberately does NOT call super().__init__ — the demo
        # path wires up its own executor instead of the regular backend setup.
        self.args: LlmArgs = LlmArgs.from_kwargs(**kwargs)
        self.mpi_session = None
        self.runtime_context = None

        # prefetch model and load tokenizer
        # (order matters: prefetch must happen before tokenizer creation,
        # since _try_load_tokenizer builds a factory over the checkpoint)
        self._prefetch_model()
        self._tokenizer = self._try_load_tokenizer()
        self.input_processor = create_input_processor(None, self.tokenizer)

        # construct demo executor + engine
        self._executor = DemoGenerationExecutor(
            world_size=self.args.world_size,
            tokenizer=self.tokenizer,
            ad_config=self.args.get_pytorch_backend_config(),
        )

    def __del__(self):
        """Ensure proper cleanup of distributed resources."""
        # hasattr guard: __del__ can run on a partially-constructed instance
        # when __init__ raised before _executor was assigned.
        if hasattr(self, "_executor") and self._executor is not None:
            self._executor.shutdown()
        # Call cleanup to ensure process group is properly destroyed
        dist_ad.cleanup()

    @staticmethod
    def _handle_response(request_output: RequestOutput, response: List[CompletionOutput]):
        """Decode completion outputs onto ``request_output`` and mark it done.

        Bound to a ``RequestOutput`` instance via ``types.MethodType`` in
        ``generate_async``, so ``request_output`` plays the role of ``self``.
        Mutates private ``RequestOutput`` state (``_done``, ``_context_logits``,
        ``_outputs``).
        """
        request_output._done = True
        gen_request = request_output._generation_request
        for i, out in enumerate(response):
            # Decode token ids to text using the request's own sampling options.
            out.text = request_output.tokenizer.decode(
                out.token_ids,
                skip_special_tokens=gen_request.sampling_params.skip_special_tokens,
                spaces_between_special_tokens=gen_request.sampling_params.spaces_between_special_tokens,
            )
            # NOTE(review): overwritten each iteration — only the last output's
            # context logits are kept; confirm this is intended for n>1 outputs.
            request_output._context_logits = out._postprocess_result["context_logits"]
            request_output._outputs[i] = out

    def generate_async(self, *args, **kwargs) -> RequestOutput:
        """Submit an async generation request with demo response handling."""
        request_output = super().generate_async(*args, **kwargs)
        # patch the handle_output method for our use case
        request_output._handle_response = types.MethodType(self._handle_response, request_output)
        return request_output