Skip to content

Commit 913549c

Browse files
authored
feat: support custom service start command (#5382)
* feat: support custom service start command
  Signed-off-by: Frost Ming <me@frostming.com>
* fix types
  Signed-off-by: Frost Ming <me@frostming.com>
* fix: mirror
  Signed-off-by: Frost Ming <me@frostming.com>
1 parent 17a160c commit 913549c

File tree

10 files changed

+209
-94
lines changed

10 files changed

+209
-94
lines changed

pdm.lock

Lines changed: 7 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ dependencies = [
6565
"aiosqlite>=0.20.0",
6666
"questionary>=2.0.1",
6767
"a2wsgi>=1.10.7",
68-
"python-dotenv>=1.0.1",
6968
]
7069
dynamic = ["version"]
7170
[project.urls]

src/_bentoml_impl/server/serving.py

Lines changed: 108 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import platform
1010
import shutil
1111
import socket
12+
import sys
1213
import tempfile
1314
import typing as t
1415

@@ -17,6 +18,8 @@
1718

1819
from _bentoml_sdk import Service
1920
from bentoml._internal.container import BentoMLContainer
21+
from bentoml._internal.utils import expand_envs
22+
from bentoml._internal.utils import reserve_free_port
2023
from bentoml._internal.utils.circus import Server
2124
from bentoml.exceptions import BentoMLConfigException
2225

@@ -64,8 +67,6 @@ def _get_server_socket(
6467
) -> tuple[str, CircusSocket]:
6568
from circus.sockets import CircusSocket
6669

67-
from bentoml._internal.utils import reserve_free_port
68-
6970
runner_port = port_stack.enter_context(reserve_free_port())
7071
runner_host = "127.0.0.1"
7172

@@ -103,30 +104,49 @@ def create_dependency_watcher(
103104
working_dir: str | None = None,
104105
env: dict[str, str] | None = None,
105106
bento_args: dict[str, t.Any] = Provide[BentoMLContainer.bento_arguments],
106-
) -> tuple[Watcher, CircusSocket, str]:
107+
) -> tuple[Watcher, CircusSocket | None, str]:
107108
from bentoml.serving import create_watcher
108109

109-
num_workers, worker_envs = scheduler.get_worker_env(svc)
110-
uri, socket = _get_server_socket(svc, uds_path, port_stack, backlog)
111-
args = [
112-
"-m",
113-
_SERVICE_WORKER_SCRIPT,
114-
bento_identifier,
115-
"--service-name",
116-
svc.name,
117-
"--fd",
118-
f"$(circus.sockets.{svc.name})",
119-
"--worker-id",
120-
"$(CIRCUS.WID)",
121-
"--args",
122-
json.dumps(bento_args),
123-
]
124-
125-
if worker_envs:
126-
args.extend(["--worker-env", json.dumps(worker_envs)])
110+
if svc.cmd is not None:
111+
num_workers = 1 # Custom command only runs a single worker
112+
svc_port = port_stack.enter_context(reserve_free_port())
113+
if env is None:
114+
env = {**os.environ, "PORT": str(svc_port)}
115+
else:
116+
env = {**os.environ, **env, "PORT": str(svc_port)}
117+
uri = f"http://127.0.0.1:{svc_port}"
118+
socket = None
119+
working_dir = svc.working_dir
120+
logger.info(
121+
"Starting with custom command for dependency service(%s): %s",
122+
svc.name,
123+
svc.cmd,
124+
)
125+
cmd, *args = [expand_envs(p, env) for p in svc.cmd]
126+
else:
127+
num_workers, worker_envs = scheduler.get_worker_env(svc)
128+
uri, socket = _get_server_socket(svc, uds_path, port_stack, backlog)
129+
args = [
130+
"-m",
131+
_SERVICE_WORKER_SCRIPT,
132+
bento_identifier,
133+
"--service-name",
134+
svc.name,
135+
"--fd",
136+
f"$(circus.sockets.{svc.name})",
137+
"--worker-id",
138+
"$(CIRCUS.WID)",
139+
"--args",
140+
json.dumps(bento_args),
141+
]
142+
cmd = sys.executable
143+
144+
if worker_envs:
145+
args.extend(["--worker-env", json.dumps(worker_envs)])
127146

128147
watcher = create_watcher(
129148
name=f"service_{svc.name}",
149+
cmd=cmd,
130150
args=args,
131151
numprocesses=num_workers,
132152
working_dir=working_dir,
@@ -269,7 +289,8 @@ def serve_http(
269289
env={k: str(v) for k, v in dependency_env.items()},
270290
)
271291
watchers.append(new_watcher)
272-
sockets.append(new_socket)
292+
if new_socket:
293+
sockets.append(new_socket)
273294
dependency_map[name] = uri
274295
server_on_deployment(dep_svc)
275296
# reserve one more to avoid conflicts
@@ -287,62 +308,78 @@ def serve_http(
287308
)
288309
except ValueError as e:
289310
raise BentoMLConfigException(f"Invalid host IP address: {host}") from e
290-
291-
sockets.append(
292-
CircusSocket(
293-
name=API_SERVER_NAME,
294-
host=host,
295-
port=port,
296-
family=family,
297-
backlog=backlog,
311+
if svc.cmd is not None:
312+
logger.info("Starting with custom command for entry service: %s", svc.cmd)
313+
num_workers = 1 # Custom command only runs a single worker
314+
env.update(os.environ)
315+
env.update(
316+
{
317+
"PORT": str(port),
318+
"BENTOML_HOST": host,
319+
"BENTOML_PORT": str(port),
320+
}
298321
)
299-
)
300-
if BentoMLContainer.ssl.enabled.get() and not ssl_certfile:
301-
raise BentoMLConfigException("ssl_certfile is required when ssl is enabled")
302-
303-
ssl_args = construct_ssl_args(
304-
ssl_certfile=ssl_certfile,
305-
ssl_keyfile=ssl_keyfile,
306-
ssl_keyfile_password=ssl_keyfile_password,
307-
ssl_version=ssl_version,
308-
ssl_cert_reqs=ssl_cert_reqs,
309-
ssl_ca_certs=ssl_ca_certs,
310-
ssl_ciphers=ssl_ciphers,
311-
)
312-
timeouts_args = construct_timeouts_args(
313-
timeout_keep_alive=timeout_keep_alive,
314-
timeout_graceful_shutdown=timeout_graceful_shutdown,
315-
)
316-
timeout_args = ["--timeout", str(timeout)] if timeout else []
317-
bento_args = BentoMLContainer.bento_arguments.get()
322+
server_cmd, *server_args = [expand_envs(p, env) for p in svc.cmd]
323+
bento_path = pathlib.Path(svc.working_dir)
324+
else:
325+
sockets.append(
326+
CircusSocket(
327+
name=API_SERVER_NAME,
328+
host=host,
329+
port=port,
330+
family=family,
331+
backlog=backlog,
332+
)
333+
)
334+
if BentoMLContainer.ssl.enabled.get() and not ssl_certfile:
335+
raise BentoMLConfigException(
336+
"ssl_certfile is required when ssl is enabled"
337+
)
318338

319-
server_args = [
320-
"-m",
321-
_SERVICE_WORKER_SCRIPT,
322-
bento_identifier,
323-
"--fd",
324-
f"$(circus.sockets.{API_SERVER_NAME})",
325-
"--service-name",
326-
svc.name,
327-
"--backlog",
328-
str(backlog),
329-
"--worker-id",
330-
"$(CIRCUS.WID)",
331-
"--args",
332-
json.dumps(bento_args),
333-
*ssl_args,
334-
*timeouts_args,
335-
*timeout_args,
336-
]
337-
if worker_envs:
338-
server_args.extend(["--worker-env", json.dumps(worker_envs)])
339-
if development_mode:
340-
server_args.append("--development-mode")
339+
ssl_args = construct_ssl_args(
340+
ssl_certfile=ssl_certfile,
341+
ssl_keyfile=ssl_keyfile,
342+
ssl_keyfile_password=ssl_keyfile_password,
343+
ssl_version=ssl_version,
344+
ssl_cert_reqs=ssl_cert_reqs,
345+
ssl_ca_certs=ssl_ca_certs,
346+
ssl_ciphers=ssl_ciphers,
347+
)
348+
timeouts_args = construct_timeouts_args(
349+
timeout_keep_alive=timeout_keep_alive,
350+
timeout_graceful_shutdown=timeout_graceful_shutdown,
351+
)
352+
timeout_args = ["--timeout", str(timeout)] if timeout else []
353+
bento_args = BentoMLContainer.bento_arguments.get()
354+
server_cmd = sys.executable
355+
server_args = [
356+
"-m",
357+
_SERVICE_WORKER_SCRIPT,
358+
bento_identifier,
359+
"--fd",
360+
f"$(circus.sockets.{API_SERVER_NAME})",
361+
"--service-name",
362+
svc.name,
363+
"--backlog",
364+
str(backlog),
365+
"--worker-id",
366+
"$(CIRCUS.WID)",
367+
"--args",
368+
json.dumps(bento_args),
369+
*ssl_args,
370+
*timeouts_args,
371+
*timeout_args,
372+
]
373+
if worker_envs:
374+
server_args.extend(["--worker-env", json.dumps(worker_envs)])
375+
if development_mode:
376+
server_args.append("--development-mode")
341377

342378
scheme = "https" if BentoMLContainer.ssl.enabled.get() else "http"
343379
watchers.append(
344380
create_watcher(
345381
name="service",
382+
cmd=server_cmd,
346383
args=server_args,
347384
working_dir=str(bento_path.absolute()),
348385
numprocesses=num_workers,

src/_bentoml_sdk/service/config.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,16 @@ class LoggingSchema(TypedDict, total=False):
236236
access: AccessLoggingSchema
237237

238238

239-
"""
240-
service level (per replica) config
241-
"""
239+
class EndpointsSchema(TypedDict, total=False):
240+
livez: str
241+
readyz: str
242242

243243

244244
class ServiceConfig(TypedDict, total=False):
245+
"""
246+
service level (per replica) config
247+
"""
248+
245249
name: str
246250
traffic: TrafficSchema
247251
backlog: Annotated[int, Ge(64)]
@@ -257,6 +261,7 @@ class ServiceConfig(TypedDict, total=False):
257261
runner_probe: RunnerProbeSchema
258262
tracing: TracingSchema
259263
monitoring: MonitoringSchema
264+
endpoints: EndpointsSchema
260265

261266

262267
schema_type = TypeAdapter(ServiceConfig)

src/_bentoml_sdk/service/factory.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838

3939
logger = logging.getLogger("bentoml.serve")
4040

41-
T = t.TypeVar("T", bound=object)
41+
T = t.TypeVar("T")
4242

4343
if t.TYPE_CHECKING:
4444
from bentoml._internal import external_typing as ext
@@ -72,18 +72,23 @@ def convert_envs(envs: t.List[t.Dict[str, t.Any]]) -> t.List[BentoEnvSchema]:
7272
return [BentoEnvSchema(**env) for env in envs]
7373

7474

75+
class _DummyService:
76+
pass
77+
78+
7579
@attrs.define
7680
class Service(t.Generic[T]):
7781
"""A Bentoml service that can be served by BentoML server."""
7882

79-
config: Config
80-
inner: type[T]
83+
config: Config = attrs.field(factory=Config)
84+
inner: type[T] = _DummyService
8185
image: t.Optional[Image] = None
8286
envs: t.List[BentoEnvSchema] = attrs.field(factory=list, converter=convert_envs)
8387
labels: t.Dict[str, str] = attrs.field(factory=dict)
84-
bento: t.Optional[Bento] = attrs.field(init=False, default=None)
8588
models: list[Model[t.Any]] = attrs.field(factory=list)
86-
apis: dict[str, APIMethod[..., t.Any]] = attrs.field(factory=dict)
89+
cmd: t.Optional[t.List[str]] = None
90+
bento: t.Optional[Bento] = attrs.field(init=False, default=None)
91+
apis: dict[str, APIMethod[..., t.Any]] = attrs.field(factory=dict, init=False)
8792
dependencies: dict[str, Dependency[t.Any]] = attrs.field(factory=dict, init=False)
8893
mount_apps: list[tuple[ext.ASGIApp, str, str]] = attrs.field(
8994
factory=list, init=False
@@ -491,6 +496,8 @@ def service(
491496
image: Image | None = None,
492497
envs: list[dict[str, str]] | None = None,
493498
labels: dict[str, str] | None = None,
499+
cmd: list[str] | None = None,
500+
service_class: type[Service[T]] = Service,
494501
**kwargs: Unpack[Config],
495502
) -> _ServiceDecorator: ...
496503

@@ -502,6 +509,8 @@ def service(
502509
image: Image | None = None,
503510
envs: list[dict[str, str]] | None = None,
504511
labels: dict[str, str] | None = None,
512+
cmd: list[str] | None = None,
513+
service_class: type[Service[T]] = Service,
505514
**kwargs: Unpack[Config],
506515
) -> t.Any:
507516
"""Mark a class as a BentoML service.
@@ -519,12 +528,13 @@ def predict(self, input: str) -> str:
519528
def decorator(inner: type[T]) -> Service[T]:
520529
if isinstance(inner, Service):
521530
raise TypeError("service() decorator can only be applied once")
522-
return Service(
531+
return service_class(
523532
config=config,
524533
inner=inner,
525534
image=image,
526535
envs=envs or [],
527536
labels=labels or {},
537+
cmd=cmd,
528538
)
529539

530540
return decorator(inner) if inner is not None else decorator

src/bentoml/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
# New SDK
4949
"service": "_bentoml_sdk:service",
5050
"runner_service": "_bentoml_sdk:runner_service",
51+
"Service": "_bentoml_sdk:Service",
5152
"api": "_bentoml_sdk:api",
5253
"task": "_bentoml_sdk:task",
5354
"depends": "_bentoml_sdk:depends",
@@ -144,6 +145,7 @@
144145
from _bentoml_impl.client import SyncHTTPClient
145146
from _bentoml_impl.loader import importing
146147
from _bentoml_sdk import IODescriptor
148+
from _bentoml_sdk import Service
147149
from _bentoml_sdk import api
148150
from _bentoml_sdk import asgi_app
149151
from _bentoml_sdk import depends
@@ -374,6 +376,7 @@ def __getattr__(name: str) -> Any:
374376
"set_serialization_strategy",
375377
# new SDK
376378
"service",
379+
"Service",
377380
"runner_service",
378381
"api",
379382
"task",

0 commit comments

Comments
 (0)