@@ -31,11 +31,13 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ) -> None:
         if not _LITSERVE_AVAILABLE:
             raise ImportError(str(_LITSERVE_AVAILABLE))
 
-        super().__init__()
+        super().__init__(api_path=api_path)
+
         self.checkpoint_dir = checkpoint_dir
         self.quantize = quantize
         self.precision = precision
@@ -61,12 +63,11 @@ def setup(self, device: str) -> None:
             accelerator=accelerator,
             quantize=self.quantize,
             precision=self.precision,
-            generate_strategy="sequential" if self.devices is not None and self.devices > 1 else None,
+            generate_strategy=("sequential" if self.devices is not None and self.devices > 1 else None),
         )
         print("Model successfully initialized.", file=sys.stderr)
 
     def decode_request(self, request: Dict[str, Any]) -> Any:
-        # Convert the request payload to your model input.
         prompt = str(request["prompt"])
         return prompt
 
@@ -82,15 +83,30 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )
 
     def setup(self, device: str):
         super().setup(device)
 
     def predict(self, inputs: str) -> Any:
         output = self.llm.generate(
-            inputs, temperature=self.temperature, top_k=self.top_k, top_p=self.top_p, max_new_tokens=self.max_new_tokens
+            inputs,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            max_new_tokens=self.max_new_tokens,
         )
         return output
 
@@ -110,14 +126,24 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )
 
     def setup(self, device: str):
         super().setup(device)
 
     def predict(self, inputs: torch.Tensor) -> Any:
-        # Run the model on the input and return the output.
         yield from self.llm.generate(
             inputs,
             temperature=self.temperature,
@@ -143,8 +169,19 @@ def __init__(
         top_p: float = 1.0,
         max_new_tokens: int = 50,
         devices: int = 1,
+        api_path: Optional[str] = None,
     ):
-        super().__init__(checkpoint_dir, quantize, precision, temperature, top_k, top_p, max_new_tokens, devices)
+        super().__init__(
+            checkpoint_dir,
+            quantize,
+            precision,
+            temperature,
+            top_k,
+            top_p,
+            max_new_tokens,
+            devices,
+            api_path=api_path,
+        )
 
     def setup(self, device: str):
         super().setup(device)
@@ -178,7 +215,12 @@ def predict(self, inputs: str, context: dict) -> Any:
 
         # Run the model on the input and return the output.
         yield from self.llm.generate(
-            inputs, temperature=temperature, top_k=self.top_k, top_p=top_p, max_new_tokens=max_new_tokens, stream=True
+            inputs,
+            temperature=temperature,
+            top_k=self.top_k,
+            top_p=top_p,
+            max_new_tokens=max_new_tokens,
+            stream=True,
         )
 
 
@@ -196,6 +238,7 @@ def run_server(
    stream: bool = False,
    openai_spec: bool = False,
    access_token: Optional[str] = None,
+    api_path: Optional[str] = "/predict",
 ) -> None:
     """Serve a LitGPT model using LitServe.
 
@@ -237,11 +280,13 @@ def run_server(
             `/v1/chat/completions` endpoints that work with the OpenAI SDK and other OpenAI-compatible clients,
             making it easy to integrate with existing applications that use the OpenAI API.
         access_token: Optional API token to access models with restrictions.
+        api_path: The custom API path for the endpoint (e.g., "/my_api/classify").
     """
     checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
     pprint(locals())
 
     api_class = OpenAISpecLitAPI if openai_spec else StreamLitAPI if stream else SimpleLitAPI
+
     server = LitServer(
         api_class(
             checkpoint_dir=checkpoint_dir,
@@ -252,6 +297,7 @@ def run_server(
             top_p=top_p,
             max_new_tokens=max_new_tokens,
             devices=devices,
+            api_path=api_path,
         ),
         spec=OpenAISpec() if openai_spec else None,
         accelerator=accelerator,
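
For orientation, here is a minimal sketch of how the new `api_path` argument might be exercised once this change lands. It is illustrative rather than part of the diff: the import path, checkpoint name, and port are assumptions (port 8000 is LitServe's default), while the `api_path` keyword and the `"prompt"` request key mirror `run_server` and `decode_request` above.

```python
# Illustrative sketch only; assumes this patch is applied.
# The import path and checkpoint name below are assumptions, not part of the diff.
from litgpt.deploy.serve import run_server

# Serve the model on a custom endpoint instead of the default "/predict".
# run_server blocks, so the client code below belongs in a separate process.
run_server(
    checkpoint_dir="microsoft/phi-2",
    api_path="/my_api/classify",
)
```

```python
# Client side: POST to the custom path. decode_request() reads request["prompt"],
# so that key is required; port 8000 is LitServe's default (assumption).
import requests

response = requests.post(
    "http://127.0.0.1:8000/my_api/classify",
    json={"prompt": "What is the capital of France?"},
)
print(response.json())
```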