Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .vimspector.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,27 @@
{
"$schema": "https://puremourning.github.io/vimspector/schema/vimspector.schema.json#",
"configurations": {
"dbtool": {
"adapter": "vscode-cpptools",
"configuration": {
"request": "launch",
"program": "${workspaceRoot}/out/build/linux-clang-debug/src/tools/dbtool",
"args": [
"dump-schema",
"--connection-string=DRIVER=SQLite3\\;DATABASE=/home/christianparpart/.local/state/warp-terminal/warp.sqlite"
],
"cwd": "${workspaceRoot}",
"externalConsole": true,
"stopAtEntry": false,
"MIMode": "gdb"
},
"breakpoints": {
"exception": {
"caught": "Y",
"uncaught": "Y"
}
}
},
"CoreTest - SQLite": {
"adapter": "vscode-cpptools",
"configuration": {
Expand Down
49 changes: 49 additions & 0 deletions src/Lightweight/SqlColumnTypeDefinitions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@

#pragma once

#include <optional>
#include <variant>

#include <sql.h>
#include <sqlext.h>

namespace SqlColumnTypeDefinitions
{

Expand Down Expand Up @@ -50,3 +54,48 @@ using SqlColumnTypeDefinition = std::variant<SqlColumnTypeDefinitions::Bigint,
SqlColumnTypeDefinitions::Timestamp,
SqlColumnTypeDefinitions::VarBinary,
SqlColumnTypeDefinitions::Varchar>;

/// Maps ODBC data type (with given @p size and @p precision) to SqlColumnTypeDefinition
///
/// @return SqlColumnTypeDefinition if the data type is supported, otherwise std::nullopt
constexpr std::optional<SqlColumnTypeDefinition> MakeColumnTypeFromNative(int value,
std::size_t size,
std::size_t precision)
{
// Maps ODBC data types to SqlColumnTypeDefinition
// See: https://learn.microsoft.com/en-us/sql/odbc/reference/appendixes/sql-data-types?view=sql-server-ver16
using namespace SqlColumnTypeDefinitions;
// clang-format off
switch (value)
{
case SQL_BIGINT: return Bigint {};
case SQL_BINARY: return Binary { size };
case SQL_BIT: return Bool {};
case SQL_CHAR: return Char { size };
case SQL_DATE: return Date {};
case SQL_DECIMAL: return Decimal { .precision = precision, .scale = size };
case SQL_DOUBLE: return Real {};
case SQL_FLOAT: return Real {};
case SQL_GUID: return Guid {};
case SQL_INTEGER: return Integer {};
case SQL_LONGVARBINARY: return VarBinary { size };
case SQL_LONGVARCHAR: return Varchar { size };
case SQL_NUMERIC: return Decimal { .precision = precision, .scale = size };
case SQL_REAL: return Real {};
case SQL_SMALLINT: return Smallint {};
case SQL_TIME: return Time {};
case SQL_TIMESTAMP: return DateTime {};
case SQL_TINYINT: return Tinyint {};
case SQL_TYPE_DATE: return Date {};
case SQL_TYPE_TIME: return Time {};
case SQL_TYPE_TIMESTAMP: return DateTime {};
case SQL_VARBINARY: return Binary { size };
case SQL_VARCHAR: return Varchar { size };
case SQL_WCHAR: return NChar { size };
case SQL_WLONGVARCHAR: return NVarchar { size };
case SQL_WVARCHAR: return NVarchar { size };
case SQL_UNKNOWN_TYPE: return std::nullopt;
default: return std::nullopt;
}
// clang-format on
}
4 changes: 4 additions & 0 deletions src/Lightweight/SqlConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "SqlQueryFormatter.hpp"

#include <sql.h>
#include <sqlext.h>

using namespace std::chrono_literals;
using namespace std::string_view_literals;
Expand Down Expand Up @@ -188,7 +189,10 @@ bool SqlConnection::Connect(SqlConnectionString sqlConnectionString) noexcept
nullptr,
SQL_DRIVER_NOPROMPT);
if (!SQL_SUCCEEDED(sqlResult))
{
RequireSuccess(sqlResult);
return false;
}

sqlResult = SQLSetConnectAttrA(m_hDbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER) SQL_AUTOCOMMIT_ON, SQL_IS_UINTEGER);
if (!SQL_SUCCEEDED(sqlResult))
Expand Down
1 change: 1 addition & 0 deletions src/Lightweight/SqlLogger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ class SqlTraceLogger: public SqlStandardLogger
_state = State::Executing;
_lastPreparedQuery = query;
_startedAt = std::chrono::steady_clock::now();
_fetchRowCount = 0;
}

void OnExecute(std::string_view const& query) override
Expand Down
2 changes: 1 addition & 1 deletion src/Lightweight/SqlQuery/Migrate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ SqlMigrationQueryBuilder& SqlMigrationQueryBuilder::DropTable(std::string_view t
SqlCreateTableQueryBuilder SqlMigrationQueryBuilder::CreateTable(std::string_view tableName)
{
_migrationPlan.steps.emplace_back(SqlCreateTablePlan {
.tableName = tableName,
.tableName = std::string(tableName),
.columns = {},
});
return SqlCreateTableQueryBuilder { std::get<SqlCreateTablePlan>(_migrationPlan.steps.back()) };
Expand Down
2 changes: 1 addition & 1 deletion src/Lightweight/SqlQuery/MigrationPlan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ struct SqlColumnDeclaration

struct SqlCreateTablePlan
{
std::string_view tableName;
std::string tableName;
std::vector<SqlColumnDeclaration> columns;
};

Expand Down
116 changes: 62 additions & 54 deletions src/Lightweight/SqlSchema.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// SPDX-License-Identifier: Apache-2.0

#include "SqlConnection.hpp"
#include "SqlColumnTypeDefinitions.hpp"
#include "SqlError.hpp"
#include "SqlSchema.hpp"
#include "SqlStatement.hpp"
Expand Down Expand Up @@ -29,55 +29,14 @@ bool operator<(KeyPair const& a, KeyPair const& b)

namespace
{
SqlColumnTypeDefinition FromNativeDataType(int value, size_t size, size_t precision)
{
// Maps ODBC data types to SqlColumnTypeDefinition
// See: https://learn.microsoft.com/en-us/sql/odbc/reference/appendixes/sql-data-types?view=sql-server-ver16
using namespace SqlColumnTypeDefinitions;
// clang-format off
switch (value)
{
case SQL_BIGINT: return Bigint {};
case SQL_BINARY: return Binary { size };
case SQL_BIT: return Bool {};
case SQL_CHAR: return Char { size };
case SQL_DATE: return Date {};
case SQL_DECIMAL: assert(size <= precision); return Decimal { .precision = precision, .scale = size };
case SQL_DOUBLE: return Real {};
case SQL_FLOAT: return Real {};
case SQL_GUID: return Guid {};
case SQL_INTEGER: return Integer {};
case SQL_LONGVARBINARY: return VarBinary { size };
case SQL_LONGVARCHAR: return Varchar { size };
case SQL_NUMERIC: assert(size <= precision); return Decimal { .precision = precision, .scale = size };
case SQL_REAL: return Real {};
case SQL_SMALLINT: return Smallint {};
case SQL_TIME: return Time {};
case SQL_TIMESTAMP: return DateTime {};
case SQL_TINYINT: return Tinyint {};
case SQL_TYPE_DATE: return Date {};
case SQL_TYPE_TIME: return Time {};
case SQL_TYPE_TIMESTAMP: return DateTime {};
case SQL_VARBINARY: return Binary { size };
case SQL_VARCHAR: return Varchar { size };
case SQL_WCHAR: return NChar { size };
case SQL_WLONGVARCHAR: return NVarchar { size };
case SQL_WVARCHAR: return NVarchar { size };
// case SQL_UNKNOWN_TYPE:
default:
SqlLogger::GetLogger().OnError(SqlError::UNSUPPORTED_TYPE);
throw std::runtime_error(std::format("Unsupported data type: {}", value));
}
// clang-format on
}

std::vector<std::string> AllTables(std::string_view database, std::string_view schema)
std::vector<std::string> AllTables(SqlConnection& connection, std::string_view schema)
{
auto const tableType = "TABLE"sv;
(void) database;
(void) schema;
auto database = connection.DatabaseName();
// (void) schema;

auto stmt = SqlStatement();
auto stmt = SqlStatement { connection };
SqlLogger::GetLogger().OnExecute(std::format(R"(SQLTables("{}"."{}".*))", database, schema));
auto sqlResult = SQLTables(stmt.NativeHandle(),
(SQLCHAR*) database.data(),
(SQLSMALLINT) database.size(),
Expand All @@ -92,6 +51,7 @@ namespace
auto result = std::vector<std::string>();
while (stmt.FetchRow())
result.emplace_back(stmt.GetColumn<std::string>(3));
stmt.CloseCursor();

return result;
}
Expand All @@ -106,6 +66,7 @@ namespace
auto* fkCatalog = (SQLCHAR*) (!foreignKey.catalog.empty() ? foreignKey.catalog.c_str() : nullptr);
auto* fkSchema = (SQLCHAR*) (!foreignKey.schema.empty() ? foreignKey.schema.c_str() : nullptr);
auto* fkTable = (SQLCHAR*) (!foreignKey.table.empty() ? foreignKey.table.c_str() : nullptr);
SqlLogger::GetLogger().OnExecute(std::format(R"(SQLForeignKeys(pk="{}", fk="{}"))", primaryKey, foreignKey));
auto sqlResult = SQLForeignKeys(stmt.NativeHandle(),
pkCatalog,
(SQLSMALLINT) primaryKey.catalog.size(),
Expand Down Expand Up @@ -148,6 +109,7 @@ namespace
keyColumns.resize(sequenceNumber);
keyColumns[sequenceNumber - 1] = std::move(pkColumnName);
}
stmt.CloseCursor();

auto result = std::vector<ForeignKeyConstraint>();
for (auto const& [keyPair, columns]: constraints)
Expand All @@ -168,6 +130,8 @@ namespace
std::vector<std::string> keys;
std::vector<size_t> sequenceNumbers;

SqlLogger::GetLogger().OnExecute(
std::format(R"(SQLPrimaryKeys("{}"."{}"."{}"))", table.catalog, table.schema, table.table));
auto sqlResult = SQLPrimaryKeys(stmt.NativeHandle(),
(SQLCHAR*) table.catalog.data(),
(SQLSMALLINT) table.catalog.size(),
Expand All @@ -183,6 +147,7 @@ namespace
keys.emplace_back(stmt.GetColumn<std::string>(4));
sequenceNumbers.emplace_back(stmt.GetColumn<size_t>(5));
}
stmt.CloseCursor();

std::vector<std::string> sortedKeys;
sortedKeys.resize(keys.size());
Expand All @@ -194,10 +159,12 @@ namespace

} // namespace

void ReadAllTables(std::string_view database, std::string_view schema, EventHandler& eventHandler)
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void ReadAllTables(SqlConnection& connection, std::string_view schema, EventHandler& eventHandler)
{
auto stmt = SqlStatement {};
auto const tableNames = AllTables(database, schema);
auto stmt = SqlStatement { connection };
auto const database = connection.DatabaseName();
auto const tableNames = AllTables(connection, schema);

for (auto const& tableName: tableNames)
{
Expand Down Expand Up @@ -225,7 +192,8 @@ void ReadAllTables(std::string_view database, std::string_view schema, EventHand
for (auto const& foreignKey: incomingForeignKeys)
eventHandler.OnExternalForeignKey(foreignKey);

auto columnStmt = SqlStatement();
auto columnStmt = SqlStatement { connection };
SqlLogger::GetLogger().OnExecute(std::format(R"(SQLColumns("{}"."{}"."{}".*))", database, schema, tableName));
auto const sqlResult = SQLColumns(columnStmt.NativeHandle(),
(SQLCHAR*) database.data(),
(SQLSMALLINT) database.size(),
Expand Down Expand Up @@ -274,7 +242,13 @@ void ReadAllTables(std::string_view database, std::string_view schema, EventHand
column.defaultValue = {};
}

column.type = FromNativeDataType(type, column.size, column.decimalDigits);
if (auto cType = MakeColumnTypeFromNative(type, column.size, column.decimalDigits); cType.has_value())
column.type = *cType;
else
{
SqlLogger::GetLogger().OnError(SqlError::UNSUPPORTED_TYPE);
throw std::runtime_error(std::format("Unsupported data type: {}", type));
}

// accumulated properties
column.isPrimaryKey = std::ranges::contains(primaryKeys, column.name);
Expand All @@ -290,6 +264,7 @@ void ReadAllTables(std::string_view database, std::string_view schema, EventHand

eventHandler.OnColumn(column);
}
stmt.CloseCursor();

eventHandler.OnTableEnd();
}
Expand All @@ -302,7 +277,7 @@ std::string ToLowerCase(std::string_view str)
return result;
}

TableList ReadAllTables(std::string_view database, std::string_view schema)
TableList ReadAllTables(SqlConnection& connection, std::string_view schema)
{
TableList tables;
struct EventHandler: public SqlSchema::EventHandler
Expand Down Expand Up @@ -341,7 +316,7 @@ TableList ReadAllTables(std::string_view database, std::string_view schema)
tables.back().externalForeignKeys.emplace_back(foreignKeyConstraint);
}
} eventHandler { tables };
ReadAllTables(database, schema, eventHandler);
ReadAllTables(connection, schema, eventHandler);

std::map<std::string, std::string> tableNameCaseMap;
for (auto const& table: tables)
Expand Down Expand Up @@ -376,4 +351,37 @@ std::vector<ForeignKeyConstraint> AllForeignKeysFrom(SqlStatement& stmt, FullyQu
return AllForeignKeys(stmt, FullyQualifiedTableName {}, table);
}

SqlMigrationQueryBuilder BuildStructureFromSchema(SqlConnection& connection,
std::string_view schemaName,
SqlQueryFormatter const& dialect)
{
auto builder = SqlMigrationQueryBuilder { dialect };
SqlSchema::TableList tables = SqlSchema::ReadAllTables(connection, schemaName);
for (SqlSchema::Table const& table: tables)
{
auto tableBuilder = builder.CreateTable(table.name);
for (SqlSchema::Column const& column: table.columns)
{
if (column.isPrimaryKey)
{
if (column.isAutoIncrement)
tableBuilder.PrimaryKeyWithAutoIncrement(column.name, column.type);
else
tableBuilder.PrimaryKey(column.name, column.type);
}
else
{
tableBuilder.Column(SqlColumnDeclaration {
.name = column.name,
.type = column.type,
.required = !column.isNullable,
.unique = column.isUnique,
.index = false, // TODO
});
}
}
}
return builder;
}

} // namespace SqlSchema
16 changes: 14 additions & 2 deletions src/Lightweight/SqlSchema.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#endif

#include "Api.hpp"
#include "SqlQuery/Migrate.hpp"
#include "SqlQuery/MigrationPlan.hpp"

#include <format>
Expand Down Expand Up @@ -121,7 +122,7 @@ class EventHandler
};

/// Reads all tables in the given database and schema and calls the event handler for each table.
LIGHTWEIGHT_API void ReadAllTables(std::string_view database, std::string_view schema, EventHandler& eventHandler);
LIGHTWEIGHT_API void ReadAllTables(SqlConnection& connection, std::string_view schema, EventHandler& eventHandler);

/// Holds the definition of a table in a SQL database as read from the database schema.
struct Table
Expand All @@ -148,7 +149,7 @@ struct Table
using TableList = std::vector<Table>;

/// Retrieves all tables in the given @p database and @p schema.
LIGHTWEIGHT_API TableList ReadAllTables(std::string_view database, std::string_view schema = {});
LIGHTWEIGHT_API TableList ReadAllTables(SqlConnection& connection, std::string_view schema = {});

/// Retrieves all tables in the given database and schema that have a foreign key to the given table.
LIGHTWEIGHT_API std::vector<ForeignKeyConstraint> AllForeignKeysTo(SqlStatement& stmt,
Expand All @@ -158,6 +159,17 @@ LIGHTWEIGHT_API std::vector<ForeignKeyConstraint> AllForeignKeysTo(SqlStatement&
LIGHTWEIGHT_API std::vector<ForeignKeyConstraint> AllForeignKeysFrom(SqlStatement& stmt,
FullyQualifiedTableName const& table);

/// Builds a migration plan to create the structure of the given database and schema.
///
/// @param connection The connection to the database to read the tables and relations from.
/// @param schemaName The name of the schema to read the tables and relations from.
/// @param dialect The SQL dialect to use for the migration plan.
///
/// @return A migration plan that creates the structure of the given database and schema.
LIGHTWEIGHT_API SqlMigrationQueryBuilder BuildStructureFromSchema(SqlConnection& connection,
std::string_view schemaName,
SqlQueryFormatter const& dialect);

} // namespace SqlSchema

template <>
Expand Down
Loading
Loading