Skip to content

Commit fbd4872

Browse files
committed
fix(native): Replace lambda body with optimized expression in NativeExpressionOptimizer
1 parent ef9ef78 commit fbd4872

File tree

3 files changed

+163
-2
lines changed

3 files changed

+163
-2
lines changed

presto-native-sidecar-plugin/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@
8484
<scope>provided</scope>
8585
</dependency>
8686

87+
<dependency>
88+
<groupId>com.facebook.presto</groupId>
89+
<artifactId>presto-analyzer</artifactId>
90+
<scope>test</scope>
91+
</dependency>
92+
8793
<dependency>
8894
<groupId>com.facebook.airlift</groupId>
8995
<artifactId>units</artifactId>

presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/NativeExpressionOptimizer.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,12 +345,15 @@ public RowExpression visitExpression(RowExpression originalExpression, Void cont
345345
@Override
346346
public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context)
347347
{
348-
if (canBeReplaced(lambda.getBody())) {
348+
if (canBeReplaced(lambda)) {
349+
RowExpression replacement = resolver.apply(lambda);
350+
// Sidecar optimizes only the body of lambda expression.
351+
RowExpression optimizedBody = ((LambdaDefinitionExpression) replacement).getBody().accept(this, context);
349352
return new LambdaDefinitionExpression(
350353
lambda.getSourceLocation(),
351354
lambda.getArgumentTypes(),
352355
lambda.getArguments(),
353-
toRowExpression(lambda.getSourceLocation(), resolver.apply(lambda.getBody()), lambda.getBody().getType()));
356+
toRowExpression(lambda.getSourceLocation(), optimizedBody, optimizedBody.getType()));
354357
}
355358
return lambda;
356359
}
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.facebook.presto.sidecar.expressions;
15+
16+
import com.facebook.airlift.bootstrap.Bootstrap;
17+
import com.facebook.airlift.json.JsonModule;
18+
import com.facebook.drift.codec.guice.ThriftCodecModule;
19+
import com.facebook.presto.block.BlockJsonSerde;
20+
import com.facebook.presto.common.block.Block;
21+
import com.facebook.presto.common.block.BlockEncoding;
22+
import com.facebook.presto.common.block.BlockEncodingManager;
23+
import com.facebook.presto.common.block.BlockEncodingSerde;
24+
import com.facebook.presto.common.type.Type;
25+
import com.facebook.presto.common.type.TypeManager;
26+
import com.facebook.presto.connector.ConnectorManager;
27+
import com.facebook.presto.metadata.FunctionAndTypeManager;
28+
import com.facebook.presto.metadata.HandleJsonModule;
29+
import com.facebook.presto.metadata.MetadataManager;
30+
import com.facebook.presto.operator.scalar.FunctionAssertions;
31+
import com.facebook.presto.sidecar.ForSidecarInfo;
32+
import com.facebook.presto.sidecar.NativeSidecarPluginQueryRunner;
33+
import com.facebook.presto.spi.NodeManager;
34+
import com.facebook.presto.spi.relation.ExpressionOptimizer;
35+
import com.facebook.presto.spi.relation.RowExpression;
36+
import com.facebook.presto.sql.TestingRowExpressionTranslator;
37+
import com.facebook.presto.sql.analyzer.FeaturesConfig;
38+
import com.facebook.presto.sql.tree.Expression;
39+
import com.facebook.presto.tests.DistributedQueryRunner;
40+
import com.facebook.presto.type.TypeDeserializer;
41+
import com.google.common.collect.ImmutableList;
42+
import com.google.inject.Injector;
43+
import com.google.inject.Module;
44+
import com.google.inject.Scopes;
45+
import org.intellij.lang.annotations.Language;
46+
import org.testng.annotations.AfterClass;
47+
import org.testng.annotations.Test;
48+
49+
import java.util.function.Function;
50+
51+
import static com.facebook.airlift.configuration.ConfigBinder.configBinder;
52+
import static com.facebook.airlift.http.client.HttpClientBinder.httpClientBinder;
53+
import static com.facebook.airlift.json.JsonBinder.jsonBinder;
54+
import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder;
55+
import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException;
56+
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
57+
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
58+
import static com.facebook.presto.sql.expressions.AbstractTestExpressionInterpreter.SYMBOL_TYPES;
59+
import static com.facebook.presto.sql.expressions.AbstractTestExpressionInterpreter.assertRowExpressionEvaluationEquals;
60+
import static com.google.inject.multibindings.Multibinder.newSetBinder;
61+
62+
public class TestNativeExpressionOptimizer
63+
{
64+
private final DistributedQueryRunner queryRunner;
65+
private final MetadataManager metadata;
66+
private final TestingRowExpressionTranslator translator;
67+
private final NativeExpressionOptimizer expressionOptimizer;
68+
69+
public TestNativeExpressionOptimizer()
70+
throws Exception
71+
{
72+
this.queryRunner = NativeSidecarPluginQueryRunner.getQueryRunner();
73+
FunctionAndTypeManager functionAndTypeManager = queryRunner.getCoordinator().getFunctionAndTypeManager();
74+
NodeManager nodeManager = queryRunner.getCoordinator().getPluginNodeManager();
75+
this.metadata = createTestMetadataManager(functionAndTypeManager);
76+
this.translator = new TestingRowExpressionTranslator(metadata);
77+
this.expressionOptimizer = getNativeExpressionOptimizer(functionAndTypeManager, nodeManager);
78+
}
79+
80+
@AfterClass(alwaysRun = true)
81+
public void tearDown()
82+
{
83+
closeAllRuntimeException(queryRunner);
84+
}
85+
86+
@Test
87+
public void testLambdaBodyConstantFolding()
88+
{
89+
assertOptimizedEquals("transform(ARRAY[unbound_long, unbound_long2], x -> 1 + 1)",
90+
"transform(ARRAY[unbound_long, unbound_long2], x -> 2)");
91+
assertOptimizedEquals("transform(ARRAY[unbound_long, unbound_long2], x -> cast('123' AS integer))", "transform(ARRAY[unbound_long, unbound_long2], x -> 123)");
92+
assertOptimizedEquals("transform(ARRAY[unbound_long, unbound_long2], x -> cast(json_parse('[1, 2]') AS ARRAY<INTEGER>)[1] + 1)",
93+
"transform(ARRAY[unbound_long, unbound_long2], x -> 2)");
94+
}
95+
96+
private void assertOptimizedEquals(@Language("SQL") String actual, @Language("SQL") String expected)
97+
{
98+
RowExpression optimizedActual = optimize(actual, ExpressionOptimizer.Level.OPTIMIZED);
99+
RowExpression optimizedExpected = optimize(expected, ExpressionOptimizer.Level.OPTIMIZED);
100+
assertRowExpressionEvaluationEquals(optimizedActual, optimizedExpected);
101+
}
102+
103+
private RowExpression optimize(@Language("SQL") String expression, ExpressionOptimizer.Level level)
104+
{
105+
RowExpression parsedExpression = sqlToRowExpression(expression);
106+
Function<com.facebook.presto.spi.relation.VariableReferenceExpression, Object> variableResolver = variable -> null;
107+
return expressionOptimizer.optimize(parsedExpression, level, TEST_SESSION.toConnectorSession(), variableResolver);
108+
}
109+
110+
private RowExpression sqlToRowExpression(String expression)
111+
{
112+
Expression parsedExpression = FunctionAssertions.createExpression(expression, metadata, SYMBOL_TYPES);
113+
return translator.translate(parsedExpression, SYMBOL_TYPES);
114+
}
115+
116+
private NativeExpressionOptimizer getNativeExpressionOptimizer(FunctionAndTypeManager functionAndTypeManager, NodeManager nodeManager)
117+
{
118+
Module testModule = binder -> {
119+
binder.bind(NodeManager.class).toInstance(nodeManager);
120+
binder.bind(TypeManager.class).toInstance(functionAndTypeManager);
121+
binder.install(new JsonModule());
122+
binder.install(new HandleJsonModule(functionAndTypeManager.getHandleResolver()));
123+
binder.bind(ConnectorManager.class).toProvider(() -> null).in(Scopes.SINGLETON);
124+
binder.install(new ThriftCodecModule());
125+
configBinder(binder).bindConfig(FeaturesConfig.class);
126+
127+
jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class);
128+
newSetBinder(binder, Type.class);
129+
binder.bind(com.facebook.presto.spi.function.FunctionMetadataManager.class).toInstance(functionAndTypeManager);
130+
binder.bind(com.facebook.presto.spi.function.StandardFunctionResolution.class).toInstance(
131+
new com.facebook.presto.sql.relational.FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()));
132+
binder.bind(BlockEncodingSerde.class).to(BlockEncodingManager.class).in(Scopes.SINGLETON);
133+
newSetBinder(binder, BlockEncoding.class);
134+
jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class);
135+
jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class);
136+
jsonCodecBinder(binder).bindListJsonCodec(RowExpression.class);
137+
jsonCodecBinder(binder).bindListJsonCodec(RowExpressionOptimizationResult.class);
138+
139+
httpClientBinder(binder).bindHttpClient("sidecar", ForSidecarInfo.class);
140+
141+
binder.bind(NativeSidecarExpressionInterpreter.class).in(Scopes.SINGLETON);
142+
binder.bind(NativeExpressionOptimizer.class).in(Scopes.SINGLETON);
143+
};
144+
Bootstrap app = new Bootstrap(ImmutableList.of(testModule));
145+
Injector injector = app
146+
.doNotInitializeLogging()
147+
.quiet()
148+
.initialize();
149+
150+
return injector.getInstance(NativeExpressionOptimizer.class);
151+
}
152+
}

0 commit comments

Comments
 (0)