Skip to content

Commit e55511c

Browse files
committed
[SPARK-50388][PYTHON][TESTS][FOLLOW-UP] Move have_flameprof to pyspark.testing.utils
### What changes were proposed in this pull request?
Move `have_flameprof` to `pyspark.testing.utils`

### Why are the changes needed?
to centralize the import check

### Does this PR introduce _any_ user-facing change?
No, test-only.

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #48973 from zhengruifeng/fix_has_flameprof.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 5425d45 commit e55511c

File tree

3 files changed

+20
-24
lines changed

3 files changed

+20
-24
lines changed

python/pyspark/sql/tests/connect/test_parity_udf_profiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
from pyspark.sql.tests.test_udf_profiler import (
2222
UDFProfiler2TestsMixin,
2323
_do_computation,
24-
has_flameprof,
2524
)
2625
from pyspark.testing.connectutils import ReusedConnectTestCase
26+
from pyspark.testing.utils import have_flameprof
2727

2828

2929
class UDFProfilerParityTests(UDFProfiler2TestsMixin, ReusedConnectTestCase):
@@ -65,7 +65,7 @@ def action(df):
6565
io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
6666
)
6767

68-
if has_flameprof:
68+
if have_flameprof:
6969
self.assertIn("svg", self.spark.profile.render(id))
7070

7171

python/pyspark/sql/tests/test_udf_profiler.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,15 @@
3131
from pyspark.sql.functions import col, pandas_udf, udf
3232
from pyspark.sql.window import Window
3333
from pyspark.profiler import UDFBasicProfiler
34-
from pyspark.testing.sqlutils import (
35-
ReusedSQLTestCase,
34+
from pyspark.testing.sqlutils import ReusedSQLTestCase
35+
from pyspark.testing.utils import (
3636
have_pandas,
3737
have_pyarrow,
38+
have_flameprof,
3839
pandas_requirement_message,
3940
pyarrow_requirement_message,
4041
)
4142

42-
try:
43-
import flameprof # noqa: F401
44-
45-
has_flameprof = True
46-
except ImportError:
47-
has_flameprof = False
48-
4943

5044
def _do_computation(spark, *, action=lambda df: df.collect(), use_arrow=False):
5145
@udf("long", useArrow=use_arrow)
@@ -208,7 +202,7 @@ def test_perf_profiler_udf(self):
208202
)
209203
self.assertTrue(f"udf_{id}_perf.pstats" in os.listdir(d))
210204

211-
if has_flameprof:
205+
if have_flameprof:
212206
self.assertIn("svg", self.spark.profile.render(id))
213207

214208
@unittest.skipIf(
@@ -230,7 +224,7 @@ def test_perf_profiler_udf_with_arrow(self):
230224
io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
231225
)
232226

233-
if has_flameprof:
227+
if have_flameprof:
234228
self.assertIn("svg", self.spark.profile.render(id))
235229

236230
def test_perf_profiler_udf_multiple_actions(self):
@@ -252,7 +246,7 @@ def action(df):
252246
io.getvalue(), f"20.*{os.path.basename(inspect.getfile(_do_computation))}"
253247
)
254248

255-
if has_flameprof:
249+
if have_flameprof:
256250
self.assertIn("svg", self.spark.profile.render(id))
257251

258252
def test_perf_profiler_udf_registered(self):
@@ -276,7 +270,7 @@ def add1(x):
276270
io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
277271
)
278272

279-
if has_flameprof:
273+
if have_flameprof:
280274
self.assertIn("svg", self.spark.profile.render(id))
281275

282276
@unittest.skipIf(
@@ -309,7 +303,7 @@ def add2(x):
309303
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
310304
)
311305

312-
if has_flameprof:
306+
if have_flameprof:
313307
self.assertIn("svg", self.spark.profile.render(id))
314308

315309
@unittest.skipIf(
@@ -345,7 +339,7 @@ def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
345339
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
346340
)
347341

348-
if has_flameprof:
342+
if have_flameprof:
349343
self.assertIn("svg", self.spark.profile.render(id))
350344

351345
@unittest.skipIf(
@@ -395,7 +389,7 @@ def mean_udf(v: pd.Series) -> float:
395389
io.getvalue(), f"5.*{os.path.basename(inspect.getfile(_do_computation))}"
396390
)
397391

398-
if has_flameprof:
392+
if have_flameprof:
399393
self.assertIn("svg", self.spark.profile.render(id))
400394

401395
@unittest.skipIf(
@@ -427,7 +421,7 @@ def min_udf(v: pd.Series) -> float:
427421
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
428422
)
429423

430-
if has_flameprof:
424+
if have_flameprof:
431425
self.assertIn("svg", self.spark.profile.render(id))
432426

433427
@unittest.skipIf(
@@ -458,7 +452,7 @@ def normalize(pdf):
458452
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
459453
)
460454

461-
if has_flameprof:
455+
if have_flameprof:
462456
self.assertIn("svg", self.spark.profile.render(id))
463457

464458
@unittest.skipIf(
@@ -496,7 +490,7 @@ def asof_join(left, right):
496490
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
497491
)
498492

499-
if has_flameprof:
493+
if have_flameprof:
500494
self.assertIn("svg", self.spark.profile.render(id))
501495

502496
@unittest.skipIf(
@@ -530,7 +524,7 @@ def normalize(table):
530524
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
531525
)
532526

533-
if has_flameprof:
527+
if have_flameprof:
534528
self.assertIn("svg", self.spark.profile.render(id))
535529

536530
@unittest.skipIf(
@@ -562,7 +556,7 @@ def summarize(left, right):
562556
io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}"
563557
)
564558

565-
if has_flameprof:
559+
if have_flameprof:
566560
self.assertIn("svg", self.spark.profile.render(id))
567561

568562
def test_perf_profiler_render(self):
@@ -572,7 +566,7 @@ def test_perf_profiler_render(self):
572566

573567
id = list(self.profile_results.keys())[0]
574568

575-
if has_flameprof:
569+
if have_flameprof:
576570
self.assertIn("svg", self.spark.profile.render(id))
577571
self.assertIn("svg", self.spark.profile.render(id, type="perf"))
578572
self.assertIn("svg", self.spark.profile.render(id, renderer="flameprof"))

python/pyspark/testing/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ def have_package(name: str) -> bool:
9191
have_graphviz = have_package("graphviz")
9292
graphviz_requirement_message = None if have_graphviz else "No module named 'graphviz'"
9393

94+
have_flameprof = have_package("flameprof")
95+
flameprof_requirement_message = None if have_flameprof else "No module named 'flameprof'"
9496

9597
pandas_requirement_message = None
9698
try:

0 commit comments

Comments
 (0)