@@ -477,8 +477,8 @@ def test_tail(df):
477477 assert result .column (2 ) == pa .array ([8 ])
478478
479479
480- def test_with_column (df ):
481- df = df .with_column ("c" , column ( "a" ) + column ( "b" ) )
480+ def test_with_column_sql_expression (df ):
481+ df = df .with_column ("c" , "a + b" )
482482
483483 # execute and collect the first (and only) batch
484484 result = df .collect ()[0 ]
@@ -492,11 +492,19 @@ def test_with_column(df):
492492 assert result .column (2 ) == pa .array ([5 , 7 , 9 ])
493493
494494
495- def test_with_column_invalid_expr (df ):
496- with pytest .raises (
497- TypeError , match = r"Use col\(\)/column\(\) or lit\(\)/literal\(\)"
498- ):
499- df .with_column ("c" , "a" )
495+ def test_with_column (df ):
496+ df = df .with_column ("c" , column ("a" ) + column ("b" ))
497+
498+ # execute and collect the first (and only) batch
499+ result = df .collect ()[0 ]
500+
501+ assert result .schema .field (0 ).name == "a"
502+ assert result .schema .field (1 ).name == "b"
503+ assert result .schema .field (2 ).name == "c"
504+
505+ assert result .column (0 ) == pa .array ([1 , 2 , 3 ])
506+ assert result .column (1 ) == pa .array ([4 , 5 , 6 ])
507+ assert result .column (2 ) == pa .array ([5 , 7 , 9 ])
500508
501509
502510def test_with_columns (df ):
0 commit comments