Commit 46756cb

fix(typing): pyspark compat for std, var
1 parent e923bf9 · commit 46756cb

File tree: 1 file changed (39 additions, 19 deletions)


narwhals/_spark_like/utils.py

Lines changed: 39 additions & 19 deletions
@@ -2,6 +2,7 @@
 
 from typing import TYPE_CHECKING
 from typing import Any
+from typing import cast
 
 from narwhals.exceptions import UnsupportedDTypeError
 from narwhals.utils import Implementation
@@ -11,6 +12,7 @@
 if TYPE_CHECKING:
     from types import ModuleType
 
+    import sqlframe.base.functions as sqlframe_functions
     import sqlframe.base.types as sqlframe_types
     from sqlframe.base.column import Column
     from typing_extensions import TypeAlias
@@ -186,43 +188,61 @@ def _std(
     _input: Column | str,
     ddof: int,
     np_version: tuple[int, ...],
-    functions: Any,
+    functions: ModuleType,
     implementation: Implementation,
 ) -> Column:
+    if TYPE_CHECKING:
+        F = sqlframe_functions  # noqa: N806
+    else:
+        F = functions  # noqa: N806
+    column = F.col(_input) if isinstance(_input, str) else _input
     if implementation is Implementation.SQLFRAME or np_version > (2, 0):
         if ddof == 0:
-            return functions.stddev_pop(_input)
+            return F.stddev_pop(column)
         if ddof == 1:
-            return functions.stddev_samp(_input)
-
-        n_rows = functions.count(_input)
-        return functions.stddev_samp(_input) * functions.sqrt(
-            (n_rows - 1) / (n_rows - ddof)
-        )
-
-    from pyspark.pandas.spark.functions import stddev
+            return F.stddev_samp(column)
+        n_rows = F.count(column)
+        return F.stddev_samp(column) * F.sqrt((n_rows - 1) / (n_rows - ddof))
+    if TYPE_CHECKING:
+        return F.stddev(column)
 
-    input_col = functions.col(_input) if isinstance(_input, str) else _input
-    return stddev(input_col, ddof=ddof)
+    return _stddev_pyspark(column, ddof)
 
 
 def _var(
     _input: Column | str,
     ddof: int,
     np_version: tuple[int, ...],
-    functions: Any,
+    functions: ModuleType,
     implementation: Implementation,
 ) -> Column:
+    if TYPE_CHECKING:
+        F = sqlframe_functions  # noqa: N806
+    else:
+        F = functions  # noqa: N806
+    column = F.col(_input) if isinstance(_input, str) else _input
     if implementation is Implementation.SQLFRAME or np_version > (2, 0):
         if ddof == 0:
-            return functions.var_pop(_input)
+            return F.var_pop(column)
         if ddof == 1:
-            return functions.var_samp(_input)
+            return F.var_samp(column)
+
+        n_rows = F.count(column)
+        return F.var_samp(column) * (n_rows - 1) / (n_rows - ddof)
+
+    if TYPE_CHECKING:
+        return F.var_samp(column)
+
+    return _var_pyspark(column, ddof)
+
+
+def _stddev_pyspark(col: Any, ddof: int, /) -> Column:
+    from pyspark.pandas.spark.functions import stddev
+
+    return cast("Column", stddev(col, ddof=ddof))
 
-        n_rows = functions.count(_input)
-        return functions.var_samp(_input) * (n_rows - 1) / (n_rows - ddof)
 
+def _var_pyspark(col: Any, ddof: int, /) -> Column:
     from pyspark.pandas.spark.functions import var
 
-    input_col = functions.col(_input) if isinstance(_input, str) else _input
-    return var(input_col, ddof=ddof)
+    return cast("Column", var(col, ddof=ddof))
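
The core of the typing fix is the TYPE_CHECKING alias: at runtime F is whatever functions module the caller passes in (pyspark.sql.functions or sqlframe.base.functions), while the type checker only ever sees the sqlframe stubs, so calls like F.col and F.stddev_samp resolve to typed signatures instead of Any. The new _stddev_pyspark and _var_pyspark helpers serve the same goal on the pyspark-only path, presumably because pyspark.pandas.spark.functions.stddev and var are not annotated as returning Column, hence the cast("Column", ...). A minimal, self-contained sketch of the alias pattern (the demo_col function and its body are illustrative, not part of the commit):

    from types import ModuleType
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by the type checker; never imported at runtime.
        import sqlframe.base.functions as sqlframe_functions


    def demo_col(functions: ModuleType, name: str):
        # Statically, F is the typed sqlframe module; at runtime it is
        # whichever module the caller supplied (pyspark or sqlframe).
        if TYPE_CHECKING:
            F = sqlframe_functions  # noqa: N806
        else:
            F = functions  # noqa: N806
        return F.col(name)

A caller would pass the concrete module, e.g. demo_col(pyspark.sql.functions, "a"), and static analysis still type-checks the body against the sqlframe signatures.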

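Both _std and _var derive the arbitrary-ddof result from the built-in sample (ddof=1) aggregates using the same identity: var_ddof = var_samp * (n - 1) / (n - ddof), and correspondingly std_ddof = std_samp * sqrt((n - 1) / (n - ddof)). A quick numpy check of that identity (illustrative only, not part of the commit):

    import numpy as np

    x = np.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0])
    n = len(x)
    ddof = 2

    # Variance for an arbitrary ddof is the sample variance (ddof=1)
    # rescaled by (n - 1) / (n - ddof); std uses the square root of that factor.
    assert np.isclose(np.var(x, ddof=ddof), np.var(x, ddof=1) * (n - 1) / (n - ddof))
    assert np.isclose(np.std(x, ddof=ddof), np.std(x, ddof=1) * np.sqrt((n - 1) / (n - ddof)))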