22
33import pytest
44
5- import narwhals as nw_main
6- import narwhals .stable .v1 as nw
5+ import narwhals as nw
6+ import narwhals .stable .v1 as nw_v1
77from tests .utils import ConstructorEager
88from tests .utils import assert_equal_data
99
1010
1111def test_expr_sample (constructor_eager : ConstructorEager ) -> None :
12- df = nw .from_native (
12+ df = nw_v1 .from_native (
1313 constructor_eager ({"a" : [1 , 2 , 3 ], "b" : [4 , 5 , 6 ]}), eager_only = True
1414 )
1515
16- result_expr = df .select (nw .col ("a" ).sample (n = 2 )).shape
16+ result_expr = df .select (nw_v1 .col ("a" ).sample (n = 2 )).shape
1717 expected_expr = (2 , 1 )
1818 assert result_expr == expected_expr
1919
@@ -24,15 +24,15 @@ def test_expr_sample(constructor_eager: ConstructorEager) -> None:
2424 with pytest .deprecated_call (
2525 match = "is deprecated and will be removed in a future version"
2626 ):
27- df .select (nw_main .col ("a" ).sample (n = 2 ))
27+ df .select (nw .col ("a" ).sample (n = 2 ))
2828
2929
3030def test_expr_sample_fraction (constructor_eager : ConstructorEager ) -> None :
31- df = nw .from_native (
31+ df = nw_v1 .from_native (
3232 constructor_eager ({"a" : [1 , 2 , 3 ] * 10 , "b" : [4 , 5 , 6 ] * 10 }), eager_only = True
3333 )
3434
35- result_expr = df .select (nw .col ("a" ).sample (fraction = 0.1 )).shape
35+ result_expr = df .select (nw_v1 .col ("a" ).sample (fraction = 0.1 )).shape
3636 expected_expr = (3 , 1 )
3737 assert result_expr == expected_expr
3838
@@ -43,15 +43,15 @@ def test_expr_sample_fraction(constructor_eager: ConstructorEager) -> None:
4343
4444def test_sample_with_seed (constructor_eager : ConstructorEager ) -> None :
4545 size , n = 100 , 10
46- df = nw .from_native (constructor_eager ({"a" : list (range (size ))}))
46+ df = nw_v1 .from_native (constructor_eager ({"a" : list (range (size ))}))
4747 expected = {"res1" : [True ], "res2" : [False ]}
4848 result = df .select (
49- seed1 = nw .col ("a" ).sample (n = n , seed = 123 ),
50- seed2 = nw .col ("a" ).sample (n = n , seed = 123 ),
51- seed3 = nw .col ("a" ).sample (n = n , seed = 42 ),
49+ seed1 = nw_v1 .col ("a" ).sample (n = n , seed = 123 ),
50+ seed2 = nw_v1 .col ("a" ).sample (n = n , seed = 123 ),
51+ seed3 = nw_v1 .col ("a" ).sample (n = n , seed = 42 ),
5252 ).select (
53- res1 = (nw .col ("seed1" ) == nw .col ("seed2" )).all (),
54- res2 = (nw .col ("seed1" ) == nw .col ("seed3" )).all (),
53+ res1 = (nw_v1 .col ("seed1" ) == nw_v1 .col ("seed2" )).all (),
54+ res2 = (nw_v1 .col ("seed1" ) == nw_v1 .col ("seed3" )).all (),
5555 )
5656
5757 assert_equal_data (result , expected )
0 commit comments