|
7 | 7 |
|
8 | 8 | package org.elasticsearch.xpack.esql.expression.function.scalar.conditional; |
9 | 9 |
|
| 10 | +import org.elasticsearch.common.breaker.CircuitBreaker; |
| 11 | +import org.elasticsearch.common.unit.ByteSizeValue; |
| 12 | +import org.elasticsearch.common.util.BigArrays; |
| 13 | +import org.elasticsearch.common.util.MockBigArrays; |
| 14 | +import org.elasticsearch.common.util.PageCacheRecycler; |
| 15 | +import org.elasticsearch.compute.data.Block; |
| 16 | +import org.elasticsearch.compute.data.BlockFactory; |
| 17 | +import org.elasticsearch.compute.data.Page; |
| 18 | +import org.elasticsearch.compute.operator.DriverContext; |
| 19 | +import org.elasticsearch.compute.operator.EvalOperator; |
10 | 20 | import org.elasticsearch.test.ESTestCase; |
| 21 | +import org.elasticsearch.xpack.esql.core.expression.Expression; |
11 | 22 | import org.elasticsearch.xpack.esql.core.expression.Literal; |
12 | 23 | import org.elasticsearch.xpack.esql.core.tree.Source; |
13 | 24 | import org.elasticsearch.xpack.esql.core.type.DataType; |
| 25 | +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; |
| 26 | +import org.junit.After; |
14 | 27 |
|
| 28 | +import java.util.ArrayList; |
| 29 | +import java.util.Collections; |
15 | 30 | import java.util.List; |
| 31 | +import java.util.function.Function; |
| 32 | +import java.util.stream.Stream; |
16 | 33 |
|
| 34 | +import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; |
17 | 35 | import static org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase.field; |
18 | 36 | import static org.hamcrest.Matchers.equalTo; |
19 | 37 | import static org.hamcrest.Matchers.sameInstance; |
@@ -166,4 +184,129 @@ public void testPartialFoldLastAfterKeepingUnknown() { |
166 | 184 | ) |
167 | 185 | ); |
168 | 186 | } |
| 187 | + |
| 188 | + public void testEvalCase() { |
| 189 | + testCase(caseExpr -> { |
| 190 | + DriverContext driverContext = driverContext(); |
| 191 | + Page page = new Page(driverContext.blockFactory().newConstantIntBlockWith(0, 1)); |
| 192 | + try ( |
| 193 | + EvalOperator.ExpressionEvaluator eval = caseExpr.toEvaluator(AbstractFunctionTestCase::evaluator).get(driverContext); |
| 194 | + Block block = eval.eval(page) |
| 195 | + ) { |
| 196 | + return toJavaObject(block, 0); |
| 197 | + } finally { |
| 198 | + page.releaseBlocks(); |
| 199 | + } |
| 200 | + }); |
| 201 | + } |
| 202 | + |
| 203 | + public void testFoldCase() { |
| 204 | + testCase(caseExpr -> { |
| 205 | + assertTrue(caseExpr.foldable()); |
| 206 | + return caseExpr.fold(); |
| 207 | + }); |
| 208 | + } |
| 209 | + |
| 210 | + public void testCase(Function<Case, Object> toValue) { |
| 211 | + assertEquals(1, toValue.apply(caseExpr(true, 1))); |
| 212 | + assertNull(toValue.apply(caseExpr(false, 1))); |
| 213 | + assertEquals(2, toValue.apply(caseExpr(false, 1, 2))); |
| 214 | + assertEquals(1, toValue.apply(caseExpr(true, 1, true, 2))); |
| 215 | + assertEquals(2, toValue.apply(caseExpr(false, 1, true, 2))); |
| 216 | + assertNull(toValue.apply(caseExpr(false, 1, false, 2))); |
| 217 | + assertEquals(3, toValue.apply(caseExpr(false, 1, false, 2, 3))); |
| 218 | + assertNull(toValue.apply(caseExpr(true, null, 1))); |
| 219 | + assertEquals(1, toValue.apply(caseExpr(false, null, 1))); |
| 220 | + assertEquals(1, toValue.apply(caseExpr(false, field("ignored", DataType.INTEGER), 1))); |
| 221 | + assertEquals(1, toValue.apply(caseExpr(true, 1, field("ignored", DataType.INTEGER)))); |
| 222 | + } |
| 223 | + |
| 224 | + public void testIgnoreLeadingNulls() { |
| 225 | + assertEquals(DataType.INTEGER, resolveType(false, null, 1)); |
| 226 | + assertEquals(DataType.INTEGER, resolveType(false, null, false, null, false, 2, null)); |
| 227 | + assertEquals(DataType.NULL, resolveType(false, null, null)); |
| 228 | + assertEquals(DataType.BOOLEAN, resolveType(false, null, field("bool", DataType.BOOLEAN))); |
| 229 | + } |
| 230 | + |
| 231 | + public void testCaseWithInvalidCondition() { |
| 232 | + assertEquals("expected at least two arguments in [<case>] but got 1", resolveCase(1).message()); |
| 233 | + assertEquals("first argument of [<case>] must be [boolean], found value [1] type [integer]", resolveCase(1, 2).message()); |
| 234 | + assertEquals( |
| 235 | + "third argument of [<case>] must be [boolean], found value [3] type [integer]", |
| 236 | + resolveCase(true, 2, 3, 4, 5).message() |
| 237 | + ); |
| 238 | + } |
| 239 | + |
| 240 | + public void testCaseWithIncompatibleTypes() { |
| 241 | + assertEquals("third argument of [<case>] must be [integer], found value [hi] type [keyword]", resolveCase(true, 1, "hi").message()); |
| 242 | + assertEquals( |
| 243 | + "fourth argument of [<case>] must be [integer], found value [hi] type [keyword]", |
| 244 | + resolveCase(true, 1, false, "hi", 5).message() |
| 245 | + ); |
| 246 | + assertEquals( |
| 247 | + "argument of [<case>] must be [integer], found value [hi] type [keyword]", |
| 248 | + resolveCase(true, 1, false, 2, true, 5, "hi").message() |
| 249 | + ); |
| 250 | + } |
| 251 | + |
| 252 | + public void testCaseIsLazy() { |
| 253 | + Case caseExpr = caseExpr(true, 1, true, 2); |
| 254 | + DriverContext driveContext = driverContext(); |
| 255 | + EvalOperator.ExpressionEvaluator evaluator = caseExpr.toEvaluator(child -> { |
| 256 | + Object value = child.fold(); |
| 257 | + if (value != null && value.equals(2)) { |
| 258 | + return dvrCtx -> new EvalOperator.ExpressionEvaluator() { |
| 259 | + @Override |
| 260 | + public Block eval(Page page) { |
| 261 | + fail("Unexpected evaluation of 4th argument"); |
| 262 | + return null; |
| 263 | + } |
| 264 | + |
| 265 | + @Override |
| 266 | + public void close() {} |
| 267 | + }; |
| 268 | + } |
| 269 | + return AbstractFunctionTestCase.evaluator(child); |
| 270 | + }).get(driveContext); |
| 271 | + Page page = new Page(driveContext.blockFactory().newConstantIntBlockWith(0, 1)); |
| 272 | + try (Block block = evaluator.eval(page)) { |
| 273 | + assertEquals(1, toJavaObject(block, 0)); |
| 274 | + } finally { |
| 275 | + page.releaseBlocks(); |
| 276 | + } |
| 277 | + } |
| 278 | + |
| 279 | + private static Case caseExpr(Object... args) { |
| 280 | + List<Expression> exps = Stream.of(args).<Expression>map(arg -> { |
| 281 | + if (arg instanceof Expression e) { |
| 282 | + return e; |
| 283 | + } |
| 284 | + return new Literal(Source.synthetic(arg == null ? "null" : arg.toString()), arg, DataType.fromJava(arg)); |
| 285 | + }).toList(); |
| 286 | + return new Case(Source.synthetic("<case>"), exps.get(0), exps.subList(1, exps.size())); |
| 287 | + } |
| 288 | + |
| 289 | + private static Expression.TypeResolution resolveCase(Object... args) { |
| 290 | + return caseExpr(args).resolveType(); |
| 291 | + } |
| 292 | + |
| 293 | + private static DataType resolveType(Object... args) { |
| 294 | + return caseExpr(args).dataType(); |
| 295 | + } |
| 296 | + |
| 297 | + private final List<CircuitBreaker> breakers = Collections.synchronizedList(new ArrayList<>()); |
| 298 | + |
| 299 | + protected final DriverContext driverContext() { |
| 300 | + BigArrays bigArrays = new MockBigArrays(PageCacheRecycler.NON_RECYCLING_INSTANCE, ByteSizeValue.ofMb(256)).withCircuitBreaking(); |
| 301 | + CircuitBreaker breaker = bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST); |
| 302 | + breakers.add(breaker); |
| 303 | + return new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); |
| 304 | + } |
| 305 | + |
| 306 | + @After |
| 307 | + public void allMemoryReleased() { |
| 308 | + for (CircuitBreaker breaker : breakers) { |
| 309 | + assertThat(breaker.getUsed(), equalTo(0L)); |
| 310 | + } |
| 311 | + } |
169 | 312 | } |
0 commit comments