2 | 2 |
3 | 3 | from typing import TYPE_CHECKING |
4 | 4 | from typing import Any |
| 5 | +from typing import cast |
5 | 6 |
6 | 7 | from narwhals.exceptions import UnsupportedDTypeError |
7 | 8 | from narwhals.utils import Implementation |
|
11 | 12 | if TYPE_CHECKING: |
12 | 13 | from types import ModuleType |
13 | 14 |
| 15 | + import sqlframe.base.functions as sqlframe_functions |
14 | 16 | import sqlframe.base.types as sqlframe_types |
15 | 17 | from sqlframe.base.column import Column |
16 | 18 | from typing_extensions import TypeAlias |
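
For context on the hunk that follows: `sqlframe.base.functions` is imported only under `TYPE_CHECKING`, so type checkers can bind `F` to a module with precise `Column`-returning signatures while the runtime binds `F` to whichever backend `functions` module the caller passed in. A minimal sketch of the pattern, with illustrative names (`_agg_example` is not from the patch):

```python
from types import ModuleType
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; sqlframe need not be importable at runtime.
    import sqlframe.base.functions as sqlframe_functions
    from sqlframe.base.column import Column


def _agg_example(name: str, functions: ModuleType) -> "Column":
    if TYPE_CHECKING:
        F = sqlframe_functions  # noqa: N806
    else:
        F = functions  # noqa: N806
    # F.col / F.count type-check against sqlframe's stubs but dispatch to the
    # real backend module (pyspark or sqlframe) at runtime.
    return F.count(F.col(name))
```
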
@@ -186,43 +188,61 @@ def _std( |
186 | 188 | _input: Column | str, |
187 | 189 | ddof: int, |
188 | 190 | np_version: tuple[int, ...], |
189 | | - functions: Any, |
| 191 | + functions: ModuleType, |
190 | 192 | implementation: Implementation, |
191 | 193 | ) -> Column: |
| 194 | + if TYPE_CHECKING: |
| 195 | + F = sqlframe_functions # noqa: N806 |
| 196 | + else: |
| 197 | + F = functions # noqa: N806 |
| 198 | + column = F.col(_input) if isinstance(_input, str) else _input |
192 | 199 | if implementation is Implementation.SQLFRAME or np_version > (2, 0): |
193 | 200 | if ddof == 0: |
194 | | - return functions.stddev_pop(_input) |
| 201 | + return F.stddev_pop(column) |
195 | 202 | if ddof == 1: |
196 | | - return functions.stddev_samp(_input) |
197 | | - |
198 | | - n_rows = functions.count(_input) |
199 | | - return functions.stddev_samp(_input) * functions.sqrt( |
200 | | - (n_rows - 1) / (n_rows - ddof) |
201 | | - ) |
202 | | - |
203 | | - from pyspark.pandas.spark.functions import stddev |
| 203 | + return F.stddev_samp(column) |
| 204 | + n_rows = F.count(column) |
| 205 | + return F.stddev_samp(column) * F.sqrt((n_rows - 1) / (n_rows - ddof)) |
| 206 | + if TYPE_CHECKING: |
| 207 | + return F.stddev(column) |
204 | 208 |
205 | | - input_col = functions.col(_input) if isinstance(_input, str) else _input |
206 | | - return stddev(input_col, ddof=ddof) |
| 209 | + return _stddev_pyspark(column, ddof) |
207 | 210 |
208 | 211 |
209 | 212 | def _var( |
210 | 213 | _input: Column | str, |
211 | 214 | ddof: int, |
212 | 215 | np_version: tuple[int, ...], |
213 | | - functions: Any, |
| 216 | + functions: ModuleType, |
214 | 217 | implementation: Implementation, |
215 | 218 | ) -> Column: |
| 219 | + if TYPE_CHECKING: |
| 220 | + F = sqlframe_functions # noqa: N806 |
| 221 | + else: |
| 222 | + F = functions # noqa: N806 |
| 223 | + column = F.col(_input) if isinstance(_input, str) else _input |
216 | 224 | if implementation is Implementation.SQLFRAME or np_version > (2, 0): |
217 | 225 | if ddof == 0: |
218 | | - return functions.var_pop(_input) |
| 226 | + return F.var_pop(column) |
219 | 227 | if ddof == 1: |
220 | | - return functions.var_samp(_input) |
| 228 | + return F.var_samp(column) |
| 229 | + |
| 230 | + n_rows = F.count(column) |
| 231 | + return F.var_samp(column) * (n_rows - 1) / (n_rows - ddof) |
| 232 | + |
| 233 | + if TYPE_CHECKING: |
| 234 | + return F.var_samp(column) |
| 235 | + |
| 236 | + return _var_pyspark(column, ddof) |
| 237 | + |
| 238 | + |
| 239 | +def _stddev_pyspark(col: Any, ddof: int, /) -> Column: |
| 240 | + from pyspark.pandas.spark.functions import stddev |
| 241 | + |
| 242 | + return cast("Column", stddev(col, ddof=ddof)) |
221 | 243 |
222 | | - n_rows = functions.count(_input) |
223 | | - return functions.var_samp(_input) * (n_rows - 1) / (n_rows - ddof) |
224 | 244 |
| 245 | +def _var_pyspark(col: Any, ddof: int, /) -> Column: |
225 | 246 | from pyspark.pandas.spark.functions import var |
226 | 247 |
227 | | - input_col = functions.col(_input) if isinstance(_input, str) else _input |
228 | | - return var(input_col, ddof=ddof) |
| 248 | + return cast("Column", var(col, ddof=ddof)) |
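
For context on the rescaling in `_std` and `_var` above: the sample statistics divide the sum of squared deviations by `n - 1`, so multiplying the variance by `(n - 1) / (n - ddof)`, or the standard deviation by the square root of that factor, converts it to an arbitrary `ddof`. A quick NumPy check of that identity (illustrative only, not part of the patch):

```python
import numpy as np

x = np.array([1.0, 2.0, 4.0, 7.0, 11.0])
n = len(x)

for ddof in (0, 2, 3):
    # Rescale the ddof=1 statistics to the divisor (n - ddof).
    var_rescaled = np.var(x, ddof=1) * (n - 1) / (n - ddof)
    std_rescaled = np.std(x, ddof=1) * np.sqrt((n - 1) / (n - ddof))
    assert np.isclose(var_rescaled, np.var(x, ddof=ddof))
    assert np.isclose(std_rescaled, np.std(x, ddof=ddof))
```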