@@ -18,33 +18,37 @@ class RayLogger:
1818
1919 def __init__ (
2020 self ,
21- log_level : int = logging .INFO ,
21+ log_level : int = logging .DEBUG ,
2222 format : str = "%(asctime)s::%(levelname)-2s::%(name)s::%(message)s" , # pylint: disable=redefined-builtin
2323 datefmt : str = "%Y-%m-%d %H:%M:%S" ,
2424 ):
2525 logging .basicConfig (level = log_level , format = format , datefmt = datefmt )
2626
27- def get_logger (self , name : Union [str , Any ] = None ) -> Union [logging .Logger , Any ]:
27+ def get_logger (self , name : Union [str , Any ] = None ) -> Optional [logging .Logger ]:
2828 """Return logger object."""
29- return logging .getLogger (name ) if engine . get () == EngineEnum . RAY else None
29+ return logging .getLogger (name )
3030
3131
32- def ray_get ( futures : List [ Any ]) -> List [ Any ]:
32+ def ray_logger ( function : Callable [..., Any ]) -> Callable [..., Any ]:
3333 """
34- Run ray.get on futures if distributed .
34+ Decorate callable to add RayLogger .
3535
3636 Parameters
3737 ----------
38- futures : List[ Any]
39- List of Ray futures
38+ function : Callable[..., Any]
39+ Callable as input to decorator.
4040
4141 Returns
4242 -------
43- List[ Any]
43+ Callable[..., Any]
4444 """
45- if engine .get () == EngineEnum .RAY :
46- return ray .get (futures )
47- return futures
45+
46+ @wraps (function )
47+ def wrapper (* args : Any , ** kwargs : Any ) -> Any :
48+ RayLogger ().get_logger (name = function .__name__ )
49+ return function (* args , ** kwargs )
50+
51+ return wrapper
4852
4953
5054def ray_remote (function : Callable [..., Any ]) -> Callable [..., Any ]:
@@ -54,7 +58,8 @@ def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]:
5458 Parameters
5559 ----------
5660 function : Callable[..., Any]
57- Callable as input to ray.remote
61+ Callable as input to ray.remote.
62+
5863 Returns
5964 -------
6065 Callable[..., Any]
@@ -64,18 +69,36 @@ def ray_remote(function: Callable[..., Any]) -> Callable[..., Any]:
6469
6570 @wraps (function )
6671 def wrapper (* args : Any , ** kwargs : Any ) -> Any :
67- return ray .remote (function ).remote (* args , ** kwargs ) # type: ignore
72+ return ray .remote (ray_logger ( function ) ).remote (* args , ** kwargs ) # type: ignore
6873
6974 return wrapper
7075
7176
77+ def ray_get (futures : List [Any ]) -> List [Any ]:
78+ """
79+ Run ray.get on futures if distributed.
80+
81+ Parameters
82+ ----------
83+ futures : List[Any]
84+ List of Ray futures
85+
86+ Returns
87+ -------
88+ List[Any]
89+ """
90+ if engine .get () == EngineEnum .RAY :
91+ return ray .get (futures )
92+ return futures
93+
94+
7295@apply_configs
7396def initialize_ray (
7497 address : Optional [str ] = None ,
7598 redis_password : Optional [str ] = None ,
76- ignore_reinit_error : Optional [ bool ] = True ,
99+ ignore_reinit_error : bool = True ,
77100 include_dashboard : Optional [bool ] = False ,
78- log_to_driver : Optional [ bool ] = True ,
101+ log_to_driver : bool = False ,
79102 object_store_memory : Optional [int ] = None ,
80103 cpu_count : Optional [int ] = None ,
81104 gpu_count : Optional [int ] = None ,
@@ -89,12 +112,12 @@ def initialize_ray(
89112 Address of the Ray cluster to connect to, by default None
90113 redis_password : Optional[str]
91114 Password to the Redis cluster, by default None
92- ignore_reinit_error : Optional[ bool]
115+ ignore_reinit_error : bool
93116 If true, Ray suppress errors from calling ray.init() twice, by default True
94117 include_dashboard : Optional[bool]
95118 Boolean flag indicating whether or not to start the Ray dashboard, by default False
96- log_to_driver : Optional[ bool]
97- Boolean flag to enable routing of all worker logs to the driver, by default True
119+ log_to_driver : bool
120+ Boolean flag to enable routing of all worker logs to the driver, by default False
98121 object_store_memory : Optional[int]
99122 The amount of memory (in bytes) to start the object store with, by default None
100123 cpu_count : Optional[int]
0 commit comments