3131from pyspark .sql .functions import col , pandas_udf , udf
3232from pyspark .sql .window import Window
3333from pyspark .profiler import UDFBasicProfiler
34- from pyspark .testing .sqlutils import (
35- ReusedSQLTestCase ,
34+ from pyspark .testing .sqlutils import ReusedSQLTestCase
35+ from pyspark . testing . utils import (
3636 have_pandas ,
3737 have_pyarrow ,
38+ have_flameprof ,
3839 pandas_requirement_message ,
3940 pyarrow_requirement_message ,
4041)
4142
42- try :
43- import flameprof # noqa: F401
44-
45- has_flameprof = True
46- except ImportError :
47- has_flameprof = False
48-
4943
5044def _do_computation (spark , * , action = lambda df : df .collect (), use_arrow = False ):
5145 @udf ("long" , useArrow = use_arrow )
@@ -208,7 +202,7 @@ def test_perf_profiler_udf(self):
208202 )
209203 self .assertTrue (f"udf_{ id } _perf.pstats" in os .listdir (d ))
210204
211- if has_flameprof :
205+ if have_flameprof :
212206 self .assertIn ("svg" , self .spark .profile .render (id ))
213207
214208 @unittest .skipIf (
@@ -230,7 +224,7 @@ def test_perf_profiler_udf_with_arrow(self):
230224 io .getvalue (), f"10.*{ os .path .basename (inspect .getfile (_do_computation ))} "
231225 )
232226
233- if has_flameprof :
227+ if have_flameprof :
234228 self .assertIn ("svg" , self .spark .profile .render (id ))
235229
236230 def test_perf_profiler_udf_multiple_actions (self ):
@@ -252,7 +246,7 @@ def action(df):
252246 io .getvalue (), f"20.*{ os .path .basename (inspect .getfile (_do_computation ))} "
253247 )
254248
255- if has_flameprof :
249+ if have_flameprof :
256250 self .assertIn ("svg" , self .spark .profile .render (id ))
257251
258252 def test_perf_profiler_udf_registered (self ):
@@ -276,7 +270,7 @@ def add1(x):
276270 io .getvalue (), f"10.*{ os .path .basename (inspect .getfile (_do_computation ))} "
277271 )
278272
279- if has_flameprof :
273+ if have_flameprof :
280274 self .assertIn ("svg" , self .spark .profile .render (id ))
281275
282276 @unittest .skipIf (
@@ -309,7 +303,7 @@ def add2(x):
309303 io .getvalue (), f"2.*{ os .path .basename (inspect .getfile (_do_computation ))} "
310304 )
311305
312- if has_flameprof :
306+ if have_flameprof :
313307 self .assertIn ("svg" , self .spark .profile .render (id ))
314308
315309 @unittest .skipIf (
@@ -345,7 +339,7 @@ def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
345339 io .getvalue (), f"2.*{ os .path .basename (inspect .getfile (_do_computation ))} "
346340 )
347341
348- if has_flameprof :
342+ if have_flameprof :
349343 self .assertIn ("svg" , self .spark .profile .render (id ))
350344
351345 @unittest .skipIf (
@@ -395,7 +389,7 @@ def mean_udf(v: pd.Series) -> float:
395389 io .getvalue (), f"5.*{ os .path .basename (inspect .getfile (_do_computation ))} "
396390 )
397391
398- if has_flameprof :
392+ if have_flameprof :
399393 self .assertIn ("svg" , self .spark .profile .render (id ))
400394
401395 @unittest .skipIf (
@@ -427,7 +421,7 @@ def min_udf(v: pd.Series) -> float:
427421 io .getvalue (), f"2.*{ os .path .basename (inspect .getfile (_do_computation ))} "
428422 )
429423
430- if has_flameprof :
424+ if have_flameprof :
431425 self .assertIn ("svg" , self .spark .profile .render (id ))
432426
433427 @unittest .skipIf (
@@ -458,7 +452,7 @@ def normalize(pdf):
458452 io .getvalue (), f"2.*{ os .path .basename (inspect .getfile (_do_computation ))} "
459453 )
460454
461- if has_flameprof :
455+ if have_flameprof :
462456 self .assertIn ("svg" , self .spark .profile .render (id ))
463457
464458 @unittest .skipIf (
@@ -496,7 +490,7 @@ def asof_join(left, right):
496490 io .getvalue (), f"2.*{ os .path .basename (inspect .getfile (_do_computation ))} "
497491 )
498492
499- if has_flameprof :
493+ if have_flameprof :
500494 self .assertIn ("svg" , self .spark .profile .render (id ))
501495
502496 @unittest .skipIf (
@@ -530,7 +524,7 @@ def normalize(table):
530524 io .getvalue (), f"2.*{ os .path .basename (inspect .getfile (_do_computation ))} "
531525 )
532526
533- if has_flameprof :
527+ if have_flameprof :
534528 self .assertIn ("svg" , self .spark .profile .render (id ))
535529
536530 @unittest .skipIf (
@@ -562,7 +556,7 @@ def summarize(left, right):
562556 io .getvalue (), f"2.*{ os .path .basename (inspect .getfile (_do_computation ))} "
563557 )
564558
565- if has_flameprof :
559+ if have_flameprof :
566560 self .assertIn ("svg" , self .spark .profile .render (id ))
567561
568562 def test_perf_profiler_render (self ):
@@ -572,7 +566,7 @@ def test_perf_profiler_render(self):
572566
573567 id = list (self .profile_results .keys ())[0 ]
574568
575- if has_flameprof :
569+ if have_flameprof :
576570 self .assertIn ("svg" , self .spark .profile .render (id ))
577571 self .assertIn ("svg" , self .spark .profile .render (id , type = "perf" ))
578572 self .assertIn ("svg" , self .spark .profile .render (id , renderer = "flameprof" ))