Skip to content

Commit d771487

Browse files
committed
Add test for nsorted method
1 parent 92d53da commit d771487

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

pandas/tests/frame/methods/test_nlargest.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ def test_nsorted_n(self, nselect_method, n: int, columns):
7575
)
7676
if "b" in columns:
7777
error_msg = (
78-
f"Column 'b' has dtype (object|str), "
79-
f"cannot use method '{nselect_method}' with this dtype"
78+
"Column 'b' has dtype (object|str), "
79+
"cannot use n-sorting with this dtype"
8080
)
8181
with pytest.raises(TypeError, match=error_msg):
8282
getattr(df, nselect_method)(n, columns)
@@ -87,6 +87,29 @@ def test_nsorted_n(self, nselect_method, n: int, columns):
8787
expected = df.sort_values(columns, ascending=ascending).head(n)
8888
tm.assert_frame_equal(result, expected)
8989

90+
def test_nsorted(self):
91+
df = pd.DataFrame(
92+
{
93+
"x": [2, 2, 1],
94+
"y": [3, 2, 1],
95+
},
96+
index=["a", "b", "c"],
97+
)
98+
cols = ["x", "y"]
99+
ascending = [True, False]
100+
n = 2
101+
df_sort_values = df.sort_values(cols, ascending=ascending).head(n)
102+
result = df.nsorted(n, cols, ascending=ascending)
103+
tm.assert_frame_equal(result, df_sort_values)
104+
expected = pd.DataFrame(
105+
{
106+
"x": [1, 2],
107+
"y": [1, 3],
108+
},
109+
index=["c", "a"],
110+
)
111+
tm.assert_frame_equal(result, expected)
112+
90113
@pytest.mark.parametrize(
91114
"columns", [["group", "category_string"], ["group", "string"]]
92115
)
@@ -95,7 +118,7 @@ def test_nsorted_error(self, df_main_dtypes, nselect_method, columns):
95118
col = columns[1]
96119
error_msg = (
97120
f"Column '{col}' has dtype {df[col].dtype}, "
98-
f"cannot use method '{nselect_method}' with this dtype"
121+
f"cannot use n-sorting with this dtype"
99122
)
100123
# escape some characters that may be in the repr
101124
error_msg = (

pandas/tests/series/methods/test_nlargest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def assert_check_nselect_boundary(vals, dtype, method):
2020
tm.assert_series_equal(result, expected)
2121

2222

23-
class TestSeriesNLargestNSmallest:
23+
class TestSeriesNSorted:
2424
@pytest.mark.parametrize(
2525
"r",
2626
[
@@ -37,7 +37,7 @@ class TestSeriesNLargestNSmallest:
3737
@pytest.mark.parametrize("arg", [2, 5, 0, -1])
3838
def test_nlargest_error(self, r, method, arg):
3939
dt = r.dtype
40-
msg = f"Cannot use method 'n(largest|smallest)' with dtype {dt}"
40+
msg = f"Cannot use n-sorting with dtype {dt}"
4141
with pytest.raises(TypeError, match=msg):
4242
getattr(r, method)(arg)
4343

@@ -78,6 +78,9 @@ def test_nsmallest_nlargest(self, data):
7878
tm.assert_series_equal(ser.nlargest(len(ser)), ser.iloc[[4, 0, 1, 3, 2]])
7979
tm.assert_series_equal(ser.nlargest(len(ser) + 1), ser.iloc[[4, 0, 1, 3, 2]])
8080

81+
tm.assert_series_equal(ser.nsorted(2, True), ser.nsmallest(2))
82+
tm.assert_series_equal(ser.nsorted(2, False), ser.nlargest(2))
83+
8184
def test_nlargest_misc(self):
8285
ser = Series([3.0, np.nan, 1, 2, 5])
8386
result = ser.nlargest()

0 commit comments

Comments
 (0)