diff --git a/.gitignore b/.gitignore index 1a05e5d8..ce69e802 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /.idea -/target \ No newline at end of file +/target +/benchmarks/data/ \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index f10473bb..84e8d020 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,6 +76,65 @@ dependencies = [ "libc", ] +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + +[[package]] +name = "anstream" +version = "0.6.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "301af1932e46185686725e0fad2f8f2aa7da69dd70bf6ecc44d6b703844a3933" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "862ed96ca487e809f1c8e5a8447f6ee2cf102f846893800b20cebdf541fc6bbd" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8bdeb6047d8983be085bab0ba1472e6dc604e7041dbf6fcd5e71523014fae9" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "403f75924867bb1033c59fbf0797484329750cfbe3c4325cd33127941fabc882" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.59.0", +] + [[package]] name = "anyhow" version = "1.0.98" @@ -364,7 +423,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -375,7 +434,7 @@ checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -393,6 +452,17 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -480,6 +550,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.9.1" @@ -623,6 +699,27 @@ dependencies = [ "phf", ] +[[package]] +name = "clap" +version = "2.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +dependencies = [ + "ansi_term", + "atty", + "bitflags 1.3.2", + "strsim", + "textwrap", + "unicode-width 0.1.14", + "vec_map", +] + +[[package]] +name = "colorchoice" +version = "1.0.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + [[package]] name = "comfy-table" version = "7.1.4" @@ -630,7 +727,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a65ebfec4fb190b6f90e944a817d60499ee0744e582530e2c9900a22e591d9a" dependencies = [ "unicode-segmentation", - "unicode-width", + "unicode-width 0.2.1", ] [[package]] @@ -1036,6 +1133,23 @@ dependencies = [ "uuid", ] +[[package]] +name = "datafusion-distributed-benchmarks" +version = "0.1.0" +dependencies = [ + "async-trait", + "datafusion", + "datafusion-distributed", + "datafusion-proto", + "env_logger", + "log", + "parquet", + "serde", + "serde_json", + "structopt", + "tokio", +] + [[package]] name = "datafusion-doc" version = "49.0.0" @@ -1233,7 +1347,7 @@ checksum = "4cabe1f32daa2fa54e6b20d14a13a9e85bef97c4161fe8a90d76b6d9693a5ac4" dependencies = [ "datafusion-expr", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -1436,7 +1550,7 @@ checksum = "6178a82cf56c836a3ba61a7935cdb1c49bfaa6fa4327cd5bf554a503087de26b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -1458,7 +1572,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -1473,6 +1587,29 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" +[[package]] +name = "env_filter" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "jiff", + "log", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -1507,7 +1644,7 @@ version = "25.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1045398c1bfd89168b5fd3f1fc11f6e70b34f6f66300c87d44d3de849463abf1" dependencies = [ - "bitflags", + "bitflags 2.9.1", "rustc_version", ] @@ -1599,7 +1736,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -1736,6 +1873,24 @@ dependencies = [ "foldhash", ] +[[package]] +name = "heck" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d621efb26863f0e9924c6ac577e8275e5e6b77455db64ffa6c65c904e9e132c" +dependencies = [ + "unicode-segmentation", +] + +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "hex" version = "0.4.3" @@ -2024,11 +2179,17 @@ version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" dependencies = [ - "bitflags", + "bitflags 2.9.1", "cfg-if", "libc", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + [[package]] name = "itertools" version = "0.14.0" @@ -2044,6 +2205,30 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +[[package]] +name = "jiff" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be1f93b8b1eb69c77f24bbb0afdf66f54b632ee39af40ca21c4365a1d7347e49" +dependencies = [ + "jiff-static", + "log", + "portable-atomic", + "portable-atomic-util", + "serde", +] + +[[package]] +name = "jiff-static" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03343451ff899767262ec32146f6d559dd759fdadf42ff0e227c7c48f72594b4" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "jobserver" version = "0.1.33" @@ -2064,6 +2249,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + [[package]] name = "lexical-core" version = "1.0.5" @@ -2364,6 +2555,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "once_cell_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" + [[package]] name = "ordered-float" version = "2.10.1" @@ -2492,7 +2689,7 @@ checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -2513,6 +2710,21 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.2" @@ -2531,6 +2743,30 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.95" @@ -2560,7 +2796,7 @@ dependencies = [ "itertools", "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -2672,7 +2908,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies 
= [ "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -2681,7 +2917,7 @@ version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e8af0dde094006011e6a740d4879319439489813bd0bcdc7d821beaeeff48ec" dependencies = [ - "bitflags", + "bitflags 2.9.1", ] [[package]] @@ -2748,7 +2984,7 @@ version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8" dependencies = [ - "bitflags", + "bitflags 2.9.1", "errno", "libc", "linux-raw-sys", @@ -2811,7 +3047,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -2927,7 +3163,7 @@ checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -2955,12 +3191,53 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "strsim" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" + +[[package]] +name = "structopt" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c6b5c64445ba8094a6ab0c3cd2ad323e07171012d9c98b0b15651daf1787a10" +dependencies = [ + "clap", + "lazy_static", + "structopt-derive", +] + +[[package]] +name = "structopt-derive" +version = "0.4.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcb5ae327f9cc13b68763b5749770cb9e048a99bd9dfdfa58d0cf05d5f64afe0" +dependencies = [ + "heck", + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "subtle" version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.104" @@ -2986,7 +3263,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -3002,6 +3279,15 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width 0.1.14", +] + [[package]] name = "thiserror" version = "2.0.12" @@ -3019,7 +3305,7 @@ checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -3080,7 +3366,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -3202,7 +3488,7 @@ checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -3244,6 +3530,12 @@ version = "1.12.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + [[package]] name = "unicode-width" version = "0.2.1" @@ -3273,6 +3565,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.17.0" @@ -3284,6 +3582,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "vec_map" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" + [[package]] name = "version_check" version = "0.9.5" @@ -3346,7 +3650,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn", + "syn 2.0.104", "wasm-bindgen-shared", ] @@ -3381,7 +3685,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3415,6 +3719,22 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.9" @@ -3424,6 +3744,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.61.2" @@ -3445,7 +3771,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -3456,7 +3782,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -3644,7 +3970,7 @@ version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" dependencies = [ - "bitflags", + "bitflags 2.9.1", ] [[package]] @@ -3682,7 +4008,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", "synstructure", ] @@ -3703,7 +4029,7 @@ checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] @@ -3723,7 +4049,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", 
"synstructure", ] @@ -3757,7 +4083,7 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.104", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 5bae9452..0235c81f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,18 @@ +[workspace] +members = [ + "benchmarks" +] + +[workspace.dependencies] +datafusion = { version = "49.0.0" } + [package] name = "datafusion-distributed" version = "0.1.0" edition = "2021" [dependencies] -datafusion = { version = "49.0.0" } +datafusion = { workspace = true } datafusion-proto = { version = "49.0.0" } arrow-flight = "55.2.0" async-trait = "0.1.88" diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml new file mode 100644 index 00000000..cd659503 --- /dev/null +++ b/benchmarks/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "datafusion-distributed-benchmarks" +version = "0.1.0" +edition = "2021" +default-run = "dfbench" + +[dependencies] +datafusion = { workspace = true } +datafusion-distributed = { path = "..", features = ["integration"] } +tokio = { version = "1.46.1", features = ["full"] } +parquet = { version = "55.2.0" } +structopt = { version = "0.3.26" } +log = "0.4.27" +serde = "1.0.219" +serde_json = "1.0.141" +env_logger = "0.11.8" +async-trait = "0.1.88" +datafusion-proto = { version = "49.0.0", optional = true } + +[[bin]] +name = "dfbench" +path = "src/bin/dfbench.rs" + +[features] +ci = [ + "datafusion-proto" +] \ No newline at end of file diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 00000000..6723821f --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,17 @@ +# Distributed DataFusion Benchmarks + +### Generating tpch data + +Generate TPCH data into the `data/` dir + +```shell +./gen-tpch.sh +``` + +### Running tpch benchmarks + +After generating the data with the command above: + +```shell +cargo run -p datafusion-distributed-benchmarks --release -- tpch --path data/tpch_sf1 +``` \ No newline at end of file diff --git a/benchmarks/gen-tpch.sh b/benchmarks/gen-tpch.sh new file mode 100755 index 00000000..98ec9c88 --- /dev/null +++ b/benchmarks/gen-tpch.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash + +set -e + +SCALE_FACTOR=1 + +# https://stackoverflow.com/questions/59895/how-do-i-get-the-directory-where-a-bash-script-is-located-from-within-the-script +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +DATA_DIR=${DATA_DIR:-$SCRIPT_DIR/data} +CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --release"} + +if [ -z "$SCALE_FACTOR" ] ; then + echo "Internal error: Scale factor not specified" + exit 1 +fi + +TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" +echo "Creating tpch dataset at Scale Factor ${SCALE_FACTOR} in ${TPCH_DIR}..." + +# Ensure the target data directory exists +mkdir -p "${TPCH_DIR}" + +# Create 'tbl' (CSV format) data into $DATA_DIR if it does not already exist +FILE="${TPCH_DIR}/supplier.tbl" +if test -f "${FILE}"; then + echo " tbl files exist ($FILE exists)." +else + echo " creating tbl files with tpch_dbgen..." + docker run -v "${TPCH_DIR}":/data -it --rm ghcr.io/scalytics/tpch-docker:main -vf -s "${SCALE_FACTOR}" +fi + +# Copy expected answers into the ./data/answers directory if it does not already exist +FILE="${TPCH_DIR}/answers/q1.out" +if test -f "${FILE}"; then + echo " Expected answers exist (${FILE} exists)." 
+else + echo " Copying answers to ${TPCH_DIR}/answers" + mkdir -p "${TPCH_DIR}/answers" + docker run -v "${TPCH_DIR}":/data -it --entrypoint /bin/bash --rm ghcr.io/scalytics/tpch-docker:main -c "cp -f /opt/tpch/2.18.0_rc2/dbgen/answers/* /data/answers/" +fi + +# Create 'parquet' files from tbl +FILE="${TPCH_DIR}/supplier" +if test -d "${FILE}"; then + echo " parquet files exist ($FILE exists)." +else + echo " creating parquet files using benchmark binary ..." + pushd "${SCRIPT_DIR}" > /dev/null + $CARGO_COMMAND -- tpch-convert --input "${TPCH_DIR}" --output "${TPCH_DIR}" --format parquet + popd > /dev/null +fi + +# Create 'csv' files from tbl +FILE="${TPCH_DIR}/csv/supplier" +if test -d "${FILE}"; then + echo " csv files exist ($FILE exists)." +else + echo " creating csv files using benchmark binary ..." + pushd "${SCRIPT_DIR}" > /dev/null + $CARGO_COMMAND -- tpch-convert --input "${TPCH_DIR}" --output "${TPCH_DIR}/csv" --format csv + popd > /dev/null +fi + diff --git a/benchmarks/queries/q1.sql b/benchmarks/queries/q1.sql new file mode 100644 index 00000000..a0fcf159 --- /dev/null +++ b/benchmarks/queries/q1.sql @@ -0,0 +1,21 @@ +select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order +from + lineitem +where + l_shipdate <= date '1998-09-02' +group by + l_returnflag, + l_linestatus +order by + l_returnflag, + l_linestatus; \ No newline at end of file diff --git a/benchmarks/queries/q10.sql b/benchmarks/queries/q10.sql new file mode 100644 index 00000000..8613fd49 --- /dev/null +++ b/benchmarks/queries/q10.sql @@ -0,0 +1,32 @@ +select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment +from + customer, + orders, + lineitem, + nation +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1994-01-01' + and l_returnflag = 'R' + and c_nationkey = n_nationkey +group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment +order by + revenue desc +limit 20; diff --git a/benchmarks/queries/q11.sql b/benchmarks/queries/q11.sql new file mode 100644 index 00000000..c23ed1c7 --- /dev/null +++ b/benchmarks/queries/q11.sql @@ -0,0 +1,27 @@ +select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value +from + partsupp, + supplier, + nation +where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' +group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) +order by + value desc; \ No newline at end of file diff --git a/benchmarks/queries/q12.sql b/benchmarks/queries/q12.sql new file mode 100644 index 00000000..f8e6d960 --- /dev/null +++ b/benchmarks/queries/q12.sql @@ -0,0 +1,30 @@ +select + l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count +from + lineitem + join + orders + on + 
l_orderkey = o_orderkey +where + l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1995-01-01' +group by + l_shipmode +order by + l_shipmode; \ No newline at end of file diff --git a/benchmarks/queries/q13.sql b/benchmarks/queries/q13.sql new file mode 100644 index 00000000..4bfe8c35 --- /dev/null +++ b/benchmarks/queries/q13.sql @@ -0,0 +1,20 @@ +select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders (c_custkey, c_count) +group by + c_count +order by + custdist desc, + c_count desc; \ No newline at end of file diff --git a/benchmarks/queries/q14.sql b/benchmarks/queries/q14.sql new file mode 100644 index 00000000..d8ef6afa --- /dev/null +++ b/benchmarks/queries/q14.sql @@ -0,0 +1,13 @@ +select + 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from + lineitem, + part +where + l_partkey = p_partkey + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-10-01'; \ No newline at end of file diff --git a/benchmarks/queries/q15.sql b/benchmarks/queries/q15.sql new file mode 100644 index 00000000..b5cb49e5 --- /dev/null +++ b/benchmarks/queries/q15.sql @@ -0,0 +1,34 @@ +create view revenue0 (supplier_no, total_revenue) as + select + l_suppkey, + sum(l_extendedprice * (1 - l_discount)) + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey; + + +select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue +from + supplier, + revenue0 +where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue0 + ) +order by + s_suppkey; + +drop view revenue0; \ No newline at end of file diff --git a/benchmarks/queries/q16.sql b/benchmarks/queries/q16.sql new file mode 100644 index 00000000..36b7c07c --- /dev/null +++ b/benchmarks/queries/q16.sql @@ -0,0 +1,30 @@ +select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt +from + partsupp, + part +where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' +) +group by + p_brand, + p_type, + p_size +order by + supplier_cnt desc, + p_brand, + p_type, + p_size; \ No newline at end of file diff --git a/benchmarks/queries/q17.sql b/benchmarks/queries/q17.sql new file mode 100644 index 00000000..1e655506 --- /dev/null +++ b/benchmarks/queries/q17.sql @@ -0,0 +1,17 @@ +select + sum(l_extendedprice) / 7.0 as avg_yearly +from + lineitem, + part +where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey +); \ No newline at end of file diff --git a/benchmarks/queries/q18.sql b/benchmarks/queries/q18.sql new file mode 100644 index 00000000..ba7ee7f7 --- /dev/null +++ b/benchmarks/queries/q18.sql @@ -0,0 +1,33 @@ +select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) +from + customer, + orders, + lineitem +where + 
o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey +group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice +order by + o_totalprice desc, + o_orderdate +limit 100; diff --git a/benchmarks/queries/q19.sql b/benchmarks/queries/q19.sql new file mode 100644 index 00000000..56668e73 --- /dev/null +++ b/benchmarks/queries/q19.sql @@ -0,0 +1,35 @@ +select + sum(l_extendedprice* (1 - l_discount)) as revenue +from + lineitem, + part +where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 1 + 10 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 10 + 10 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 20 + 10 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ); \ No newline at end of file diff --git a/benchmarks/queries/q2.sql b/benchmarks/queries/q2.sql new file mode 100644 index 00000000..68e478f6 --- /dev/null +++ b/benchmarks/queries/q2.sql @@ -0,0 +1,44 @@ +select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment +from + part, + supplier, + partsupp, + nation, + region +where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' +) +order by + s_acctbal desc, + n_name, + s_name, + p_partkey +limit 100; diff --git a/benchmarks/queries/q20.sql b/benchmarks/queries/q20.sql new file mode 100644 index 00000000..dd61a7d8 --- /dev/null +++ b/benchmarks/queries/q20.sql @@ -0,0 +1,37 @@ +select + s_name, + s_address +from + supplier, + nation +where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' +order by + s_name; \ No newline at end of file diff --git a/benchmarks/queries/q21.sql b/benchmarks/queries/q21.sql new file mode 100644 index 00000000..b95e7b0d --- /dev/null +++ b/benchmarks/queries/q21.sql @@ -0,0 +1,40 @@ +select + s_name, + count(*) as numwait +from + supplier, + lineitem l1, + orders, + nation +where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = 
l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' +group by + s_name +order by + numwait desc, + s_name +limit 100; diff --git a/benchmarks/queries/q22.sql b/benchmarks/queries/q22.sql new file mode 100644 index 00000000..90aea6fd --- /dev/null +++ b/benchmarks/queries/q22.sql @@ -0,0 +1,37 @@ +select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal +from + ( + select + substring(c_phone from 1 for 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone from 1 for 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale +group by + cntrycode +order by + cntrycode; \ No newline at end of file diff --git a/benchmarks/queries/q3.sql b/benchmarks/queries/q3.sql new file mode 100644 index 00000000..e5fa9e38 --- /dev/null +++ b/benchmarks/queries/q3.sql @@ -0,0 +1,23 @@ +select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + o_orderdate, + o_shippriority +from + customer, + orders, + lineitem +where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < date '1995-03-15' + and l_shipdate > date '1995-03-15' +group by + l_orderkey, + o_orderdate, + o_shippriority +order by + revenue desc, + o_orderdate +limit 10; diff --git a/benchmarks/queries/q4.sql b/benchmarks/queries/q4.sql new file mode 100644 index 00000000..74a620db --- /dev/null +++ b/benchmarks/queries/q4.sql @@ -0,0 +1,21 @@ +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority; \ No newline at end of file diff --git a/benchmarks/queries/q5.sql b/benchmarks/queries/q5.sql new file mode 100644 index 00000000..5a336b23 --- /dev/null +++ b/benchmarks/queries/q5.sql @@ -0,0 +1,24 @@ +select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue +from + customer, + orders, + lineitem, + supplier, + nation, + region +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1995-01-01' +group by + n_name +order by + revenue desc; \ No newline at end of file diff --git a/benchmarks/queries/q6.sql b/benchmarks/queries/q6.sql new file mode 100644 index 00000000..5806f980 --- /dev/null +++ b/benchmarks/queries/q6.sql @@ -0,0 +1,9 @@ +select + sum(l_extendedprice * l_discount) as revenue +from + lineitem +where + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1995-01-01' + and l_discount between 0.06 - 0.01 and 0.06 + 0.01 + and l_quantity < 24; \ No newline at end of file diff --git a/benchmarks/queries/q7.sql b/benchmarks/queries/q7.sql new file mode 100644 index 00000000..512e5be5 --- /dev/null +++ b/benchmarks/queries/q7.sql 
@@ -0,0 +1,39 @@ +select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue +from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + extract(year from l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping +group by + supp_nation, + cust_nation, + l_year +order by + supp_nation, + cust_nation, + l_year; diff --git a/benchmarks/queries/q8.sql b/benchmarks/queries/q8.sql new file mode 100644 index 00000000..6ddb2a67 --- /dev/null +++ b/benchmarks/queries/q8.sql @@ -0,0 +1,37 @@ +select + o_year, + sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume) as mkt_share +from + ( + select + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between date '1995-01-01' and date '1996-12-31' + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations +group by + o_year +order by + o_year; \ No newline at end of file diff --git a/benchmarks/queries/q9.sql b/benchmarks/queries/q9.sql new file mode 100644 index 00000000..587bbc8a --- /dev/null +++ b/benchmarks/queries/q9.sql @@ -0,0 +1,32 @@ +select + nation, + o_year, + sum(amount) as sum_profit +from + ( + select + n_name as nation, + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit +group by + nation, + o_year +order by + nation, + o_year desc; \ No newline at end of file diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs new file mode 100644 index 00000000..be63e844 --- /dev/null +++ b/benchmarks/src/bin/dfbench.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
DataFusion Distributed benchmark runner +use datafusion::error::Result; + +use structopt::StructOpt; + +use datafusion_distributed_benchmarks::tpch; + +#[derive(Debug, StructOpt)] +#[structopt(about = "benchmark command")] +enum Options { + Tpch(tpch::RunOpt), + TpchConvert(tpch::ConvertOpt), +} + +// Main benchmark runner entrypoint +#[tokio::main] +pub async fn main() -> Result<()> { + env_logger::init(); + + match Options::from_args() { + Options::Tpch(opt) => Box::pin(opt.run()).await, + Options::TpchConvert(opt) => opt.run().await, + } +} diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs new file mode 100644 index 00000000..5ede4af0 --- /dev/null +++ b/benchmarks/src/lib.rs @@ -0,0 +1,2 @@ +pub mod tpch; +mod util; diff --git a/benchmarks/src/tpch/convert.rs b/benchmarks/src/tpch/convert.rs new file mode 100644 index 00000000..4d036e4c --- /dev/null +++ b/benchmarks/src/tpch/convert.rs @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::instant::Instant; +use datafusion::logical_expr::select_expr::SelectExpr; +use std::fs; +use std::path::{Path, PathBuf}; + +use datafusion::common::not_impl_err; + +use super::get_tbl_tpch_table_schema; +use super::TPCH_TABLES; +use datafusion::error::Result; +use datafusion::prelude::*; +use parquet::basic::Compression; +use parquet::file::properties::WriterProperties; +use structopt::StructOpt; + +/// Convert tpch .tbl files to .parquet or .csv files +#[derive(Debug, StructOpt)] +pub struct ConvertOpt { + /// Path to the input .tbl files + #[structopt(parse(from_os_str), required = true, short = "i", long = "input")] + input_path: PathBuf, + + /// Output path + #[structopt(parse(from_os_str), required = true, short = "o", long = "output")] + output_path: PathBuf, + + /// Output file format: `csv` or `parquet` + #[structopt(short = "f", long = "format")] + file_format: String, + + /// Compression to use when writing Parquet files + #[structopt(short = "c", long = "compression", default_value = "zstd")] + compression: String, + + /// Number of partitions to produce + #[structopt(short = "n", long = "partitions", default_value = "1")] + partitions: usize, + + /// Batch size when reading CSV or Parquet files + #[structopt(short = "s", long = "batch-size", default_value = "8192")] + batch_size: usize, + + /// Sort each table by its first column in ascending order.
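+    /// The input is also assumed to be pre-sorted by that column (see `file_sort_order` below), which speeds up the conversion.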
+    #[structopt(short = "t", long = "sort")] +    sort: bool, +} + +impl ConvertOpt { +    pub async fn run(self) -> Result<()> { +        let compression = self.compression()?; + +        let input_path = self.input_path.to_str().unwrap(); +        let output_path = self.output_path.to_str().unwrap(); + +        let output_root_path = Path::new(output_path); +        for table in TPCH_TABLES { +            let start = Instant::now(); +            let schema = get_tbl_tpch_table_schema(table); +            let key_column_name = schema.fields()[0].name(); + +            let input_path = format!("{input_path}/{table}.tbl"); +            let options = CsvReadOptions::new() +                .schema(&schema) +                .has_header(false) +                .delimiter(b'|') +                .file_extension(".tbl"); +            let options = if self.sort { +                // indicate that the file is already sorted by its first column to speed up the conversion +                options.file_sort_order(vec![vec![col(key_column_name).sort(true, false)]]) +            } else { +                options +            }; + +            let config = SessionConfig::new().with_batch_size(self.batch_size); +            let ctx = SessionContext::new_with_config(config); + +            // build plan to read the TBL file +            let mut csv = ctx.read_csv(&input_path, options).await?; + +            // Select all apart from the padding column +            let selection = csv +                .schema() +                .iter() +                .take(schema.fields.len() - 1) +                .map(Expr::from) +                .map(SelectExpr::from) +                .collect::<Vec<_>>(); + +            csv = csv.select(selection)?; +            // optionally, repartition the file +            let partitions = self.partitions; +            if partitions > 1 { +                csv = csv.repartition(Partitioning::RoundRobinBatch(partitions))? +            } +            let csv = if self.sort { +                csv.sort_by(vec![col(key_column_name)])? +            } else { +                csv +            }; + +            // create the physical plan +            let csv = csv.create_physical_plan().await?; + +            let output_path = output_root_path.join(table); +            let output_path = output_path.to_str().unwrap().to_owned(); +            fs::create_dir_all(&output_path)?; +            println!( +                "Converting '{}' to {} files in directory '{}'", +                &input_path, self.file_format, &output_path +            ); +            match self.file_format.as_str() { +                "csv" => ctx.write_csv(csv, output_path).await?, +                "parquet" => { +                    let props = WriterProperties::builder() +                        .set_compression(compression) +                        .build(); +                    ctx.write_parquet(csv, output_path, Some(props)).await? +                } +                other => { +                    return not_impl_err!("Invalid output format: {other}"); +                } +            } +            println!("Conversion completed in {} ms", start.elapsed().as_millis()); +        } + +        Ok(()) +    } + +    /// Return the compression method to use when writing parquet +    fn compression(&self) -> Result<Compression> { +        Ok(match self.compression.as_str() { +            "none" => Compression::UNCOMPRESSED, +            "snappy" => Compression::SNAPPY, +            "brotli" => Compression::BROTLI(Default::default()), +            "gzip" => Compression::GZIP(Default::default()), +            "lz4" => Compression::LZ4, +            "lzo" => Compression::LZO, +            "zstd" => Compression::ZSTD(Default::default()), +            other => { +                return not_impl_err!("Invalid compression format: {other}"); +            } +        }) +    } +} diff --git a/benchmarks/src/tpch/mod.rs b/benchmarks/src/tpch/mod.rs new file mode 100644 index 00000000..d4aea119 --- /dev/null +++ b/benchmarks/src/tpch/mod.rs @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmark derived from TPC-H. This is not an official TPC-H benchmark. + +use datafusion::arrow::datatypes::SchemaBuilder; +use datafusion::{ + arrow::datatypes::{DataType, Field, Schema}, + common::plan_err, + error::Result, +}; +use std::fs; +mod run; +pub use run::RunOpt; + +mod convert; +pub use convert::ConvertOpt; + +pub const TPCH_TABLES: &[&str] = &[ + "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region", +]; + +pub const TPCH_QUERY_START_ID: usize = 1; +pub const TPCH_QUERY_END_ID: usize = 22; + +/// The `.tbl` file contains a trailing column +pub fn get_tbl_tpch_table_schema(table: &str) -> Schema { + let mut schema = SchemaBuilder::from(get_tpch_table_schema(table).fields); + schema.push(Field::new("__placeholder", DataType::Utf8, false)); + schema.finish() +} + +/// Get the schema for the benchmarks derived from TPC-H +pub fn get_tpch_table_schema(table: &str) -> Schema { + // note that the schema intentionally uses signed integers so that any generated Parquet + // files can also be used to benchmark tools that only support signed integers, such as + // Apache Spark + + match table { + "part" => Schema::new(vec![ + Field::new("p_partkey", DataType::Int64, false), + Field::new("p_name", DataType::Utf8, false), + Field::new("p_mfgr", DataType::Utf8, false), + Field::new("p_brand", DataType::Utf8, false), + Field::new("p_type", DataType::Utf8, false), + Field::new("p_size", DataType::Int32, false), + Field::new("p_container", DataType::Utf8, false), + Field::new("p_retailprice", DataType::Decimal128(15, 2), false), + Field::new("p_comment", DataType::Utf8, false), + ]), + + "supplier" => Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_name", DataType::Utf8, false), + Field::new("s_address", DataType::Utf8, false), + Field::new("s_nationkey", DataType::Int64, false), + Field::new("s_phone", DataType::Utf8, false), + Field::new("s_acctbal", DataType::Decimal128(15, 2), false), + Field::new("s_comment", DataType::Utf8, false), + ]), + + "partsupp" => Schema::new(vec![ + Field::new("ps_partkey", DataType::Int64, false), + Field::new("ps_suppkey", DataType::Int64, false), + Field::new("ps_availqty", DataType::Int32, false), + Field::new("ps_supplycost", DataType::Decimal128(15, 2), false), + Field::new("ps_comment", DataType::Utf8, false), + ]), + + "customer" => Schema::new(vec![ + Field::new("c_custkey", DataType::Int64, false), + Field::new("c_name", DataType::Utf8, false), + Field::new("c_address", DataType::Utf8, false), + Field::new("c_nationkey", DataType::Int64, false), + Field::new("c_phone", DataType::Utf8, false), + Field::new("c_acctbal", DataType::Decimal128(15, 2), false), + Field::new("c_mktsegment", DataType::Utf8, false), + Field::new("c_comment", DataType::Utf8, false), + ]), + + "orders" => Schema::new(vec![ + Field::new("o_orderkey", DataType::Int64, false), + Field::new("o_custkey", DataType::Int64, false), + Field::new("o_orderstatus", DataType::Utf8, false), + Field::new("o_totalprice", DataType::Decimal128(15, 2), false), + Field::new("o_orderdate", DataType::Date32, 
false), +    Field::new("o_orderpriority", DataType::Utf8, false), +    Field::new("o_clerk", DataType::Utf8, false), +    Field::new("o_shippriority", DataType::Int32, false), +    Field::new("o_comment", DataType::Utf8, false), +    ]), + +    "lineitem" => Schema::new(vec![ +    Field::new("l_orderkey", DataType::Int64, false), +    Field::new("l_partkey", DataType::Int64, false), +    Field::new("l_suppkey", DataType::Int64, false), +    Field::new("l_linenumber", DataType::Int32, false), +    Field::new("l_quantity", DataType::Decimal128(15, 2), false), +    Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), +    Field::new("l_discount", DataType::Decimal128(15, 2), false), +    Field::new("l_tax", DataType::Decimal128(15, 2), false), +    Field::new("l_returnflag", DataType::Utf8, false), +    Field::new("l_linestatus", DataType::Utf8, false), +    Field::new("l_shipdate", DataType::Date32, false), +    Field::new("l_commitdate", DataType::Date32, false), +    Field::new("l_receiptdate", DataType::Date32, false), +    Field::new("l_shipinstruct", DataType::Utf8, false), +    Field::new("l_shipmode", DataType::Utf8, false), +    Field::new("l_comment", DataType::Utf8, false), +    ]), + +    "nation" => Schema::new(vec![ +    Field::new("n_nationkey", DataType::Int64, false), +    Field::new("n_name", DataType::Utf8, false), +    Field::new("n_regionkey", DataType::Int64, false), +    Field::new("n_comment", DataType::Utf8, false), +    ]), + +    "region" => Schema::new(vec![ +    Field::new("r_regionkey", DataType::Int64, false), +    Field::new("r_name", DataType::Utf8, false), +    Field::new("r_comment", DataType::Utf8, false), +    ]), + +    _ => unimplemented!(), +    } +} + +/// Get the SQL statements from the specified query file +pub fn get_query_sql(query: usize) -> Result<Vec<String>> { +    if query > 0 && query < 23 { +        let possibilities = vec![ +            format!("queries/q{query}.sql"), +            format!("benchmarks/queries/q{query}.sql"), +        ]; +        let mut errors = vec![]; +        for filename in possibilities { +            match fs::read_to_string(&filename) { +                Ok(contents) => { +                    return Ok(contents +                        .split(';') +                        .map(|s| s.trim()) +                        .filter(|s| !s.is_empty()) +                        .map(|s| s.to_string()) +                        .collect()); +                } +                Err(e) => errors.push(format!("{filename}: {e}")), +            }; +        } +        plan_err!("invalid query. Could not find query: {:?}", errors) +    } else { +        plan_err!("invalid query. Expected value between 1 and 22") +    } +} + +pub const QUERY_LIMIT: [Option<usize>; 22] = [ +    None, +    Some(100), +    Some(10), +    None, +    None, +    None, +    None, +    None, +    None, +    Some(20), +    None, +    None, +    None, +    None, +    None, +    None, +    None, +    Some(100), +    None, +    None, +    Some(100), +    None, +]; diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs new file mode 100644 index 00000000..b0ac2560 --- /dev/null +++ b/benchmarks/src/tpch/run.rs @@ -0,0 +1,522 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied.
See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{ +    get_query_sql, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_QUERY_END_ID, +    TPCH_QUERY_START_ID, TPCH_TABLES, +}; +use async_trait::async_trait; +use std::path::PathBuf; +use std::sync::Arc; + +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::arrow::util::pretty::{self, pretty_format_batches}; +use datafusion::common::instant::Instant; +use datafusion::common::utils::get_available_parallelism; +use datafusion::common::{DEFAULT_CSV_EXTENSION, DEFAULT_PARQUET_EXTENSION}; +use datafusion::datasource::file_format::csv::CsvFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::listing::{ +    ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::datasource::{MemTable, TableProvider}; +use datafusion::error::{DataFusionError, Result}; +use datafusion::execution::{SessionState, SessionStateBuilder}; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::{collect, displayable}; +use datafusion::prelude::*; + +use crate::util::{print_memory_stats, BenchmarkRun, CommonOpt, QueryResult}; +use datafusion_distributed::test_utils::localhost::start_localhost_context; +use datafusion_distributed::{DistributedPhysicalOptimizerRule, SessionBuilder}; +use log::info; +use structopt::StructOpt; + +// hack to avoid `default_value is meaningless for bool` errors +type BoolDefaultTrue = bool; + +/// Run the tpch benchmark. +/// +/// This benchmark is derived from the [TPC-H][1] version +/// [2.17.1]. The data and answers are generated using `tpch-gen` from +/// [2]. +/// +/// [1]: http://www.tpc.org/tpch/ +/// [2]: https://github.com/databricks/tpch-dbgen.git +/// [2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { +    /// Query number. If not specified, runs all queries +    #[structopt(short, long)] +    pub query: Option<usize>, + +    /// Common options +    #[structopt(flatten)] +    common: CommonOpt, + +    /// Path to data files +    #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] +    path: PathBuf, + +    /// File format: `csv` or `parquet` +    #[structopt(short = "f", long = "format", default_value = "csv")] +    file_format: String, + +    /// Load the data into a MemTable before executing the query +    #[structopt(short = "m", long = "mem-table")] +    mem_table: bool, + +    /// Path to machine readable output file +    #[structopt(parse(from_os_str), short = "o", long = "output")] +    output_path: Option<PathBuf>, + +    /// Whether to disable collection of statistics (and cost based optimizations) or not. +    #[structopt(short = "S", long = "disable-statistics")] +    disable_statistics: bool, + +    /// If true then hash join is used, if false then sort merge join. +    /// True by default. +    #[structopt(short = "j", long = "prefer_hash_join", default_value = "true")] +    prefer_hash_join: BoolDefaultTrue, + +    /// Mark the first column of each table as sorted in ascending order. +    /// The tables should have been created with the `--sort` option for this to have any effect. +    #[structopt(short = "t", long = "sorted")] +    sorted: bool, + +    /// Maximum number of partitions per distributed task. +    /// Caps the partition count handed to each task via `DistributedPhysicalOptimizerRule::with_maximum_partitions_per_task`.
+    #[structopt(long = "ppt")] +    partitions_per_task: Option<usize>, +} + +#[async_trait] +impl SessionBuilder for RunOpt { +    fn session_state_builder( +        &self, +        builder: SessionStateBuilder, +    ) -> Result<SessionStateBuilder> { +        let mut config = self +            .common +            .config()? +            .with_collect_statistics(!self.disable_statistics); +        config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; +        let rt_builder = self.common.runtime_env_builder()?; + +        let mut rule = DistributedPhysicalOptimizerRule::new(); +        if let Some(ppt) = self.partitions_per_task { +            rule = rule.with_maximum_partitions_per_task(ppt); +        } +        Ok(builder +            .with_config(config) +            .with_physical_optimizer_rule(Arc::new(rule)) +            .with_runtime_env(rt_builder.build_arc()?)) +    } + +    async fn session_context( +        &self, +        ctx: SessionContext, +    ) -> std::result::Result<SessionContext, DataFusionError> { +        self.register_tables(&ctx).await?; +        Ok(ctx) +    } +} + +impl RunOpt { +    pub async fn run(self) -> Result<()> { +        let (ctx, _guard) = start_localhost_context([50051], self.clone()).await; +        println!("Running benchmarks with the following options: {self:?}"); +        let query_range = match self.query { +            Some(query_id) => query_id..=query_id, +            None => TPCH_QUERY_START_ID..=TPCH_QUERY_END_ID, +        }; + +        let mut benchmark_run = BenchmarkRun::new(); + +        for query_id in query_range { +            benchmark_run.start_new_case(&format!("Query {query_id}")); +            let query_run = self.benchmark_query(query_id, &ctx).await; +            match query_run { +                Ok(query_results) => { +                    for iter in query_results { +                        benchmark_run.write_iter(iter.elapsed, iter.row_count); +                    } +                } +                Err(e) => { +                    benchmark_run.mark_failed(); +                    eprintln!("Query {query_id} failed: {e}"); +                } +            } +        } +        benchmark_run.maybe_write_json(self.output_path.as_ref())?; +        benchmark_run.maybe_print_failures(); +        Ok(()) +    } + +    async fn benchmark_query( +        &self, +        query_id: usize, +        ctx: &SessionContext, +    ) -> Result<Vec<QueryResult>> { +        let mut millis = vec![]; +        // run benchmark +        let mut query_results = vec![]; + +        let sql = &get_query_sql(query_id)?; + +        for i in 0..self.iterations() { +            let start = Instant::now(); + +            // query 15 is special, with 3 statements. the second statement is the one from which we +            // want to capture the results +            let mut result = vec![]; +            if query_id == 15 { +                for (n, query) in sql.iter().enumerate() { +                    if n == 1 { +                        result = self.execute_query(ctx, query).await?; +                    } else { +                        self.execute_query(ctx, query).await?; +                    } +                } +            } else { +                for query in sql { +                    result = self.execute_query(ctx, query).await?; +                } +            } + +            let elapsed = start.elapsed(); +            let ms = elapsed.as_secs_f64() * 1000.0; +            millis.push(ms); +            info!("output:\n\n{}\n\n", pretty_format_batches(&result)?); +            let row_count = result.iter().map(|b| b.num_rows()).sum(); +            println!( +                "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" +            ); +            query_results.push(QueryResult { elapsed, row_count }); +        } + +        let avg = millis.iter().sum::<f64>() / millis.len() as f64; +        println!("Query {query_id} avg time: {avg:.2} ms"); + +        // Print memory stats using mimalloc (only when compiled with --features mimalloc_extended) +        print_memory_stats(); + +        Ok(query_results) +    } + +    async fn register_tables(&self, ctx: &SessionContext) -> Result<()> { +        for table in TPCH_TABLES { +            let table_provider = { self.get_table(ctx, table).await?
+
+    async fn register_tables(&self, ctx: &SessionContext) -> Result<()> {
+        for table in TPCH_TABLES {
+            let table_provider = { self.get_table(ctx, table).await? };
+
+            if self.mem_table {
+                println!("Loading table '{table}' into memory");
+                let start = Instant::now();
+                let memtable =
+                    MemTable::load(table_provider, Some(self.partitions()), &ctx.state()).await?;
+                println!(
+                    "Loaded table '{}' into memory in {} ms",
+                    table,
+                    start.elapsed().as_millis()
+                );
+                ctx.register_table(*table, Arc::new(memtable))?;
+            } else {
+                ctx.register_table(*table, table_provider)?;
+            }
+        }
+        Ok(())
+    }
+
+    async fn execute_query(&self, ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
+        let debug = self.common.debug;
+        let plan = ctx.sql(sql).await?;
+        let (state, plan) = plan.into_parts();
+
+        if debug {
+            println!("=== Logical plan ===\n{plan}\n");
+        }
+
+        let plan = state.optimize(&plan)?;
+        if debug {
+            println!("=== Optimized logical plan ===\n{plan}\n");
+        }
+        let physical_plan = state.create_physical_plan(&plan).await?;
+        if debug {
+            println!(
+                "=== Physical plan ===\n{}\n",
+                displayable(physical_plan.as_ref()).indent(true)
+            );
+        }
+        let result = collect(physical_plan.clone(), state.task_ctx()).await?;
+        if debug {
+            println!(
+                "=== Physical plan with metrics ===\n{}\n",
+                DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()).indent(true)
+            );
+            if !result.is_empty() {
+                // do not call print_batches if there are no batches as the result is confusing
+                // and makes it look like there is a batch with no columns
+                pretty::print_batches(&result)?;
+            }
+        }
+        Ok(result)
+    }
+
+    async fn get_table(&self, ctx: &SessionContext, table: &str) -> Result<Arc<dyn TableProvider>> {
+        let path = self.path.to_str().unwrap();
+        let table_format = self.file_format.as_str();
+        let target_partitions = self.partitions();
+
+        // Obtain a snapshot of the SessionState
+        let state = ctx.state();
+        let (format, path, extension): (Arc<dyn FileFormat>, String, &'static str) =
+            match table_format {
+                // dbgen creates .tbl ('|' delimited) files without header
+                "tbl" => {
+                    let path = format!("{path}/{table}.tbl");
+
+                    let format = CsvFormat::default()
+                        .with_delimiter(b'|')
+                        .with_has_header(false);
+
+                    (Arc::new(format), path, ".tbl")
+                }
+                "csv" => {
+                    let path = format!("{path}/csv/{table}");
+                    let format = CsvFormat::default()
+                        .with_delimiter(b',')
+                        .with_has_header(true);
+
+                    (Arc::new(format), path, DEFAULT_CSV_EXTENSION)
+                }
+                "parquet" => {
+                    let path = format!("{path}/{table}");
+                    let format = ParquetFormat::default()
+                        .with_options(ctx.state().table_options().parquet.clone());
+
+                    (Arc::new(format), path, DEFAULT_PARQUET_EXTENSION)
+                }
+                other => {
+                    unimplemented!("Invalid file format '{}'", other);
+                }
+            };
+
+        let table_path = ListingTableUrl::parse(path)?;
+        let options = ListingOptions::new(format)
+            .with_file_extension(extension)
+            .with_target_partitions(target_partitions)
+            .with_collect_stat(state.config().collect_statistics());
+        let schema = match table_format {
+            "parquet" => options.infer_schema(&state, &table_path).await?,
+            "tbl" => Arc::new(get_tbl_tpch_table_schema(table)),
+            "csv" => Arc::new(get_tpch_table_schema(table)),
+            _ => unreachable!(),
+        };
+        let options = if self.sorted {
+            let key_column_name = schema.fields()[0].name();
+            options.with_file_sort_order(vec![vec![col(key_column_name).sort(true, false)]])
+        } else {
+            options
+        };
+
+        let config = ListingTableConfig::new(table_path)
+            .with_listing_options(options)
+            .with_schema(schema);
+
+        Ok(Arc::new(ListingTable::try_new(config)?))
+    }
+
+    fn iterations(&self) -> usize {
+        self.common.iterations
+    }
+
+    fn partitions(&self) -> usize {
+        self.common
+            .partitions
+            .unwrap_or_else(get_available_parallelism)
+    }
+}
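Nothing in this file defines the binary entrypoint that drives `RunOpt::run`. A hypothetical wiring sketch (the real main.rs is not part of this diff; the `tpch` module path and CLI name are assumptions), using the `env_logger` and `tokio` dependencies this PR adds to the benchmarks crate:

```rust
use structopt::StructOpt;

// Hypothetical entrypoint, assuming `RunOpt` is exposed from a `tpch`
// module of the benchmarks crate.
#[derive(Debug, StructOpt)]
#[structopt(name = "dfbench", about = "datafusion-distributed benchmarks")]
enum Options {
    Tpch(tpch::RunOpt),
}

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    env_logger::init();
    match Options::from_args() {
        Options::Tpch(opt) => opt.run().await,
    }
}
```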
+
+#[cfg(test)]
+// Only run with "ci" mode when we have the data
+#[cfg(feature = "ci")]
+mod tests {
+    use std::path::Path;
+
+    use super::*;
+
+    use datafusion::common::exec_err;
+    use datafusion::error::Result;
+    use datafusion_proto::bytes::{
+        logical_plan_from_bytes, logical_plan_to_bytes, physical_plan_from_bytes,
+        physical_plan_to_bytes,
+    };
+
+    fn get_tpch_data_path() -> Result<String> {
+        let path = std::env::var("TPCH_DATA").unwrap_or_else(|_| "benchmarks/data".to_string());
+        if !Path::new(&path).exists() {
+            return exec_err!(
+                "Benchmark data not found (set TPCH_DATA env var to override): {}",
+                path
+            );
+        }
+        Ok(path)
+    }
+
+    async fn round_trip_logical_plan(query: usize) -> Result<()> {
+        let ctx = SessionContext::default();
+        let path = get_tpch_data_path()?;
+        let common = CommonOpt {
+            iterations: 1,
+            partitions: Some(2),
+            batch_size: Some(8192),
+            mem_pool_type: "fair".to_string(),
+            memory_limit: None,
+            sort_spill_reservation_bytes: None,
+            debug: false,
+        };
+        let opt = RunOpt {
+            query: Some(query),
+            common,
+            path: PathBuf::from(path),
+            file_format: "tbl".to_string(),
+            mem_table: false,
+            output_path: None,
+            disable_statistics: false,
+            prefer_hash_join: true,
+            sorted: false,
+            partitions_per_task: None,
+        };
+        opt.register_tables(&ctx).await?;
+        let queries = get_query_sql(query)?;
+        for query in queries {
+            let plan = ctx.sql(&query).await?;
+            let plan = plan.into_optimized_plan()?;
+            let bytes = logical_plan_to_bytes(&plan)?;
+            let plan2 = logical_plan_from_bytes(&bytes, &ctx)?;
+            let plan_formatted = format!("{}", plan.display_indent());
+            let plan2_formatted = format!("{}", plan2.display_indent());
+            assert_eq!(plan_formatted, plan2_formatted);
+        }
+        Ok(())
+    }
+
+    async fn round_trip_physical_plan(query: usize) -> Result<()> {
+        let ctx = SessionContext::default();
+        let path = get_tpch_data_path()?;
+        let common = CommonOpt {
+            iterations: 1,
+            partitions: Some(2),
+            batch_size: Some(8192),
+            mem_pool_type: "fair".to_string(),
+            memory_limit: None,
+            sort_spill_reservation_bytes: None,
+            debug: false,
+        };
+        let opt = RunOpt {
+            query: Some(query),
+            common,
+            path: PathBuf::from(path),
+            file_format: "tbl".to_string(),
+            mem_table: false,
+            output_path: None,
+            disable_statistics: false,
+            prefer_hash_join: true,
+            sorted: false,
+            partitions_per_task: None,
+        };
+        opt.register_tables(&ctx).await?;
+        let queries = get_query_sql(query)?;
+        for query in queries {
+            let plan = ctx.sql(&query).await?;
+            let plan = plan.create_physical_plan().await?;
+            let bytes = physical_plan_to_bytes(plan.clone())?;
+            let plan2 = physical_plan_from_bytes(&bytes, &ctx)?;
+            let plan_formatted = format!("{}", displayable(plan.as_ref()).indent(false));
+            let plan2_formatted = format!("{}", displayable(plan2.as_ref()).indent(false));
+            assert_eq!(plan_formatted, plan2_formatted);
+        }
+        Ok(())
+    }
+
+    macro_rules! test_round_trip_logical {
+        ($tn:ident, $query:expr) => {
+            #[tokio::test]
+            async fn $tn() -> Result<()> {
+                round_trip_logical_plan($query).await
+            }
+        };
+    }
+
+    macro_rules! test_round_trip_physical {
+        ($tn:ident, $query:expr) => {
+            #[tokio::test]
+            async fn $tn() -> Result<()> {
+                round_trip_physical_plan($query).await
+            }
+        };
+    }
+
+    // logical plan tests
+    test_round_trip_logical!(round_trip_logical_plan_q1, 1);
+    test_round_trip_logical!(round_trip_logical_plan_q2, 2);
+    test_round_trip_logical!(round_trip_logical_plan_q3, 3);
+    test_round_trip_logical!(round_trip_logical_plan_q4, 4);
+    test_round_trip_logical!(round_trip_logical_plan_q5, 5);
+    test_round_trip_logical!(round_trip_logical_plan_q6, 6);
+    test_round_trip_logical!(round_trip_logical_plan_q7, 7);
+    test_round_trip_logical!(round_trip_logical_plan_q8, 8);
+    test_round_trip_logical!(round_trip_logical_plan_q9, 9);
+    test_round_trip_logical!(round_trip_logical_plan_q10, 10);
+    test_round_trip_logical!(round_trip_logical_plan_q11, 11);
+    test_round_trip_logical!(round_trip_logical_plan_q12, 12);
+    test_round_trip_logical!(round_trip_logical_plan_q13, 13);
+    test_round_trip_logical!(round_trip_logical_plan_q14, 14);
+    test_round_trip_logical!(round_trip_logical_plan_q15, 15);
+    test_round_trip_logical!(round_trip_logical_plan_q16, 16);
+    test_round_trip_logical!(round_trip_logical_plan_q17, 17);
+    test_round_trip_logical!(round_trip_logical_plan_q18, 18);
+    test_round_trip_logical!(round_trip_logical_plan_q19, 19);
+    test_round_trip_logical!(round_trip_logical_plan_q20, 20);
+    test_round_trip_logical!(round_trip_logical_plan_q21, 21);
+    test_round_trip_logical!(round_trip_logical_plan_q22, 22);
+
+    // physical plan tests
+    test_round_trip_physical!(round_trip_physical_plan_q1, 1);
+    test_round_trip_physical!(round_trip_physical_plan_q2, 2);
+    test_round_trip_physical!(round_trip_physical_plan_q3, 3);
+    test_round_trip_physical!(round_trip_physical_plan_q4, 4);
+    test_round_trip_physical!(round_trip_physical_plan_q5, 5);
+    test_round_trip_physical!(round_trip_physical_plan_q6, 6);
+    test_round_trip_physical!(round_trip_physical_plan_q7, 7);
+    test_round_trip_physical!(round_trip_physical_plan_q8, 8);
+    test_round_trip_physical!(round_trip_physical_plan_q9, 9);
+    test_round_trip_physical!(round_trip_physical_plan_q10, 10);
+    test_round_trip_physical!(round_trip_physical_plan_q11, 11);
+    test_round_trip_physical!(round_trip_physical_plan_q12, 12);
+    test_round_trip_physical!(round_trip_physical_plan_q13, 13);
+    test_round_trip_physical!(round_trip_physical_plan_q14, 14);
+    test_round_trip_physical!(round_trip_physical_plan_q15, 15);
+    test_round_trip_physical!(round_trip_physical_plan_q16, 16);
+    test_round_trip_physical!(round_trip_physical_plan_q17, 17);
+    test_round_trip_physical!(round_trip_physical_plan_q18, 18);
+    test_round_trip_physical!(round_trip_physical_plan_q19, 19);
+    test_round_trip_physical!(round_trip_physical_plan_q20, 20);
+    test_round_trip_physical!(round_trip_physical_plan_q21, 21);
+    test_round_trip_physical!(round_trip_physical_plan_q22, 22);
+}
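For clarity, each macro invocation above expands to an ordinary `#[tokio::test]` that delegates to the shared helper; for example, `test_round_trip_logical!(round_trip_logical_plan_q1, 1)` expands to:

```rust
#[tokio::test]
async fn round_trip_logical_plan_q1() -> Result<()> {
    round_trip_logical_plan(1).await
}
```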
diff --git a/benchmarks/src/util/memory.rs b/benchmarks/src/util/memory.rs
new file mode 100644
index 00000000..2eb7ea5f
--- /dev/null
+++ b/benchmarks/src/util/memory.rs
@@ -0,0 +1,21 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Print peak RSS, peak commit, and page faults based on the mimalloc API.
+/// No-op here: the mimalloc-based implementation is not used in this project.
+pub fn print_memory_stats() {
+    // intentionally empty; removed as not used in this project
+}
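For reference, a rough sketch of what restoring the upstream behavior could look like, gated on the `mimalloc_extended` feature mentioned in run.rs. This is not part of this project's code, and it assumes the `libmimalloc-sys` crate exposes `mi_process_info` with out-pointer parameters in this order:

```rust
/// Print peak RSS, peak commit, and page faults as reported by mimalloc.
#[cfg(feature = "mimalloc_extended")]
pub fn print_memory_stats() {
    use libmimalloc_sys::mi_process_info;

    let mut elapsed_msecs: usize = 0;
    let mut user_msecs: usize = 0;
    let mut system_msecs: usize = 0;
    let mut current_rss: usize = 0;
    let mut peak_rss: usize = 0;
    let mut current_commit: usize = 0;
    let mut peak_commit: usize = 0;
    let mut page_faults: usize = 0;
    // Safety: mi_process_info only writes to the provided out-pointers.
    unsafe {
        mi_process_info(
            &mut elapsed_msecs,
            &mut user_msecs,
            &mut system_msecs,
            &mut current_rss,
            &mut peak_rss,
            &mut current_commit,
            &mut peak_commit,
            &mut page_faults,
        );
    }
    println!(
        "Peak RSS: {:.1} MiB, Peak Commit: {:.1} MiB, Page Faults: {page_faults}",
        peak_rss as f64 / (1024.0 * 1024.0),
        peak_commit as f64 / (1024.0 * 1024.0),
    );
}
```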
diff --git a/benchmarks/src/util/mod.rs b/benchmarks/src/util/mod.rs
new file mode 100644
index 00000000..a38d37de
--- /dev/null
+++ b/benchmarks/src/util/mod.rs
@@ -0,0 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Shared benchmark utilities
+mod memory;
+mod options;
+mod run;
+
+pub use memory::print_memory_stats;
+pub use options::CommonOpt;
+pub use run::{BenchmarkRun, QueryResult};
diff --git a/benchmarks/src/util/options.rs b/benchmarks/src/util/options.rs
new file mode 100644
index 00000000..f36c175a
--- /dev/null
+++ b/benchmarks/src/util/options.rs
@@ -0,0 +1,153 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{num::NonZeroUsize, sync::Arc};
+
+use datafusion::common::{DataFusionError, Result};
+use datafusion::{
+    execution::{
+        disk_manager::DiskManagerBuilder,
+        memory_pool::{FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool},
+        runtime_env::RuntimeEnvBuilder,
+    },
+    prelude::SessionConfig,
+};
+use structopt::StructOpt;
+
+// Common benchmark options (don't use doc comments otherwise this doc
+// shows up in help files)
+#[derive(Debug, StructOpt, Clone)]
+pub struct CommonOpt {
+    /// Number of iterations of each test run
+    #[structopt(short = "i", long = "iterations", default_value = "3")]
+    pub iterations: usize,
+
+    /// Number of partitions to process in parallel. Defaults to number of available cores.
+    #[structopt(short = "n", long = "partitions")]
+    pub partitions: Option<usize>,
+
+    /// Batch size when reading CSV or Parquet files
+    #[structopt(short = "s", long = "batch-size")]
+    pub batch_size: Option<usize>,
+
+    /// The memory pool type to use, should be one of "fair" or "greedy"
+    #[structopt(long = "mem-pool-type", default_value = "fair")]
+    pub mem_pool_type: String,
+
+    /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits
+    /// for the given query if there are any; otherwise run with no memory limit.
+    #[structopt(long = "memory-limit", parse(try_from_str = parse_memory_limit))]
+    pub memory_limit: Option<usize>,
+
+    /// The amount of memory to reserve for sort spill operations. DataFusion's default value
+    /// will be used if not specified.
+    #[structopt(long = "sort-spill-reservation-bytes", parse(try_from_str = parse_memory_limit))]
+    pub sort_spill_reservation_bytes: Option<usize>,
+
+    /// Activate debug mode to see more details
+    #[structopt(short, long)]
+    pub debug: bool,
+}
+
+impl CommonOpt {
+    /// Return an appropriately configured `SessionConfig`
+    pub fn config(&self) -> Result<SessionConfig> {
+        SessionConfig::from_env().map(|config| self.update_config(config))
+    }
+
+    /// Modify the existing config appropriately
+    pub fn update_config(&self, mut config: SessionConfig) -> SessionConfig {
+        if let Some(batch_size) = self.batch_size {
+            config = config.with_batch_size(batch_size);
+        }
+
+        if let Some(partitions) = self.partitions {
+            config = config.with_target_partitions(partitions);
+        }
+
+        if let Some(sort_spill_reservation_bytes) = self.sort_spill_reservation_bytes {
+            config = config.with_sort_spill_reservation_bytes(sort_spill_reservation_bytes);
+        }
+
+        config
+    }
+
+    /// Return an appropriately configured `RuntimeEnvBuilder`
+    pub fn runtime_env_builder(&self) -> Result<RuntimeEnvBuilder> {
+        let mut rt_builder = RuntimeEnvBuilder::new();
+        const NUM_TRACKED_CONSUMERS: usize = 5;
+        if let Some(memory_limit) = self.memory_limit {
+            let pool: Arc<dyn MemoryPool> = match self.mem_pool_type.as_str() {
+                "fair" => Arc::new(TrackConsumersPool::new(
+                    FairSpillPool::new(memory_limit),
+                    NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
+                )),
+                "greedy" => Arc::new(TrackConsumersPool::new(
+                    GreedyMemoryPool::new(memory_limit),
+                    NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
+                )),
+                _ => {
+                    return Err(DataFusionError::Configuration(format!(
+                        "Invalid memory pool type: {}",
+                        self.mem_pool_type
+                    )))
+                }
+            };
+            rt_builder = rt_builder
+                .with_memory_pool(pool)
+                .with_disk_manager_builder(DiskManagerBuilder::default());
+        }
+        Ok(rt_builder)
+    }
+}
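To illustrate how these options flow into a session, a small usage sketch of `update_config` (field values are arbitrary; this works inside the benchmarks crate where the fields are public):

```rust
use datafusion::prelude::SessionConfig;

#[test]
fn common_opt_flows_into_session_config() {
    let opt = CommonOpt {
        iterations: 3,
        partitions: Some(8),
        batch_size: Some(16384),
        mem_pool_type: "fair".to_string(),
        memory_limit: None,
        sort_spill_reservation_bytes: None,
        debug: false,
    };
    // Only the options that were set override the defaults.
    let config = opt.update_config(SessionConfig::new());
    assert_eq!(config.batch_size(), 16384);
    assert_eq!(config.target_partitions(), 8);
}
```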
+
+/// Parse memory limit from string to number of bytes,
+/// e.g. '1.5M' -> 1572864, '2G' -> 2147483648
+fn parse_memory_limit(limit: &str) -> Result<usize, String> {
+    let (number, unit) = limit.split_at(limit.len() - 1);
+    let number: f64 = number
+        .parse()
+        .map_err(|_| format!("Failed to parse number from memory limit '{limit}'"))?;
+
+    match unit {
+        "K" => Ok((number * 1024.0) as usize),
+        "M" => Ok((number * 1024.0 * 1024.0) as usize),
+        "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as usize),
+        _ => Err(format!(
+            "Unsupported unit '{unit}' in memory limit '{limit}'"
+        )),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_memory_limit_all() {
+        // Test valid inputs
+        assert_eq!(parse_memory_limit("100K").unwrap(), 102400);
+        assert_eq!(parse_memory_limit("1.5M").unwrap(), 1572864);
+        assert_eq!(parse_memory_limit("2G").unwrap(), 2147483648);
+
+        // Test invalid unit
+        assert!(parse_memory_limit("500X").is_err());
+
+        // Test invalid number
+        assert!(parse_memory_limit("abcM").is_err());
+    }
+}
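Since `parse_memory_limit` is wired into structopt via `parse(try_from_str = ...)`, the flags can also be exercised end to end. A sketch using structopt's `from_iter` (argument values are arbitrary):

```rust
use structopt::StructOpt;

#[test]
fn memory_flags_parse() {
    // The first element is the binary name, as on a real command line.
    let opt = CommonOpt::from_iter(vec![
        "bench",
        "--memory-limit",
        "1.5G",
        "--mem-pool-type",
        "greedy",
    ]);
    assert_eq!(opt.memory_limit, Some(1_610_612_736)); // 1.5 * 1024^3
    assert_eq!(opt.mem_pool_type, "greedy");
}
```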
diff --git a/benchmarks/src/util/run.rs b/benchmarks/src/util/run.rs
new file mode 100644
index 00000000..fdbe82ab
--- /dev/null
+++ b/benchmarks/src/util/run.rs
@@ -0,0 +1,184 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::common::utils::get_available_parallelism;
+use datafusion::{error::Result, DATAFUSION_VERSION};
+use serde::{Serialize, Serializer};
+use serde_json::Value;
+use std::{
+    collections::HashMap,
+    path::Path,
+    time::{Duration, SystemTime},
+};
+
+fn serialize_start_time<S>(start_time: &SystemTime, ser: S) -> Result<S::Ok, S::Error>
+where
+    S: Serializer,
+{
+    ser.serialize_u64(
+        start_time
+            .duration_since(SystemTime::UNIX_EPOCH)
+            .expect("current time is later than UNIX_EPOCH")
+            .as_secs(),
+    )
+}
+
+fn serialize_elapsed<S>(elapsed: &Duration, ser: S) -> Result<S::Ok, S::Error>
+where
+    S: Serializer,
+{
+    let ms = elapsed.as_secs_f64() * 1000.0;
+    ser.serialize_f64(ms)
+}
+
+#[derive(Debug, Serialize)]
+pub struct RunContext {
+    /// Benchmark crate version
+    pub benchmark_version: String,
+    /// DataFusion crate version
+    pub datafusion_version: String,
+    /// Number of CPU cores
+    pub num_cpus: usize,
+    /// Start time
+    #[serde(serialize_with = "serialize_start_time")]
+    pub start_time: SystemTime,
+    /// CLI arguments
+    pub arguments: Vec<String>,
+}
+
+impl Default for RunContext {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl RunContext {
+    pub fn new() -> Self {
+        Self {
+            benchmark_version: env!("CARGO_PKG_VERSION").to_owned(),
+            datafusion_version: DATAFUSION_VERSION.to_owned(),
+            num_cpus: get_available_parallelism(),
+            start_time: SystemTime::now(),
+            arguments: std::env::args().skip(1).collect::<Vec<String>>(),
+        }
+    }
+}
+
+/// A single iteration of a benchmark query
+#[derive(Debug, Serialize)]
+struct QueryIter {
+    #[serde(serialize_with = "serialize_elapsed")]
+    elapsed: Duration,
+    row_count: usize,
+}
+
+/// A single benchmark case
+#[derive(Debug, Serialize)]
+pub struct BenchQuery {
+    query: String,
+    iterations: Vec<QueryIter>,
+    #[serde(serialize_with = "serialize_start_time")]
+    start_time: SystemTime,
+    success: bool,
+}
+
+/// Internal representation of a single benchmark query iteration result.
+pub struct QueryResult {
+    pub elapsed: Duration,
+    pub row_count: usize,
+}
+
+/// Collects benchmark run data and then serializes it at the end
+pub struct BenchmarkRun {
+    context: RunContext,
+    queries: Vec<BenchQuery>,
+    current_case: Option<usize>,
+}
+
+impl Default for BenchmarkRun {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl BenchmarkRun {
+    /// Create a new, empty benchmark run
+    pub fn new() -> Self {
+        Self {
+            context: RunContext::new(),
+            queries: vec![],
+            current_case: None,
+        }
+    }
+
+    /// Begin a new case. Iterations added after this will be included in the new case
+    pub fn start_new_case(&mut self, id: &str) {
+        self.queries.push(BenchQuery {
+            query: id.to_owned(),
+            iterations: vec![],
+            start_time: SystemTime::now(),
+            success: true,
+        });
+        if let Some(c) = self.current_case.as_mut() {
+            *c += 1;
+        } else {
+            self.current_case = Some(0);
+        }
+    }
+
+    /// Write a new iteration to the current case
+    pub fn write_iter(&mut self, elapsed: Duration, row_count: usize) {
+        if let Some(idx) = self.current_case {
+            self.queries[idx]
+                .iterations
+                .push(QueryIter { elapsed, row_count })
+        } else {
+            panic!("no cases existed yet");
+        }
+    }
+
+    /// Print the names of failed queries, if any
+    pub fn maybe_print_failures(&self) {
+        let failed_queries: Vec<&str> = self
+            .queries
+            .iter()
+            .filter_map(|q| (!q.success).then_some(q.query.as_str()))
+            .collect();
+
+        if !failed_queries.is_empty() {
+            println!("Failed Queries: {}", failed_queries.join(", "));
+        }
+    }
+
+    /// Mark the current query as failed
+    pub fn mark_failed(&mut self) {
+        if let Some(idx) = self.current_case {
+            self.queries[idx].success = false;
+        } else {
+            unreachable!("Cannot mark failure: no current case");
+        }
+    }
+
+    /// Stringify data into formatted json
+    pub fn to_json(&self) -> String {
+        let mut output = HashMap::<&str, Value>::new();
+        output.insert("context", serde_json::to_value(&self.context).unwrap());
+        output.insert("queries", serde_json::to_value(&self.queries).unwrap());
+        serde_json::to_string_pretty(&output).unwrap()
+    }
+
+    /// Write data as json into the output path, if one was given.
+    pub fn maybe_write_json(&self, maybe_path: Option<impl AsRef<Path>>) -> Result<()> {
+        if let Some(path) = maybe_path {
+            std::fs::write(path, self.to_json())?;
+        };
+        Ok(())
+    }
+}
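Putting the pieces of this module together, a typical call sequence looks like the following sketch (query names, timings, and the output path are illustrative):

```rust
use std::time::Duration;

fn record_example_run() -> datafusion::error::Result<()> {
    let mut run = BenchmarkRun::new();
    run.start_new_case("Query 1");
    run.write_iter(Duration::from_millis(123), 100); // 123 ms, 100 rows returned
    run.start_new_case("Query 2");
    run.mark_failed();
    run.maybe_print_failures(); // prints: Failed Queries: Query 2
    run.maybe_write_json(Some("results.json"))?; // pass None to skip writing
    Ok(())
}
```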
diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs
index 210934b1..a2b52ab5 100644
--- a/src/flight_service/do_get.rs
+++ b/src/flight_service/do_get.rs
@@ -10,6 +10,7 @@ use arrow_flight::flight_service_server::FlightService;
 use arrow_flight::Ticket;
 use datafusion::execution::SessionStateBuilder;
 use datafusion::optimizer::OptimizerConfig;
+use datafusion::prelude::SessionContext;
 use futures::TryStreamExt;
 use prost::Message;
 use std::sync::Arc;
@@ -42,8 +43,17 @@ impl ArrowFlightEndpoint {
         let state_builder = SessionStateBuilder::new()
             .with_runtime_env(Arc::clone(&self.runtime))
             .with_default_features();
+        let state_builder = self
+            .session_builder
+            .session_state_builder(state_builder)
+            .map_err(|err| datafusion_error_to_tonic_status(&err))?;
 
-        let mut state = self.session_builder.on_new_session(state_builder).build();
+        let state = state_builder.build();
+        let mut state = self
+            .session_builder
+            .session_state(state)
+            .await
+            .map_err(|err| datafusion_error_to_tonic_status(&err))?;
 
         let function_registry = state.function_registry().ok_or(Status::invalid_argument(
             "FunctionRegistry not present in newly built SessionState",
@@ -55,7 +65,7 @@ impl ArrowFlightEndpoint {
             combined_codec.push_arc(Arc::clone(&user_codec));
         }
 
-        let mut stage = stage_from_proto(
+        let stage = stage_from_proto(
             stage_msg,
             function_registry,
             &self.runtime.as_ref(),
@@ -69,8 +79,16 @@ impl ArrowFlightEndpoint {
         config.set_extension(Arc::clone(&self.channel_manager));
         config.set_extension(Arc::new(stage));
 
+        let ctx = SessionContext::new_with_state(state);
+
+        let ctx = self
+            .session_builder
+            .session_context(ctx)
+            .await
+            .map_err(|err| datafusion_error_to_tonic_status(&err))?;
+
         let stream = inner_plan
-            .execute(doget.partition as usize, state.task_ctx())
+            .execute(doget.partition as usize, ctx.task_ctx())
            .map_err(|err| Status::internal(format!("Error executing stage plan: {err:#?}")))?;
 
         let flight_data_stream = FlightDataEncoderBuilder::new()
diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs
index b773bad9..76373a90 100644
--- a/src/flight_service/mod.rs
+++ b/src/flight_service/mod.rs
@@ -6,4 +6,4 @@ mod stream_partitioner_registry;
 pub(crate) use do_get::DoGet;
 pub use service::ArrowFlightEndpoint;
 
-pub use session_builder::SessionBuilder;
+pub use session_builder::{NoopSessionBuilder, SessionBuilder};
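As the do_get.rs hunks show, the endpoint now applies the three `SessionBuilder` hooks in a fixed order: `session_state_builder` before the state is built, `session_state` on the built state, and `session_context` on the final context. An implementor can override any subset; a sketch of a builder that touches all three (the comments describe plausible uses, not anything this diff does):

```rust
use async_trait::async_trait;
use datafusion::error::DataFusionError;
use datafusion::execution::{SessionState, SessionStateBuilder};
use datafusion::prelude::SessionContext;
use datafusion_distributed::SessionBuilder;

#[derive(Clone)]
struct FullSessionBuilder;

#[async_trait]
impl SessionBuilder for FullSessionBuilder {
    fn session_state_builder(
        &self,
        builder: SessionStateBuilder,
    ) -> Result<SessionStateBuilder, DataFusionError> {
        // Runs before `build()`: register codecs, optimizer rules, UDFs...
        Ok(builder)
    }

    async fn session_state(&self, state: SessionState) -> Result<SessionState, DataFusionError> {
        // Runs on the built state, e.g. async lookups against a catalog.
        Ok(state)
    }

    async fn session_context(
        &self,
        ctx: SessionContext,
    ) -> Result<SessionContext, DataFusionError> {
        // Runs last, e.g. registering tables on the context.
        Ok(ctx)
    }
}
```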
diff --git a/src/flight_service/session_builder.rs b/src/flight_service/session_builder.rs
index eb6d06f4..64be9ddd 100644
--- a/src/flight_service/session_builder.rs
+++ b/src/flight_service/session_builder.rs
@@ -1,7 +1,11 @@
-use datafusion::execution::SessionStateBuilder;
+use async_trait::async_trait;
+use datafusion::error::DataFusionError;
+use datafusion::execution::{SessionState, SessionStateBuilder};
+use datafusion::prelude::SessionContext;
 
 /// Trait called by the Arrow Flight endpoint that handles distributed parts of a DataFusion
 /// plan for building a DataFusion's [datafusion::prelude::SessionContext].
+#[async_trait]
 pub trait SessionBuilder {
     /// Takes a [SessionStateBuilder] and adds whatever is necessary for it to work, like
     /// custom extension codecs, custom physical optimization rules, UDFs, UDAFs, config
@@ -10,8 +14,9 @@ pub trait SessionBuilder {
     /// Example: adding some custom extension plan codecs
     ///
     /// ```rust
-    ///
     /// # use std::sync::Arc;
+    /// # use async_trait::async_trait;
+    /// # use datafusion::error::DataFusionError;
     /// # use datafusion::execution::runtime_env::RuntimeEnv;
     /// # use datafusion::execution::{FunctionRegistry, SessionStateBuilder};
     /// # use datafusion::physical_plan::ExecutionPlan;
@@ -33,22 +38,81 @@ pub trait SessionBuilder {
     ///
     /// #[derive(Clone)]
     /// struct CustomSessionBuilder;
+    ///
+    /// #[async_trait]
     /// impl SessionBuilder for CustomSessionBuilder {
-    ///     fn on_new_session(&self, mut builder: SessionStateBuilder) -> SessionStateBuilder {
+    ///     fn session_state_builder(&self, mut builder: SessionStateBuilder) -> Result<SessionStateBuilder, DataFusionError> {
     ///         // Add your UDFs, optimization rules, etc...
-    ///         with_user_codec(builder, CustomExecCodec)
+    ///         Ok(with_user_codec(builder, CustomExecCodec))
+    ///     }
+    /// }
+    /// ```
+    fn session_state_builder(
+        &self,
+        builder: SessionStateBuilder,
+    ) -> Result<SessionStateBuilder, DataFusionError> {
+        Ok(builder)
+    }
+
+    /// Modifies the [SessionState] and returns it. Same as [SessionBuilder::session_state_builder]
+    /// but operating on an already built [SessionState].
+    ///
+    /// Example:
+    ///
+    /// ```rust
+    /// # use async_trait::async_trait;
+    /// # use datafusion::common::DataFusionError;
+    /// # use datafusion::execution::SessionState;
+    /// # use datafusion_distributed::SessionBuilder;
+    ///
+    /// #[derive(Clone)]
+    /// struct CustomSessionBuilder;
+    ///
+    /// #[async_trait]
+    /// impl SessionBuilder for CustomSessionBuilder {
+    ///     async fn session_state(&self, state: SessionState) -> Result<SessionState, DataFusionError> {
+    ///         // mutate the state adding any custom logic
+    ///         Ok(state)
    ///     }
    /// }
    /// ```
-    fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder;
+    async fn session_state(&self, state: SessionState) -> Result<SessionState, DataFusionError> {
+        Ok(state)
+    }
+
+    /// Modifies the [SessionContext] and returns it. Same as [SessionBuilder::session_state_builder]
+    /// or [SessionBuilder::session_state] but operating on an already built [SessionContext].
+    ///
+    /// Example:
+    ///
+    /// ```rust
+    /// # use async_trait::async_trait;
+    /// # use datafusion::common::DataFusionError;
+    /// # use datafusion::prelude::SessionContext;
+    /// # use datafusion_distributed::SessionBuilder;
+    ///
+    /// #[derive(Clone)]
+    /// struct CustomSessionBuilder;
+    ///
+    /// #[async_trait]
+    /// impl SessionBuilder for CustomSessionBuilder {
+    ///     async fn session_context(&self, ctx: SessionContext) -> Result<SessionContext, DataFusionError> {
+    ///         // mutate the context adding any custom logic
+    ///         Ok(ctx)
+    ///     }
+    /// }
+    /// ```
+    async fn session_context(
+        &self,
+        ctx: SessionContext,
+    ) -> Result<SessionContext, DataFusionError> {
+        Ok(ctx)
+    }
 }
 
 /// Noop implementation of the [SessionBuilder]. Used by default if no [SessionBuilder] is provided
 /// while building the Arrow Flight endpoint.
+#[derive(Debug, Clone)]
 pub struct NoopSessionBuilder;
 
-impl SessionBuilder for NoopSessionBuilder {
-    fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
-        builder
-    }
-}
+impl SessionBuilder for NoopSessionBuilder {}
diff --git a/src/lib.rs b/src/lib.rs
index 5e03e1dd..ee36dcbb 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -13,7 +13,7 @@ mod user_provided_codec;
 pub mod test_utils;
 
 pub use channel_manager::{BoxCloneSyncChannel, ChannelManager, ChannelResolver};
-pub use flight_service::{ArrowFlightEndpoint, SessionBuilder};
+pub use flight_service::{ArrowFlightEndpoint, NoopSessionBuilder, SessionBuilder};
 pub use physical_optimizer::DistributedPhysicalOptimizerRule;
 pub use plan::ArrowFlightReadExec;
 pub use stage::{display_stage_graphviz, ExecutionStage};
diff --git a/src/test_utils/insta.rs b/src/test_utils/insta.rs
index ad7b6c40..40fdb609 100644
--- a/src/test_utils/insta.rs
+++ b/src/test_utils/insta.rs
@@ -1,3 +1,4 @@
+use datafusion::common::utils::get_available_parallelism;
 use std::env;
 
 pub use insta;
diff --git a/src/test_utils/localhost.rs b/src/test_utils/localhost.rs
index b33572bc..fb5ff214 100644
--- a/src/test_utils/localhost.rs
+++ b/src/test_utils/localhost.rs
@@ -14,14 +14,6 @@ use std::time::Duration;
 use tonic::transport::{Channel, Server};
 use url::Url;
 
-#[derive(Debug, Clone)]
-pub struct NoopSessionBuilder;
-impl SessionBuilder for NoopSessionBuilder {
-    fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
-        builder
-    }
-}
-
 pub async fn start_localhost_context<I, B>(
     ports: I,
     session_builder: B,
 )
@@ -49,12 +41,16 @@ where
 
     let config = SessionConfig::new().with_target_partitions(3);
 
-    let state = SessionStateBuilder::new()
+    let builder = SessionStateBuilder::new()
         .with_default_features()
-        .with_config(config)
-        .build();
+        .with_config(config);
+    let builder = session_builder.session_state_builder(builder).unwrap();
+
+    let state = builder.build();
+    let state = session_builder.session_state(state).await.unwrap();
 
     let ctx = SessionContext::new_with_state(state);
+    let ctx = session_builder.session_context(ctx).await.unwrap();
 
     ctx.state_ref()
         .write()
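With `NoopSessionBuilder` now exported from the crate root, tests no longer need to reach into `test_utils` for it. A minimal usage sketch of the localhost test helper (port numbers are arbitrary; the helper's exact return type is as used elsewhere in this diff, a context plus a guard that keeps the workers alive):

```rust
use datafusion_distributed::test_utils::localhost::start_localhost_context;
use datafusion_distributed::NoopSessionBuilder;

#[tokio::test]
async fn distributed_smoke_test() -> Result<(), Box<dyn std::error::Error>> {
    // Spins up in-process Arrow Flight workers on the given ports and
    // returns a SessionContext wired to them.
    let (ctx, _guard) = start_localhost_context([50060, 50061], NoopSessionBuilder).await;
    let batches = ctx.sql("SELECT 1 AS one").await?.collect().await?;
    assert_eq!(batches[0].num_rows(), 1);
    Ok(())
}
```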
diff --git a/tests/custom_extension_codec.rs b/tests/custom_extension_codec.rs
index 98b7d81c..c0673f87 100644
--- a/tests/custom_extension_codec.rs
+++ b/tests/custom_extension_codec.rs
@@ -25,8 +25,7 @@ mod tests {
     use datafusion_distributed::assert_snapshot;
     use datafusion_distributed::test_utils::localhost::start_localhost_context;
     use datafusion_distributed::{
-        add_user_codec, with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule,
-        SessionBuilder,
+        with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder,
     };
     use datafusion_proto::physical_plan::PhysicalExtensionCodec;
     use datafusion_proto::protobuf::proto_error;
@@ -42,14 +41,16 @@ mod tests {
         #[derive(Clone)]
         struct CustomSessionBuilder;
         impl SessionBuilder for CustomSessionBuilder {
-            fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
-                with_user_codec(builder, Int64ListExecCodec)
+            fn session_state_builder(
+                &self,
+                builder: SessionStateBuilder,
+            ) -> Result<SessionStateBuilder, DataFusionError> {
+                Ok(with_user_codec(builder, Int64ListExecCodec))
             }
         }
-        let (mut ctx, _guard) =
+        let (ctx, _guard) =
             start_localhost_context([50050, 50051, 50052], CustomSessionBuilder).await;
-        add_user_codec(&mut ctx, Int64ListExecCodec);
 
         let single_node_plan = build_plan(false)?;
         assert_snapshot!(displayable(single_node_plan.as_ref()).indent(true).to_string(), @r"
diff --git a/tests/distributed_aggregation.rs b/tests/distributed_aggregation.rs
index 09d5e659..0d70ae43 100644
--- a/tests/distributed_aggregation.rs
+++ b/tests/distributed_aggregation.rs
@@ -2,12 +2,10 @@
 mod tests {
     use datafusion::arrow::util::pretty::pretty_format_batches;
     use datafusion::physical_plan::{displayable, execute_stream};
-    use datafusion_distributed::assert_snapshot;
-    use datafusion_distributed::test_utils::localhost::{
-        start_localhost_context, NoopSessionBuilder,
-    };
+    use datafusion_distributed::test_utils::localhost::start_localhost_context;
     use datafusion_distributed::test_utils::parquet::register_parquet_tables;
     use datafusion_distributed::test_utils::plan::distribute_aggregate;
+    use datafusion_distributed::{assert_snapshot, NoopSessionBuilder};
     use futures::TryStreamExt;
     use std::error::Error;
diff --git a/tests/error_propagation.rs b/tests/error_propagation.rs
index cd541f6b..6afb1963 100644
--- a/tests/error_propagation.rs
+++ b/tests/error_propagation.rs
@@ -13,8 +13,7 @@ mod tests {
     };
     use datafusion_distributed::test_utils::localhost::start_localhost_context;
     use datafusion_distributed::{
-        add_user_codec, with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule,
-        SessionBuilder,
+        with_user_codec, ArrowFlightReadExec, DistributedPhysicalOptimizerRule, SessionBuilder,
     };
     use datafusion_proto::physical_plan::PhysicalExtensionCodec;
     use datafusion_proto::protobuf::proto_error;
@@ -30,13 +29,15 @@ mod tests {
         #[derive(Clone)]
         struct CustomSessionBuilder;
         impl SessionBuilder for CustomSessionBuilder {
-            fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
-                with_user_codec(builder, ErrorExecCodec)
+            fn session_state_builder(
+                &self,
+                builder: SessionStateBuilder,
+            ) -> Result<SessionStateBuilder, DataFusionError> {
+                Ok(with_user_codec(builder, ErrorExecCodec))
             }
         }
-        let (mut ctx, _guard) =
+        let (ctx, _guard) =
             start_localhost_context([50050, 50051, 50053], CustomSessionBuilder).await;
-        add_user_codec(&mut ctx, ErrorExecCodec);
 
         let mut plan: Arc<dyn ExecutionPlan> = Arc::new(ErrorExec::new("something failed"));
diff --git a/tests/highly_distributed_query.rs b/tests/highly_distributed_query.rs
index 9b6f97cb..3994ee1d 100644
--- a/tests/highly_distributed_query.rs
+++ b/tests/highly_distributed_query.rs
@@ -2,11 +2,9 @@
 mod tests {
     use datafusion::physical_expr::Partitioning;
     use datafusion::physical_plan::{displayable, execute_stream};
-    use datafusion_distributed::test_utils::localhost::{
-        start_localhost_context, NoopSessionBuilder,
-    };
+    use datafusion_distributed::test_utils::localhost::start_localhost_context;
     use datafusion_distributed::test_utils::parquet::register_parquet_tables;
-    use datafusion_distributed::{assert_snapshot, ArrowFlightReadExec};
+    use datafusion_distributed::{assert_snapshot, ArrowFlightReadExec, NoopSessionBuilder};
     use futures::TryStreamExt;
     use std::error::Error;
     use std::sync::Arc;
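For downstream users, the migration pattern shown across these test diffs is mechanical: `on_new_session` becomes the fallible `session_state_builder`, and the separate `add_user_codec` call on the context disappears because the codec now travels with the builder. A sketch of the before/after shape (`MySessionBuilder` is a stand-in; the commented-out block is the pre-change API):

```rust
use datafusion::error::DataFusionError;
use datafusion::execution::SessionStateBuilder;
use datafusion_distributed::SessionBuilder;

#[derive(Clone)]
struct MySessionBuilder;

// Before this change, implementors wrote:
//
//     impl SessionBuilder for MySessionBuilder {
//         fn on_new_session(&self, builder: SessionStateBuilder) -> SessionStateBuilder {
//             builder
//         }
//     }
//
// and additionally called `add_user_codec(&mut ctx, ...)` on the context.
// Now the hook returns a Result and any codec is attached here instead:
impl SessionBuilder for MySessionBuilder {
    fn session_state_builder(
        &self,
        builder: SessionStateBuilder,
    ) -> Result<SessionStateBuilder, DataFusionError> {
        Ok(builder) // e.g. `Ok(with_user_codec(builder, MyCodec))`
    }
}
```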