|
3 | 3 | import pytest |
4 | 4 |
|
5 | 5 | def test_register_filtered_dataframe(): |
6 | | - # Create a new session context |
7 | 6 | ctx = SessionContext() |
8 | 7 |
|
9 | | - # Create sample data as a dictionary |
10 | 8 | data = { |
11 | 9 | "a": [1, 2, 3, 4, 5], |
12 | 10 | "b": [10, 20, 30, 40, 50] |
13 | 11 | } |
14 | 12 |
|
15 | | - # Create a DataFrame from the dictionary |
16 | 13 | df = ctx.from_pydict(data, "my_table") |
17 | 14 |
|
18 | | - # Filter the DataFrame (for example, keep rows where a > 2) |
19 | 15 | df_filtered = df.filter(col("a") > literal(2)) |
20 | 16 | view = df_filtered.into_view() |
21 | 17 |
|
| 18 | + assert view.kind == "view" |
22 | 19 |
|
23 | | - # Register the filtered DataFrame as a table called "view1" |
24 | 20 | ctx.register_table("view1", view) |
25 | 21 |
|
26 | | - # Now run a SQL query against the registered table "view1" |
27 | 22 | df_view = ctx.sql("SELECT * FROM view1") |
28 | 23 |
|
29 | | - # Collect the results (as a list of Arrow RecordBatches) |
30 | 24 | results = df_view.collect() |
31 | 25 |
|
32 | | - # Convert results to a list of dictionaries for easier assertion |
33 | 26 | result_dicts = [batch.to_pydict() for batch in results] |
34 | 27 |
|
35 | | - # Expected results |
36 | 28 | expected_results = [ |
37 | 29 | {"a": [3, 4, 5], "b": [30, 40, 50]} |
38 | 30 | ] |
39 | 31 |
|
40 | | - # Assert the results match the expected results |
41 | 32 | assert result_dicts == expected_results |
42 | 33 |
|
43 | | - assert view.kind == "view" |
0 commit comments