# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Benchmark script showing how to maximize CPU usage with DataFusion by increasing the target partition count."""

from __future__ import annotations

import argparse
import multiprocessing
import time

import pyarrow as pa
from datafusion import SessionConfig, SessionContext, col
from datafusion import functions as f


def main(num_rows: int, partitions: int) -> None:
    """Run a simple aggregation after repartitioning."""
    # Create some example data
    array = pa.array(range(num_rows))
    batch = pa.record_batch([array], names=["a"])
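    # Note: the generated data is a single in-memory batch, which forms a
    # single partition until it is explicitly repartitioned below.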

    # Configure the session to use a higher target partition count and
    # enable automatic repartitioning.
    config = (
        SessionConfig()
        .with_target_partitions(partitions)
        .with_repartition_joins(enabled=True)
        .with_repartition_aggregations(enabled=True)
        .with_repartition_windows(enabled=True)
    )
    ctx = SessionContext(config)

    # Create a DataFrame from the input data and repartition it explicitly
    # so that all of the target partitions are used.
    df = ctx.create_dataframe([[batch]]).repartition(partitions)

    start = time.time()
    df = df.aggregate([], [f.sum(col("a"))])
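    # collect() triggers execution of the lazy plan; the aggregation runs
    # in parallel across the configured partitions.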
    df.collect()
    end = time.time()

    print(
        f"Processed {num_rows} rows using {partitions} partitions in {end - start:.3f}s"
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--rows",
        type=int,
        default=1_000_000,
        help="Number of rows in the generated dataset",
    )
    parser.add_argument(
        "--partitions",
        type=int,
        default=multiprocessing.cpu_count(),
        help="Target number of partitions to use",
    )
    args = parser.parse_args()
    main(args.rows, args.partitions)
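
# Example invocation (the file name below is illustrative):
#   python cpu_usage_benchmark.py --rows 10000000 --partitions 8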