Merged
Changes from 18 commits
24 changes: 12 additions & 12 deletions benchmarks/db-benchmark/groupby-datafusion.py
@@ -20,7 +20,7 @@
 import timeit

 import datafusion as df
-import pyarrow
+import pyarrow as pa
 from datafusion import (
     RuntimeEnvBuilder,
     SessionConfig,
@@ -37,7 +37,7 @@
 exec(open("./_helpers/helpers.py").read())


-def ans_shape(batches):
+def ans_shape(batches) -> tuple[int, int]:
     rows, cols = 0, 0
     for batch in batches:
         rows += batch.num_rows
@@ -48,7 +48,7 @@ def ans_shape(batches):
     return rows, cols


-def execute(df):
+def execute(df) -> list:
     print(df.execution_plan().display_indent())
     return df.collect()

@@ -68,14 +68,14 @@ def execute(df):
 src_grp = os.path.join("data", data_name + ".csv")
 print("loading dataset %s" % src_grp, flush=True)

-schema = pyarrow.schema(
+schema = pa.schema(
     [
-        ("id4", pyarrow.int32()),
-        ("id5", pyarrow.int32()),
-        ("id6", pyarrow.int32()),
-        ("v1", pyarrow.int32()),
-        ("v2", pyarrow.int32()),
-        ("v3", pyarrow.float64()),
+        ("id4", pa.int32()),
+        ("id5", pa.int32()),
+        ("id6", pa.int32()),
+        ("v1", pa.int32()),
+        ("v2", pa.int32()),
+        ("v3", pa.float64()),
     ]
 )

@@ -93,8 +93,8 @@ def execute(df):
 )
 config = (
     SessionConfig()
-    .with_repartition_joins(False)
-    .with_repartition_aggregations(False)
+    .with_repartition_joins(enabled=False)
+    .with_repartition_aggregations(enabled=False)
     .set("datafusion.execution.coalesce_batches", "false")
 )
 ctx = SessionContext(config, runtime)
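
Taken together, these hunks are mechanical lint-style fixes: pyarrow gets its conventional pa alias, the helpers get return annotations, and boolean config toggles are passed by keyword so the call site names the flag being set. A minimal, self-contained sketch of the pattern the file converges on (the runtime builder chain is borrowed from examples/create-context.py below; the pool size there is illustrative):

import pyarrow as pa
from datafusion import RuntimeEnvBuilder, SessionConfig, SessionContext

# schema built with the short alias, as in the hunk above
schema = pa.schema([("id4", pa.int32()), ("v3", pa.float64())])

runtime = RuntimeEnvBuilder().with_disk_manager_os().with_fair_spill_pool(10000000)
config = (
    SessionConfig()
    .with_repartition_joins(enabled=False)  # keyword, not a bare positional False
    .with_repartition_aggregations(enabled=False)
    .set("datafusion.execution.coalesce_batches", "false")
)
ctx = SessionContext(config, runtime)
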
5 changes: 3 additions & 2 deletions benchmarks/db-benchmark/join-datafusion.py
@@ -29,7 +29,7 @@
 exec(open("./_helpers/helpers.py").read())


-def ans_shape(batches):
+def ans_shape(batches) -> tuple[int, int]:
     rows, cols = 0, 0
     for batch in batches:
         rows += batch.num_rows
@@ -57,7 +57,8 @@ def ans_shape(batches):
     os.path.join("data", y_data_name[2] + ".csv"),
 ]
 if len(src_jn_y) != 3:
-    raise Exception("Something went wrong in preparing files used for join")
+    error_msg = "Something went wrong in preparing files used for join"
+    raise Exception(error_msg)

 print(
     "loading datasets "
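
Hoisting the message into error_msg before raising looks like ruff's EM101 fix (no string literal directly inside an exception constructor), which keeps the raise line in tracebacks short. The same pattern in isolation, with hypothetical names:

def require_three_files(files: list) -> None:
    # bind the message to a variable first, then raise it
    if len(files) != 3:
        error_msg = f"expected 3 join files, got {len(files)}"
        raise ValueError(error_msg)
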
7 changes: 2 additions & 5 deletions benchmarks/tpch/tpch.py
@@ -21,7 +21,7 @@
 from datafusion import SessionContext


-def bench(data_path, query_path):
+def bench(data_path, query_path) -> None:
     with open("results.csv", "w") as results:
         # register tables
         start = time.time()
@@ -68,10 +68,7 @@ def bench(data_path, query_path):
         with open(f"{query_path}/q{query}.sql") as f:
             text = f.read()
         tmp = text.split(";")
-        queries = []
-        for str in tmp:
-            if len(str.strip()) > 0:
-                queries.append(str.strip())
+        queries = [s.strip() for s in tmp if len(s.strip()) > 0]

         try:
             start = time.time()
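
Besides being shorter, the comprehension stops shadowing the builtin str, which the old loop variable did. Its behavior is easy to verify on a toy input (a standalone check, not part of the benchmark):

text = "SELECT 1;\n\nSELECT 2;\n"
tmp = text.split(";")
queries = [s.strip() for s in tmp if len(s.strip()) > 0]
# the empty fragment after the final ";" is stripped to "" and dropped
assert queries == ["SELECT 1", "SELECT 2"]
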
6 changes: 3 additions & 3 deletions dev/release/generate-changelog.py
@@ -24,7 +24,7 @@
 from github import Github


-def print_pulls(repo_name, title, pulls):
+def print_pulls(repo_name, title, pulls) -> None:
     if len(pulls) > 0:
         print(f"**{title}:**")
         print()
@@ -34,7 +34,7 @@ def print_pulls(repo_name, title, pulls):
     print()


-def generate_changelog(repo, repo_name, tag1, tag2, version):
+def generate_changelog(repo, repo_name, tag1, tag2, version) -> None:
     # get a list of commits between two tags
     print(f"Fetching list of commits between {tag1} and {tag2}", file=sys.stderr)
     comparison = repo.compare(tag1, tag2)
@@ -154,7 +154,7 @@ def generate_changelog(repo, repo_name, tag1, tag2, version):
 )


-def cli(args=None):
+def cli(args=None) -> None:
     """Process command line arguments."""
     if not args:
         args = sys.argv[1:]
4 changes: 2 additions & 2 deletions docs/source/conf.py
@@ -73,7 +73,7 @@
 autoapi_python_class_content = "both"


-def autoapi_skip_member_fn(app, what, name, obj, skip, options):  # noqa: ARG001
+def autoapi_skip_member_fn(app, what, name, obj, skip, options) -> bool:  # noqa: ARG001
     skip_contents = [
         # Re-exports
         ("class", "datafusion.DataFrame"),
@@ -93,7 +93,7 @@ def autoapi_skip_member_fn(app, what, name, obj, skip, options):  # noqa: ARG001
     return skip


-def setup(sphinx):
+def setup(sphinx) -> None:
     sphinx.connect("autoapi-skip-member", autoapi_skip_member_fn)
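
For readers unfamiliar with the hook: Sphinx AutoAPI fires autoapi-skip-member once per documented object, and returning True suppresses that entry. A stripped-down sketch of the mechanism, reusing one tuple from the skip list above (the membership test is an assumption, since the diff elides the function body):

def autoapi_skip_member_fn(app, what, name, obj, skip, options) -> bool:  # noqa: ARG001
    skip_contents = [
        ("class", "datafusion.DataFrame"),  # re-export, documented at its source
    ]
    if (what, name) in skip_contents:
        return True  # True tells AutoAPI to omit this object
    return skip


def setup(sphinx) -> None:
    sphinx.connect("autoapi-skip-member", autoapi_skip_member_fn)
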
12 changes: 6 additions & 6 deletions examples/create-context.py
@@ -25,14 +25,14 @@
 runtime = RuntimeEnvBuilder().with_disk_manager_os().with_fair_spill_pool(10000000)
 config = (
     SessionConfig()
-    .with_create_default_catalog_and_schema(True)
+    .with_create_default_catalog_and_schema(enabled=True)
     .with_default_catalog_and_schema("foo", "bar")
     .with_target_partitions(8)
-    .with_information_schema(True)
-    .with_repartition_joins(False)
-    .with_repartition_aggregations(False)
-    .with_repartition_windows(False)
-    .with_parquet_pruning(False)
+    .with_information_schema(enabled=True)
+    .with_repartition_joins(enabled=False)
+    .with_repartition_aggregations(enabled=False)
+    .with_repartition_windows(enabled=False)
+    .with_parquet_pruning(enabled=False)
     .set("datafusion.execution.parquet.pushdown_filters", "true")
 )
 ctx = SessionContext(config, runtime)
36 changes: 16 additions & 20 deletions examples/python-udaf.py
@@ -16,7 +16,7 @@
 # under the License.

 import datafusion
-import pyarrow
+import pyarrow as pa
 import pyarrow.compute
 from datafusion import Accumulator, col, udaf

@@ -26,48 +26,44 @@ class MyAccumulator(Accumulator):
     Interface of a user-defined accumulation.
     """

-    def __init__(self):
-        self._sum = pyarrow.scalar(0.0)
+    def __init__(self) -> None:
+        self._sum = pa.scalar(0.0)

-    def update(self, values: pyarrow.Array) -> None:
+    def update(self, values: pa.Array) -> None:
         # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
-        self._sum = pyarrow.scalar(
-            self._sum.as_py() + pyarrow.compute.sum(values).as_py()
-        )
+        self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(values).as_py())

-    def merge(self, states: pyarrow.Array) -> None:
+    def merge(self, states: pa.Array) -> None:
         # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
-        self._sum = pyarrow.scalar(
-            self._sum.as_py() + pyarrow.compute.sum(states).as_py()
-        )
+        self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(states).as_py())

-    def state(self) -> pyarrow.Array:
-        return pyarrow.array([self._sum.as_py()])
+    def state(self) -> pa.Array:
+        return pa.array([self._sum.as_py()])

-    def evaluate(self) -> pyarrow.Scalar:
+    def evaluate(self) -> pa.Scalar:
         return self._sum


 # create a context
 ctx = datafusion.SessionContext()

 # create a RecordBatch and a new DataFrame from it
-batch = pyarrow.RecordBatch.from_arrays(
-    [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+batch = pa.RecordBatch.from_arrays(
+    [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
     names=["a", "b"],
 )
 df = ctx.create_dataframe([[batch]])

 my_udaf = udaf(
     MyAccumulator,
-    pyarrow.float64(),
-    pyarrow.float64(),
-    [pyarrow.float64()],
+    pa.float64(),
+    pa.float64(),
+    [pa.float64()],
     "stable",
 )

 df = df.aggregate([], [my_udaf(col("a"))])

 result = df.collect()[0]

-assert result.column(0) == pyarrow.array([6.0])
+assert result.column(0) == pa.array([6.0])
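
Because the accumulator is plain Python, the contract is easy to exercise outside a query: update folds in raw input values, merge folds in another partition's state. A quick check, assuming the class as defined above:

acc = MyAccumulator()
acc.update(pa.array([1.0, 2.0, 3.0]))  # raw values: running sum becomes 6.0
acc.merge(pa.array([6.0]))             # another partition's state: sum becomes 12.0
assert acc.state() == pa.array([12.0])
assert acc.evaluate().as_py() == 12.0
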
6 changes: 3 additions & 3 deletions examples/python-udf-comparisons.py
@@ -163,9 +163,9 @@ def udf_using_pyarrow_compute_impl(
         resultant_arr = pc.and_(resultant_arr, filtered_returnflag_arr)

-        if results is None:
-            results = resultant_arr
-        else:
-            results = pc.or_(results, resultant_arr)
+        results = (
+            resultant_arr if results is None else pc.or_(results, resultant_arr)
+        )
Review comment (Member): Can you double check this? It looks like it changed the logic slightly. If results is not None it doesn't look like this will work as before
     return results
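
Regarding the review question: a side-by-side check on toy boolean arrays suggests the ternary matches the old if/else in both the None and not-None cases (pc is pyarrow.compute, as in the example):

import pyarrow as pa
import pyarrow.compute as pc

resultant_arr = pa.array([True, False, True])
for results in (None, pa.array([False, False, True])):
    # original branching form
    if results is None:
        branchy = resultant_arr
    else:
        branchy = pc.or_(results, resultant_arr)
    # refactored ternary form
    ternary = resultant_arr if results is None else pc.or_(results, resultant_arr)
    assert branchy == ternary
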
12 changes: 6 additions & 6 deletions examples/python-udf.py
@@ -15,23 +15,23 @@
 # specific language governing permissions and limitations
 # under the License.

-import pyarrow
+import pyarrow as pa
 from datafusion import SessionContext, udf
 from datafusion import functions as f


-def is_null(array: pyarrow.Array) -> pyarrow.Array:
+def is_null(array: pa.Array) -> pa.Array:
     return array.is_null()


-is_null_arr = udf(is_null, [pyarrow.int64()], pyarrow.bool_(), "stable")
+is_null_arr = udf(is_null, [pa.int64()], pa.bool_(), "stable")

 # create a context
 ctx = SessionContext()

 # create a RecordBatch and a new DataFrame from it
-batch = pyarrow.RecordBatch.from_arrays(
-    [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+batch = pa.RecordBatch.from_arrays(
+    [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
     names=["a", "b"],
 )
 df = ctx.create_dataframe([[batch]])
@@ -40,4 +40,4 @@ def is_null(array: pyarrow.Array) -> pyarrow.Array:

 result = df.collect()[0]

-assert result.column(0) == pyarrow.array([False] * 3)
+assert result.column(0) == pa.array([False] * 3)
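
The UDF body is pure pyarrow, so it can be sanity-checked directly, without a SessionContext (a standalone check, not part of the example file):

assert is_null(pa.array([1, None, 3])) == pa.array([False, True, False])
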
10 changes: 5 additions & 5 deletions examples/query-pyarrow-data.py
@@ -16,15 +16,15 @@
 # under the License.

 import datafusion
-import pyarrow
+import pyarrow as pa
 from datafusion import col

 # create a context
 ctx = datafusion.SessionContext()

 # create a RecordBatch and a new DataFrame from it
-batch = pyarrow.RecordBatch.from_arrays(
-    [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+batch = pa.RecordBatch.from_arrays(
+    [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
     names=["a", "b"],
 )
 df = ctx.create_dataframe([[batch]])
@@ -38,5 +38,5 @@
 # execute and collect the first (and only) batch
 result = df.collect()[0]

-assert result.column(0) == pyarrow.array([5, 7, 9])
-assert result.column(1) == pyarrow.array([-3, -3, -3])
+assert result.column(0) == pa.array([5, 7, 9])
+assert result.column(1) == pa.array([-3, -3, -3])
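
The elided middle of this example is the projection itself; consistent with the asserts just shown, it presumably computes a + b and a - b. A hypothetical reconstruction, not the file's actual line:

df = df.select(col("a") + col("b"), col("a") - col("b"))
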
2 changes: 1 addition & 1 deletion examples/sql-using-python-udaf.py
@@ -25,7 +25,7 @@ class MyAccumulator(Accumulator):
     Interface of a user-defined accumulation.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         self._sum = pa.scalar(0.0)

     def update(self, values: pa.Array) -> None: