diff --git a/.vimspector.json b/.vimspector.json index e21895ba..ddfb17c8 100644 --- a/.vimspector.json +++ b/.vimspector.json @@ -1,6 +1,27 @@ { "$schema": "https://puremourning.github.io/vimspector/schema/vimspector.schema.json#", "configurations": { + "dbtool": { + "adapter": "vscode-cpptools", + "configuration": { + "request": "launch", + "program": "${workspaceRoot}/out/build/linux-clang-debug/src/tools/dbtool", + "args": [ + "dump-schema", + "--connection-string=DRIVER=SQLite3\\;DATABASE=/home/christianparpart/.local/state/warp-terminal/warp.sqlite" + ], + "cwd": "${workspaceRoot}", + "externalConsole": true, + "stopAtEntry": false, + "MIMode": "gdb" + }, + "breakpoints": { + "exception": { + "caught": "Y", + "uncaught": "Y" + } + } + }, "CoreTest - SQLite": { "adapter": "vscode-cpptools", "configuration": { diff --git a/src/Lightweight/SqlColumnTypeDefinitions.hpp b/src/Lightweight/SqlColumnTypeDefinitions.hpp index 021d658e..8913637f 100644 --- a/src/Lightweight/SqlColumnTypeDefinitions.hpp +++ b/src/Lightweight/SqlColumnTypeDefinitions.hpp @@ -2,8 +2,12 @@ #pragma once +#include #include +#include +#include + namespace SqlColumnTypeDefinitions { @@ -50,3 +54,48 @@ using SqlColumnTypeDefinition = std::variant; + +/// Maps ODBC data type (with given @p size and @p precision) to SqlColumnTypeDefinition +/// +/// @return SqlColumnTypeDefinition if the data type is supported, otherwise std::nullopt +constexpr std::optional MakeColumnTypeFromNative(int value, + std::size_t size, + std::size_t precision) +{ + // Maps ODBC data types to SqlColumnTypeDefinition + // See: https://learn.microsoft.com/en-us/sql/odbc/reference/appendixes/sql-data-types?view=sql-server-ver16 + using namespace SqlColumnTypeDefinitions; + // clang-format off + switch (value) + { + case SQL_BIGINT: return Bigint {}; + case SQL_BINARY: return Binary { size }; + case SQL_BIT: return Bool {}; + case SQL_CHAR: return Char { size }; + case SQL_DATE: return Date {}; + case SQL_DECIMAL: return Decimal { .precision = precision, .scale = size }; + case SQL_DOUBLE: return Real {}; + case SQL_FLOAT: return Real {}; + case SQL_GUID: return Guid {}; + case SQL_INTEGER: return Integer {}; + case SQL_LONGVARBINARY: return VarBinary { size }; + case SQL_LONGVARCHAR: return Varchar { size }; + case SQL_NUMERIC: return Decimal { .precision = precision, .scale = size }; + case SQL_REAL: return Real {}; + case SQL_SMALLINT: return Smallint {}; + case SQL_TIME: return Time {}; + case SQL_TIMESTAMP: return DateTime {}; + case SQL_TINYINT: return Tinyint {}; + case SQL_TYPE_DATE: return Date {}; + case SQL_TYPE_TIME: return Time {}; + case SQL_TYPE_TIMESTAMP: return DateTime {}; + case SQL_VARBINARY: return Binary { size }; + case SQL_VARCHAR: return Varchar { size }; + case SQL_WCHAR: return NChar { size }; + case SQL_WLONGVARCHAR: return NVarchar { size }; + case SQL_WVARCHAR: return NVarchar { size }; + case SQL_UNKNOWN_TYPE: return std::nullopt; + default: return std::nullopt; + } + // clang-format on +} diff --git a/src/Lightweight/SqlConnection.cpp b/src/Lightweight/SqlConnection.cpp index 9338e5ea..b8e0dc94 100644 --- a/src/Lightweight/SqlConnection.cpp +++ b/src/Lightweight/SqlConnection.cpp @@ -5,6 +5,7 @@ #include "SqlQueryFormatter.hpp" #include +#include using namespace std::chrono_literals; using namespace std::string_view_literals; @@ -188,7 +189,10 @@ bool SqlConnection::Connect(SqlConnectionString sqlConnectionString) noexcept nullptr, SQL_DRIVER_NOPROMPT); if (!SQL_SUCCEEDED(sqlResult)) + { + RequireSuccess(sqlResult); return false; + } sqlResult = SQLSetConnectAttrA(m_hDbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER) SQL_AUTOCOMMIT_ON, SQL_IS_UINTEGER); if (!SQL_SUCCEEDED(sqlResult)) diff --git a/src/Lightweight/SqlLogger.cpp b/src/Lightweight/SqlLogger.cpp index 4c5b8a3f..ec74b1fd 100644 --- a/src/Lightweight/SqlLogger.cpp +++ b/src/Lightweight/SqlLogger.cpp @@ -218,6 +218,7 @@ class SqlTraceLogger: public SqlStandardLogger _state = State::Executing; _lastPreparedQuery = query; _startedAt = std::chrono::steady_clock::now(); + _fetchRowCount = 0; } void OnExecute(std::string_view const& query) override diff --git a/src/Lightweight/SqlQuery/Migrate.cpp b/src/Lightweight/SqlQuery/Migrate.cpp index ebe36f29..64ea8836 100644 --- a/src/Lightweight/SqlQuery/Migrate.cpp +++ b/src/Lightweight/SqlQuery/Migrate.cpp @@ -24,7 +24,7 @@ SqlMigrationQueryBuilder& SqlMigrationQueryBuilder::DropTable(std::string_view t SqlCreateTableQueryBuilder SqlMigrationQueryBuilder::CreateTable(std::string_view tableName) { _migrationPlan.steps.emplace_back(SqlCreateTablePlan { - .tableName = tableName, + .tableName = std::string(tableName), .columns = {}, }); return SqlCreateTableQueryBuilder { std::get(_migrationPlan.steps.back()) }; diff --git a/src/Lightweight/SqlQuery/MigrationPlan.hpp b/src/Lightweight/SqlQuery/MigrationPlan.hpp index e7b37430..784bdcc6 100644 --- a/src/Lightweight/SqlQuery/MigrationPlan.hpp +++ b/src/Lightweight/SqlQuery/MigrationPlan.hpp @@ -218,7 +218,7 @@ struct SqlColumnDeclaration struct SqlCreateTablePlan { - std::string_view tableName; + std::string tableName; std::vector columns; }; diff --git a/src/Lightweight/SqlSchema.cpp b/src/Lightweight/SqlSchema.cpp index e57d5ee4..6e72ee46 100644 --- a/src/Lightweight/SqlSchema.cpp +++ b/src/Lightweight/SqlSchema.cpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -#include "SqlConnection.hpp" +#include "SqlColumnTypeDefinitions.hpp" #include "SqlError.hpp" #include "SqlSchema.hpp" #include "SqlStatement.hpp" @@ -29,55 +29,14 @@ bool operator<(KeyPair const& a, KeyPair const& b) namespace { - SqlColumnTypeDefinition FromNativeDataType(int value, size_t size, size_t precision) - { - // Maps ODBC data types to SqlColumnTypeDefinition - // See: https://learn.microsoft.com/en-us/sql/odbc/reference/appendixes/sql-data-types?view=sql-server-ver16 - using namespace SqlColumnTypeDefinitions; - // clang-format off - switch (value) - { - case SQL_BIGINT: return Bigint {}; - case SQL_BINARY: return Binary { size }; - case SQL_BIT: return Bool {}; - case SQL_CHAR: return Char { size }; - case SQL_DATE: return Date {}; - case SQL_DECIMAL: assert(size <= precision); return Decimal { .precision = precision, .scale = size }; - case SQL_DOUBLE: return Real {}; - case SQL_FLOAT: return Real {}; - case SQL_GUID: return Guid {}; - case SQL_INTEGER: return Integer {}; - case SQL_LONGVARBINARY: return VarBinary { size }; - case SQL_LONGVARCHAR: return Varchar { size }; - case SQL_NUMERIC: assert(size <= precision); return Decimal { .precision = precision, .scale = size }; - case SQL_REAL: return Real {}; - case SQL_SMALLINT: return Smallint {}; - case SQL_TIME: return Time {}; - case SQL_TIMESTAMP: return DateTime {}; - case SQL_TINYINT: return Tinyint {}; - case SQL_TYPE_DATE: return Date {}; - case SQL_TYPE_TIME: return Time {}; - case SQL_TYPE_TIMESTAMP: return DateTime {}; - case SQL_VARBINARY: return Binary { size }; - case SQL_VARCHAR: return Varchar { size }; - case SQL_WCHAR: return NChar { size }; - case SQL_WLONGVARCHAR: return NVarchar { size }; - case SQL_WVARCHAR: return NVarchar { size }; - // case SQL_UNKNOWN_TYPE: - default: - SqlLogger::GetLogger().OnError(SqlError::UNSUPPORTED_TYPE); - throw std::runtime_error(std::format("Unsupported data type: {}", value)); - } - // clang-format on - } - - std::vector AllTables(std::string_view database, std::string_view schema) + std::vector AllTables(SqlConnection& connection, std::string_view schema) { auto const tableType = "TABLE"sv; - (void) database; - (void) schema; + auto database = connection.DatabaseName(); + // (void) schema; - auto stmt = SqlStatement(); + auto stmt = SqlStatement { connection }; + SqlLogger::GetLogger().OnExecute(std::format(R"(SQLTables("{}"."{}".*))", database, schema)); auto sqlResult = SQLTables(stmt.NativeHandle(), (SQLCHAR*) database.data(), (SQLSMALLINT) database.size(), @@ -92,6 +51,7 @@ namespace auto result = std::vector(); while (stmt.FetchRow()) result.emplace_back(stmt.GetColumn(3)); + stmt.CloseCursor(); return result; } @@ -106,6 +66,7 @@ namespace auto* fkCatalog = (SQLCHAR*) (!foreignKey.catalog.empty() ? foreignKey.catalog.c_str() : nullptr); auto* fkSchema = (SQLCHAR*) (!foreignKey.schema.empty() ? foreignKey.schema.c_str() : nullptr); auto* fkTable = (SQLCHAR*) (!foreignKey.table.empty() ? foreignKey.table.c_str() : nullptr); + SqlLogger::GetLogger().OnExecute(std::format(R"(SQLForeignKeys(pk="{}", fk="{}"))", primaryKey, foreignKey)); auto sqlResult = SQLForeignKeys(stmt.NativeHandle(), pkCatalog, (SQLSMALLINT) primaryKey.catalog.size(), @@ -148,6 +109,7 @@ namespace keyColumns.resize(sequenceNumber); keyColumns[sequenceNumber - 1] = std::move(pkColumnName); } + stmt.CloseCursor(); auto result = std::vector(); for (auto const& [keyPair, columns]: constraints) @@ -168,6 +130,8 @@ namespace std::vector keys; std::vector sequenceNumbers; + SqlLogger::GetLogger().OnExecute( + std::format(R"(SQLPrimaryKeys("{}"."{}"."{}"))", table.catalog, table.schema, table.table)); auto sqlResult = SQLPrimaryKeys(stmt.NativeHandle(), (SQLCHAR*) table.catalog.data(), (SQLSMALLINT) table.catalog.size(), @@ -183,6 +147,7 @@ namespace keys.emplace_back(stmt.GetColumn(4)); sequenceNumbers.emplace_back(stmt.GetColumn(5)); } + stmt.CloseCursor(); std::vector sortedKeys; sortedKeys.resize(keys.size()); @@ -194,10 +159,12 @@ namespace } // namespace -void ReadAllTables(std::string_view database, std::string_view schema, EventHandler& eventHandler) +// NOLINTNEXTLINE(readability-function-cognitive-complexity) +void ReadAllTables(SqlConnection& connection, std::string_view schema, EventHandler& eventHandler) { - auto stmt = SqlStatement {}; - auto const tableNames = AllTables(database, schema); + auto stmt = SqlStatement { connection }; + auto const database = connection.DatabaseName(); + auto const tableNames = AllTables(connection, schema); for (auto const& tableName: tableNames) { @@ -225,7 +192,8 @@ void ReadAllTables(std::string_view database, std::string_view schema, EventHand for (auto const& foreignKey: incomingForeignKeys) eventHandler.OnExternalForeignKey(foreignKey); - auto columnStmt = SqlStatement(); + auto columnStmt = SqlStatement { connection }; + SqlLogger::GetLogger().OnExecute(std::format(R"(SQLColumns("{}"."{}"."{}".*))", database, schema, tableName)); auto const sqlResult = SQLColumns(columnStmt.NativeHandle(), (SQLCHAR*) database.data(), (SQLSMALLINT) database.size(), @@ -274,7 +242,13 @@ void ReadAllTables(std::string_view database, std::string_view schema, EventHand column.defaultValue = {}; } - column.type = FromNativeDataType(type, column.size, column.decimalDigits); + if (auto cType = MakeColumnTypeFromNative(type, column.size, column.decimalDigits); cType.has_value()) + column.type = *cType; + else + { + SqlLogger::GetLogger().OnError(SqlError::UNSUPPORTED_TYPE); + throw std::runtime_error(std::format("Unsupported data type: {}", type)); + } // accumulated properties column.isPrimaryKey = std::ranges::contains(primaryKeys, column.name); @@ -290,6 +264,7 @@ void ReadAllTables(std::string_view database, std::string_view schema, EventHand eventHandler.OnColumn(column); } + stmt.CloseCursor(); eventHandler.OnTableEnd(); } @@ -302,7 +277,7 @@ std::string ToLowerCase(std::string_view str) return result; } -TableList ReadAllTables(std::string_view database, std::string_view schema) +TableList ReadAllTables(SqlConnection& connection, std::string_view schema) { TableList tables; struct EventHandler: public SqlSchema::EventHandler @@ -341,7 +316,7 @@ TableList ReadAllTables(std::string_view database, std::string_view schema) tables.back().externalForeignKeys.emplace_back(foreignKeyConstraint); } } eventHandler { tables }; - ReadAllTables(database, schema, eventHandler); + ReadAllTables(connection, schema, eventHandler); std::map tableNameCaseMap; for (auto const& table: tables) @@ -376,4 +351,37 @@ std::vector AllForeignKeysFrom(SqlStatement& stmt, FullyQu return AllForeignKeys(stmt, FullyQualifiedTableName {}, table); } +SqlMigrationQueryBuilder BuildStructureFromSchema(SqlConnection& connection, + std::string_view schemaName, + SqlQueryFormatter const& dialect) +{ + auto builder = SqlMigrationQueryBuilder { dialect }; + SqlSchema::TableList tables = SqlSchema::ReadAllTables(connection, schemaName); + for (SqlSchema::Table const& table: tables) + { + auto tableBuilder = builder.CreateTable(table.name); + for (SqlSchema::Column const& column: table.columns) + { + if (column.isPrimaryKey) + { + if (column.isAutoIncrement) + tableBuilder.PrimaryKeyWithAutoIncrement(column.name, column.type); + else + tableBuilder.PrimaryKey(column.name, column.type); + } + else + { + tableBuilder.Column(SqlColumnDeclaration { + .name = column.name, + .type = column.type, + .required = !column.isNullable, + .unique = column.isUnique, + .index = false, // TODO + }); + } + } + } + return builder; +} + } // namespace SqlSchema diff --git a/src/Lightweight/SqlSchema.hpp b/src/Lightweight/SqlSchema.hpp index 1396dca5..00a64c97 100644 --- a/src/Lightweight/SqlSchema.hpp +++ b/src/Lightweight/SqlSchema.hpp @@ -7,6 +7,7 @@ #endif #include "Api.hpp" +#include "SqlQuery/Migrate.hpp" #include "SqlQuery/MigrationPlan.hpp" #include @@ -121,7 +122,7 @@ class EventHandler }; /// Reads all tables in the given database and schema and calls the event handler for each table. -LIGHTWEIGHT_API void ReadAllTables(std::string_view database, std::string_view schema, EventHandler& eventHandler); +LIGHTWEIGHT_API void ReadAllTables(SqlConnection& connection, std::string_view schema, EventHandler& eventHandler); /// Holds the definition of a table in a SQL database as read from the database schema. struct Table @@ -148,7 +149,7 @@ struct Table using TableList = std::vector; /// Retrieves all tables in the given @p database and @p schema. -LIGHTWEIGHT_API TableList ReadAllTables(std::string_view database, std::string_view schema = {}); +LIGHTWEIGHT_API TableList ReadAllTables(SqlConnection& connection, std::string_view schema = {}); /// Retrieves all tables in the given database and schema that have a foreign key to the given table. LIGHTWEIGHT_API std::vector AllForeignKeysTo(SqlStatement& stmt, @@ -158,6 +159,17 @@ LIGHTWEIGHT_API std::vector AllForeignKeysTo(SqlStatement& LIGHTWEIGHT_API std::vector AllForeignKeysFrom(SqlStatement& stmt, FullyQualifiedTableName const& table); +/// Builds a migration plan to create the structure of the given database and schema. +/// +/// @param connection The connection to the database to read the tables and relations from. +/// @param schemaName The name of the schema to read the tables and relations from. +/// @param dialect The SQL dialect to use for the migration plan. +/// +/// @return A migration plan that creates the structure of the given database and schema. +LIGHTWEIGHT_API SqlMigrationQueryBuilder BuildStructureFromSchema(SqlConnection& connection, + std::string_view schemaName, + SqlQueryFormatter const& dialect); + } // namespace SqlSchema template <> diff --git a/src/Lightweight/SqlStatement.hpp b/src/Lightweight/SqlStatement.hpp index f9ac3001..fe15d0e5 100644 --- a/src/Lightweight/SqlStatement.hpp +++ b/src/Lightweight/SqlStatement.hpp @@ -828,7 +828,10 @@ void SqlStatement::MigrateDirect(Callable const& callable, std::source_location callable(migration); auto const queries = migration.GetPlan().ToSql(); for (auto const& query: queries) + { ExecuteDirect(query, location); + CloseCursor(); + } } template diff --git a/src/tests/CoreTests.cpp b/src/tests/CoreTests.cpp index 9c6204ba..61462635 100644 --- a/src/tests/CoreTests.cpp +++ b/src/tests/CoreTests.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -558,4 +559,38 @@ TEST_CASE_METHOD(SqlTestFixture, "SELECT into SqlVariantRowIterator", "[SqlState } } +TEST_CASE_METHOD(SqlTestFixture, "SqlSchema: simple", "[SqlSchema]") +{ + auto conn = SqlConnection {}; + auto stmt = SqlStatement { conn }; + stmt.MigrateDirect([](SqlMigrationQueryBuilder& migration) { + using namespace SqlColumnTypeDefinitions; + for (int i = 1; i <= 2; ++i) + { + migration.CreateTable("Table" + std::to_string(i)) + .PrimaryKeyWithAutoIncrement("pk", Integer {}) + .Column("c1", Varchar { 30 }); + } + }); + + auto const schemaName = ""sv; // TODO(pr) how to get the schema name for SQL server tests (dbo) here, generically? + auto const& dialect = SqlQueryFormatter::SqlServer(); + SqlMigrationQueryBuilder migrationPlan = SqlSchema::BuildStructureFromSchema(conn, schemaName, dialect); + SqlMigrationPlan const& plan = migrationPlan.GetPlan(); + auto const sqlStatements = plan.ToSql(); + CHECK(sqlStatements.size() == 2); + CHECK(NormalizeText(sqlStatements.at(0)) == NormalizeText(R"sql( + CREATE TABLE "Table1" ( + "pk" INTEGER IDENTITY(1,1) PRIMARY KEY, + "c1" VARCHAR(30) + ); + )sql")); + CHECK(NormalizeText(sqlStatements.at(1)) == NormalizeText(R"sql( + CREATE TABLE "Table2" ( + "pk" INTEGER IDENTITY(1,1) PRIMARY KEY, + "c1" VARCHAR(30) + ); + )sql")); +} + // NOLINTEND(readability-container-size-empty) diff --git a/src/tools/CMakeLists.txt b/src/tools/CMakeLists.txt index 55b8d59e..6e7338b3 100644 --- a/src/tools/CMakeLists.txt +++ b/src/tools/CMakeLists.txt @@ -2,3 +2,8 @@ add_executable(ddl2cpp ddl2cpp.cpp) target_link_libraries(ddl2cpp PRIVATE Lightweight::Lightweight) target_compile_features(ddl2cpp PUBLIC cxx_std_23) install(TARGETS ddl2cpp DESTINATION bin) + +add_executable(dbtool dbtool.cpp) +target_link_libraries(dbtool PRIVATE Lightweight::Lightweight) +target_compile_features(dbtool PUBLIC cxx_std_23) +install(TARGETS dbtool DESTINATION bin) diff --git a/src/tools/ddl2cpp.cpp b/src/tools/ddl2cpp.cpp index ce504b44..796b7357 100644 --- a/src/tools/ddl2cpp.cpp +++ b/src/tools/ddl2cpp.cpp @@ -440,7 +440,8 @@ int main(int argc, char const* argv[]) PrintInfo(); - std::vector tables = SqlSchema::ReadAllTables(config.database, config.schema); + auto connection = SqlConnection {}; + std::vector tables = SqlSchema::ReadAllTables(connection, config.schema); CxxModelPrinter printer; printer.Config().makeAliases = config.makeAliases; printer.Config().formatType = config.formatType;