diff --git a/.github/workflows/release-notes-check.yml b/.github/workflows/release-notes-check.yml index 2d2ed1cf3b709..d846f75dbdb26 100644 --- a/.github/workflows/release-notes-check.yml +++ b/.github/workflows/release-notes-check.yml @@ -10,6 +10,9 @@ env: jobs: check_release_note: runs-on: ubuntu-latest + concurrency: + group: ${{ github.workflow }}-check-release-note-${{ github.event.pull_request.number }} + cancel-in-progress: true steps: - name: Checkout uses: actions/checkout@v4 diff --git a/README.md b/README.md index 383325b241f6e..2da36a8a31aea 100644 --- a/README.md +++ b/README.md @@ -40,11 +40,16 @@ After opening the project in IntelliJ, double check that the Java SDK is properl * Open the File menu and select Project Structure * In the SDKs section, ensure that a distribution of JDK 17 is selected (create one if none exist) * In the Project section, ensure the Project language level is set to at least 8.0. +* When using JDK 17, an [IntelliJ bug](https://youtrack.jetbrains.com/issue/IDEA-201168) requires you + to disable the `Use '--release' option for cross-compilation (Java 9 and later)` setting in + `Settings > Build, Execution, Deployment > Compiler > Java Compiler`. If this option remains enabled, + you may encounter errors such as: `package sun.misc does not exist` because IntelliJ fails to resolve + certain internal JDK classes. Presto comes with sample configuration that should work out-of-the-box for development. 
Use the following options to create a run configuration: * Main Class: `com.facebook.presto.server.PrestoServer` -* VM Options: `-ea -XX:+UseG1GC -XX:G1HeapRegionSize=32M -XX:+UseGCOverheadLimit -XX:+ExplicitGCInvokesConcurrent -Xmx2G -Dconfig=etc/config.properties -Dlog.levels-file=etc/log.properties` +* VM Options: `-ea -XX:+UseG1GC -XX:G1HeapRegionSize=32M -XX:+UseGCOverheadLimit -XX:+ExplicitGCInvokesConcurrent -Xmx2G -Dconfig=etc/config.properties -Dlog.levels-file=etc/log.properties -Djdk.attach.allowAttachSelf=true` * Working directory: `$MODULE_WORKING_DIR$` or `$MODULE_DIR$`(Depends your version of IntelliJ) * Use classpath of module: `presto-main` @@ -54,6 +59,32 @@ Additionally, the Hive plugin must be configured with location of your Hive meta -Dhive.metastore.uri=thrift://localhost:9083 +### Additional configuration for Java 17 + +When running with Java 17, additional `--add-opens` flags are required to allow the reflective access that some catalogs rely on. Which flags are needed depends on the catalogs that are configured. 
+For the default set of catalogs loaded when starting the Presto server in IntelliJ without changes, add the following flags to the **VM Options**: + + --add-opens=java.base/java.io=ALL-UNNAMED + --add-opens=java.base/java.lang=ALL-UNNAMED + --add-opens=java.base/java.lang.ref=ALL-UNNAMED + --add-opens=java.base/java.lang.reflect=ALL-UNNAMED + --add-opens=java.base/java.net=ALL-UNNAMED + --add-opens=java.base/java.nio=ALL-UNNAMED + --add-opens=java.base/java.security=ALL-UNNAMED + --add-opens=java.base/javax.security.auth=ALL-UNNAMED + --add-opens=java.base/javax.security.auth.login=ALL-UNNAMED + --add-opens=java.base/java.text=ALL-UNNAMED + --add-opens=java.base/java.util=ALL-UNNAMED + --add-opens=java.base/java.util.concurrent=ALL-UNNAMED + --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED + --add-opens=java.base/java.util.regex=ALL-UNNAMED + --add-opens=java.base/jdk.internal.loader=ALL-UNNAMED + --add-opens=java.base/sun.security.action=ALL-UNNAMED + --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED + +These flags ensure that internal JDK modules are accessible at runtime for components used by Presto’s default configuration. +It is not a comprehensive list. Additional flags may need to be added, depending on the catalogs configured on the server. + ### Using SOCKS for Hive or HDFS If your Hive metastore or HDFS cluster is not directly accessible to your local machine, you can use SSH port forwarding to access it. 
Setup a dynamic SOCKS proxy with SSH listening on local port 1080: diff --git a/pom.xml b/pom.xml index 35983917d94a4..6e4a38ae5fdc4 100644 --- a/pom.xml +++ b/pom.xml @@ -83,7 +83,8 @@ 1.26.2 4.29.0 12.0.18 - 4.1.119.Final + 4.1.122.Final + 1.2.8 2.0 2.12.1 3.18.0 @@ -1555,6 +1556,18 @@ 0.11.5 + + io.projectreactor.netty + reactor-netty-core + ${dep.reactor-netty.version} + + + + io.projectreactor.netty + reactor-netty-http + ${dep.reactor-netty.version} + + org.apache.thrift libthrift diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst index 866f895863666..aec2950358687 100644 --- a/presto-docs/src/main/sphinx/admin/properties.rst +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -150,7 +150,7 @@ To enable the ``OFFSET`` clause in SQL query expressions, set this property to ` The corresponding session property is :ref:`admin/properties-session:\`\`offset_clause_enabled\`\``. ``max-serializable-object-size`` -^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ * **Type:** ``long`` * **Default value:** ``1000`` diff --git a/presto-docs/src/main/sphinx/connector/hive.rst b/presto-docs/src/main/sphinx/connector/hive.rst index 461c28e815816..7e427b9834b3b 100644 --- a/presto-docs/src/main/sphinx/connector/hive.rst +++ b/presto-docs/src/main/sphinx/connector/hive.rst @@ -1005,12 +1005,16 @@ Hive catalog is called ``web``:: CALL web.system.example_procedure() -The following procedures are available: +Create Empty Partition +^^^^^^^^^^^^^^^^^^^^^^ * ``system.create_empty_partition(schema_name, table_name, partition_columns, partition_values)`` Create an empty partition in the specified table. +Sync Partition Metadata +^^^^^^^^^^^^^^^^^^^^^^^ + * ``system.sync_partition_metadata(schema_name, table_name, mode, case_sensitive)`` Check and update partitions list in metastore. 
There are three modes available: @@ -1024,6 +1028,9 @@ The following procedures are available: file system paths to use lowercase (e.g. ``col_x=SomeValue``). Partitions on the file system not conforming to this convention are ignored, unless the argument is set to ``false``. +Invalidate Directory List Cache +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + * ``system.invalidate_directory_list_cache()`` Flush full directory list cache. @@ -1032,6 +1039,9 @@ The following procedures are available: Invalidate directory list cache for specified directory_path. +Invalidate Metastore Cache +^^^^^^^^^^^^^^^^^^^^^^^^^^ + * ``system.invalidate_metastore_cache()`` Invalidate all metastore caches. @@ -1048,8 +1058,10 @@ The following procedures are available: Invalidate all metastore cache entries linked to a specific partition. -Note: To enable ``system.invalidate_metastore_cache`` procedure, please refer to the properties that -apply to Hive Metastore and are listed in the `Metastore Configuration Properties`_ table. + .. note:: + + To enable ``system.invalidate_metastore_cache`` procedure, ``hive.invalidate-metastore-cache-procedure-enabled`` must be set to ``true``. + See the properties in `Metastore Configuration Properties`_ table for more information. Extra Hidden Columns -------------------- @@ -1064,22 +1076,22 @@ columns as a part of the query like any other columns of the table. How to invalidate metastore cache? ---------------------------------- -The Hive connector exposes a procedure over JMX (``com.facebook.presto.hive.metastore.CachingHiveMetastore#flushCache``) to invalidate the metastore cache. -You can call this procedure to invalidate the metastore cache by connecting via jconsole or jmxterm. +Invalidating metastore cache is useful when the Hive metastore is updated outside of Presto and you want to make the changes visible to Presto immediately. 
+There are a couple of ways to invalidate this cache, listed below: -This is useful when the Hive metastore is updated outside of Presto and you want to make the changes visible to Presto immediately. +* The Hive connector exposes a procedure over JMX (``com.facebook.presto.hive.metastore.InMemoryCachingHiveMetastore#invalidateAll``) to invalidate the metastore cache. You can call this procedure to invalidate the metastore cache by connecting via jconsole or jmxterm. However, this procedure flushes the cache for all the tables in all the schemas. -Currently, this procedure flushes the cache for all the tables in all the schemas. This is a known limitation and will be enhanced in the future. +* The Hive connector exposes the ``system.invalidate_metastore_cache`` procedure, which lets users invalidate the metastore cache completely or partially, depending on the arguments it is invoked with. See `Invalidate Metastore Cache`_ for more information. How to invalidate directory list cache? --------------------------------------- -The Hive connector exposes a procedure over JMX (``com.facebook.presto.hive.HiveDirectoryLister#flushCache``) to invalidate the directory list cache. -You can call this procedure to invalidate the directory list cache by connecting via jconsole or jmxterm. +Invalidating the directory list cache is useful when files are added or deleted in the cache directory path and you want to make the changes visible to Presto immediately. +There are a couple of ways to invalidate this cache, listed below: -This is useful when the files are added or deleted in the cache directory path and you want to make the changes visible to Presto immediately. +* The Hive connector exposes a procedure over JMX (``com.facebook.presto.hive.CachingDirectoryLister#flushCache``) to invalidate the directory list cache. You can call this procedure to invalidate the directory list cache by connecting via jconsole or jmxterm. 
This procedure flushes all the cache entries. -Currently, this procedure flushes all the cache entries. This is a known limitation and will be enhanced in the future. +* The Hive connector exposes the ``system.invalidate_directory_list_cache`` procedure, which can invalidate the directory list cache completely or partially, depending on how it is invoked. See `Invalidate Directory List Cache`_ for more information. Examples -------- diff --git a/presto-docs/src/main/sphinx/connector/iceberg.rst b/presto-docs/src/main/sphinx/connector/iceberg.rst index 990abcb4c367d..40a0f899c2634 100644 --- a/presto-docs/src/main/sphinx/connector/iceberg.rst +++ b/presto-docs/src/main/sphinx/connector/iceberg.rst @@ -2208,7 +2208,7 @@ Sorting can be combined with partitioning on the same column. For example:: sorted_by = ARRAY['join_date'] ) -The Iceberg connector does not support sort order transforms. The following sort order transformations are not supported: +Sort order does not support transforms. The following transforms are not supported: .. 
code-block:: text diff --git a/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveErrorCode.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveErrorCode.java index bee080b4e64ba..aa122e1c28f1f 100644 --- a/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveErrorCode.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/HiveErrorCode.java @@ -76,6 +76,7 @@ public enum HiveErrorCode HIVE_RANGER_SERVER_ERROR(48, EXTERNAL), HIVE_FUNCTION_INITIALIZATION_ERROR(49, EXTERNAL), HIVE_METASTORE_INITIALIZE_SSL_ERROR(50, EXTERNAL), + UNKNOWN_TABLE_TYPE(51, EXTERNAL), /**/; private final ErrorCode errorCode; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/UnknownTableTypeException.java b/presto-hive-common/src/main/java/com/facebook/presto/hive/UnknownTableTypeException.java similarity index 68% rename from presto-iceberg/src/main/java/com/facebook/presto/iceberg/UnknownTableTypeException.java rename to presto-hive-common/src/main/java/com/facebook/presto/hive/UnknownTableTypeException.java index b7d6323ec6ff2..cb740e975b535 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/UnknownTableTypeException.java +++ b/presto-hive-common/src/main/java/com/facebook/presto/hive/UnknownTableTypeException.java @@ -11,18 +11,17 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.facebook.presto.iceberg; +package com.facebook.presto.hive; import com.facebook.presto.spi.PrestoException; -import com.facebook.presto.spi.SchemaTableName; -import static com.facebook.presto.iceberg.IcebergErrorCode.ICEBERG_UNKNOWN_TABLE_TYPE; +import static com.facebook.presto.hive.HiveErrorCode.UNKNOWN_TABLE_TYPE; public class UnknownTableTypeException extends PrestoException { - public UnknownTableTypeException(SchemaTableName tableName) + public UnknownTableTypeException(String message) { - super(ICEBERG_UNKNOWN_TABLE_TYPE, "Not an Iceberg table: " + tableName); + super(UNKNOWN_TABLE_TYPE, message); } } diff --git a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java index 7f776dd043ee4..e0d07e7ce4a58 100644 --- a/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java +++ b/presto-hive/src/main/java/com/facebook/presto/hive/HiveMetadata.java @@ -15,6 +15,7 @@ import com.facebook.airlift.json.JsonCodec; import com.facebook.airlift.json.smile.SmileCodec; +import com.facebook.airlift.log.Logger; import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.Page; import com.facebook.presto.common.Subfield; @@ -389,6 +390,7 @@ public class HiveMetadata implements TransactionalMetadata { + private static final Logger log = Logger.get(HiveMetadata.class); public static final Set RESERVED_ROLES = ImmutableSet.of("all", "default", "none"); public static final String REFERENCED_MATERIALIZED_VIEWS = "referenced_materialized_views"; @@ -674,7 +676,7 @@ private ConnectorTableMetadata getTableMetadata(Optional table, SchemaTab } if (isIcebergTable(table.get()) || isDeltaLakeTable(table.get())) { - throw new PrestoException(HIVE_UNSUPPORTED_FORMAT, format("Not a Hive table '%s'", tableName)); + throw new UnknownTableTypeException("Not a Hive table: " + tableName); } List> tableConstraints = 
metastore.getTableConstraints(metastoreContext, tableName.getSchemaName(), tableName.getTableName()); @@ -862,6 +864,9 @@ public Map> listTableColumns(ConnectorSess catch (TableNotFoundException e) { // table disappeared during listing operation } + catch (UnknownTableTypeException e) { + log.warn(String.format("%s: Unknown table type of table %s", e.getMessage(), tableName)); + } } return columns.build(); } diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/HiveTableOperations.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/HiveTableOperations.java index 2b63f666fd38d..b608c5e93d332 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/HiveTableOperations.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/HiveTableOperations.java @@ -17,6 +17,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.presto.hive.HdfsContext; import com.facebook.presto.hive.HdfsEnvironment; +import com.facebook.presto.hive.UnknownTableTypeException; import com.facebook.presto.hive.metastore.ExtendedHiveMetastore; import com.facebook.presto.hive.metastore.HivePrivilegeInfo; import com.facebook.presto.hive.metastore.MetastoreContext; @@ -215,7 +216,7 @@ public TableMetadata refresh() Table table = getTable(); if (!isIcebergTable(table)) { - throw new UnknownTableTypeException(getSchemaTableName()); + throw new UnknownTableTypeException("Not an Iceberg table: " + getSchemaTableName()); } if (isPrestoView(table)) { diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java index 6f6cc902f101c..920013e55289d 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergAbstractMetadata.java @@ -28,6 +28,7 @@ import com.facebook.presto.hive.HiveOutputMetadata; import 
com.facebook.presto.hive.HivePartition; import com.facebook.presto.hive.NodeVersion; +import com.facebook.presto.hive.UnknownTableTypeException; import com.facebook.presto.iceberg.changelog.ChangelogOperation; import com.facebook.presto.iceberg.changelog.ChangelogUtil; import com.facebook.presto.iceberg.statistics.StatisticsFileCache; diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergErrorCode.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergErrorCode.java index c833a65eaf62f..2d78252763f7f 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergErrorCode.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergErrorCode.java @@ -24,7 +24,6 @@ public enum IcebergErrorCode implements ErrorCodeSupplier { - ICEBERG_UNKNOWN_TABLE_TYPE(0, EXTERNAL), ICEBERG_INVALID_METADATA(1, EXTERNAL), ICEBERG_TOO_MANY_OPEN_PARTITIONS(2, USER_ERROR), ICEBERG_INVALID_PARTITION_VALUE(3, EXTERNAL), diff --git a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java index 1a6ac458ea456..4e94fdd1fe59e 100644 --- a/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java +++ b/presto-iceberg/src/main/java/com/facebook/presto/iceberg/IcebergHiveMetadata.java @@ -24,6 +24,7 @@ import com.facebook.presto.hive.HiveTypeTranslator; import com.facebook.presto.hive.NodeVersion; import com.facebook.presto.hive.TableAlreadyExistsException; +import com.facebook.presto.hive.UnknownTableTypeException; import com.facebook.presto.hive.ViewAlreadyExistsException; import com.facebook.presto.hive.metastore.Column; import com.facebook.presto.hive.metastore.Database; @@ -230,7 +231,7 @@ protected boolean tableExists(ConnectorSession session, SchemaTableName schemaTa return false; } if (!isIcebergTable(hiveTable.get())) { - throw new UnknownTableTypeException(schemaTableName); + throw new 
UnknownTableTypeException("Not an Iceberg table: " + schemaTableName); } return true; } diff --git a/presto-jdbc/pom.xml b/presto-jdbc/pom.xml index 7099d5ea4bb80..db69553a52b2e 100644 --- a/presto-jdbc/pom.xml +++ b/presto-jdbc/pom.xml @@ -123,6 +123,12 @@ test + + com.facebook.presto + presto-memory + test + + com.facebook.presto presto-main-base @@ -218,11 +224,13 @@ jjwt-api test + io.jsonwebtoken jjwt-impl test + io.jsonwebtoken jjwt-jackson diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java index a49db1ea6e40a..73f00a19c1833 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoDatabaseMetaData.java @@ -1144,8 +1144,7 @@ public boolean insertsAreDetected(int type) public boolean supportsBatchUpdates() throws SQLException { - // TODO: support batch updates - return false; + return true; } @Override diff --git a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoPreparedStatement.java b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoPreparedStatement.java index 9720e6b167e0f..bd1af7ea2233b 100644 --- a/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoPreparedStatement.java +++ b/presto-jdbc/src/main/java/com/facebook/presto/jdbc/PrestoPreparedStatement.java @@ -23,6 +23,7 @@ import java.math.BigDecimal; import java.net.URL; import java.sql.Array; +import java.sql.BatchUpdateException; import java.sql.Blob; import java.sql.Clob; import java.sql.Date; @@ -42,6 +43,7 @@ import java.sql.Timestamp; import java.sql.Types; import java.util.ArrayList; +import java.util.Arrays; import java.util.Calendar; import java.util.HashMap; import java.util.List; @@ -72,9 +74,11 @@ public class PrestoPreparedStatement implements PreparedStatement { private final Map parameters = new HashMap<>(); + private final List> batchValues = new 
ArrayList<>(); private final String statementName; private final String originalSql; private boolean isClosed; + private boolean isBatch; PrestoPreparedStatement(PrestoConnection connection, String statementName, String sql) throws SQLException @@ -101,7 +105,8 @@ public void close() public ResultSet executeQuery() throws SQLException { - if (!super.execute(getExecuteSql())) { + requireNonBatchStatement(); + if (!super.execute(getExecuteSql(statementName, toValues(parameters)))) { throw new SQLException("Prepared SQL statement is not a query: " + originalSql); } return getResultSet(); @@ -111,6 +116,7 @@ public ResultSet executeQuery() public int executeUpdate() throws SQLException { + requireNonBatchStatement(); return Ints.saturatedCast(executeLargeUpdate()); } @@ -118,7 +124,8 @@ public int executeUpdate() public long executeLargeUpdate() throws SQLException { - if (super.execute(getExecuteSql())) { + requireNonBatchStatement(); + if (super.execute(getExecuteSql(statementName, toValues(parameters)))) { throw new SQLException("Prepared SQL is not an update statement: " + originalSql); } return getLargeUpdateCount(); @@ -128,7 +135,8 @@ public long executeLargeUpdate() public boolean execute() throws SQLException { - return super.execute(getExecuteSql()); + requireNonBatchStatement(); + return super.execute(getExecuteSql(statementName, toValues(parameters))); } @Override @@ -430,7 +438,41 @@ else if (x instanceof Timestamp) { public void addBatch() throws SQLException { - throw new NotImplementedException("PreparedStatement", "addBatch"); + checkOpen(); + batchValues.add(toValues(parameters)); + isBatch = true; + } + + @Override + public void clearBatch() + throws SQLException + { + checkOpen(); + batchValues.clear(); + isBatch = false; + } + + @Override + public int[] executeBatch() + throws SQLException + { + try { + int[] batchUpdateCounts = new int[batchValues.size()]; + for (int i = 0; i < batchValues.size(); i++) { + try { + 
super.execute(getExecuteSql(statementName, batchValues.get(i))); + batchUpdateCounts[i] = getUpdateCount(); + } + catch (SQLException e) { + long[] updateCounts = Arrays.stream(batchUpdateCounts).mapToLong(j -> j).toArray(); + throw new BatchUpdateException(e.getMessage(), e.getSQLState(), e.getErrorCode(), updateCounts, e.getCause()); + } + } + return batchUpdateCounts; + } + finally { + clearBatch(); + } } @Override @@ -759,27 +801,34 @@ private void setParameter(int parameterIndex, String value) parameters.put(parameterIndex - 1, value); } - private void formatParametersTo(StringBuilder builder) + private static List toValues(Map parameters) throws SQLException { - List values = new ArrayList<>(); + ImmutableList.Builder values = ImmutableList.builder(); for (int index = 0; index < parameters.size(); index++) { if (!parameters.containsKey(index)) { throw new SQLException("No value specified for parameter " + (index + 1)); } values.add(parameters.get(index)); } - Joiner.on(", ").appendTo(builder, values); + return values.build(); } - private String getExecuteSql() + private void requireNonBatchStatement() throws SQLException + { + if (isBatch) { + throw new SQLException("Batch prepared statement must be executed using executeBatch method"); + } + } + + private static String getExecuteSql(String statementName, List values) { StringBuilder sql = new StringBuilder(); sql.append("EXECUTE ").append(statementName); - if (!parameters.isEmpty()) { + if (!values.isEmpty()) { sql.append(" USING "); - formatParametersTo(sql); + Joiner.on(", ").appendTo(sql, values); } return sql.toString(); } diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcPreparedStatement.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcPreparedStatement.java index 0cfa4668d81bf..6d46f36475e61 100644 --- a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcPreparedStatement.java +++ 
b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestJdbcPreparedStatement.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logging; import com.facebook.presto.plugin.blackhole.BlackHolePlugin; +import com.facebook.presto.plugin.memory.MemoryPlugin; import com.facebook.presto.server.testing.TestingPrestoServer; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -40,10 +41,13 @@ import static com.facebook.presto.jdbc.TestPrestoDriver.closeQuietly; import static com.facebook.presto.jdbc.TestPrestoDriver.waitForNodeRefresh; +import static com.facebook.presto.jdbc.TestingJdbcUtils.list; +import static com.facebook.presto.jdbc.TestingJdbcUtils.readRows; import static com.google.common.base.Strings.repeat; import static com.google.common.primitives.Ints.asList; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -61,7 +65,9 @@ public void setup() Logging.initialize(); server = new TestingPrestoServer(); server.installPlugin(new BlackHolePlugin()); + server.installPlugin(new MemoryPlugin()); server.createCatalog("blackhole", "blackhole"); + server.createCatalog("memory", "memory"); waitForNodeRefresh(server); try (Connection connection = createConnection(); @@ -636,6 +642,88 @@ public void testInvalidConversions() assertInvalidConversion((ps, i) -> ps.setObject(i, "abc", Types.SMALLINT), "Cannot convert instance of java.lang.String to SQL type " + Types.SMALLINT); } + @Test + public void testExecuteBatch() + throws Exception + { + try (Connection connection = createConnection("memory", "default")) { + try (Statement statement = connection.createStatement()) { + statement.execute("CREATE TABLE test_execute_batch(c_int integer)"); + } + + try (PreparedStatement 
preparedStatement = connection.prepareStatement( + "INSERT INTO test_execute_batch VALUES (?)")) { + // Run executeBatch before addBatch + assertEquals(preparedStatement.executeBatch(), new int[] {}); + + for (int i = 0; i < 3; i++) { + preparedStatement.setInt(1, i); + preparedStatement.addBatch(); + } + assertEquals(preparedStatement.executeBatch(), new int[] {1, 1, 1}); + + try (Statement statement = connection.createStatement()) { + ResultSet resultSet = statement.executeQuery("SELECT c_int FROM test_execute_batch"); + assertThat(readRows(resultSet)) + .containsExactlyInAnyOrder( + list(0), + list(1), + list(2)); + } + + // Make sure the above executeBatch cleared existing batch + assertEquals(preparedStatement.executeBatch(), new int[] {}); + + // clearBatch removes added batch and cancel batch mode + preparedStatement.setBoolean(1, true); + preparedStatement.clearBatch(); + assertEquals(preparedStatement.executeBatch(), new int[] {}); + + preparedStatement.setInt(1, 1); + assertEquals(preparedStatement.executeUpdate(), 1); + } + + try (Statement statement = connection.createStatement()) { + statement.execute("DROP TABLE test_execute_batch"); + } + } + } + + @Test + public void testInvalidExecuteBatch() + throws Exception + { + try (Connection connection = createConnection("blackhole", "blackhole")) { + try (Statement statement = connection.createStatement()) { + statement.execute("CREATE TABLE test_invalid_execute_batch(c_int integer)"); + } + + try (PreparedStatement statement = connection.prepareStatement( + "INSERT INTO test_invalid_execute_batch VALUES (?)")) { + statement.setInt(1, 1); + statement.addBatch(); + + String message = "Batch prepared statement must be executed using executeBatch method"; + assertThatThrownBy(statement::executeQuery) + .isInstanceOf(SQLException.class) + .hasMessage(message); + assertThatThrownBy(statement::executeUpdate) + .isInstanceOf(SQLException.class) + .hasMessage(message); + 
assertThatThrownBy(statement::executeLargeUpdate) + .isInstanceOf(SQLException.class) + .hasMessage(message); + assertThatThrownBy(statement::execute) + .isInstanceOf(SQLException.class) + .hasMessage(message); + } + + try (Statement statement = connection.createStatement()) { + statement.execute("DROP TABLE test_invalid_execute_batch"); + } + } + } + private void assertInvalidConversion(Binder binder, String message) { assertThatThrownBy(() -> assertParameter(null, Types.NULL, binder)) diff --git a/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestingJdbcUtils.java b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestingJdbcUtils.java new file mode 100644 index 0000000000000..26ad29b209e50 --- /dev/null +++ b/presto-jdbc/src/test/java/com/facebook/presto/jdbc/TestingJdbcUtils.java @@ -0,0 +1,49 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.jdbc; + +import com.google.common.collect.ImmutableList; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import static java.util.Arrays.asList; + +public class TestingJdbcUtils +{ + private TestingJdbcUtils() {} + + public static List> readRows(ResultSet rs) + throws SQLException + { + ImmutableList.Builder> rows = ImmutableList.builder(); + int columnCount = rs.getMetaData().getColumnCount(); + while (rs.next()) { + List row = new ArrayList<>(); + for (int i = 1; i <= columnCount; i++) { + row.add(rs.getObject(i)); + } + rows.add(row); + } + return rows.build(); + } + + @SafeVarargs + public static List list(T... elements) + { + return asList(elements); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java index ce1fb24c3fb1d..1a30a76798d37 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -255,6 +255,7 @@ public final class SystemSessionProperties public static final String MAX_STAGE_COUNT_FOR_EAGER_SCHEDULING = "max_stage_count_for_eager_scheduling"; public static final String HYPERLOGLOG_STANDARD_ERROR_WARNING_THRESHOLD = "hyperloglog_standard_error_warning_threshold"; public static final String PREFER_MERGE_JOIN_FOR_SORTED_INPUTS = "prefer_merge_join_for_sorted_inputs"; + public static final String PREFER_SORT_MERGE_JOIN = "prefer_sort_merge_join"; public static final String SEGMENTED_AGGREGATION_ENABLED = "segmented_aggregation_enabled"; public static final String USE_HISTORY_BASED_PLAN_STATISTICS = "use_history_based_plan_statistics"; public static final String TRACK_HISTORY_BASED_PLAN_STATISTICS = "track_history_based_plan_statistics"; @@ -1386,6 +1387,11 @@ public SystemSessionProperties( "To make it 
work, the connector needs to guarantee and expose the data properties of the underlying table.", featuresConfig.isPreferMergeJoinForSortedInputs(), true), + booleanProperty( + PREFER_SORT_MERGE_JOIN, + "Prefer sort merge join for all joins. A SortNode is added if input is not already sorted.", + featuresConfig.isPreferSortMergeJoin(), + true), booleanProperty( SEGMENTED_AGGREGATION_ENABLED, "Enable segmented aggregation.", @@ -2881,6 +2887,11 @@ public static boolean preferMergeJoinForSortedInputs(Session session) return session.getSystemProperty(PREFER_MERGE_JOIN_FOR_SORTED_INPUTS, Boolean.class); } + public static boolean preferSortMergeJoin(Session session) + { + return session.getSystemProperty(PREFER_SORT_MERGE_JOIN, Boolean.class); + } + public static boolean isSegmentedAggregationEnabled(Session session) { return session.getSystemProperty(SEGMENTED_AGGREGATION_ENABLED, Boolean.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java index 40de75c838169..e8638f342a17b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java @@ -24,18 +24,26 @@ import java.util.List; import java.util.Objects; +import static com.facebook.presto.metadata.BuiltInFunctionKind.ENGINE; import static java.util.Objects.requireNonNull; public class BuiltInFunctionHandle implements FunctionHandle { private final Signature signature; + private final BuiltInFunctionKind builtInFunctionKind; @JsonCreator public BuiltInFunctionHandle(@JsonProperty("signature") Signature signature) + { + this(signature, ENGINE); + } + + public BuiltInFunctionHandle(Signature signature, BuiltInFunctionKind builtInFunctionKind) { this.signature = requireNonNull(signature, "signature is null"); 
checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature); + this.builtInFunctionKind = requireNonNull(builtInFunctionKind, "builtInFunctionKind is null"); } @JsonProperty @@ -68,6 +76,12 @@ public CatalogSchemaName getCatalogSchemaName() return signature.getName().getCatalogSchemaName(); } + @JsonProperty + public BuiltInFunctionKind getBuiltInFunctionKind() + { + return builtInFunctionKind; + } + @Override public boolean equals(Object o) { @@ -78,13 +92,14 @@ public boolean equals(Object o) return false; } BuiltInFunctionHandle that = (BuiltInFunctionHandle) o; - return Objects.equals(signature, that.signature); + return Objects.equals(signature, that.signature) + && Objects.equals(builtInFunctionKind, that.builtInFunctionKind); } @Override public int hashCode() { - return Objects.hash(signature); + return Objects.hash(signature, builtInFunctionKind); } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionKind.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionKind.java new file mode 100644 index 0000000000000..4d12bb7f97d16 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInFunctionKind.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.metadata; + +import com.facebook.drift.annotations.ThriftEnum; +import com.facebook.drift.annotations.ThriftEnumValue; + +@ThriftEnum +public enum BuiltInFunctionKind +{ + ENGINE(0), + PLUGIN(1); + + private final int value; + + BuiltInFunctionKind(int value) + { + this.value = value; + } + + @ThriftEnumValue + public int getValue() + { + return value; + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInPluginFunctionNamespaceManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInPluginFunctionNamespaceManager.java new file mode 100644 index 0000000000000..6a694f436f17e --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInPluginFunctionNamespaceManager.java @@ -0,0 +1,259 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.metadata; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.function.SqlFunctionResult; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.common.type.UserDefinedType; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.function.AlterRoutineCharacteristics; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.function.FunctionNamespaceManager; +import com.facebook.presto.spi.function.FunctionNamespaceTransactionHandle; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.ScalarFunctionImplementation; +import com.facebook.presto.spi.function.Signature; +import com.facebook.presto.spi.function.SqlFunction; +import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.function.SqlInvokedScalarFunctionImplementation; +import com.google.common.base.Supplier; +import com.google.common.base.Suppliers; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.util.concurrent.UncheckedExecutionException; + +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static com.facebook.presto.metadata.BuiltInFunctionKind.PLUGIN; +import static com.facebook.presto.spi.function.FunctionImplementationType.SQL; +import static com.facebook.presto.spi.function.FunctionKind.SCALAR; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Throwables.throwIfInstanceOf; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static 
java.util.Collections.emptyList; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.HOURS; + +public class BuiltInPluginFunctionNamespaceManager + implements FunctionNamespaceManager +{ + private volatile FunctionMap functions = new FunctionMap(); + private final FunctionAndTypeManager functionAndTypeManager; + private final Supplier cachedFunctions = + Suppliers.memoize(this::checkForNamingConflicts); + private final LoadingCache specializedFunctionKeyCache; + private final LoadingCache specializedScalarCache; + + public BuiltInPluginFunctionNamespaceManager(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + specializedFunctionKeyCache = CacheBuilder.newBuilder() + .maximumSize(1000) + .expireAfterWrite(1, HOURS) + .build(CacheLoader.from(this::doGetSpecializedFunctionKey)); + specializedScalarCache = CacheBuilder.newBuilder() + .maximumSize(1000) + .expireAfterWrite(1, HOURS) + .build(CacheLoader.from(key -> { + checkArgument( + key.getFunction() instanceof SqlInvokedFunction, + "Unsupported scalar function class: %s", + key.getFunction().getClass()); + return new SqlInvokedScalarFunctionImplementation(((SqlInvokedFunction) key.getFunction()).getBody()); + })); + } + + public synchronized void registerPluginFunctions(List functions) + { + checkForNamingConflicts(functions); + this.functions = new FunctionMap(this.functions, functions); + } + + @Override + public FunctionHandle getFunctionHandle(Optional transactionHandle, Signature signature) + { + return new BuiltInFunctionHandle(signature, PLUGIN); + } + + @Override + public Collection getFunctions(Optional transactionHandle, QualifiedObjectName functionName) + { + if (functions.list().isEmpty() || + (!functionName.getCatalogSchemaName().equals(functionAndTypeManager.getDefaultNamespace()))) { + return emptyList(); + } + return 
cachedFunctions.get().get(functionName); + } + + /** + * likePattern / escape is not used for optimization, returning all functions. + */ + @Override + public Collection listFunctions(Optional likePattern, Optional escape) + { + return cachedFunctions.get().list(); + } + + public FunctionMetadata getFunctionMetadata(FunctionHandle functionHandle) + { + checkArgument(functionHandle instanceof BuiltInFunctionHandle, "Expect BuiltInFunctionHandle"); + Signature signature = ((BuiltInFunctionHandle) functionHandle).getSignature(); + SpecializedFunctionKey functionKey; + try { + functionKey = specializedFunctionKeyCache.getUnchecked(signature); + } + catch (UncheckedExecutionException e) { + throwIfInstanceOf(e.getCause(), PrestoException.class); + throw e; + } + SqlFunction function = functionKey.getFunction(); + checkArgument(function instanceof SqlInvokedFunction, "BuiltInPluginFunctionNamespaceManager only support SqlInvokedFunctions"); + SqlInvokedFunction sqlFunction = (SqlInvokedFunction) function; + List argumentNames = sqlFunction.getParameters().stream().map(Parameter::getName).collect(toImmutableList()); + return new FunctionMetadata( + signature.getName(), + signature.getArgumentTypes(), + argumentNames, + signature.getReturnType(), + signature.getKind(), + sqlFunction.getRoutineCharacteristics().getLanguage(), + SQL, + function.isDeterministic(), + function.isCalledOnNullInput(), + sqlFunction.getVersion(), + sqlFunction.getComplexTypeFunctionDescriptor()); + } + + public ScalarFunctionImplementation getScalarFunctionImplementation(FunctionHandle functionHandle) + { + checkArgument(functionHandle instanceof BuiltInFunctionHandle, "Expect BuiltInFunctionHandle"); + return getScalarFunctionImplementation(((BuiltInFunctionHandle) functionHandle).getSignature()); + } + + @Override + public void setBlockEncodingSerde(BlockEncodingSerde blockEncodingSerde) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support setting 
block encoding"); + } + + @Override + public FunctionNamespaceTransactionHandle beginTransaction() + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support beginning transactions"); + } + + @Override + public void commit(FunctionNamespaceTransactionHandle transactionHandle) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support committing transactions"); + } + + @Override + public void abort(FunctionNamespaceTransactionHandle transactionHandle) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support aborting transactions"); + } + + @Override + public void createFunction(SqlInvokedFunction function, boolean replace) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support create function"); + } + + @Override + public void dropFunction(QualifiedObjectName functionName, Optional parameterTypes, boolean exists) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support drop function"); + } + + @Override + public void alterFunction(QualifiedObjectName functionName, Optional parameterTypes, AlterRoutineCharacteristics alterRoutineCharacteristics) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not alter function"); + } + + @Override + public void addUserDefinedType(UserDefinedType userDefinedType) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support adding user defined types"); + } + + @Override + public Optional getUserDefinedType(QualifiedObjectName typeName) + { + throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not support getting user defined types"); + } + + @Override + public CompletableFuture executeFunction(String source, FunctionHandle functionHandle, Page input, List channels, TypeManager typeManager) + { +
throw new UnsupportedOperationException("BuiltInPluginFunctionNamespaceManager does not execute function"); + } + + private ScalarFunctionImplementation getScalarFunctionImplementation(Signature signature) + { + checkArgument(signature.getKind() == SCALAR, "%s is not a scalar function", signature); + checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature); + + try { + return specializedScalarCache.getUnchecked(getSpecializedFunctionKey(signature)); + } + catch (UncheckedExecutionException e) { + throwIfInstanceOf(e.getCause(), PrestoException.class); + throw e; + } + } + + private synchronized FunctionMap checkForNamingConflicts() + { + Optional> functionNamespaceManager = + functionAndTypeManager.getServingFunctionNamespaceManager(functionAndTypeManager.getDefaultNamespace()); + checkArgument(functionNamespaceManager.isPresent(), "Cannot find function namespace for catalog '%s'", functionAndTypeManager.getDefaultNamespace().getCatalogName()); + checkForNamingConflicts(functionNamespaceManager.get().listFunctions(Optional.empty(), Optional.empty())); + return functions; + } + + private synchronized void checkForNamingConflicts(Collection functions) + { + for (SqlFunction function : functions) { + for (SqlFunction existingFunction : this.functions.list()) { + checkArgument(!function.getSignature().equals(existingFunction.getSignature()), "Function already registered: %s", function.getSignature()); + } + } + } + + private SpecializedFunctionKey doGetSpecializedFunctionKey(Signature signature) + { + return functionAndTypeManager.getSpecializedFunctionKey(signature, getFunctions(Optional.empty(), signature.getName())); + } + + private SpecializedFunctionKey getSpecializedFunctionKey(Signature signature) + { + try { + return specializedFunctionKeyCache.getUnchecked(signature); + } + catch (UncheckedExecutionException e) { + throwIfInstanceOf(e.getCause(), PrestoException.class); + throw e; + } + } +} diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index 27faec86a4256..eca3c9d6e1e79 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -292,9 +292,6 @@ import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.Multimap; -import com.google.common.collect.Multimaps; import com.google.common.util.concurrent.UncheckedExecutionException; import com.google.errorprone.annotations.ThreadSafe; import io.airlift.slice.Slice; @@ -1395,44 +1392,6 @@ private static class EmptyTransactionHandle { } - private static class FunctionMap - { - private final Multimap functions; - - public FunctionMap() - { - functions = ImmutableListMultimap.of(); - } - - public FunctionMap(FunctionMap map, Iterable functions) - { - this.functions = ImmutableListMultimap.builder() - .putAll(map.functions) - .putAll(Multimaps.index(functions, function -> function.getSignature().getName())) - .build(); - - // Make sure all functions with the same name are aggregations or none of them are - for (Map.Entry> entry : this.functions.asMap().entrySet()) { - Collection values = entry.getValue(); - long aggregations = values.stream() - .map(function -> function.getSignature().getKind()) - .filter(kind -> kind == AGGREGATE) - .count(); - checkState(aggregations == 0 || aggregations == values.size(), "'%s' is both an aggregation and a scalar function", entry.getKey()); - } - } - - public List list() - { - return ImmutableList.copyOf(functions.values()); - } - - public Collection get(QualifiedObjectName name) - 
{ - return functions.get(name); - } - } - /** * TypeSignature but has overridden equals(). Here, we compare exact signature of any underlying distinct * types. Some distinct types may have extra information on their lazily loaded parents, and same parent diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java index 343b4e94d1d9f..8c3dfe0f9b855 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionAndTypeManager.java @@ -35,6 +35,7 @@ import com.facebook.presto.operator.window.WindowFunctionSupplier; import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.StandardErrorCode; import com.facebook.presto.spi.function.AggregationFunctionImplementation; import com.facebook.presto.spi.function.AlterRoutineCharacteristics; import com.facebook.presto.spi.function.FunctionHandle; @@ -92,11 +93,13 @@ import static com.facebook.presto.SystemSessionProperties.isExperimentalFunctionsEnabled; import static com.facebook.presto.SystemSessionProperties.isListBuiltInFunctionsOnly; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.metadata.BuiltInFunctionKind.PLUGIN; import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.metadata.CastType.toOperatorType; import static com.facebook.presto.metadata.FunctionSignatureMatcher.constructFunctionNotFoundErrorMessage; import static com.facebook.presto.metadata.SessionFunctionHandle.SESSION_NAMESPACE; import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables; +import static com.facebook.presto.spi.StandardErrorCode.AMBIGUOUS_FUNCTION_CALL; import static 
com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_USER_ERROR; @@ -146,6 +149,7 @@ public class FunctionAndTypeManager private final CatalogSchemaName defaultNamespace; private final AtomicReference servingTypeManager; private final AtomicReference>> servingTypeManagerParametricTypesSupplier; + private final BuiltInPluginFunctionNamespaceManager builtInPluginFunctionNamespaceManager; @Inject public FunctionAndTypeManager( @@ -177,6 +181,7 @@ public FunctionAndTypeManager( this.defaultNamespace = configureDefaultNamespace(functionsConfig.getDefaultNamespacePrefix()); this.servingTypeManager = new AtomicReference<>(builtInTypeAndFunctionNamespaceManager); this.servingTypeManagerParametricTypesSupplier = new AtomicReference<>(this::getServingTypeManagerParametricTypes); + this.builtInPluginFunctionNamespaceManager = new BuiltInPluginFunctionNamespaceManager(this); } public static FunctionAndTypeManager createTestFunctionAndTypeManager() @@ -345,6 +350,9 @@ public FunctionMetadata getFunctionMetadata(FunctionHandle functionHandle) if (functionHandle.getCatalogSchemaName().equals(SESSION_NAMESPACE)) { return ((SessionFunctionHandle) functionHandle).getFunctionMetadata(); } + if (isBuiltInPluginFunctionHandle(functionHandle)) { + return builtInPluginFunctionNamespaceManager.getFunctionMetadata(functionHandle); + } Optional> functionNamespaceManager = getServingFunctionNamespaceManager(functionHandle.getCatalogSchemaName()); checkArgument(functionNamespaceManager.isPresent(), "Cannot find function namespace for '%s'", functionHandle.getCatalogSchemaName()); return functionNamespaceManager.get().getFunctionMetadata(functionHandle); @@ -436,6 +444,11 @@ public void registerBuiltInFunctions(List functions) builtInTypeAndFunctionNamespaceManager.registerBuiltInFunctions(functions); } + public void 
registerPluginFunctions(List functions) + { + builtInPluginFunctionNamespaceManager.registerPluginFunctions(functions); + } + /** * likePattern / escape is an opportunistic optimization push down to function namespace managers. * Not all function namespace managers can handle it, thus the returned function list could @@ -453,12 +466,14 @@ public List listFunctions(Session session, Optional likePat functions.addAll(functionNamespaceManagers.get( defaultNamespace.getCatalogName()).listFunctions(likePattern, escape).stream() .collect(toImmutableList())); + functions.addAll(builtInPluginFunctionNamespaceManager.listFunctions(likePattern, escape).stream().collect(toImmutableList())); } else { functions.addAll(SessionFunctionUtils.listFunctions(session.getSessionFunctions())); functions.addAll(functionNamespaceManagers.values().stream() .flatMap(manager -> manager.listFunctions(likePattern, escape).stream()) .collect(toImmutableList())); + functions.addAll(builtInPluginFunctionNamespaceManager.listFunctions(likePattern, escape).stream().collect(toImmutableList())); } return functions.build().stream() @@ -486,7 +501,7 @@ public Collection getFunctions(Session session, Qualified Optional transactionHandle = session.getTransactionId().map( id -> transactionManager.getFunctionNamespaceTransaction(id, functionName.getCatalogName())); - return functionNamespaceManager.get().getFunctions(transactionHandle, functionName); + return getFunctions(functionName, transactionHandle, functionNamespaceManager.get()); } public void createFunction(SqlInvokedFunction function, boolean replace) @@ -601,6 +616,9 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(FunctionHand if (functionHandle.getCatalogSchemaName().equals(SESSION_NAMESPACE)) { return ((SessionFunctionHandle) functionHandle).getScalarFunctionImplementation(); } + if (isBuiltInPluginFunctionHandle(functionHandle)) { + return 
builtInPluginFunctionNamespaceManager.getScalarFunctionImplementation(functionHandle); + } Optional> functionNamespaceManager = getServingFunctionNamespaceManager(functionHandle.getCatalogSchemaName()); checkArgument(functionNamespaceManager.isPresent(), "Cannot find function namespace for '%s'", functionHandle.getCatalogSchemaName()); return functionNamespaceManager.get().getScalarFunctionImplementation(functionHandle); @@ -709,13 +727,7 @@ public FunctionHandle lookupFunction(QualifiedObjectName functionName, List candidates = functionNamespaceManager.get().getFunctions(Optional.empty(), functionName); - Optional match = functionSignatureMatcher.match(candidates, parameterTypes, false); - if (!match.isPresent()) { - throw new PrestoException(FUNCTION_NOT_FOUND, constructFunctionNotFoundErrorMessage(functionName, parameterTypes, candidates)); - } - - return functionNamespaceManager.get().getFunctionHandle(Optional.empty(), match.get()); + return getMatchingFunctionHandle(functionName, Optional.empty(), functionNamespaceManager.get(), parameterTypes, false); } public FunctionHandle lookupCast(CastType castType, Type fromType, Type toType) @@ -785,11 +797,14 @@ private FunctionHandle resolveFunctionInternal(Optional transacti return functionNamespaceManager.resolveFunction(transactionHandle, functionName, parameterTypes.stream().map(TypeSignatureProvider::getTypeSignature).collect(toImmutableList())); } - Collection candidates = functionNamespaceManager.getFunctions(transactionHandle, functionName); - - Optional match = functionSignatureMatcher.match(candidates, parameterTypes, true); - if (match.isPresent()) { - return functionNamespaceManager.getFunctionHandle(transactionHandle, match.get()); + try { + return getMatchingFunctionHandle(functionName, transactionHandle, functionNamespaceManager, parameterTypes, true); + } + catch (PrestoException e) { + // Could still match to a magic literal function + if (e.getErrorCode().getCode() != 
StandardErrorCode.FUNCTION_NOT_FOUND.toErrorCode().getCode()) { + throw e; + } } if (functionName.getObjectName().startsWith(MAGIC_LITERAL_FUNCTION_PREFIX)) { @@ -805,7 +820,8 @@ private FunctionHandle resolveFunctionInternal(Optional transacti return new BuiltInFunctionHandle(getMagicLiteralFunctionSignature(type)); } - throw new PrestoException(FUNCTION_NOT_FOUND, constructFunctionNotFoundErrorMessage(functionName, parameterTypes, candidates)); + throw new PrestoException(FUNCTION_NOT_FOUND, constructFunctionNotFoundErrorMessage( + functionName, parameterTypes, getFunctions(functionName, transactionHandle, functionNamespaceManager))); } private FunctionHandle resolveBuiltInFunction(QualifiedObjectName functionName, List parameterTypes) @@ -829,7 +845,7 @@ private FunctionHandle lookupCachedFunction(QualifiedObjectName functionName, Li } } - private Optional> getServingFunctionNamespaceManager(CatalogSchemaName functionNamespace) + public Optional> getServingFunctionNamespaceManager(CatalogSchemaName functionNamespace) { return Optional.ofNullable(functionNamespaceManagers.get(functionNamespace.getCatalogName())); } @@ -840,7 +856,6 @@ private Optional> getServingFunc } @Override - @SuppressWarnings("unchecked") public SpecializedFunctionKey getSpecializedFunctionKey(Signature signature) { QualifiedObjectName functionName = signature.getName(); @@ -849,8 +864,13 @@ public SpecializedFunctionKey getSpecializedFunctionKey(Signature signature) throw new PrestoException(FUNCTION_NOT_FOUND, format("Cannot find function namespace for signature '%s'", functionName)); } - Collection candidates = (Collection) functionNamespaceManager.get().getFunctions(Optional.empty(), functionName); + Collection candidates = functionNamespaceManager.get().getFunctions(Optional.empty(), functionName); + + return getSpecializedFunctionKey(signature, candidates); + } + public SpecializedFunctionKey getSpecializedFunctionKey(Signature signature, Collection candidates) + { // search for exact 
match Type returnType = getType(signature.getReturnType()); List argumentTypeSignatureProviders = fromTypeSignatures(signature.getArgumentTypes()); @@ -914,6 +934,66 @@ private Map getServingTypeManagerParametricTypes() .collect(toImmutableMap(ParametricType::getName, parametricType -> parametricType)); } + private Collection getFunctions( + QualifiedObjectName functionName, + Optional transactionHandle, + FunctionNamespaceManager functionNamespaceManager) + { + return ImmutableList.builder() + .addAll(functionNamespaceManager.getFunctions(transactionHandle, functionName)) + .addAll(builtInPluginFunctionNamespaceManager.getFunctions(transactionHandle, functionName)) + .build(); + } + + /** + * Gets the function handle of the function if there is a match. We enforce explicit naming for dynamic function namespaces. + * All unqualified function names will only be resolved against the built-in default function namespace. We get all the candidates + * from the current default namespace and additionally all the candidates from builtInPluginFunctionNamespaceManager. + * + * @throws PrestoException if there are no matches or multiple matches + */ + private FunctionHandle getMatchingFunctionHandle( + QualifiedObjectName functionName, + Optional transactionHandle, + FunctionNamespaceManager functionNamespaceManager, + List parameterTypes, + boolean coercionAllowed) + { + Optional matchingDefaultFunctionSignature = + getMatchingFunction(functionNamespaceManager.getFunctions(transactionHandle, functionName), parameterTypes, coercionAllowed); + Optional matchingPluginFunctionSignature = + getMatchingFunction(builtInPluginFunctionNamespaceManager.getFunctions(transactionHandle, functionName), parameterTypes, coercionAllowed); + + if (matchingDefaultFunctionSignature.isPresent() && matchingPluginFunctionSignature.isPresent()) { + throw new PrestoException(AMBIGUOUS_FUNCTION_CALL, format("Function '%s' has two matching signatures. Please specify parameter types. 
\n" + + "First match : '%s', Second match: '%s'", functionName, matchingDefaultFunctionSignature.get(), matchingPluginFunctionSignature.get())); + } + + if (matchingDefaultFunctionSignature.isPresent()) { + return functionNamespaceManager.getFunctionHandle(transactionHandle, matchingDefaultFunctionSignature.get()); + } + + if (matchingPluginFunctionSignature.isPresent()) { + return builtInPluginFunctionNamespaceManager.getFunctionHandle(transactionHandle, matchingPluginFunctionSignature.get()); + } + + throw new PrestoException(FUNCTION_NOT_FOUND, constructFunctionNotFoundErrorMessage(functionName, parameterTypes, + getFunctions(functionName, transactionHandle, functionNamespaceManager))); + } + + private Optional getMatchingFunction( + Collection candidates, + List parameterTypes, + boolean coercionAllowed) + { + return functionSignatureMatcher.match(candidates, parameterTypes, coercionAllowed); + } + + private boolean isBuiltInPluginFunctionHandle(FunctionHandle functionHandle) + { + return (functionHandle instanceof BuiltInFunctionHandle) && ((BuiltInFunctionHandle) functionHandle).getBuiltInFunctionKind().equals(PLUGIN); + } + private static class FunctionResolutionCacheKey { private final QualifiedObjectName functionName; diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionExtractor.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionExtractor.java index 9521027f30812..7dc7ee5194249 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionExtractor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionExtractor.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.metadata; +import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.operator.scalar.annotations.CodegenScalarFromAnnotationsParser; import com.facebook.presto.operator.scalar.annotations.ScalarFromAnnotationsParser; import 
com.facebook.presto.operator.scalar.annotations.SqlInvokedScalarFromAnnotationsParser; @@ -28,6 +29,7 @@ import java.util.Collection; import java.util.List; +import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; @@ -44,6 +46,11 @@ public static List extractFunctions(Collection> classes) } public static List extractFunctions(Class clazz) + { + return extractFunctions(clazz, JAVA_BUILTIN_NAMESPACE); + } + + public static List extractFunctions(Class clazz, CatalogSchemaName defaultNamespace) { if (WindowFunction.class.isAssignableFrom(clazz)) { @SuppressWarnings("unchecked") @@ -61,12 +68,12 @@ public static List extractFunctions(Class clazz) } if (clazz.isAnnotationPresent(SqlInvokedScalarFunction.class)) { - return SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinition(clazz); + return SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinition(clazz, defaultNamespace); } List scalarFunctions = ImmutableList.builder() .addAll(ScalarFromAnnotationsParser.parseFunctionDefinitions(clazz)) - .addAll(SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(clazz)) + .addAll(SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(clazz, defaultNamespace)) .addAll(CodegenScalarFromAnnotationsParser.parseFunctionDefinitions(clazz)) .build(); checkArgument(!scalarFunctions.isEmpty(), "Class [%s] does not define any scalar functions", clazz.getName()); diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionListBuilder.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionListBuilder.java index 13a90da951278..f3dab976902a9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionListBuilder.java +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionListBuilder.java @@ -24,6 +24,7 
@@ import java.util.ArrayList; import java.util.List; +import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static java.util.Objects.requireNonNull; public class FunctionListBuilder @@ -62,13 +63,13 @@ public FunctionListBuilder scalars(Class clazz) public FunctionListBuilder sqlInvokedScalar(Class clazz) { - functions.addAll(SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinition(clazz)); + functions.addAll(SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinition(clazz, JAVA_BUILTIN_NAMESPACE)); return this; } public FunctionListBuilder sqlInvokedScalars(Class clazz) { - functions.addAll(SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(clazz)); + functions.addAll(SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(clazz, JAVA_BUILTIN_NAMESPACE)); return this; } diff --git a/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionMap.java b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionMap.java new file mode 100644 index 0000000000000..d9f2f22f43900 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/metadata/FunctionMap.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.metadata; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.spi.function.SqlFunction; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.Multimap; +import com.google.common.collect.Multimaps; + +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static com.facebook.presto.spi.function.FunctionKind.AGGREGATE; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class FunctionMap +{ + private final Multimap functions; + + public FunctionMap() + { + functions = ImmutableListMultimap.of(); + } + + public FunctionMap(FunctionMap map, Iterable functions) + { + requireNonNull(map, "map is null"); + requireNonNull(functions, "functions is null"); + this.functions = ImmutableListMultimap.builder() + .putAll(map.functions) + .putAll(Multimaps.index(functions, function -> function.getSignature().getName())) + .build(); + + // Make sure all functions with the same name are aggregations or none of them are + for (Map.Entry> entry : this.functions.asMap().entrySet()) { + Collection values = entry.getValue(); + long aggregations = values.stream() + .map(function -> function.getSignature().getKind()) + .filter(kind -> kind == AGGREGATE) + .count(); + checkState(aggregations == 0 || aggregations == values.size(), "'%s' is both an aggregation and a scalar function", entry.getKey()); + } + } + + public List list() + { + return ImmutableList.copyOf(functions.values()); + } + + public Collection get(QualifiedObjectName name) + { + return functions.get(name); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java index 
e34ede50a57f3..1c5f621cb8d05 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java +++ b/presto-main-base/src/main/java/com/facebook/presto/operator/scalar/annotations/SqlInvokedScalarFromAnnotationsParser.java @@ -13,6 +13,7 @@ */ package com.facebook.presto.operator.scalar.annotations; +import com.facebook.presto.common.CatalogSchemaName; import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.type.TypeSignature; import com.facebook.presto.spi.PrestoException; @@ -39,7 +40,6 @@ import java.util.stream.Stream; import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; -import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; import static com.facebook.presto.operator.annotations.FunctionsParserHelper.findPublicStaticMethods; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static com.facebook.presto.spi.function.FunctionKind.SCALAR; @@ -60,7 +60,7 @@ public final class SqlInvokedScalarFromAnnotationsParser { private SqlInvokedScalarFromAnnotationsParser() {} - public static List parseFunctionDefinition(Class clazz) + public static List parseFunctionDefinition(Class clazz, CatalogSchemaName defaultNamespace) { checkArgument(clazz.isAnnotationPresent(SqlInvokedScalarFunction.class), "Class is not annotated with SqlInvokedScalarFunction: %s", clazz.getName()); @@ -68,15 +68,15 @@ public static List parseFunctionDefinition(Class clazz) Optional description = Optional.ofNullable(clazz.getAnnotation(Description.class)).map(Description::value); return findScalarsInFunctionDefinitionClass(clazz).stream() - .map(method -> createSqlInvokedFunctions(method, Optional.of(header), description)) + .map(method -> createSqlInvokedFunctions(method, Optional.of(header), description, defaultNamespace)) .flatMap(List::stream) .collect(toImmutableList()); } - 
public static List parseFunctionDefinitions(Class clazz) + public static List parseFunctionDefinitions(Class clazz, CatalogSchemaName defaultNamespace) { return findScalarsInFunctionSetClass(clazz).stream() - .map(method -> createSqlInvokedFunctions(method, Optional.empty(), Optional.empty())) + .map(method -> createSqlInvokedFunctions(method, Optional.empty(), Optional.empty(), defaultNamespace)) .flatMap(List::stream) .collect(toImmutableList()); } @@ -121,7 +121,7 @@ private static List findScalarsInFunctionSetClass(Class clazz) return ImmutableList.copyOf(methods); } - private static List createSqlInvokedFunctions(Method method, Optional header, Optional description) + private static List createSqlInvokedFunctions(Method method, Optional header, Optional description, CatalogSchemaName defaultNamespace) { SqlInvokedScalarFunction functionHeader = header.orElseGet(() -> method.getAnnotation(SqlInvokedScalarFunction.class)); String functionDescription = description.orElseGet(() -> method.isAnnotationPresent(Description.class) ? 
method.getAnnotation(Description.class).value() : ""); @@ -167,7 +167,7 @@ else if (method.isAnnotationPresent(SqlParameters.class)) { return Stream.concat(Stream.of(functionHeader.value()), stream(functionHeader.alias())) .map(name -> new SqlInvokedFunction( - QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, name), + QualifiedObjectName.valueOf(defaultNamespace, name), parameters, typeVariableConstraints, emptyList(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java b/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java index 3dfe5202d05e6..d072762d14dbc 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java +++ b/presto-main-base/src/main/java/com/facebook/presto/server/PluginManager.java @@ -306,6 +306,12 @@ public void installPlugin(Plugin plugin) log.info("Registering client request filter factory"); clientRequestFilterManager.registerClientRequestFilterFactory(clientRequestFilterFactory); } + + for (Class functionClass : plugin.getSqlInvokedFunctions()) { + log.info("Registering functions from %s", functionClass.getName()); + metadata.getFunctionAndTypeManager().registerPluginFunctions( + extractFunctions(functionClass, metadata.getFunctionAndTypeManager().getDefaultNamespace())); + } } public void installCoordinatorPlugin(CoordinatorPlugin plugin) diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index a6749bf1a7e18..2b0b89d99b472 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -226,6 +226,7 @@ public class FeaturesConfig private boolean streamingForPartialAggregationEnabled; private boolean preferMergeJoinForSortedInputs; + private boolean preferSortMergeJoin; private boolean 
segmentedAggregationEnabled; private int maxStageCountForEagerScheduling = 25; @@ -2232,6 +2233,19 @@ public FeaturesConfig setPreferMergeJoinForSortedInputs(boolean preferMergeJoinF return this; } + public boolean isPreferSortMergeJoin() + { + return preferSortMergeJoin; + } + + @Config("experimental.optimizer.prefer-sort-merge-join") + @ConfigDescription("Prefer sort merge join for all joins. A SortNode is added if input is not already sorted.") + public FeaturesConfig setPreferSortMergeJoin(boolean preferSortMergeJoin) + { + this.preferSortMergeJoin = preferSortMergeJoin; + return this; + } + public boolean isSegmentedAggregationEnabled() { return segmentedAggregationEnabled; @@ -2970,6 +2984,7 @@ public boolean isInEqualityJoinPushdownEnabled() { return inEqualityJoinPushdownEnabled; } + public boolean isPrestoSparkExecutionEnvironment() { return prestoSparkExecutionEnvironment; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/GroupedExecutionTagger.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/GroupedExecutionTagger.java index d4a6d6981f612..71d1199bf9da8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/GroupedExecutionTagger.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/GroupedExecutionTagger.java @@ -40,6 +40,7 @@ import static com.facebook.presto.SystemSessionProperties.GROUPED_EXECUTION; import static com.facebook.presto.SystemSessionProperties.isGroupedExecutionEnabled; +import static com.facebook.presto.SystemSessionProperties.preferSortMergeJoin; import static com.facebook.presto.spi.StandardErrorCode.INVALID_PLAN_ERROR; import static com.facebook.presto.spi.connector.ConnectorCapabilities.SUPPORTS_PAGE_SINK_COMMIT; import static com.facebook.presto.spi.connector.ConnectorCapabilities.SUPPORTS_REWINDABLE_SPLIT_SOURCE; @@ -161,6 +162,10 @@ public GroupedExecutionTagger.GroupedExecutionProperties visitMergeJoin(MergeJoi left.totalLifespans, 
left.recoveryEligible && right.recoveryEligible); } + if (preferSortMergeJoin(session)) { + // TODO: This will break the other use case for merge join operating on sorted tables, which requires grouped execution for correctness. + return GroupedExecutionTagger.GroupedExecutionProperties.notCapable(); + } throw new PrestoException( INVALID_PLAN_ERROR, format("When grouped execution can't be enabled, merge join plan is not valid." + diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index dd1bb55ef6741..6e01015cf1d46 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -45,6 +45,7 @@ import com.facebook.presto.sql.planner.iterative.rule.EvaluateZeroLimit; import com.facebook.presto.sql.planner.iterative.rule.EvaluateZeroSample; import com.facebook.presto.sql.planner.iterative.rule.ExtractSpatialJoins; +import com.facebook.presto.sql.planner.iterative.rule.ExtractSystemTableFilterRuleSet; import com.facebook.presto.sql.planner.iterative.rule.GatherAndMergeWindows; import com.facebook.presto.sql.planner.iterative.rule.ImplementBernoulliSampleAsFilter; import com.facebook.presto.sql.planner.iterative.rule.ImplementFilteredAggregations; @@ -186,6 +187,7 @@ import com.facebook.presto.sql.planner.optimizations.SetFlatteningOptimizer; import com.facebook.presto.sql.planner.optimizations.ShardJoins; import com.facebook.presto.sql.planner.optimizations.SimplifyPlanWithEmptyInput; +import com.facebook.presto.sql.planner.optimizations.SortMergeJoinOptimizer; import com.facebook.presto.sql.planner.optimizations.StatsRecordingPlanOptimizer; import com.facebook.presto.sql.planner.optimizations.TransformQuantifiedComparisonApplyToLateralJoin; import com.facebook.presto.sql.planner.optimizations.UnaliasSymbolReferences; @@ 
-927,7 +929,8 @@ public PlanOptimizers( // MergeJoinForSortedInputOptimizer can avoid the local exchange for a join operation // Should be placed after AddExchanges, but before AddLocalExchange // To replace the JoinNode to MergeJoin ahead of AddLocalExchange to avoid adding extra local exchange - builder.add(new MergeJoinForSortedInputOptimizer(metadata, featuresConfig.isNativeExecutionEnabled())); + builder.add(new MergeJoinForSortedInputOptimizer(metadata, featuresConfig.isNativeExecutionEnabled()), + new SortMergeJoinOptimizer(metadata, featuresConfig.isNativeExecutionEnabled())); // Optimizers above this don't understand local exchanges, so be careful moving this. builder.add(new AddLocalExchanges(metadata, featuresConfig.isNativeExecutionEnabled())); @@ -935,13 +938,13 @@ public PlanOptimizers( // Optimizers above this do not need to care about aggregations with the type other than SINGLE // This optimizer must be run after all exchange-related optimizers builder.add(new IterativeOptimizer( - metadata, - ruleStats, - statsCalculator, - costCalculator, - ImmutableSet.of( - new PushPartialAggregationThroughJoin(), - new PushPartialAggregationThroughExchange(metadata.getFunctionAndTypeManager(), featuresConfig.isNativeExecutionEnabled()))), + metadata, + ruleStats, + statsCalculator, + costCalculator, + ImmutableSet.of( + new PushPartialAggregationThroughJoin(), + new PushPartialAggregationThroughExchange(metadata.getFunctionAndTypeManager(), featuresConfig.isNativeExecutionEnabled()))), // MergePartialAggregationsWithFilter should immediately follow PushPartialAggregationThroughExchange new MergePartialAggregationsWithFilter(metadata.getFunctionAndTypeManager()), new IterativeOptimizer( @@ -980,6 +983,14 @@ public PlanOptimizers( // Pass after connector optimizer, as it relies on connector optimizer to identify empty input tables and convert them to empty ValuesNode builder.add(new SimplifyPlanWithEmptyInput()); + builder.add( + new IterativeOptimizer( + 
metadata, + ruleStats, + statsCalculator, + costCalculator, + new ExtractSystemTableFilterRuleSet(metadata.getFunctionAndTypeManager()).rules())); + // DO NOT add optimizers that change the plan shape (computations) after this point // Precomputed hashes - this assumes that partitioning will not change diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSystemTableFilterRuleSet.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSystemTableFilterRuleSet.java new file mode 100644 index 0000000000000..af5eeb84d4d99 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ExtractSystemTableFilterRuleSet.java @@ -0,0 +1,296 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.PartitioningScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.sql.planner.PlannerUtils; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.matching.Capture.newCapture; +import static com.facebook.presto.sql.planner.plan.Patterns.exchange; +import static com.facebook.presto.sql.planner.plan.Patterns.filter; +import static com.facebook.presto.sql.planner.plan.Patterns.project; +import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.planner.plan.Patterns.tableScan; +import static com.facebook.presto.sql.relational.RowExpressionUtils.containsNonCoordinatorEligibleCallExpression; +import static java.util.Objects.requireNonNull; + +/** + * RuleSet for extracting system table filters when they contain non-coordinator-eligible functions (e.g., CPP functions). + * This ensures that system table scans happen on the coordinator while CPP functions execute on workers. + * + * Patterns handled: + * 1. Exchange -> Project -> Filter -> TableScan (system) => Project -> Filter -> Exchange -> TableScan + * 2. Exchange -> Project -> TableScan (system) => Project -> Exchange -> TableScan + * 3. 
Exchange -> Filter -> TableScan (system) => Filter -> Exchange -> TableScan + */ +public class ExtractSystemTableFilterRuleSet +{ + private final FunctionAndTypeManager functionAndTypeManager; + + public ExtractSystemTableFilterRuleSet(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + } + + public Set> rules() + { + return ImmutableSet.of( + new ProjectFilterScanRule(), + new ProjectScanRule(), + new FilterScanRule()); + } + + private abstract class SystemTableFilterRule + implements Rule + { + protected final Capture tableScanCapture = newCapture(); + + protected boolean containsFunctionsIneligibleOnCoordinator(Optional filterNode, Optional projectNode) + { + boolean hasIneligiblePredicates = filterNode + .map(filter -> containsNonCoordinatorEligibleCallExpression(functionAndTypeManager, filter.getPredicate())) + .orElse(false); + + boolean hasIneligibleProjections = projectNode + .map(project -> project.getAssignments().getExpressions().stream() + .anyMatch(expression -> containsNonCoordinatorEligibleCallExpression(functionAndTypeManager, expression))) + .orElse(false); + + return hasIneligiblePredicates || hasIneligibleProjections; + } + } + + private final class ProjectFilterScanRule + extends SystemTableFilterRule + { + private final Capture exchangeCapture = newCapture(); + private final Capture projectCapture = newCapture(); + private final Capture filterCapture = newCapture(); + + @Override + public Pattern getPattern() + { + return exchange() + .capturedAs(exchangeCapture) + .with(source().matching( + project() + .capturedAs(projectCapture) + .with(source().matching( + filter() + .capturedAs(filterCapture) + .with(source().matching( + tableScan() + .capturedAs(tableScanCapture) + .matching(PlannerUtils::containsSystemTableScan))))))); + } + + @Override + public Result apply(ExchangeNode node, Captures captures, Context context) + { + TableScanNode 
tableScanNode = captures.get(tableScanCapture); + ExchangeNode exchangeNode = captures.get(exchangeCapture); + ProjectNode projectNode = captures.get(projectCapture); + FilterNode filterNode = captures.get(filterCapture); + + if (!containsFunctionsIneligibleOnCoordinator(Optional.of(filterNode), Optional.of(projectNode))) { + return Result.empty(); + } + + // The exchange's output variables must match what the filter expects + // Since the filter was originally between project and table scan, it expects + // the table scan's output variables + PartitioningScheme newPartitioningScheme = new PartitioningScheme( + exchangeNode.getPartitioningScheme().getPartitioning(), + tableScanNode.getOutputVariables(), + exchangeNode.getPartitioningScheme().getHashColumn(), + exchangeNode.getPartitioningScheme().isScaleWriters(), + exchangeNode.getPartitioningScheme().isReplicateNullsAndAny(), + exchangeNode.getPartitioningScheme().getEncoding(), + exchangeNode.getPartitioningScheme().getBucketToPartition()); + + // Create new exchange with table scan as source + ExchangeNode newExchange = new ExchangeNode( + exchangeNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + exchangeNode.getType(), + exchangeNode.getScope(), + newPartitioningScheme, + ImmutableList.of(tableScanNode), + ImmutableList.of(tableScanNode.getOutputVariables()), + exchangeNode.isEnsureSourceOrdering(), + exchangeNode.getOrderingScheme()); + + // Recreate filter with exchange as source + FilterNode newFilter = new FilterNode( + filterNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + newExchange, + filterNode.getPredicate()); + + // Recreate project with filter as source + ProjectNode newProject = new ProjectNode( + projectNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + newFilter, + projectNode.getAssignments(), + projectNode.getLocality()); + + return Result.ofPlanNode(newProject); + } + } + + private final class ProjectScanRule + extends 
SystemTableFilterRule + { + private final Capture exchangeCapture = newCapture(); + private final Capture projectCapture = newCapture(); + + @Override + public Pattern getPattern() + { + return exchange() + .capturedAs(exchangeCapture) + .with(source().matching( + project() + .capturedAs(projectCapture) + .with(source().matching( + tableScan() + .capturedAs(tableScanCapture) + .matching(PlannerUtils::containsSystemTableScan))))); + } + + @Override + public Result apply(ExchangeNode node, Captures captures, Context context) + { + TableScanNode tableScanNode = captures.get(tableScanCapture); + ExchangeNode exchangeNode = captures.get(exchangeCapture); + ProjectNode projectNode = captures.get(projectCapture); + + if (!containsFunctionsIneligibleOnCoordinator(Optional.empty(), Optional.of(projectNode))) { + return Result.empty(); + } + + // Update partitioning scheme to match table scan outputs + PartitioningScheme newPartitioningScheme = new PartitioningScheme( + exchangeNode.getPartitioningScheme().getPartitioning(), + tableScanNode.getOutputVariables(), + exchangeNode.getPartitioningScheme().getHashColumn(), + exchangeNode.getPartitioningScheme().isScaleWriters(), + exchangeNode.getPartitioningScheme().isReplicateNullsAndAny(), + exchangeNode.getPartitioningScheme().getEncoding(), + exchangeNode.getPartitioningScheme().getBucketToPartition()); + + // Create new exchange with table scan as source + ExchangeNode newExchange = new ExchangeNode( + exchangeNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + exchangeNode.getType(), + exchangeNode.getScope(), + newPartitioningScheme, + ImmutableList.of(tableScanNode), + ImmutableList.of(tableScanNode.getOutputVariables()), + exchangeNode.isEnsureSourceOrdering(), + exchangeNode.getOrderingScheme()); + + // Recreate project with exchange as source + ProjectNode newProject = new ProjectNode( + projectNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + newExchange, + 
projectNode.getAssignments(), + projectNode.getLocality()); + + return Result.ofPlanNode(newProject); + } + } + + private final class FilterScanRule + extends SystemTableFilterRule + { + private final Capture exchangeCapture = newCapture(); + private final Capture filterCapture = newCapture(); + + @Override + public Pattern getPattern() + { + return exchange() + .capturedAs(exchangeCapture) + .with(source().matching( + filter() + .capturedAs(filterCapture) + .with(source().matching( + tableScan() + .capturedAs(tableScanCapture) + .matching(PlannerUtils::containsSystemTableScan))))); + } + + @Override + public Result apply(ExchangeNode node, Captures captures, Context context) + { + TableScanNode tableScanNode = captures.get(tableScanCapture); + ExchangeNode exchangeNode = captures.get(exchangeCapture); + FilterNode filterNode = captures.get(filterCapture); + + if (!containsFunctionsIneligibleOnCoordinator(Optional.of(filterNode), Optional.empty())) { + return Result.empty(); + } + + // Update partitioning scheme to match table scan outputs + PartitioningScheme newPartitioningScheme = new PartitioningScheme( + exchangeNode.getPartitioningScheme().getPartitioning(), + tableScanNode.getOutputVariables(), + exchangeNode.getPartitioningScheme().getHashColumn(), + exchangeNode.getPartitioningScheme().isScaleWriters(), + exchangeNode.getPartitioningScheme().isReplicateNullsAndAny(), + exchangeNode.getPartitioningScheme().getEncoding(), + exchangeNode.getPartitioningScheme().getBucketToPartition()); + + // Create new exchange with table scan as source + ExchangeNode newExchange = new ExchangeNode( + exchangeNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + exchangeNode.getType(), + exchangeNode.getScope(), + newPartitioningScheme, + ImmutableList.of(tableScanNode), + ImmutableList.of(tableScanNode.getOutputVariables()), + exchangeNode.isEnsureSourceOrdering(), + exchangeNode.getOrderingScheme()); + + // Recreate filter with exchange as source + FilterNode 
newFilter = new FilterNode( + filterNode.getSourceLocation(), + context.getIdAllocator().getNextId(), + newExchange, + filterNode.getPredicate()); + + return Result.ofPlanNode(newFilter); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java index faf96d27adabc..ea79310608dfb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PickTableLayout.java @@ -59,10 +59,12 @@ import static com.facebook.presto.metadata.TableLayoutResult.computeEnforced; import static com.facebook.presto.spi.relation.DomainTranslator.BASIC_COLUMN_EXTRACTOR; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; +import static com.facebook.presto.sql.planner.PlannerUtils.containsSystemTableScan; import static com.facebook.presto.sql.planner.iterative.rule.PreconditionRules.checkRulesAreFiredBeforeAddExchangesRule; import static com.facebook.presto.sql.planner.plan.Patterns.filter; import static com.facebook.presto.sql.planner.plan.Patterns.source; import static com.facebook.presto.sql.planner.plan.Patterns.tableScan; +import static com.facebook.presto.sql.relational.RowExpressionUtils.containsNonCoordinatorEligibleCallExpression; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Sets.intersection; @@ -271,6 +273,16 @@ private static PlanNode pushPredicateIntoTableScan( new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()), metadata.getFunctionAndTypeManager()); RowExpression deterministicPredicate = logicalRowExpressions.filterDeterministicConjuncts(predicate); + // If the predicate contains 
non-Java expressions, we cannot prune partitions over system tables. + RowExpression ineligiblePredicate = TRUE_CONSTANT; + if (containsSystemTableScan(node)) { + ineligiblePredicate = logicalRowExpressions.filterConjuncts( + deterministicPredicate, + expression -> containsNonCoordinatorEligibleCallExpression(metadata.getFunctionAndTypeManager(), expression)); + deterministicPredicate = logicalRowExpressions.filterConjuncts( + deterministicPredicate, + expression -> !containsNonCoordinatorEligibleCallExpression(metadata.getFunctionAndTypeManager(), expression)); + } DomainTranslator.ExtractionResult decomposedPredicate = domainTranslator.fromPredicate( session.toConnectorSession(), deterministicPredicate, @@ -339,7 +351,8 @@ private static PlanNode pushPredicateIntoTableScan( RowExpression resultingPredicate = logicalRowExpressions.combineConjuncts( domainTranslator.toPredicate(layout.getUnenforcedConstraint().transform(assignments::get)), logicalRowExpressions.filterNonDeterministicConjuncts(predicate), - decomposedPredicate.getRemainingExpression()); + decomposedPredicate.getRemainingExpression(), + ineligiblePredicate); if (!TRUE_CONSTANT.equals(resultingPredicate)) { return new FilterNode(node.getSourceLocation(), idAllocator.getNextId(), tableScan, resultingPredicate); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index d0f15ab3f08dc..1618c69b543e6 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -29,6 +29,7 @@ import com.facebook.presto.spi.plan.JoinNode; import com.facebook.presto.spi.plan.LimitNode; import com.facebook.presto.spi.plan.MarkDistinctNode; +import com.facebook.presto.spi.plan.MergeJoinNode; import 
com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.OutputNode; import com.facebook.presto.spi.plan.Partitioning; @@ -78,6 +79,7 @@ import static com.facebook.presto.SystemSessionProperties.isQuickDistinctLimitEnabled; import static com.facebook.presto.SystemSessionProperties.isSegmentedAggregationEnabled; import static com.facebook.presto.SystemSessionProperties.isSpillEnabled; +import static com.facebook.presto.SystemSessionProperties.preferSortMergeJoin; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; import static com.facebook.presto.operator.aggregation.AggregationUtils.hasSingleNodeExecutionPreference; @@ -887,6 +889,17 @@ public PlanWithProperties visitSpatialJoin(SpatialJoinNode node, StreamPreferred return rebaseAndDeriveProperties(node, ImmutableList.of(probe, build)); } + @Override + public PlanWithProperties visitMergeJoin(MergeJoinNode node, StreamPreferredProperties parentPreferences) + { + if (preferSortMergeJoin(session)) { + PlanWithProperties probe = planAndEnforce(node.getLeft(), singleStream(), singleStream()); + PlanWithProperties build = planAndEnforce(node.getRight(), singleStream(), singleStream()); + return rebaseAndDeriveProperties(node, ImmutableList.of(probe, build)); + } + return super.visitMergeJoin(node, parentPreferences); + } + @Override public PlanWithProperties visitIndexJoin(IndexJoinNode node, StreamPreferredProperties parentPreferences) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SortMergeJoinOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SortMergeJoinOptimizer.java new file mode 100644 index 0000000000000..f97398529a6f5 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SortMergeJoinOptimizer.java @@ -0,0 +1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.EquiJoinClause; +import com.facebook.presto.spi.plan.JoinNode; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.MergeJoinNode; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.SortNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.preferSortMergeJoin; +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_FIRST; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class SortMergeJoinOptimizer + implements PlanOptimizer +{ + private final Metadata metadata; + private final boolean nativeExecution; + private boolean isEnabledForTesting; + + public 
SortMergeJoinOptimizer(Metadata metadata, boolean nativeExecution) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.nativeExecution = nativeExecution; + } + + @Override + public void setEnabledForTesting(boolean isSet) + { + isEnabledForTesting = isSet; + } + + @Override + public boolean isEnabled(Session session) + { + // TODO: Consider group execution and single node execution. + return isEnabledForTesting || preferSortMergeJoin(session); + } + + @Override + public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider type, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + { + requireNonNull(plan, "plan is null"); + requireNonNull(session, "session is null"); + requireNonNull(variableAllocator, "variableAllocator is null"); + requireNonNull(idAllocator, "idAllocator is null"); + + if (isEnabled(session)) { + Rewriter rewriter = new SortMergeJoinOptimizer.Rewriter(idAllocator, metadata, session); + PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, null); + return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged()); + } + return PlanOptimizerResult.optimizerResult(plan, false); + } + + /** + * @param joinNode + * @return returns true if merge join is supported for the given join node. 
+ */ + public boolean isMergeJoinEligible(JoinNode joinNode) + { + return (joinNode.getType() == JoinType.INNER || joinNode.getType() == JoinType.LEFT || joinNode.getType() == JoinType.RIGHT) + && !joinNode.isCrossJoin(); + } + + private class Rewriter + extends SimplePlanRewriter + { + private final PlanNodeIdAllocator idAllocator; + private final Metadata metadata; + private final Session session; + private boolean planChanged; + + private Rewriter(PlanNodeIdAllocator idAllocator, Metadata metadata, Session session) + { + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); + this.session = requireNonNull(session, "session is null"); + } + + public boolean isPlanChanged() + { + return planChanged; + } + + @Override + public PlanNode visitJoin(JoinNode node, RewriteContext context) + { + if (!isMergeJoinEligible(node)) { + return node; + } + + PlanNode left = node.getLeft(); + PlanNode right = node.getRight(); + + List leftJoinColumns = node.getCriteria().stream().map(EquiJoinClause::getLeft).collect(toImmutableList()); + + if (!isPlanOutputSortedByColumns(left, leftJoinColumns)) { + List leftOrdering = node.getCriteria().stream() + .map(criterion -> new Ordering(criterion.getLeft(), ASC_NULLS_FIRST)) + .collect(toImmutableList()); + left = new SortNode( + Optional.empty(), + idAllocator.getNextId(), + left, + new OrderingScheme(leftOrdering), + true, + ImmutableList.of()); + } + + List rightJoinColumns = node.getCriteria().stream() + .map(EquiJoinClause::getRight) + .collect(toImmutableList()); + if (!isPlanOutputSortedByColumns(right, rightJoinColumns)) { + List rightOrdering = node.getCriteria().stream() + .map(criterion -> new Ordering(criterion.getRight(), ASC_NULLS_FIRST)) + .collect(toImmutableList()); + right = new SortNode( + Optional.empty(), + idAllocator.getNextId(), + right, + new OrderingScheme(rightOrdering), + true, + ImmutableList.of()); + } + + planChanged = true; + 
return new MergeJoinNode( + Optional.empty(), + node.getId(), + node.getType(), + left, + right, + node.getCriteria(), + node.getOutputVariables(), + node.getFilter(), + node.getLeftHashVariable(), + node.getRightHashVariable()); + } + + private boolean isPlanOutputSortedByColumns(PlanNode plan, List columns) + { + StreamPropertyDerivations.StreamProperties properties = StreamPropertyDerivations.derivePropertiesRecursively(plan, metadata, session, nativeExecution); + + // Check if partitioning columns (bucketed-by columns [B]) are a subset of join columns [J] + // B = subset (J) + if (!verifyStreamProperties(properties, columns)) { + return false; + } + + // Check if the output of the subplan is ordered by the join columns + return !LocalProperties.match(properties.getLocalProperties(), LocalProperties.sorted(columns, ASC_NULLS_FIRST)).get(0).isPresent(); + } + + private boolean verifyStreamProperties(StreamPropertyDerivations.StreamProperties streamProperties, List joinColumns) + { + if (!streamProperties.getPartitioningColumns().isPresent()) { + return false; + } + List partitioningColumns = streamProperties.getPartitioningColumns().get(); + return partitioningColumns.size() <= joinColumns.size() && joinColumns.containsAll(partitioningColumns); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/CheckNoIneligibleFunctionsInCoordinatorFragments.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/CheckNoIneligibleFunctionsInCoordinatorFragments.java new file mode 100644 index 0000000000000..a2abbe91c84a9 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/CheckNoIneligibleFunctionsInCoordinatorFragments.java @@ -0,0 +1,174 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.sanity; + +import com.facebook.presto.Session; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.sql.planner.SimplePlanVisitor; +import com.facebook.presto.sql.planner.plan.ExchangeNode; + +import static com.facebook.presto.sql.planner.PlannerUtils.containsSystemTableScan; +import static com.facebook.presto.sql.relational.RowExpressionUtils.containsNonCoordinatorEligibleCallExpression; +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +/** + * Validates that there are no filter or projection nodes containing non-Java functions + * (which must be evaluated on native nodes) within the same fragment as a system table scan + * (which must be evaluated on the coordinator). 
+ */ +public class CheckNoIneligibleFunctionsInCoordinatorFragments + implements PlanChecker.Checker +{ + @Override + public void validate(PlanNode planNode, Session session, Metadata metadata, WarningCollector warningCollector) + { + FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager(); + // Validate each fragment independently + validateFragment(planNode, functionAndTypeManager); + } + + private void validateFragment(PlanNode root, FunctionAndTypeManager functionAndTypeManager) + { + // First, collect information about this fragment + FragmentValidator validator = new FragmentValidator(functionAndTypeManager); + root.accept(validator, null); + + // Check if this fragment violates the constraint + checkState( + !(validator.hasSystemTableScan() && validator.hasNonCoordinatorEligibleFunction()), + "Fragment contains both system table scan and non-Java functions. " + + "System table scans must execute on the coordinator while non-Java functions must execute on native nodes. " + + "These operations must be in separate fragments separated by an exchange."); + + // Recursively validate child fragments + ChildFragmentVisitor childVisitor = new ChildFragmentVisitor(functionAndTypeManager); + root.accept(childVisitor, null); + } + + /** + * Visits nodes within a single fragment to collect information about + * system table scans and non-coordinator-eligible functions. + * Stops at exchange boundaries. 
+ */ + private static class FragmentValidator + extends SimplePlanVisitor + { + private final FunctionAndTypeManager functionAndTypeManager; + private boolean hasSystemTableScan; + private boolean hasNonCoordinatorEligibleFunction; + + public FragmentValidator(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + } + + public boolean hasSystemTableScan() + { + return hasSystemTableScan; + } + + public boolean hasNonCoordinatorEligibleFunction() + { + return hasNonCoordinatorEligibleFunction; + } + + @Override + public Void visitExchange(ExchangeNode node, Void context) + { + // Don't traverse into exchange sources - they are different fragments + return null; + } + + @Override + public Void visitTableScan(TableScanNode node, Void context) + { + if (containsSystemTableScan(node)) { + hasSystemTableScan = true; + } + return null; + } + + @Override + public Void visitFilter(FilterNode node, Void context) + { + RowExpression predicate = node.getPredicate(); + if (containsNonCoordinatorEligibleCallExpression(functionAndTypeManager, predicate)) { + hasNonCoordinatorEligibleFunction = true; + } + return visitPlan(node, context); + } + + @Override + public Void visitProject(ProjectNode node, Void context) + { + boolean hasIneligibleProjections = node.getAssignments().getExpressions().stream() + .anyMatch(expression -> containsNonCoordinatorEligibleCallExpression(functionAndTypeManager, expression)); + + if (hasIneligibleProjections) { + hasNonCoordinatorEligibleFunction = true; + } + return visitPlan(node, context); + } + + @Override + public Void visitPlan(PlanNode node, Void context) + { + for (PlanNode source : node.getSources()) { + source.accept(this, context); + } + return null; + } + } + + /** + * Visits nodes to find and validate child fragments (those below exchanges). 
+ */ + private class ChildFragmentVisitor + extends SimplePlanVisitor + { + private final FunctionAndTypeManager functionAndTypeManager; + + public ChildFragmentVisitor(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + } + + @Override + public Void visitExchange(ExchangeNode node, Void context) + { + // Each source of an exchange is a separate fragment + for (PlanNode source : node.getSources()) { + validateFragment(source, functionAndTypeManager); + } + return null; + } + + @Override + public Void visitPlan(PlanNode node, Void context) + { + for (PlanNode source : node.getSources()) { + source.accept(this, context); + } + return null; + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanChecker.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanChecker.java index e8914dbbc331f..d0a85b8632e35 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanChecker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/PlanChecker.java @@ -74,9 +74,12 @@ public PlanChecker(FeaturesConfig featuresConfig, boolean noExchange, PlanChecke new VerifyProjectionLocality(), new DynamicFiltersChecker(), new WarnOnScanWithoutPartitionPredicate(featuresConfig)); - if (featuresConfig.isNativeExecutionEnabled() && (featuresConfig.isDisableTimeStampWithTimeZoneForNative() || - featuresConfig.isDisableIPAddressForNative())) { - builder.put(Stage.INTERMEDIATE, new CheckUnsupportedPrestissimoTypes(featuresConfig)); + if (featuresConfig.isNativeExecutionEnabled()) { + if (featuresConfig.isDisableTimeStampWithTimeZoneForNative() || + featuresConfig.isDisableIPAddressForNative()) { + builder.put(Stage.INTERMEDIATE, new CheckUnsupportedPrestissimoTypes(featuresConfig)); + } + builder.put(Stage.FINAL, new CheckNoIneligibleFunctionsInCoordinatorFragments()); } 
checkers = builder.build(); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionUtils.java new file mode 100644 index 0000000000000..531b25e715de5 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/relational/RowExpressionUtils.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.relational; + +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.RowExpressionVisitor; + +import static java.util.Objects.requireNonNull; + +public class RowExpressionUtils +{ + private RowExpressionUtils() {} + + public static boolean containsNonCoordinatorEligibleCallExpression(FunctionAndTypeManager functionAndTypeManager, RowExpression expression) + { + return expression.accept(new ContainsNonCoordinatorEligibleCallExpressionVisitor(functionAndTypeManager), null); + } + + private static class ContainsNonCoordinatorEligibleCallExpressionVisitor + implements RowExpressionVisitor + { + private final FunctionAndTypeManager functionAndTypeManager; + + public 
ContainsNonCoordinatorEligibleCallExpressionVisitor(FunctionAndTypeManager functionAndTypeManager) + { + this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null"); + } + + @Override + public Boolean visitCall(CallExpression call, Void context) + { + // If the call is not a Java function, we return true to indicate that we found a non-Java expression + FunctionHandle functionHandle = call.getFunctionHandle(); + FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(functionHandle); + if (!functionMetadata.getImplementationType().canBeEvaluatedInCoordinator()) { + return true; + } + for (RowExpression argument : call.getArguments()) { + if (argument.accept(this, context)) { + return true; // Found a non-Java expression in arguments + } + } + return false; + } + + @Override + public Boolean visitExpression(RowExpression expression, Void context) + { + for (RowExpression child : expression.getChildren()) { + if (child.accept(this, context)) { + return true; // Found a non-Java expression + } + } + return false; + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index d3c7044fc6df1..6bdd21fcc2312 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -1117,12 +1117,17 @@ public Plan createPlan(Session session, @Language("SQL") String sql, Optimizer.P } public Plan createPlan(Session session, @Language("SQL") String sql, Optimizer.PlanStage stage, boolean noExchange, WarningCollector warningCollector) + { + return createPlan(session, sql, stage, noExchange, false, warningCollector); + } + + public Plan createPlan(Session session, @Language("SQL") String sql, Optimizer.PlanStage stage, boolean noExchange, boolean nativeExecutionEnabled, 
WarningCollector warningCollector) { AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, warningCollector); BuiltInPreparedQuery preparedQuery = new BuiltInQueryPreparer(sqlParser).prepareQuery(analyzerOptions, sql, session.getPreparedStatements(), warningCollector); assertFormattedSql(sqlParser, createParsingOptions(session), preparedQuery.getStatement()); - return createPlan(session, sql, getPlanOptimizers(noExchange), stage, warningCollector); + return createPlan(session, sql, getPlanOptimizers(noExchange, nativeExecutionEnabled), stage, warningCollector); } public void setAdditionalOptimizer(List additionalOptimizer) @@ -1131,10 +1136,16 @@ public void setAdditionalOptimizer(List additionalOptimizer) } public List getPlanOptimizers(boolean noExchange) + { + return getPlanOptimizers(noExchange, false); + } + + public List getPlanOptimizers(boolean noExchange, boolean nativeExecutionEnabled) { FeaturesConfig featuresConfig = new FeaturesConfig() .setDistributedIndexJoinsEnabled(false) - .setOptimizeHashGeneration(true); + .setOptimizeHashGeneration(true) + .setNativeExecutionEnabled(nativeExecutionEnabled); ImmutableList.Builder planOptimizers = ImmutableList.builder(); if (!additionalOptimizer.isEmpty()) { planOptimizers.addAll(additionalOptimizer); diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForSqlInvokedScalars.java b/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForSqlInvokedScalars.java index 053635444a321..6ed8a5dacb20b 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForSqlInvokedScalars.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/TestAnnotationEngineForSqlInvokedScalars.java @@ -51,7 +51,7 @@ public void testParseFunctionDefinition() new ArrayType(BIGINT).getTypeSignature(), ImmutableList.of(INTEGER.getTypeSignature())); - List functions = 
SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(SingleImplementationSQLInvokedScalarFunction.class); + List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(SingleImplementationSQLInvokedScalarFunction.class, JAVA_BUILTIN_NAMESPACE); assertEquals(functions.size(), 1); SqlInvokedFunction f = functions.get(0); @@ -75,7 +75,7 @@ public void testParseFunctionDefinitionWithTypeParameter() ImmutableList.of(new TypeSignature("T")), false); - List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(SingleImplementationSQLInvokedScalarFunctionWithTypeParameter.class); + List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(SingleImplementationSQLInvokedScalarFunctionWithTypeParameter.class, JAVA_BUILTIN_NAMESPACE); assertEquals(functions.size(), 1); SqlInvokedFunction f = functions.get(0); diff --git a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestCustomFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestCustomFunctions.java index be335dd54c5c9..6d0b4f0f65052 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestCustomFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/operator/scalar/TestCustomFunctions.java @@ -24,6 +24,7 @@ import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; public class TestCustomFunctions extends AbstractTestFunctions @@ -41,7 +42,7 @@ protected TestCustomFunctions(FeaturesConfig config) public void setupClass() { registerScalar(CustomFunctions.class); - List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(CustomFunctions.class); + List functions = SqlInvokedScalarFromAnnotationsParser.parseFunctionDefinitions(CustomFunctions.class, 
JAVA_BUILTIN_NAMESPACE); this.functionAssertions.addFunctions(functions); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 56920001aa06e..25a802b1e06ea 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -194,6 +194,7 @@ public void testDefaults() .setMaxStageCountForEagerScheduling(25) .setHyperloglogStandardErrorWarningThreshold(0.004) .setPreferMergeJoinForSortedInputs(false) + .setPreferSortMergeJoin(false) .setSegmentedAggregationEnabled(false) .setQueryAnalyzerTimeout(new Duration(3, MINUTES)) .setQuickDistinctLimitEnabled(false) @@ -408,6 +409,7 @@ public void testExplicitPropertyMappings() .put("execution-policy.max-stage-count-for-eager-scheduling", "123") .put("hyperloglog-standard-error-warning-threshold", "0.02") .put("optimizer.prefer-merge-join-for-sorted-inputs", "true") + .put("experimental.optimizer.prefer-sort-merge-join", "true") .put("optimizer.segmented-aggregation-enabled", "true") .put("planner.query-analyzer-timeout", "10s") .put("optimizer.quick-distinct-limit-enabled", "true") @@ -619,6 +621,7 @@ public void testExplicitPropertyMappings() .setMaxStageCountForEagerScheduling(123) .setHyperloglogStandardErrorWarningThreshold(0.02) .setPreferMergeJoinForSortedInputs(true) + .setPreferSortMergeJoin(true) .setSegmentedAggregationEnabled(true) .setQueryAnalyzerTimeout(new Duration(10, SECONDS)) .setQuickDistinctLimitEnabled(true) diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java index 509eebb559a61..13759ae9df0a6 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java +++ 
b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestLogicalPlanner.java @@ -66,8 +66,10 @@ import static com.facebook.presto.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static com.facebook.presto.SystemSessionProperties.LEAF_NODE_LIMIT_ENABLED; import static com.facebook.presto.SystemSessionProperties.MAX_LEAF_NODES_IN_PLAN; +import static com.facebook.presto.SystemSessionProperties.NATIVE_EXECUTION_ENABLED; import static com.facebook.presto.SystemSessionProperties.OFFSET_CLAUSE_ENABLED; import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_HASH_GENERATION; +import static com.facebook.presto.SystemSessionProperties.PREFER_SORT_MERGE_JOIN; import static com.facebook.presto.SystemSessionProperties.PUSH_REMOTE_EXCHANGE_THROUGH_GROUP_ID; import static com.facebook.presto.SystemSessionProperties.REMOVE_CROSS_JOIN_WITH_CONSTANT_SINGLE_ROW_INPUT; import static com.facebook.presto.SystemSessionProperties.SIMPLIFY_PLAN_WITH_EMPTY_INPUT; @@ -108,6 +110,7 @@ import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.limit; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.markDistinct; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.mergeJoin; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; @@ -129,6 +132,7 @@ import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPLICATE; +import static com.facebook.presto.sql.tree.SortItem.NullOrdering.FIRST; import static com.facebook.presto.sql.tree.SortItem.NullOrdering.LAST; import 
static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; import static com.facebook.presto.sql.tree.SortItem.Ordering.DESCENDING; @@ -526,6 +530,64 @@ public void testJoinWithOrderBySameKey() tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))))); } + @Test + public void testSortMergeJoin() + { + Session preferSortMergeJoin = Session.builder(noJoinReordering()) + .setSystemProperty(NATIVE_EXECUTION_ENABLED, "true") + .setSystemProperty(PREFER_SORT_MERGE_JOIN, "true") + .setSystemProperty(DISTRIBUTED_SORT, "false") + .build(); + + // Both sides are not sorted. + assertPlan("SELECT o.orderkey FROM orders o INNER JOIN lineitem l ON o.custkey = l.partkey", + preferSortMergeJoin, + anyTree( + mergeJoin(INNER, ImmutableList.of(equiJoinClause("ORDERS_CK", "LINEITEM_PK")), Optional.empty(), + sort( + ImmutableList.of(sort("ORDERS_CK", ASCENDING, FIRST)), + exchange(LOCAL, GATHER, ImmutableList.of(), + tableScan("orders", ImmutableMap.of("ORDERS_CK", "custkey")))), + sort( + ImmutableList.of(sort("LINEITEM_PK", ASCENDING, FIRST)), + exchange(LOCAL, GATHER, ImmutableList.of(), + tableScan("lineitem", ImmutableMap.of("LINEITEM_PK", "partkey"))))))); + + // Left side is sorted. + assertPlan("SELECT o.orderkey FROM orders o INNER JOIN lineitem l ON o.orderkey = l.partkey", + preferSortMergeJoin, + anyTree( + mergeJoin(INNER, ImmutableList.of(equiJoinClause("ORDERS_OK", "LINEITEM_PK")), Optional.empty(), + exchange(LOCAL, GATHER, ImmutableList.of(), + tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey"))), + sort( + ImmutableList.of(sort("LINEITEM_PK", ASCENDING, FIRST)), + exchange(LOCAL, GATHER, ImmutableList.of(), + tableScan("lineitem", ImmutableMap.of("LINEITEM_PK", "partkey"))))))); + + // Right side is sorted. 
+ assertPlan("SELECT o.orderkey FROM orders o INNER JOIN lineitem l ON o.custkey = l.orderkey", + preferSortMergeJoin, + anyTree( + mergeJoin(INNER, ImmutableList.of(equiJoinClause("ORDERS_CK", "LINEITEM_OK")), Optional.empty(), + sort( + ImmutableList.of(sort("ORDERS_CK", ASCENDING, FIRST)), + exchange(LOCAL, GATHER, ImmutableList.of(), + tableScan("orders", ImmutableMap.of("ORDERS_CK", "custkey")))), + exchange(LOCAL, GATHER, ImmutableList.of(), + tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))))); + + // Both sides are sorted. + assertPlan("SELECT o.orderkey FROM orders o INNER JOIN lineitem l ON o.orderkey = l.orderkey", + preferSortMergeJoin, + anyTree( + mergeJoin(INNER, ImmutableList.of(equiJoinClause("ORDERS_OK", "LINEITEM_OK")), Optional.empty(), + exchange(LOCAL, GATHER, ImmutableList.of(), + tableScan("orders", ImmutableMap.of("ORDERS_OK", "orderkey"))), + exchange(LOCAL, GATHER, ImmutableList.of(), + tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))))); + } + @Test public void testUncorrelatedSubqueries() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java index 45deffb292fa7..f090c3e1e897a 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/BasePlanTest.java @@ -226,11 +226,21 @@ protected void assertDistributedPlan(String sql, PlanMatchPattern pattern) assertDistributedPlan(sql, getQueryRunner().getDefaultSession(), pattern); } + protected void assertNativeDistributedPlan(String sql, PlanMatchPattern pattern) + { + assertNativeDistributedPlan(sql, getQueryRunner().getDefaultSession(), pattern); + } + protected void assertDistributedPlan(String sql, Session session, PlanMatchPattern pattern) { assertPlanWithSession(sql, session, false, 
pattern); } + protected void assertNativeDistributedPlan(String sql, Session session, PlanMatchPattern pattern) + { + assertPlanWithSession(sql, session, false, true, pattern); + } + protected void assertMinimallyOptimizedPlan(@Language("SQL") String sql, PlanMatchPattern pattern) { List optimizers = ImmutableList.of( @@ -262,9 +272,14 @@ protected void assertMinimallyOptimizedPlanDoesNotMatch(@Language("SQL") String } protected void assertPlanWithSession(@Language("SQL") String sql, Session session, boolean noExchange, PlanMatchPattern pattern) + { + assertPlanWithSession(sql, session, noExchange, false, pattern); + } + + protected void assertPlanWithSession(@Language("SQL") String sql, Session session, boolean noExchange, boolean nativeExecutionEnabled, PlanMatchPattern pattern) { queryRunner.inTransaction(session, transactionSession -> { - Plan actualPlan = queryRunner.createPlan(transactionSession, sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, noExchange, WarningCollector.NOOP); + Plan actualPlan = queryRunner.createPlan(transactionSession, sql, Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED, noExchange, nativeExecutionEnabled, WarningCollector.NOOP); PlanAssert.assertPlan(transactionSession, queryRunner.getMetadata(), queryRunner.getStatsCalculator(), actualPlan, pattern); return null; }); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlansWithFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlansWithFunctions.java new file mode 100644 index 0000000000000..577779a780ad4 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestAddExchangesPlansWithFunctions.java @@ -0,0 +1,636 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig; +import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor; +import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors; +import com.facebook.presto.functionNamespace.testing.InMemoryFunctionNamespaceManager; +import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.operator.scalar.CombineHashFunction; +import com.facebook.presto.spi.function.FunctionImplementationType; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FunctionsConfig; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.facebook.presto.type.BigintOperators; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.stream.Collectors; + +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static 
com.facebook.presto.operator.scalar.annotations.ScalarFromAnnotationsParser.parseFunctionDefinitions; +import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.JAVA; +import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; +import static com.facebook.presto.spi.plan.JoinType.INNER; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +/** + * These are plan tests similar to what we have for other optimizers (e.g. {@link com.facebook.presto.sql.planner.TestPredicatePushdown}) + * They test that the plan for a query after the optimizer runs is as expected. + * These are separate from {@link TestAddExchanges} because those are unit tests for + * how layouts get chosen. + *

+ * Key behavior tested: When CPP functions are used with system tables, the filter containing + * the CPP function is preserved above the exchange (not pushed down) to ensure the filter + * executes in a different fragment from the system table scan. This validates the fragment + * boundary between CPP function evaluation and system table access. + */ +public class TestAddExchangesPlansWithFunctions + extends BasePlanTest +{ + public TestAddExchangesPlansWithFunctions() + { + super(TestAddExchangesPlansWithFunctions::createTestQueryRunner); + } + + private static final SqlInvokedFunction CPP_FOO = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "cpp_foo"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "cpp_foo(x)", + RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static final SqlInvokedFunction CPP_BAZ = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "cpp_baz"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "cpp_baz(x)", + RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static final SqlInvokedFunction JAVA_BAR = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "java_bar"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "java_bar(x)", + RoutineCharacteristics.builder().setLanguage(JAVA).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static final SqlInvokedFunction JAVA_FEE = new SqlInvokedFunction( + new QualifiedObjectName("dummy", 
"unittest", "java_fee"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature(StandardTypes.BIGINT), + "java_fee(x)", + RoutineCharacteristics.builder().setLanguage(JAVA).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static final SqlInvokedFunction NOT = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "not"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BOOLEAN))), + parseTypeSignature(StandardTypes.BOOLEAN), + "not(x)", + RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static final SqlInvokedFunction CPP_ARRAY_CONSTRUCTOR = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "array_constructor"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.BIGINT)), new Parameter("y", parseTypeSignature(StandardTypes.BIGINT))), + parseTypeSignature("array(bigint)"), + "array_constructor(x, y)", + RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + private static LocalQueryRunner createTestQueryRunner() + { + LocalQueryRunner queryRunner = new LocalQueryRunner(testSessionBuilder() + .setCatalog("tpch") + .setSchema("tiny") + .build(), + new FeaturesConfig(), + new FunctionsConfig().setDefaultNamespacePrefix("dummy.unittest")); + queryRunner.createCatalog("tpch", new TpchConnectorFactory(), ImmutableMap.of()); + queryRunner.getMetadata().getFunctionAndTypeManager().addFunctionNamespace( + "dummy", + new InMemoryFunctionNamespaceManager( + "dummy", + new SqlFunctionExecutors( + ImmutableMap.of( + CPP, FunctionImplementationType.CPP, + JAVA, FunctionImplementationType.JAVA), + new NoopSqlFunctionExecutor()), + new 
SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("cpp"))); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(CPP_FOO, true); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(CPP_BAZ, true); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(JAVA_BAR, true); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(JAVA_FEE, true); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(NOT, true); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(CPP_ARRAY_CONSTRUCTOR, true); + parseFunctionDefinitions(BigintOperators.class).stream() + .map(TestAddExchangesPlansWithFunctions::convertToSqlInvokedFunction) + .forEach(function -> queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(function, true)); + parseFunctionDefinitions(CombineHashFunction.class).stream() + .map(TestAddExchangesPlansWithFunctions::convertToSqlInvokedFunction) + .forEach(function -> queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(function, true)); + return queryRunner; + } + + public static SqlInvokedFunction convertToSqlInvokedFunction(SqlScalarFunction scalarFunction) + { + QualifiedObjectName functionName = new QualifiedObjectName("dummy", "unittest", scalarFunction.getSignature().getName().getObjectName()); + TypeSignature returnType = scalarFunction.getSignature().getReturnType(); + RoutineCharacteristics characteristics = RoutineCharacteristics.builder() + .setLanguage(RoutineCharacteristics.Language.JAVA) // The scalar function is re-registered under dummy.unittest as a JAVA-language function + .setDeterminism(RoutineCharacteristics.Determinism.DETERMINISTIC) + .setNullCallClause(RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT) + .build(); + + // Convert scalar function arguments to typed SqlInvokedFunction parameters (each parameter is named after its type signature) + ImmutableList<Parameter> parameters = scalarFunction.getSignature().getArgumentTypes().stream() + .map(type -> new Parameter(type.toString(),
TypeSignature.parseTypeSignature(type.toString()))) + .collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf)); + + // Create the SqlInvokedFunction + return new SqlInvokedFunction( + functionName, + parameters, + returnType, + scalarFunction.getSignature().getName().toString(), // Using the function name as the body for simplicity + characteristics, + "", // Empty description + notVersioned()); + } + + @Test + public void testFilterWithCppFunctionDoesNotGetPushedIntoSystemTableScan() + { + // java_fee and java_bar are java functions, they are both pushed down past the exchange + assertNativeDistributedPlan("SELECT java_fee(ordinal_position) FROM information_schema.columns WHERE java_bar(ordinal_position) = 1", + anyTree( + exchange(REMOTE_STREAMING, GATHER, + project(ImmutableMap.of("java_fee", expression("java_fee(ordinal_position)")), + filter("java_bar(ordinal_position) = BIGINT'1'", + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))))); + // cpp_foo is a CPP function, it is not pushed down past the exchange because the source is a system table scan + // The filter is preserved above the exchange to prove that the filter is not in the same fragment as the system table scan + assertNativeDistributedPlan("SELECT cpp_baz(ordinal_position) FROM information_schema.columns WHERE cpp_foo(ordinal_position) = 1", + anyTree( + project(ImmutableMap.of("cpp_baz", expression("cpp_baz(ordinal_position)")), + filter("cpp_foo(ordinal_position) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))))); + } + + @Test + public void testJoinWithCppFunctionDoesNotGetPushedIntoSystemTableScan() + { + // java_bar is a java function, it is pushed down past the exchange + assertNativeDistributedPlan( + "SELECT c1.table_name FROM information_schema.columns c1 JOIN information_schema.columns c2 ON c1.ordinal_position = c2.ordinal_position WHERE 
java_bar(c1.ordinal_position) = 1", + anyTree( + exchange( + join(INNER, ImmutableList.of(equiJoinClause("ordinal_position", "ordinal_position_4")), + anyTree( + exchange(REMOTE_STREAMING, GATHER, + project( + filter("java_bar(ordinal_position) = BIGINT'1'", + tableScan("columns", ImmutableMap.of( + "ordinal_position", "ordinal_position", + "table_name", "table_name")))))), + anyTree( + exchange(REMOTE_STREAMING, GATHER, + project( + filter("java_bar(ordinal_position_4) = BIGINT'1'", + tableScan("columns", ImmutableMap.of( + "ordinal_position_4", "ordinal_position")))))))))); + + // cpp_foo is a CPP function, it is not pushed down past the exchange because the source is a system table scan + assertNativeDistributedPlan( + "SELECT cpp_baz(c1.ordinal_position) FROM information_schema.columns c1 JOIN information_schema.columns c2 ON c1.ordinal_position = c2.ordinal_position WHERE cpp_foo(c1.ordinal_position) = 1", + output( + exchange( + project(ImmutableMap.of("cpp_baz", expression("cpp_baz(ordinal_position)")), + join(INNER, ImmutableList.of(equiJoinClause("ordinal_position", "ordinal_position_4")), + anyTree( + filter("cpp_foo(ordinal_position) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + project( + tableScan("columns", ImmutableMap.of( + "ordinal_position", "ordinal_position")))))), + anyTree( + filter("cpp_foo(ordinal_position_4) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + project( + tableScan("columns", ImmutableMap.of( + "ordinal_position_4", "ordinal_position"))))))))))); + } + + @Test + public void testMixedFunctionTypesInComplexPredicates() + { + // Test AND condition with mixed Java and CPP functions + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE java_bar(ordinal_position) = 1 AND cpp_foo(ordinal_position) > 0", + anyTree( + filter("java_bar(ordinal_position) = BIGINT'1' AND cpp_foo(ordinal_position) > BIGINT'0'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", 
ImmutableMap.of("ordinal_position", "ordinal_position")))))); + + // Test OR condition with mixed functions - entire predicate should be evaluated after exchange + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE java_bar(ordinal_position) = 1 OR cpp_foo(ordinal_position) = 2", + anyTree( + filter("java_bar(ordinal_position) = BIGINT'1' OR cpp_foo(ordinal_position) = BIGINT'2'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testNestedFunctionCalls() + { + // CPP function nested inside Java function - should not push down + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE java_bar(cpp_foo(ordinal_position)) = 1", + anyTree( + filter("java_bar(cpp_foo(ordinal_position)) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + + // Java function nested inside CPP function - should not push down + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE cpp_foo(java_bar(ordinal_position)) = 1", + anyTree( + filter("cpp_foo(java_bar(ordinal_position)) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + + // Multiple levels of nesting + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE cpp_foo(java_bar(cpp_foo(ordinal_position))) = 1", + anyTree( + filter("cpp_foo(java_bar(cpp_foo(ordinal_position))) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testMixedSystemAndRegularTables() + { + // System table with CPP function joined with regular table + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns c JOIN nation n ON c.ordinal_position = n.nationkey WHERE 
cpp_foo(c.ordinal_position) = 1", + output( + join(INNER, ImmutableList.of(equiJoinClause("ordinal_position", "nationkey")), + filter("cpp_foo(ordinal_position) = BIGINT'1'", + exchange(REMOTE_STREAMING, GATHER, + project(ImmutableMap.of("ordinal_position", expression("ordinal_position")), + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))), + anyTree( + project(ImmutableMap.of("nationkey", expression("nationkey")), + filter( + tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))))); + + // Regular table with CPP function (should work normally without extra exchange) + assertNativeDistributedPlan( + "SELECT * FROM nation WHERE cpp_foo(nationkey) = 1", + anyTree( + exchange(REMOTE_STREAMING, GATHER, + filter("cpp_foo(nationkey) = BIGINT'1'", + tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))); + } + + @Test + public void testAggregationsWithMixedFunctions() + { + // SELECT DISTINCT over a CPP function - the projection is expected to stay above the exchange + assertNativeDistributedPlan( + "SELECT DISTINCT cpp_foo(ordinal_position) FROM information_schema.columns", + anyTree( + project(ImmutableMap.of("cpp_foo", expression("cpp_foo(ordinal_position)")), + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + + // SELECT DISTINCT over a Java function - the projection can be pushed down below the exchange + assertNativeDistributedPlan( + "SELECT DISTINCT java_bar(ordinal_position) FROM information_schema.columns", + anyTree( + exchange(REMOTE_STREAMING, GATHER, + project(ImmutableMap.of("java_bar", expression("java_bar(ordinal_position)")), + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testComplexPredicateWithMultipleFunctions() + { + // Complex predicate with multiple CPP and Java functions + // Since the predicate contains CPP functions (cpp_foo, cpp_baz), the exchange is inserted before the system table scan + // The RemoveRedundantExchanges rule removes
the inner exchange that was added by ExtractIneligiblePredicatesFromSystemTableScans + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE (cpp_foo(ordinal_position) > 0 AND java_bar(ordinal_position) < 100) OR cpp_baz(ordinal_position) = 50", + anyTree( + filter( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testProjectionWithMixedFunctions() + { + // Projection with both Java and CPP functions + assertNativeDistributedPlan( + "SELECT java_bar(ordinal_position) as java_result, cpp_foo(ordinal_position) as cpp_result FROM information_schema.columns", + anyTree( + project(ImmutableMap.of( + "java_result", expression("java_bar(ordinal_position)"), + "cpp_result", expression("cpp_foo(ordinal_position)")), + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testCaseStatementsWithCppFunctions() + { + // CASE statement with CPP function in condition + // The RemoveRedundantExchanges optimizer removes the redundant exchange + assertNativeDistributedPlan( + "SELECT CASE WHEN cpp_foo(ordinal_position) > 0 THEN 'positive' ELSE 'negative' END FROM information_schema.columns", + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + + // CASE statement with CPP function in result + // The RemoveRedundantExchanges optimizer removes the redundant exchange + assertNativeDistributedPlan( + "SELECT CASE WHEN ordinal_position > 0 THEN cpp_foo(ordinal_position) ELSE 0 END FROM information_schema.columns", + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testBuiltinFunctionWithExplicitNamespace() + { + // Test that built-in functions with explicit 
namespace are handled correctly + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE presto.default.length(table_name) > 10", + anyTree( + exchange(REMOTE_STREAMING, GATHER, + filter("length(table_name) > BIGINT'10'", + tableScan("columns", ImmutableMap.of("table_name", "table_name")))))); + } + + @Test(enabled = false) // TODO: Window functions are resolved with namespace which causes issues in tests + public void testWindowFunctionsWithCppFunctions() + { + // Window function with CPP function in partition by + assertNativeDistributedPlan( + "SELECT row_number() OVER (PARTITION BY cpp_foo(ordinal_position)) FROM information_schema.columns", + anyTree( + exchange( + project( + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))))))); + + // Window function with CPP function in order by + assertNativeDistributedPlan( + "SELECT row_number() OVER (ORDER BY cpp_foo(ordinal_position)) FROM information_schema.columns", + anyTree( + exchange( + project( + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position"))))))))); + } + + @Test + public void testMultipleSystemTableJoins() + { + // Multiple system tables with CPP functions + // This test verifies that when joining two system tables with a CPP function comparison, + // an exchange is added between the table scan and the join to ensure CPP functions + // execute in a separate fragment from system table access + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns c1 " + + "JOIN information_schema.columns c2 ON cpp_foo(c1.ordinal_position) = cpp_foo(c2.ordinal_position)", + anyTree( + exchange( + join(INNER, ImmutableList.of(equiJoinClause("cpp_foo", "foo_4")), + exchange( + project(ImmutableMap.of("cpp_foo", expression("cpp_foo")), + project(ImmutableMap.of("cpp_foo", expression("cpp_foo(ordinal_position)")), 
+ exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))), + anyTree( + exchange( + project(ImmutableMap.of("foo_4", expression("foo_4")), + project(ImmutableMap.of("foo_4", expression("cpp_foo(ordinal_position_4)")), + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position_4", "ordinal_position"))))))))))); + } + + @Test + public void testInPredicateWithCppFunction() + { + // IN predicate with CPP function + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE cpp_foo(ordinal_position) IN (1, 2, 3)", + anyTree( + filter("cpp_foo(ordinal_position) IN (BIGINT'1', BIGINT'2', BIGINT'3')", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testBetweenPredicateWithCppFunction() + { + // BETWEEN predicate with CPP function + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE cpp_foo(ordinal_position) BETWEEN 1 AND 10", + anyTree( + filter("cpp_foo(ordinal_position) BETWEEN BIGINT'1' AND BIGINT'10'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testNullHandlingWithCppFunctions() + { + // IS NULL check with CPP function + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE cpp_foo(ordinal_position) IS NULL", + anyTree( + filter("cpp_foo(ordinal_position) IS NULL", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + + // COALESCE with CPP function + // The RemoveRedundantExchanges optimizer removes the redundant exchange + assertNativeDistributedPlan( + "SELECT COALESCE(cpp_foo(ordinal_position), 0) FROM information_schema.columns", + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", 
ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testUnionWithCppFunctions() + { + // UNION ALL with CPP functions from system tables + assertNativeDistributedPlan( + "SELECT cpp_foo(ordinal_position) FROM information_schema.columns " + + "UNION ALL SELECT cpp_foo(nationkey) FROM nation", + output( + exchange( + anyTree( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))), + anyTree( + tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))); + } + + @Test + public void testExistsSubqueryWithCppFunction() + { + // EXISTS subquery with CPP function + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns c WHERE EXISTS (SELECT 1 FROM nation n WHERE cpp_foo(c.ordinal_position) = n.nationkey)", + anyTree( + join( + anyTree( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))), + anyTree( + tableScan("nation", ImmutableMap.of("nationkey", "nationkey")))))); + } + + @Test + public void testLimitWithCppFunction() + { + // LIMIT with CPP function in ORDER BY + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns ORDER BY cpp_foo(ordinal_position) LIMIT 10", + output( + project( + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))))); + } + + @Test + public void testCastOperationsWithCppFunctions() + { + // CAST operations with CPP functions + // The RemoveRedundantExchanges optimizer removes the redundant exchange + assertNativeDistributedPlan( + "SELECT CAST(cpp_foo(ordinal_position) AS VARCHAR) FROM information_schema.columns", + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testArrayConstructorWithCppFunction() + { + // Array constructor with 
CPP function + assertNativeDistributedPlan( + "SELECT ARRAY[cpp_foo(ordinal_position), cpp_baz(ordinal_position)] FROM information_schema.columns", + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testRowConstructorWithCppFunction() + { + // ROW constructor with CPP function + // The RemoveRedundantExchanges optimizer removes the redundant exchange + assertNativeDistributedPlan( + "SELECT ROW(cpp_foo(ordinal_position), table_name) FROM information_schema.columns", + anyTree( + project( + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of( + "ordinal_position", "ordinal_position", + "table_name", "table_name")))))); + } + + @Test + public void testIsNotNullWithCppFunction() + { + // IS NOT NULL check with CPP function + assertNativeDistributedPlan( + "SELECT * FROM information_schema.columns WHERE cpp_foo(ordinal_position) IS NOT NULL", + anyTree( + filter("cpp_foo(ordinal_position) IS NOT NULL", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of("ordinal_position", "ordinal_position")))))); + } + + @Test + public void testComplexJoinWithMultipleCppFunctions() + { + // Complex join with multiple CPP functions in different positions + // The filters are pushed into FilterProject nodes and the join happens on the expression cpp_foo(c1.ordinal_position) + assertNativeDistributedPlan( + "SELECT c1.table_name, n.name FROM information_schema.columns c1 " + + "JOIN nation n ON cpp_foo(c1.ordinal_position) = n.nationkey " + + "WHERE cpp_baz(c1.ordinal_position) > 0 AND cpp_foo(n.nationkey) < 100", + anyTree( + join(INNER, ImmutableList.of(equiJoinClause("cpp_foo", "nationkey")), + project(ImmutableMap.of("table_name", expression("table_name"), "cpp_foo", expression("cpp_foo")), + project(ImmutableMap.of("cpp_foo", expression("cpp_foo(ordinal_position)")), + filter("cpp_baz(ordinal_position) > BIGINT'0' AND 
cpp_foo(cpp_foo(ordinal_position)) < BIGINT'100'", + exchange(REMOTE_STREAMING, GATHER, + tableScan("columns", ImmutableMap.of( + "ordinal_position", "ordinal_position", + "table_name", "table_name")))))), + anyTree( + project(ImmutableMap.of("nationkey", expression("nationkey"), + "name", expression("name")), + filter("cpp_foo(nationkey) < BIGINT'100'", + tableScan("nation", ImmutableMap.of( + "nationkey", "nationkey", + "name", "name")))))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/sanity/TestCheckNoIneligibleFunctionsInCoordinatorFragments.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/sanity/TestCheckNoIneligibleFunctionsInCoordinatorFragments.java new file mode 100644 index 0000000000000..106b526a6b75e --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/sanity/TestCheckNoIneligibleFunctionsInCoordinatorFragments.java @@ -0,0 +1,474 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.sanity; + +import com.facebook.presto.Session; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig; +import com.facebook.presto.functionNamespace.execution.NoopSqlFunctionExecutor; +import com.facebook.presto.functionNamespace.execution.SqlFunctionExecutors; +import com.facebook.presto.functionNamespace.testing.InMemoryFunctionNamespaceManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.TestingColumnHandle; +import com.facebook.presto.spi.WarningCollector; +import com.facebook.presto.spi.function.FunctionImplementationType; +import com.facebook.presto.spi.function.Parameter; +import com.facebook.presto.spi.function.RoutineCharacteristics; +import com.facebook.presto.spi.function.SqlInvokedFunction; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.Partitioning; +import com.facebook.presto.spi.plan.PartitioningScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FunctionsConfig; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.testing.TestingMetadata.TestingTableHandle; +import com.facebook.presto.testing.TestingTransactionHandle; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableList; +import 
com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.function.Function; + +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Determinism.DETERMINISTIC; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.CPP; +import static com.facebook.presto.spi.function.RoutineCharacteristics.Language.JAVA; +import static com.facebook.presto.spi.function.RoutineCharacteristics.NullCallClause.RETURNS_NULL_ON_NULL_INPUT; +import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +public class TestCheckNoIneligibleFunctionsInCoordinatorFragments + extends BasePlanTest +{ + // CPP function for testing (similar to TestAddExchangesPlansWithFunctions) + private static final SqlInvokedFunction CPP_FUNC = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "cpp_func"), + ImmutableList.of(new Parameter("x", parseTypeSignature(StandardTypes.VARCHAR))), + parseTypeSignature(StandardTypes.VARCHAR), + "cpp_func(x)", + RoutineCharacteristics.builder().setLanguage(CPP).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + // JAVA function for testing + private static final SqlInvokedFunction JAVA_FUNC = new SqlInvokedFunction( + new QualifiedObjectName("dummy", "unittest", "java_func"), + ImmutableList.of(new Parameter("x", 
parseTypeSignature(StandardTypes.VARCHAR))), + parseTypeSignature(StandardTypes.VARCHAR), + "java_func(x)", + RoutineCharacteristics.builder().setLanguage(JAVA).setDeterminism(DETERMINISTIC).setNullCallClause(RETURNS_NULL_ON_NULL_INPUT).build(), + "", + notVersioned()); + + public TestCheckNoIneligibleFunctionsInCoordinatorFragments() + { + super(TestCheckNoIneligibleFunctionsInCoordinatorFragments::createTestQueryRunner); + } + + private static LocalQueryRunner createTestQueryRunner() + { + LocalQueryRunner queryRunner = new LocalQueryRunner( + testSessionBuilder() + .setCatalog("local") + .setSchema("tiny") + .build(), + new FeaturesConfig().setNativeExecutionEnabled(true), + new FunctionsConfig().setDefaultNamespacePrefix("dummy.unittest")); + + queryRunner.createCatalog("local", new TpchConnectorFactory(), ImmutableMap.of()); + + // Add function namespace with both CPP and JAVA functions + queryRunner.getMetadata().getFunctionAndTypeManager().addFunctionNamespace( + "dummy", + new InMemoryFunctionNamespaceManager( + "dummy", + new SqlFunctionExecutors( + ImmutableMap.of( + CPP, FunctionImplementationType.CPP, + JAVA, FunctionImplementationType.JAVA), + new NoopSqlFunctionExecutor()), + new SqlInvokedFunctionNamespaceManagerConfig().setSupportedFunctionLanguages("CPP,JAVA"))); + + // Register the functions + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(CPP_FUNC, false); + queryRunner.getMetadata().getFunctionAndTypeManager().createFunction(JAVA_FUNC, false); + + return queryRunner; + } + + @Test + public void testSystemTableScanWithJavaFunctionPasses() + { + // System table scan with Java function in same fragment should pass + validatePlan( + p -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + VariableReferenceExpression result = p.variable("result", VARCHAR); + + // Create a system table scan - using proper system connector ID + TableHandle systemTableHandle = new TableHandle( + 
ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode tableScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col), + ImmutableMap.of(col, new TestingColumnHandle("col"))); + + // Java function (using our registered java_func) + return p.project( + assignment(result, p.rowExpression("java_func(col)")), + tableScan); + }); + } + + @Test(expectedExceptions = IllegalStateException.class, + expectedExceptionsMessageRegExp = "Fragment contains both system table scan and non-Java functions.*") + public void testSystemTableScanWithCppFunctionInProjectFails() + { + // System table scan with C++ function in same fragment should fail + validatePlan( + p -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + VariableReferenceExpression result = p.variable("result", VARCHAR); + + // System table scan + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode systemScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col), + ImmutableMap.of(col, new TestingColumnHandle("col"))); + + // C++ function (using our registered cpp_func) + return p.project( + assignment(result, p.rowExpression("cpp_func(col)")), + systemScan); + }); + } + + @Test(expectedExceptions = IllegalStateException.class, + expectedExceptionsMessageRegExp = "Fragment contains both system table scan and non-Java functions.*") + public void testSystemTableScanWithCppFunctionInFilterFails() + { + // System table scan with C++ function in filter should fail + validatePlan( + p -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + + // System table scan + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")), + new 
TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode systemScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col), + ImmutableMap.of(col, new TestingColumnHandle("col"))); + + // Filter with C++ function + return p.filter( + p.rowExpression("cpp_func(col) = 'test'"), + systemScan); + }); + } + + @Test + public void testSystemTableScanWithCppFunctionSeparatedByExchangePasses() + { + // System table scan and C++ function separated by exchange should pass + validatePlan( + p -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + VariableReferenceExpression result = p.variable("result", VARCHAR); + + // System table scan + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode systemScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col), + ImmutableMap.of(col, new TestingColumnHandle("col"))); + + // Exchange creates fragment boundary + PartitioningScheme partitioningScheme = new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + ImmutableList.of(col)); + + ExchangeNode exchange = new ExchangeNode( + Optional.empty(), + p.getIdAllocator().getNextId(), + ExchangeNode.Type.GATHER, + ExchangeNode.Scope.LOCAL, + partitioningScheme, + ImmutableList.of(systemScan), + ImmutableList.of(ImmutableList.of(col)), + false, + Optional.empty()); + + // C++ function in different fragment + return p.project( + assignment(result, p.rowExpression("cpp_func(col)")), + exchange); + }); + } + + @Test + public void testRegularTableScanWithCppFunctionPasses() + { + // Regular table scan with C++ function should pass (no system table) + validatePlan( + p -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + VariableReferenceExpression result = p.variable("result", VARCHAR); + + // Regular table scan (not 
system) + TableHandle regularTableHandle = new TableHandle( + new ConnectorId("local"), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode regularScan = p.tableScan( + regularTableHandle, + ImmutableList.of(col), + ImmutableMap.of(col, new TestingColumnHandle("col"))); + + // C++ function + return p.project( + assignment(result, p.rowExpression("cpp_func(col)")), + regularScan); + }); + } + + @Test(expectedExceptions = IllegalStateException.class, + expectedExceptionsMessageRegExp = "Fragment contains both system table scan and non-Java functions.*") + public void testMultipleFragmentsWithCppFunctionInSystemFragment() + { + // Complex plan where CPP function is in same fragment as system table scan (should fail) + validatePlan( + p -> { + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", BIGINT); + VariableReferenceExpression col3 = p.variable("col3", BIGINT); + + // Fragment 1: System table scan + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode systemScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col1), + ImmutableMap.of(col1, new TestingColumnHandle("col1"))); + + // Convert to numeric for join (using CPP function - this should fail) + PlanNode project1 = p.project( + assignment(col2, p.rowExpression("cast(cpp_func(col1) as bigint)")), + systemScan); + + // Fragment 2: Regular values with computation + PlanNode values = p.values(col3); + + // Exchange to separate fragments + PartitioningScheme partitioningScheme1 = new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + ImmutableList.of(col2)); + + ExchangeNode exchange1 = new ExchangeNode( + Optional.empty(), + p.getIdAllocator().getNextId(), + ExchangeNode.Type.GATHER, + 
ExchangeNode.Scope.LOCAL, + partitioningScheme1, + ImmutableList.of(project1), + ImmutableList.of(ImmutableList.of(col2)), + false, + Optional.empty()); + + PartitioningScheme partitioningScheme2 = new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + ImmutableList.of(col3)); + + ExchangeNode exchange2 = new ExchangeNode( + Optional.empty(), + p.getIdAllocator().getNextId(), + ExchangeNode.Type.GATHER, + ExchangeNode.Scope.LOCAL, + partitioningScheme2, + ImmutableList.of(values), + ImmutableList.of(ImmutableList.of(col3)), + false, + Optional.empty()); + + // Join the results + return p.join( + JoinType.INNER, + exchange1, + exchange2, + p.rowExpression("col2 = col3")); + }); + } + + @Test + public void testMultipleFragmentsWithExchange() + { + // Complex plan with multiple fragments properly separated (Java function - should pass) + validatePlan( + p -> { + VariableReferenceExpression col1 = p.variable("col1", VARCHAR); + VariableReferenceExpression col2 = p.variable("col2", BIGINT); + VariableReferenceExpression col3 = p.variable("col3", BIGINT); + + // Fragment 1: System table scan + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode systemScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col1), + ImmutableMap.of(col1, new TestingColumnHandle("col1"))); + + // Convert to numeric for join (using Java function) + PlanNode project1 = p.project( + assignment(col2, p.rowExpression("cast(java_func(col1) as bigint)")), + systemScan); + + // Fragment 2: Regular values with computation + PlanNode values = p.values(col3); + + // Exchange to separate fragments + PartitioningScheme partitioningScheme1 = new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + ImmutableList.of(col2)); + + ExchangeNode exchange1 = new ExchangeNode( + 
Optional.empty(), + p.getIdAllocator().getNextId(), + ExchangeNode.Type.GATHER, + ExchangeNode.Scope.LOCAL, + partitioningScheme1, + ImmutableList.of(project1), + ImmutableList.of(ImmutableList.of(col2)), + false, + Optional.empty()); + + PartitioningScheme partitioningScheme2 = new PartitioningScheme( + Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), + ImmutableList.of(col3)); + + ExchangeNode exchange2 = new ExchangeNode( + Optional.empty(), + p.getIdAllocator().getNextId(), + ExchangeNode.Type.GATHER, + ExchangeNode.Scope.LOCAL, + partitioningScheme2, + ImmutableList.of(values), + ImmutableList.of(ImmutableList.of(col3)), + false, + Optional.empty()); + + // Join the results + return p.join( + JoinType.INNER, + exchange1, + exchange2, + p.rowExpression("col2 = col3")); + }); + } + + @Test + public void testFilterAndProjectWithSystemTable() + { + // Test filter and project both with Java functions on system table + validatePlan( + p -> { + VariableReferenceExpression col = p.variable("col", VARCHAR); + VariableReferenceExpression len = p.variable("len", BIGINT); + + // System table scan + TableHandle systemTableHandle = new TableHandle( + ConnectorId.createSystemTablesConnectorId(new ConnectorId("local")), + new TestingTableHandle(), + TestingTransactionHandle.create(), + Optional.empty()); + + PlanNode systemScan = p.tableScan( + systemTableHandle, + ImmutableList.of(col), + ImmutableMap.of(col, new TestingColumnHandle("col"))); + + // Filter with Java function + PlanNode filtered = p.filter( + p.rowExpression("java_func(col) = 'test'"), + systemScan); + + // Project with Java function + return p.project( + assignment(len, p.rowExpression("cast(java_func(col) as bigint)")), + filtered); + }); + } + + private void validatePlan(Function planProvider) + { + Session session = testSessionBuilder() + .setCatalog("local") + .setSchema("tiny") + .build(); + + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + Metadata metadata = 
getQueryRunner().getMetadata(); + PlanBuilder builder = new PlanBuilder(TEST_SESSION, idAllocator, metadata); + PlanNode planNode = planProvider.apply(builder); + + getQueryRunner().inTransaction(session, transactionSession -> { + new CheckNoIneligibleFunctionsInCoordinatorFragments().validate(planNode, transactionSession, metadata, WarningCollector.NOOP); + return null; + }); + } +} diff --git a/presto-main/pom.xml b/presto-main/pom.xml index 277a3126aab40..07e0f4828ceb0 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -267,13 +267,11 @@ io.projectreactor.netty reactor-netty-core - 1.1.29 io.projectreactor.netty reactor-netty-http - 1.1.29 diff --git a/presto-native-execution/presto_cpp/main/QueryContextManager.cpp b/presto-native-execution/presto_cpp/main/QueryContextManager.cpp index f67df7f111f23..e48a1aa0acc8d 100644 --- a/presto-native-execution/presto_cpp/main/QueryContextManager.cpp +++ b/presto-native-execution/presto_cpp/main/QueryContextManager.cpp @@ -70,7 +70,7 @@ void updateFromSystemConfigs( const auto& systemConfigName = configNameEntry.second; if (queryConfigs.count(veloxConfigName) == 0) { const auto propertyOpt = systemConfig->optionalProperty(systemConfigName); - if (propertyOpt.hasValue()) { + if (propertyOpt.has_value()) { queryConfigs[veloxConfigName] = propertyOpt.value(); } } diff --git a/presto-native-execution/presto_cpp/main/common/ConfigReader.cpp b/presto-native-execution/presto_cpp/main/common/ConfigReader.cpp index 027193a2f045a..b10d9962c0bb9 100644 --- a/presto-native-execution/presto_cpp/main/common/ConfigReader.cpp +++ b/presto-native-execution/presto_cpp/main/common/ConfigReader.cpp @@ -87,7 +87,7 @@ std::string requiredProperty( const velox::config::ConfigBase& properties, const std::string& name) { auto value = properties.get(name); - if (!value.hasValue()) { + if (!value.has_value()) { VELOX_USER_FAIL("Missing configuration property {}", name); } return value.value(); @@ -120,7 +120,7 @@ std::string 
getOptionalProperty( const std::string& name, const std::string& defaultValue) { auto value = properties.get(name); - if (!value.hasValue()) { + if (!value.has_value()) { return defaultValue; } return value.value(); diff --git a/presto-native-execution/presto_cpp/main/common/Configs.cpp b/presto-native-execution/presto_cpp/main/common/Configs.cpp index 893e21f224d0c..91a47abccb551 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.cpp +++ b/presto-native-execution/presto_cpp/main/common/Configs.cpp @@ -110,7 +110,7 @@ folly::Optional ConfigBase::setValue( propertyName); auto oldValue = config_->get(propertyName); config_->set(propertyName, value); - if (oldValue.hasValue()) { + if (oldValue.has_value()) { return oldValue; } return registeredProps_[propertyName]; @@ -372,7 +372,7 @@ SystemConfig::remoteFunctionServerLocation() const { // First check if there is a UDS path registered. If there's one, use it. auto remoteServerUdsPath = optionalProperty(kRemoteFunctionServerThriftUdsPath); - if (remoteServerUdsPath.hasValue()) { + if (remoteServerUdsPath.has_value()) { return folly::SocketAddress::makeFromPath(remoteServerUdsPath.value()); } @@ -382,13 +382,13 @@ SystemConfig::remoteFunctionServerLocation() const { auto remoteServerPort = optionalProperty(kRemoteFunctionServerThriftPort); - if (remoteServerPort.hasValue()) { + if (remoteServerPort.has_value()) { // Fallback to localhost if address is not specified. - return remoteServerAddress.hasValue() + return remoteServerAddress.has_value() ? 
folly:: SocketAddress{remoteServerAddress.value(), remoteServerPort.value()} : folly::SocketAddress{"::1", remoteServerPort.value()}; - } else if (remoteServerAddress.hasValue()) { + } else if (remoteServerAddress.has_value()) { VELOX_FAIL( "Remote function server port not provided using '{}'.", kRemoteFunctionServerThriftPort); @@ -959,7 +959,7 @@ int NodeConfig::prometheusExecutorThreads() const { static constexpr int kNodePrometheusExecutorThreadsDefault = 2; auto resultOpt = optionalProperty(kNodePrometheusExecutorThreads); - if (resultOpt.hasValue()) { + if (resultOpt.has_value()) { return resultOpt.value(); } return kNodePrometheusExecutorThreadsDefault; @@ -967,7 +967,7 @@ int NodeConfig::prometheusExecutorThreads() const { std::string NodeConfig::nodeId() const { auto resultOpt = optionalProperty(kNodeId); - if (resultOpt.hasValue()) { + if (resultOpt.has_value()) { return resultOpt.value(); } // Generate the nodeId which must be a UUID. nodeId must be a singleton. @@ -985,7 +985,7 @@ std::string NodeConfig::nodeInternalAddress( auto resultOpt = optionalProperty(kNodeInternalAddress); /// node.ip(kNodeIp) is legacy config replaced with node.internal-address, but /// still valid config in Presto, so handling both. 
- if (!resultOpt.hasValue()) { + if (!resultOpt.has_value()) { resultOpt = optionalProperty(kNodeIp); } if (resultOpt.has_value()) { diff --git a/presto-native-execution/presto_cpp/main/common/Configs.h b/presto-native-execution/presto_cpp/main/common/Configs.h index 6512e9e47046f..148469e0c351b 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.h +++ b/presto-native-execution/presto_cpp/main/common/Configs.h @@ -95,7 +95,7 @@ class ConfigBase { template folly::Optional optionalProperty(const std::string& propertyName) const { auto valOpt = config_->get(propertyName); - if (valOpt.hasValue()) { + if (valOpt.has_value()) { return valOpt.value(); } const auto it = registeredProps_.find(propertyName); @@ -115,7 +115,7 @@ class ConfigBase { folly::Optional optionalProperty( const std::string& propertyName) const { auto val = config_->get(propertyName); - if (val.hasValue()) { + if (val.has_value()) { return val; } const auto it = registeredProps_.find(propertyName); diff --git a/presto-native-execution/scripts/dockerfiles/centos-dependency.dockerfile b/presto-native-execution/scripts/dockerfiles/centos-dependency.dockerfile index 207270d69f970..e67770e78a0b8 100644 --- a/presto-native-execution/scripts/dockerfiles/centos-dependency.dockerfile +++ b/presto-native-execution/scripts/dockerfiles/centos-dependency.dockerfile @@ -25,9 +25,10 @@ COPY velox/CMake/resolve_dependency_modules/arrow/cmake-compatibility.patch /vel ENV VELOX_ARROW_CMAKE_PATCH=/velox/cmake-compatibility.patch RUN bash -c "mkdir build && \ (cd build && ../scripts/setup-centos.sh && \ - ../velox/scripts/setup-centos9.sh install_adapters && \ ../scripts/setup-adapters.sh && \ source ../velox/scripts/setup-centos9.sh && \ + source ../velox/scripts/setup-centos-adapters.sh && \ + install_adapters && \ install_clang15 && \ install_cuda 12.8) && \ rm -rf build" diff --git a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java 
b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java index 2e25e413e3b21..31bd66b36092c 100644 --- a/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java +++ b/presto-native-sidecar-plugin/src/test/java/com/facebook/presto/sidecar/TestNativeSidecarPlugin.java @@ -318,9 +318,8 @@ public void testApproxPercentile() @Test public void testInformationSchemaTables() { - assertQueryFails("select lower(table_name) from information_schema.tables " - + "where table_name = 'lineitem' or table_name = 'LINEITEM' ", - "Compiler failed"); + assertQuery("select lower(table_name) from information_schema.tables " + + "where table_name = 'lineitem' or table_name = 'LINEITEM' "); } @Test diff --git a/presto-product-tests/src/main/java/com/facebook/presto/tests/SystemConnectorTests.java b/presto-product-tests/src/main/java/com/facebook/presto/tests/SystemConnectorTests.java index fa48afa418524..3317d07e90f8c 100644 --- a/presto-product-tests/src/main/java/com/facebook/presto/tests/SystemConnectorTests.java +++ b/presto-product-tests/src/main/java/com/facebook/presto/tests/SystemConnectorTests.java @@ -13,14 +13,20 @@ */ package com.facebook.presto.tests; +import com.google.common.collect.ImmutableList; import io.prestodb.tempto.ProductTest; +import io.prestodb.tempto.assertions.QueryAssert; import org.testng.annotations.Test; import java.sql.JDBCType; +import java.util.List; +import java.util.stream.Collectors; import static com.facebook.presto.tests.TestGroups.JDBC; import static com.facebook.presto.tests.TestGroups.SYSTEM_CONNECTOR; import static com.facebook.presto.tests.utils.JdbcDriverUtils.usingTeradataJdbcDriver; +import static com.facebook.presto.tests.utils.QueryExecutors.onPresto; +import static io.prestodb.tempto.assertions.QueryAssert.Row.row; import static io.prestodb.tempto.assertions.QueryAssert.assertThat; import static 
io.prestodb.tempto.query.QueryExecutor.defaultQueryExecutor; import static io.prestodb.tempto.query.QueryExecutor.query; @@ -107,4 +113,41 @@ public void selectMetadataCatalogs() .hasColumns(VARCHAR, VARCHAR) .hasAnyRows(); } + + @Test(groups = SYSTEM_CONNECTOR) + public void selectJdbcColumns() + { + try { + String hiveSQL = "select table_name, column_name from system.jdbc.columns where table_cat = 'hive' AND table_schem = 'default'"; + String icebergSQL = "select table_name, column_name from system.jdbc.columns where table_cat = 'iceberg' AND table_schem = 'default'"; + + List preexistingHiveColumns = onPresto().executeQuery(hiveSQL).rows().stream() + .map(list -> row(list.toArray())) + .collect(Collectors.toList()); + + List preexistingIcebergColumns = onPresto().executeQuery(icebergSQL).rows().stream() + .map(list -> row(list.toArray())) + .collect(Collectors.toList()); + + onPresto().executeQuery("CREATE TABLE hive.default.test_hive_system_jdbc_columns (_double DOUBLE)"); + onPresto().executeQuery("CREATE TABLE iceberg.default.test_iceberg_system_jdbc_columns (_string VARCHAR, _integer INTEGER)"); + + assertThat(onPresto().executeQuery(hiveSQL)) + .containsOnly(ImmutableList.builder() + .addAll(preexistingHiveColumns) + .add(row("test_hive_system_jdbc_columns", "_double")) + .build()); + + assertThat(onPresto().executeQuery(icebergSQL)) + .containsOnly(ImmutableList.builder() + .addAll(preexistingIcebergColumns) + .add(row("test_iceberg_system_jdbc_columns", "_string")) + .add(row("test_iceberg_system_jdbc_columns", "_integer")) + .build()); + } + finally { + onPresto().executeQuery("DROP TABLE IF EXISTS hive.default.test_hive_system_jdbc_columns"); + onPresto().executeQuery("DROP TABLE IF EXISTS iceberg.default.test_iceberg_system_jdbc_columns"); + } + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java b/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java index 2da8ad1970eec..0164dd10c3af1 100644 --- 
a/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java @@ -153,4 +153,9 @@ default Iterable getClientRequestFilterFactories() { return emptyList(); } + + default Set> getSqlInvokedFunctions() + { + return emptySet(); + } } diff --git a/presto-tests/src/test/java/com/facebook/presto/functions/TestDuplicateSqlInvokedFunctions.java b/presto-tests/src/test/java/com/facebook/presto/functions/TestDuplicateSqlInvokedFunctions.java new file mode 100644 index 0000000000000..e3df08e051039 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/functions/TestDuplicateSqlInvokedFunctions.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.functions; + +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.TypeParameter; + +public final class TestDuplicateSqlInvokedFunctions +{ + private TestDuplicateSqlInvokedFunctions() {} + + @SqlInvokedScalarFunction(value = "array_intersect", deterministic = true, calledOnNullInput = false) + @Description("Intersects elements of all arrays in the given array") + @TypeParameter("T") + @SqlParameter(name = "input", type = "array>") + @SqlType("array") + public static String arrayIntersectArray() + { + return "RETURN reduce(input, IF((cardinality(input) = 0), ARRAY[], input[1]), (s, x) -> array_intersect(s, x), (s) -> s)"; + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/functions/TestFunctions.java b/presto-tests/src/test/java/com/facebook/presto/functions/TestFunctions.java new file mode 100644 index 0000000000000..532c12acbc65a --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/functions/TestFunctions.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.functions; + +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.SqlType; +import io.airlift.slice.Slice; + +public final class TestFunctions +{ + private TestFunctions() + {} + + @Description("Returns modulo of value by numberOfBuckets") + @ScalarFunction + @SqlType(StandardTypes.BIGINT) + public static long modulo( + @SqlType(StandardTypes.BIGINT) long value, + @SqlType(StandardTypes.BIGINT) long numberOfBuckets) + { + return value % numberOfBuckets; + } + + @Description(("Return the input string")) + @ScalarFunction + @SqlType(StandardTypes.VARCHAR) + public static Slice identity(@SqlType(StandardTypes.VARCHAR) Slice slice) + { + return slice; + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/functions/TestPluginLoadedDuplicateSqlInvokedFunctions.java b/presto-tests/src/test/java/com/facebook/presto/functions/TestPluginLoadedDuplicateSqlInvokedFunctions.java new file mode 100644 index 0000000000000..4ab7e80a82906 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/functions/TestPluginLoadedDuplicateSqlInvokedFunctions.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.functions; + +import com.facebook.presto.common.type.TimeZoneKey; +import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.tests.TestingPrestoClient; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Set; +import java.util.regex.Pattern; + +import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.lang.String.format; +import static org.testng.Assert.fail; + +public class TestPluginLoadedDuplicateSqlInvokedFunctions +{ + protected TestingPrestoServer server; + protected TestingPrestoClient client; + + @BeforeClass + public void setup() + throws Exception + { + server = new TestingPrestoServer(); + server.installPlugin(new TestDuplicateFunctionsPlugin()); + client = new TestingPrestoClient(server, testSessionBuilder() + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey("America/Bahia_Banderas")) + .build()); + } + + public void assertInvalidFunction(String expr, String exceptionPattern) + { + try { + client.execute("SELECT " + expr); + fail("Function expected to fail but not"); + } + catch (Exception e) { + if (!(e.getMessage().matches(exceptionPattern))) { + fail(format("Expected exception message '%s' to match '%s' but not", + e.getMessage(), exceptionPattern)); + } + } + } + + private static class TestDuplicateFunctionsPlugin + implements Plugin + { + @Override + public Set> getSqlInvokedFunctions() + { + return ImmutableSet.>builder() + .add(TestDuplicateSqlInvokedFunctions.class) + .build(); + } + } + + @Test + public void testDuplicateFunctionsLoaded() + { + assertInvalidFunction(JAVA_BUILTIN_NAMESPACE + ".modulo(10,3)", + Pattern.quote(format("java.lang.IllegalArgumentException: Function already registered: 
%s.array_intersect(array(array(T))):array(T)", JAVA_BUILTIN_NAMESPACE))); + } +} diff --git a/presto-tests/src/test/java/com/facebook/presto/functions/TestPluginLoadedSqlInvokedFunctions.java b/presto-tests/src/test/java/com/facebook/presto/functions/TestPluginLoadedSqlInvokedFunctions.java new file mode 100644 index 0000000000000..233d04e737603 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/functions/TestPluginLoadedSqlInvokedFunctions.java @@ -0,0 +1,125 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.functions; + +import com.facebook.presto.common.type.TimeZoneKey; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.tests.TestingPrestoClient; +import com.google.common.collect.ImmutableSet; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Set; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static java.lang.String.format; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.fail; + +public class TestPluginLoadedSqlInvokedFunctions +{ + protected TestingPrestoServer server; + protected TestingPrestoClient client; + + private static final String CATALOG_NAME = JAVA_BUILTIN_NAMESPACE.getCatalogName(); + + @BeforeClass + public void setup() + throws Exception + { + server = new TestingPrestoServer(); + server.installPlugin(new TestFunctionsPlugin()); + client = new TestingPrestoClient(server, testSessionBuilder() + .setTimeZoneKey(TimeZoneKey.getTimeZoneKey("America/Bahia_Banderas")) + .build()); + } + + public void assertInvalidFunction(String expr, String exceptionPattern) + { + try { + client.execute("SELECT " + expr); + fail("Function expected to fail but not"); + } + catch (Exception e) { + if (!(e.getMessage().matches(exceptionPattern))) { + fail(format("Expected exception message '%s' to match '%s' but not", + e.getMessage(), 
exceptionPattern)); + } + } + } + + private static class TestFunctionsPlugin + implements Plugin + { + @Override + public Set> getSqlInvokedFunctions() + { + return ImmutableSet.>builder() + .add(TestSqlInvokedFunctionsPlugin.class) + .build(); + } + + @Override + public Set> getFunctions() + { + return ImmutableSet.>builder() + .add(TestFunctions.class) + .build(); + } + } + + public void check(@Language("SQL") String query, Type expectedType, Object expectedValue) + { + MaterializedResult result = client.execute(query).getResult(); + assertEquals(result.getRowCount(), 1); + assertEquals(result.getTypes().get(0), expectedType); + Object actual = result.getMaterializedRows().get(0).getField(0); + assertEquals(actual, expectedValue); + } + + @Test + public void testNewFunctionNamespaceFunction() + { + check("SELECT " + JAVA_BUILTIN_NAMESPACE + ".modulo(10,3)", BIGINT, 1L); + check("SELECT " + JAVA_BUILTIN_NAMESPACE + ".identity('test-functions')", VARCHAR, "test-functions"); + check("SELECT " + JAVA_BUILTIN_NAMESPACE + ".custom_square(2, 3)", INTEGER, 4); + check("SELECT " + JAVA_BUILTIN_NAMESPACE + ".custom_square(null, 3)", INTEGER, 9); + } + + @Test + public void testInvalidFunctionAndNamespace() + { + assertInvalidFunction(CATALOG_NAME + ".namespace.modulo(10,3)", format("line 1:8: Function %s.namespace.modulo not registered", CATALOG_NAME)); + assertInvalidFunction(CATALOG_NAME + ".system.some_func(10)", format("line 1:8: Function %s.system.some_func not registered", CATALOG_NAME)); + } + + @Test(dependsOnMethods = + {"testNewFunctionNamespaceFunction", + "testInvalidFunctionAndNamespace"}) + public void testDuplicateFunctionsLoaded() + { + // Because we trigger the conflict check as soon as the plugins are loaded, + // this will throw an Exception: Function already registered: presto.default.modulo(bigint,bigint):bigint while installing the plugin itself + assertThrows(IllegalArgumentException.class, () -> server.installPlugin(new TestFunctionsPlugin())); + } 
+} diff --git a/presto-tests/src/test/java/com/facebook/presto/functions/TestSqlInvokedFunctionsPlugin.java b/presto-tests/src/test/java/com/facebook/presto/functions/TestSqlInvokedFunctionsPlugin.java new file mode 100644 index 0000000000000..65ec6162758f8 --- /dev/null +++ b/presto-tests/src/test/java/com/facebook/presto/functions/TestSqlInvokedFunctionsPlugin.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.functions; + +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.SqlInvokedScalarFunction; +import com.facebook.presto.spi.function.SqlParameter; +import com.facebook.presto.spi.function.SqlParameters; +import com.facebook.presto.spi.function.SqlType; + +public final class TestSqlInvokedFunctionsPlugin +{ + private TestSqlInvokedFunctionsPlugin() + {} + + @SqlInvokedScalarFunction(value = "custom_square", deterministic = true, calledOnNullInput = false) + @Description("Custom SQL to test NULLIF in Functions") + @SqlParameters({@SqlParameter(name = "x", type = "integer"), @SqlParameter(name = "y", type = "integer")}) + @SqlType("integer") + public static String customSquare() + { + return "RETURN IF(NULLIF(x, y) IS NOT NULL, x * x, y * y)"; + } +}