diff --git a/driver/connection.go b/driver/connection.go index 2d8303a..a27f4f0 100644 --- a/driver/connection.go +++ b/driver/connection.go @@ -1,13 +1,14 @@ package driver import ( - "cloud.google.com/go/bigquery" "context" "database/sql/driver" "fmt" + + "cloud.google.com/go/bigquery" ) -type bigQueryConnection struct { +type BigQueryConnection struct { ctx context.Context client *bigquery.Client config bigQueryConfig @@ -16,7 +17,7 @@ type bigQueryConnection struct { dataset *bigquery.Dataset } -func (connection *bigQueryConnection) GetDataset() *bigquery.Dataset { +func (connection *BigQueryConnection) GetDataset() *bigquery.Dataset { if connection.dataset != nil { return connection.dataset } @@ -24,11 +25,11 @@ func (connection *bigQueryConnection) GetDataset() *bigquery.Dataset { return connection.dataset } -func (connection *bigQueryConnection) GetContext() context.Context { +func (connection *BigQueryConnection) GetContext() context.Context { return connection.ctx } -func (connection *bigQueryConnection) Ping(ctx context.Context) error { +func (connection *BigQueryConnection) Ping(ctx context.Context) error { dataset := connection.GetDataset() if dataset == nil { @@ -43,12 +44,12 @@ func (connection *bigQueryConnection) Ping(ctx context.Context) error { return nil } -func (connection *bigQueryConnection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { +func (connection *BigQueryConnection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { var statement = &bigQueryStatement{connection, query} return statement.QueryContext(ctx, args) } -func (connection *bigQueryConnection) Query(query string, args []driver.Value) (driver.Rows, error) { +func (connection *BigQueryConnection) Query(query string, args []driver.Value) (driver.Rows, error) { statement, err := connection.Prepare(query) if err != nil { return nil, nil @@ -57,13 +58,13 @@ func (connection *bigQueryConnection) Query(query string, args []driver.Value) ( return statement.Query(args) } -func (connection *bigQueryConnection) Prepare(query string) (driver.Stmt, error) { +func (connection *BigQueryConnection) Prepare(query string) (driver.Stmt, error) { var statement = &bigQueryStatement{connection, query} return statement, nil } -func (connection *bigQueryConnection) Close() error { +func (connection *BigQueryConnection) Close() error { if connection.closed { return nil } @@ -74,27 +75,27 @@ func (connection *bigQueryConnection) Close() error { return connection.client.Close() } -func (connection *bigQueryConnection) Begin() (driver.Tx, error) { +func (connection *BigQueryConnection) Begin() (driver.Tx, error) { var transaction = &bigQueryTransaction{connection} return transaction, nil } -func (connection *bigQueryConnection) query(query string) (*bigquery.Query, error) { +func (connection *BigQueryConnection) query(query string) (*bigquery.Query, error) { return connection.client.Query(query), nil } -func (connection *bigQueryConnection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { +func (connection *BigQueryConnection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { var statement = &bigQueryStatement{connection, query} return statement.ExecContext(ctx, args) } -func (connection *bigQueryConnection) Exec(query string, args []driver.Value) (driver.Result, error) { +func (connection *BigQueryConnection) Exec(query string, args []driver.Value) (driver.Result, error) { var statement = &bigQueryStatement{connection, query} return statement.Exec(args) } -func (bigQueryConnection) CheckNamedValue(*driver.NamedValue) error { +func (BigQueryConnection) CheckNamedValue(*driver.NamedValue) error { // TODO: Revise in the future return nil } diff --git a/driver/driver.go b/driver/driver.go index fff8696..d11a1c8 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -57,7 +57,7 @@ func (b bigQueryDriver) Open(uri string) (driver.Conn, error) { return nil, err } - return &bigQueryConnection{ + return &BigQueryConnection{ ctx: ctx, client: client, config: *config, diff --git a/driver/statement.go b/driver/statement.go index dd28080..14ef201 100644 --- a/driver/statement.go +++ b/driver/statement.go @@ -1,16 +1,17 @@ package driver import ( - "cloud.google.com/go/bigquery" "context" "database/sql/driver" "errors" + + "cloud.google.com/go/bigquery" "github.com/sirupsen/logrus" "gorm.io/driver/bigquery/adaptor" ) type bigQueryStatement struct { - connection *bigQueryConnection + connection *BigQueryConnection query string } diff --git a/driver/transaction.go b/driver/transaction.go index 66b2539..e53148e 100644 --- a/driver/transaction.go +++ b/driver/transaction.go @@ -1,7 +1,7 @@ package driver type bigQueryTransaction struct { - connection *bigQueryConnection + connection *BigQueryConnection } func (transaction *bigQueryTransaction) Commit() error { diff --git a/migrator.go b/migrator.go index be7996d..dd6e844 100644 --- a/migrator.go +++ b/migrator.go @@ -1,9 +1,16 @@ package bigquery import ( + "context" "errors" + "fmt" + "slices" + "strings" + + "gorm.io/driver/bigquery/driver" "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/logger" "gorm.io/gorm/migrator" "gorm.io/gorm/schema" ) @@ -13,8 +20,11 @@ type Migrator struct { } func (m Migrator) CurrentDatabase() (name string) { - m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name) - return + datasetID, err := m.getDatasetID() + if err != nil { + return "" + } + return datasetID } func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { @@ -40,7 +50,15 @@ func (m Migrator) DropIndex(value interface{}, name string) error { func (m Migrator) HasTable(value interface{}) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { - return m.DB.Raw("SELECT count(*) FROM `INFORMATION_SCHEMA.TABLES` WHERE table_name = ?", stmt.Table).Row().Scan(&count) + // According to the BigQuery documentation, an INFORMATION_SCHEMA view must be qualified with a dataset or region. + // See: https://docs.cloud.google.com/bigquery/docs/information-schema-intro + // + // We are going to attempt to get the dataset ID from the connection and use it to qualify the INFORMATION_SCHEMA view. + datasetID, err := m.getDatasetID() + if err != nil { + return err + } + return m.DB.Raw("SELECT count(*) FROM `"+datasetID+".INFORMATION_SCHEMA.TABLES` WHERE table_name = ?", stmt.Table).Row().Scan(&count) }) return count > 0 @@ -67,8 +85,17 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { name = field.DBName } + // According to the BigQuery documentation, an INFORMATION_SCHEMA view must be qualified with a dataset or region. + // See: https://docs.cloud.google.com/bigquery/docs/information-schema-intro + // + // We are going to attempt to get the dataset ID from the connection and use it to qualify the INFORMATION_SCHEMA view. + datasetID, err := m.getDatasetID() + if err != nil { + return err + } + return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?", + "SELECT count(*) FROM `"+datasetID+".INFORMATION_SCHEMA.columns` WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?", stmt.Table, name, ).Row().Scan(&count) }) @@ -79,11 +106,87 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { + // According to the BigQuery documentation, an INFORMATION_SCHEMA view must be qualified with a dataset or region. + // See: https://docs.cloud.google.com/bigquery/docs/information-schema-intro + // + // We are going to attempt to get the dataset ID from the connection and use it to qualify the INFORMATION_SCHEMA view. + datasetID, err := m.getDatasetID() + if err != nil { + return err + } + return m.DB.Raw( - "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND constraint_name = ?", + "SELECT count(*) FROM `"+datasetID+".INFORMATION_SCHEMA.table_constraints` WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND constraint_name = ?", stmt.Table, name, ).Row().Scan(&count) }) return count > 0 } + +// FullDataTypeOf returns field's db full data type +func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { + expr.SQL = m.DataTypeOf(field) + + if field.NotNull { + expr.SQL += " NOT NULL" + } + + if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { + if field.DefaultValueInterface != nil { + defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} + m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) + expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) + } else if field.DefaultValue != "(-)" { + expr.SQL += " DEFAULT " + field.DefaultValue + } + } + + options := map[string]string{} + if field.Comment != "" { + options["description"] = field.Comment + } + + if len(options) > 0 { + optionParts := []string{} + for key, value := range options { + optionParts = append(optionParts, fmt.Sprintf("%s = %s", key, logger.ExplainSQL("?", nil, `'`, value))) + } + slices.Sort(optionParts) + expr.SQL += " OPTIONS (" + strings.Join(optionParts, ", ") + ")" + } + + return +} + +// getDatasetID is a helper function to get the dataset ID from the connection. +func (m Migrator) getDatasetID() (string, error) { + sqlDB, err := m.DB.DB() + if err != nil { + return "", fmt.Errorf("could not get underlying database: %w", err) + } + ctx := context.Background() + conn, err := sqlDB.Conn(ctx) + if err != nil { + return "", fmt.Errorf("could not get connection: %w", err) + } + + datasetID := "" + err = conn.Raw(func(rawConnection any) error { + bigQueryConnection, ok := rawConnection.(*driver.BigQueryConnection) + if !ok { + return errors.New("connection is not a *driver.BigQueryConnection") + } + dataset := bigQueryConnection.GetDataset() + if dataset == nil { + return errors.New("dataset is nil") + } + datasetID = dataset.DatasetID + return nil + }) + if err != nil { + return "", fmt.Errorf("could not get dataset ID: %w", err) + } + + return datasetID, nil +}