1414
1515import asyncio
1616import atexit
17- import logging
17+ import os
1818import sys
19+ import logging
1920from concurrent .futures import Future as SyncFuture
2021from typing import Dict , List , Union
2122
2728from ...resource import cpu_count , cuda_count , mem_total , Resource
2829from ...services import NodeRole
2930from ...typing import ClusterType , ClientType
30- from ..utils import get_third_party_modules_from_config
31+ from ..utils import get_third_party_modules_from_config , load_config
3132from .pool import create_supervisor_actor_pool , create_worker_actor_pool
3233from .service import (
3334 start_supervisor ,
3435 start_worker ,
3536 stop_supervisor ,
3637 stop_worker ,
37- load_config ,
3838)
3939from .session import AbstractSession , _new_session , ensure_isolation_created
4040
4646)
4747atexit .register (stop_isolation )
4848
49+ # The default config file.
50+ DEFAULT_CONFIG_FILE = os .path .join (
51+ os .path .dirname (os .path .abspath (__file__ )), "config.yml"
52+ )
53+
54+
55+ def _load_config (config : Union [str , Dict ] = None ):
56+ return load_config (config , default_config_file = DEFAULT_CONFIG_FILE )
57+
4958
5059async def new_cluster_in_isolation (
5160 address : str = "0.0.0.0" ,
@@ -67,6 +76,7 @@ async def new_cluster_in_isolation(
6776 mem_bytes ,
6877 cuda_devices ,
6978 subprocess_start_method ,
79+ backend ,
7080 config ,
7181 web ,
7282 n_supervisor_process ,
@@ -82,6 +92,7 @@ async def new_cluster(
8292 mem_bytes : Union [int , str ] = "auto" ,
8393 cuda_devices : Union [List [int ], str ] = "auto" ,
8494 subprocess_start_method : str = None ,
95+ backend : str = None ,
8596 config : Union [str , Dict ] = None ,
8697 web : bool = True ,
8798 loop : asyncio .AbstractEventLoop = None ,
@@ -95,6 +106,7 @@ async def new_cluster(
95106 mem_bytes = mem_bytes ,
96107 cuda_devices = cuda_devices ,
97108 subprocess_start_method = subprocess_start_method ,
109+ backend = backend ,
98110 config = config ,
99111 web = web ,
100112 n_supervisor_process = n_supervisor_process ,
@@ -121,6 +133,7 @@ def __init__(
121133 mem_bytes : Union [int , str ] = "auto" ,
122134 cuda_devices : Union [List [int ], List [List [int ]], str ] = "auto" ,
123135 subprocess_start_method : str = None ,
136+ backend : str = None ,
124137 config : Union [str , Dict ] = None ,
125138 web : Union [bool , str ] = "auto" ,
126139 n_supervisor_process : int = 0 ,
@@ -133,11 +146,11 @@ def __init__(
133146 "spawn" if sys .platform == "win32" else "forkserver"
134147 )
135148 # load config file to dict.
136- if not config or isinstance (config , str ):
137- config = load_config (config )
138149 self ._address = address
139150 self ._subprocess_start_method = subprocess_start_method
140- self ._config = config
151+ self ._config = load_config (config , default_config_file = DEFAULT_CONFIG_FILE )
152+ if backend is not None :
153+ self ._config ["task" ]["task_executor_config" ]["backend" ] = backend
141154 self ._n_cpu = cpu_count () if n_cpu == "auto" else n_cpu
142155 self ._mem_bytes = mem_total () if mem_bytes == "auto" else mem_bytes
143156 self ._n_supervisor_process = n_supervisor_process
0 commit comments