3737from ....storage import StorageLevel
3838from ....services .storage import StorageAPI
3939from ....tensor .arithmetic .add import TensorAdd
40+ from ....tests .core import check_dict_structure_same
4041from ..local import new_cluster
4142from ..service import load_config
4243from ..session import (
6768 os .path .dirname (__file__ ), "local_test_with_third_parity_modules_config.yml"
6869)
6970
71+ EXPECT_PROFILING_STRUCTURE = {
72+ "supervisor" : {
73+ "general" : {
74+ "optimize" : 0.0005879402160644531 ,
75+ "incref_fetch_tileables" : 0.0010840892791748047 ,
76+ "stage_*" : {
77+ "tile" : 0.008243083953857422 ,
78+ "gen_subtask_graph" : 0.012202978134155273 ,
79+ "run" : 0.27870702743530273 ,
80+ "total" : 0.30318617820739746 ,
81+ },
82+ "total" : 0.30951380729675293 ,
83+ },
84+ "serialization" : {},
85+ }
86+ }
7087
7188params = ["default" ]
7289if vineyard is not None :
@@ -147,8 +164,15 @@ async def test_vineyard_operators(create_cluster):
147164 pd .testing .assert_frame_equal (df , raw )
148165
149166
167+ @pytest .mark .parametrize (
168+ "config" ,
169+ [
170+ [{"enable_profiling" : True }, EXPECT_PROFILING_STRUCTURE ],
171+ [{}, {}],
172+ ],
173+ )
150174@pytest .mark .asyncio
151- async def test_execute (create_cluster ):
175+ async def test_execute (create_cluster , config ):
152176 session = get_default_async_session ()
153177 assert session .address is not None
154178 assert session .session_id is not None
@@ -157,8 +181,14 @@ async def test_execute(create_cluster):
157181 a = mt .tensor (raw , chunk_size = 5 )
158182 b = a + 1
159183
160- info = await session .execute (b )
184+ extra_config , expect_profiling_structure = config
185+
186+ info = await session .execute (b , extra_config = extra_config )
161187 await info
188+ if extra_config :
189+ check_dict_structure_same (info .profiling_result (), expect_profiling_structure )
190+ else :
191+ assert not info .profiling_result ()
162192 assert info .result () is None
163193 assert info .exception () is None
164194 assert info .progress () == 1
@@ -296,16 +326,23 @@ def _my_func():
296326 await session .destroy ()
297327
298328
329+ @pytest .mark .parametrize (
330+ "config" ,
331+ [
332+ [{"enable_profiling" : True }, EXPECT_PROFILING_STRUCTURE ],
333+ [{}, {}],
334+ ],
335+ )
299336@pytest .mark .asyncio
300- async def test_web_session (create_cluster ):
337+ async def test_web_session (create_cluster , config ):
301338 client = create_cluster [0 ]
302339 session_id = str (uuid .uuid4 ())
303340 web_address = client .web_address
304341 session = await AsyncSession .init (web_address , session_id )
305342 assert await session .get_web_endpoint () == web_address
306343 session .as_default ()
307344 assert isinstance (session ._isolated_session , _IsolatedWebSession )
308- await test_execute (client )
345+ await test_execute (client , config )
309346 await test_iterative_tiling (client )
310347 AsyncSession .reset_default ()
311348 await session .destroy ()
0 commit comments