|
20 | 20 | package org.apache.spark.sql.benchmark |
21 | 21 |
|
/**
 * Benchmark to measure Comet execution performance for conditional expressions (`CASE WHEN`
 * and `IF`). To run this benchmark:
 * {{{
 *   SPARK_GENERATE_BENCHMARK_FILES=1 make benchmark-org.apache.spark.sql.benchmark.CometConditionalExpressionBenchmark
 * }}}
 * Results will be written to
 * "spark/benchmarks/CometConditionalExpressionBenchmark-**results.txt".
 */
object CometConditionalExpressionBenchmark extends CometBenchmarkBase {

  /**
   * Passes a four-column projection of `tbl` to `prepareTable` and then evaluates `f`, all
   * inside `withTempPath` and `withTempTable("parquetV1Table")`.
   *
   * Projected columns:
   *   - c1: the source `value` column, unchanged
   *   - c2: `ABS(value % 100)` cast to INT (0-99), used by the multi-branch predicates
   *   - c3: `value * 2` cast to LONG, a second numeric column for non-literal results
   *   - c4: `value` cast to STRING; no query in this object references it at the moment
   *
   * @param f
   *   the work to run while the test table is available
   */
  private def prepareTestTable(f: => Unit): Unit = {
    withTempPath { dir =>
      withTempTable("parquetV1Table") {
        prepareTable(
          dir,
          spark.sql(s"""
            SELECT
              value AS c1,
              CAST(ABS(value % 100) AS INT) AS c2,
              CAST(value * 2 AS LONG) AS c3,
              CAST(value AS STRING) AS c4
            FROM $tbl
          """))
        f
      }
    }
  }

  /**
   * Shared scaffold for every benchmark below: sets up the test table, then forwards `name`,
   * `values` and `query` to `runExpressionBenchmark`.
   */
  private def benchmarkOnTestTable(name: String, values: Int, query: String): Unit = {
    prepareTestTable {
      runExpressionBenchmark(name, values, query)
    }
  }

  /**
   * Three-way `CASE WHEN` on `c1` whose results are all string literals.
   *
   * @param values
   *   forwarded unchanged to `runExpressionBenchmark`
   */
  def caseWhenLiteralBenchmark(values: Int): Unit = {
    val query =
      "SELECT CASE WHEN c1 < 0 THEN '<0' WHEN c1 = 0 THEN '=0' ELSE '>0' END FROM parquetV1Table"
    benchmarkOnTestTable("Case When Literal (3 branches)", values, query)
  }

  /**
   * `CASE WHEN` with nine `WHEN` arms plus `ELSE` (10 outcomes) over `c2`, literal results.
   *
   * @param values
   *   forwarded unchanged to `runExpressionBenchmark`
   */
  def caseWhenManyBranchesLiteralBenchmark(values: Int): Unit = {
    val query = """
      SELECT CASE
        WHEN c2 < 10 THEN 'a'
        WHEN c2 < 20 THEN 'b'
        WHEN c2 < 30 THEN 'c'
        WHEN c2 < 40 THEN 'd'
        WHEN c2 < 50 THEN 'e'
        WHEN c2 < 60 THEN 'f'
        WHEN c2 < 70 THEN 'g'
        WHEN c2 < 80 THEN 'h'
        WHEN c2 < 90 THEN 'i'
        ELSE 'j'
      END FROM parquetV1Table
    """
    benchmarkOnTestTable("Case When Literal (10 branches)", values, query)
  }

  /**
   * Three-way `CASE WHEN` on `c1` whose results are column references / column arithmetic
   * rather than literals.
   *
   * @param values
   *   forwarded unchanged to `runExpressionBenchmark`
   */
  def caseWhenColumnResultBenchmark(values: Int): Unit = {
    val query =
      "SELECT CASE WHEN c1 < 0 THEN c3 WHEN c1 = 0 THEN c1 ELSE c3 + c1 END FROM parquetV1Table"
    benchmarkOnTestTable("Case When Column Result (3 branches)", values, query)
  }

  /**
   * `CASE WHEN` with nine `WHEN` arms plus `ELSE` (10 outcomes) over `c2`, where every result
   * is a column expression.
   *
   * @param values
   *   forwarded unchanged to `runExpressionBenchmark`
   */
  def caseWhenManyBranchesColumnResultBenchmark(values: Int): Unit = {
    // NOTE(review): `c1 / 2` uses Spark SQL's `/`, which is documented as floating-point
    // division; presumably this widens the result type of the whole CASE - confirm that this
    // is what the benchmark is meant to measure.
    val query = """
      SELECT CASE
        WHEN c2 < 10 THEN c1
        WHEN c2 < 20 THEN c3
        WHEN c2 < 30 THEN c1 + c3
        WHEN c2 < 40 THEN c1 - c2
        WHEN c2 < 50 THEN c3 * 2
        WHEN c2 < 60 THEN c1 / 2
        WHEN c2 < 70 THEN c2 + c3
        WHEN c2 < 80 THEN c1 * c2
        WHEN c2 < 90 THEN c3 - c1
        ELSE c1 + c2 + c3
      END FROM parquetV1Table
    """
    benchmarkOnTestTable("Case When Column Result (10 branches)", values, query)
  }

  /**
   * Single `IF` on `c1` with string-literal results.
   *
   * @param values
   *   forwarded unchanged to `runExpressionBenchmark`
   */
  def ifLiteralBenchmark(values: Int): Unit = {
    val query = "SELECT IF(c1 < 0, '<0', '>=0') FROM parquetV1Table"
    benchmarkOnTestTable("If Literal", values, query)
  }

  /**
   * Single `IF` on `c1` whose results are column references / column arithmetic.
   *
   * @param values
   *   forwarded unchanged to `runExpressionBenchmark`
   */
  def ifColumnResultBenchmark(values: Int): Unit = {
    val query = "SELECT IF(c1 < 0, c3, c1 + c3) FROM parquetV1Table"
    benchmarkOnTestTable("If Column Result", values, query)
  }

  /**
   * Three nested `IF`s over `c2` (four possible outcomes), literal results.
   *
   * @param values
   *   forwarded unchanged to `runExpressionBenchmark`
   */
  def nestedIfBenchmark(values: Int): Unit = {
    val query = """
      SELECT IF(c2 < 25, 'a',
        IF(c2 < 50, 'b',
          IF(c2 < 75, 'c', 'd')))
      FROM parquetV1Table
    """
    benchmarkOnTestTable("Nested If Literal (4 outcomes)", values, query)
  }

  /**
   * Three nested `IF`s over `c2` (four possible outcomes), column-expression results.
   *
   * @param values
   *   forwarded unchanged to `runExpressionBenchmark`
   */
  def nestedIfColumnResultBenchmark(values: Int): Unit = {
    val query = """
      SELECT IF(c2 < 25, c1,
        IF(c2 < 50, c3,
          IF(c2 < 75, c1 + c3, c3 * 2)))
      FROM parquetV1Table
    """
    benchmarkOnTestTable("Nested If Column Result (4 outcomes)", values, query)
  }

  /**
   * Runs every benchmark in this object, each under its own `runBenchmarkWithTable` call, in
   * the order listed below.
   *
   * @param mainArgs
   *   not read by this implementation
   */
  override def runCometBenchmark(mainArgs: Array[String]): Unit = {
    val values = 1024 * 1024

    // (name handed to runBenchmarkWithTable, benchmark body); order is the execution order.
    val benchmarks: Seq[(String, Int => Unit)] = Seq(
      // CASE WHEN with literal results
      ("caseWhenLiteral", caseWhenLiteralBenchmark _),
      ("caseWhenManyBranchesLiteral", caseWhenManyBranchesLiteralBenchmark _),
      // CASE WHEN with column/expression results
      ("caseWhenColumnResult", caseWhenColumnResultBenchmark _),
      ("caseWhenManyBranchesColumnResult", caseWhenManyBranchesColumnResultBenchmark _),
      // IF with literal results
      ("ifLiteral", ifLiteralBenchmark _),
      // IF with column/expression results
      ("ifColumnResult", ifColumnResultBenchmark _),
      // Nested IF expressions
      ("nestedIfLiteral", nestedIfBenchmark _),
      ("nestedIfColumnResult", nestedIfColumnResultBenchmark _))

    benchmarks.foreach { case (name, benchmark) =>
      runBenchmarkWithTable(name, values) { v =>
        benchmark(v)
      }
    }
  }
}
0 commit comments