Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions driver/connection.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package driver

import (
"cloud.google.com/go/bigquery"
"context"
"database/sql/driver"
"fmt"

"cloud.google.com/go/bigquery"
)

type bigQueryConnection struct {
type BigQueryConnection struct {
ctx context.Context
client *bigquery.Client
config bigQueryConfig
Expand All @@ -16,19 +17,19 @@ type bigQueryConnection struct {
dataset *bigquery.Dataset
}

func (connection *bigQueryConnection) GetDataset() *bigquery.Dataset {
func (connection *BigQueryConnection) GetDataset() *bigquery.Dataset {
if connection.dataset != nil {
return connection.dataset
}
connection.dataset = connection.client.Dataset(connection.config.dataSet)
return connection.dataset
}

func (connection *bigQueryConnection) GetContext() context.Context {
func (connection *BigQueryConnection) GetContext() context.Context {
return connection.ctx
}

func (connection *bigQueryConnection) Ping(ctx context.Context) error {
func (connection *BigQueryConnection) Ping(ctx context.Context) error {

dataset := connection.GetDataset()
if dataset == nil {
Expand All @@ -43,12 +44,12 @@ func (connection *bigQueryConnection) Ping(ctx context.Context) error {
return nil
}

func (connection *bigQueryConnection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
func (connection *BigQueryConnection) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
var statement = &bigQueryStatement{connection, query}
return statement.QueryContext(ctx, args)
}

func (connection *bigQueryConnection) Query(query string, args []driver.Value) (driver.Rows, error) {
func (connection *BigQueryConnection) Query(query string, args []driver.Value) (driver.Rows, error) {
statement, err := connection.Prepare(query)
if err != nil {
return nil, nil
Expand All @@ -57,13 +58,13 @@ func (connection *bigQueryConnection) Query(query string, args []driver.Value) (
return statement.Query(args)
}

func (connection *bigQueryConnection) Prepare(query string) (driver.Stmt, error) {
func (connection *BigQueryConnection) Prepare(query string) (driver.Stmt, error) {
var statement = &bigQueryStatement{connection, query}

return statement, nil
}

func (connection *bigQueryConnection) Close() error {
func (connection *BigQueryConnection) Close() error {
if connection.closed {
return nil
}
Expand All @@ -74,27 +75,27 @@ func (connection *bigQueryConnection) Close() error {
return connection.client.Close()
}

func (connection *bigQueryConnection) Begin() (driver.Tx, error) {
func (connection *BigQueryConnection) Begin() (driver.Tx, error) {
var transaction = &bigQueryTransaction{connection}

return transaction, nil
}

func (connection *bigQueryConnection) query(query string) (*bigquery.Query, error) {
func (connection *BigQueryConnection) query(query string) (*bigquery.Query, error) {
return connection.client.Query(query), nil
}

func (connection *bigQueryConnection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
func (connection *BigQueryConnection) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
var statement = &bigQueryStatement{connection, query}
return statement.ExecContext(ctx, args)
}

func (connection *bigQueryConnection) Exec(query string, args []driver.Value) (driver.Result, error) {
func (connection *BigQueryConnection) Exec(query string, args []driver.Value) (driver.Result, error) {
var statement = &bigQueryStatement{connection, query}
return statement.Exec(args)
}

func (bigQueryConnection) CheckNamedValue(*driver.NamedValue) error {
func (BigQueryConnection) CheckNamedValue(*driver.NamedValue) error {
// TODO: Revise in the future
return nil
}
2 changes: 1 addition & 1 deletion driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (b bigQueryDriver) Open(uri string) (driver.Conn, error) {
return nil, err
}

return &bigQueryConnection{
return &BigQueryConnection{
ctx: ctx,
client: client,
config: *config,
Expand Down
5 changes: 3 additions & 2 deletions driver/statement.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package driver

import (
"cloud.google.com/go/bigquery"
"context"
"database/sql/driver"
"errors"

"cloud.google.com/go/bigquery"
"github.com/sirupsen/logrus"
"gorm.io/driver/bigquery/adaptor"
)

type bigQueryStatement struct {
connection *bigQueryConnection
connection *BigQueryConnection
query string
}

Expand Down
2 changes: 1 addition & 1 deletion driver/transaction.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package driver

type bigQueryTransaction struct {
connection *bigQueryConnection
connection *BigQueryConnection
}

func (transaction *bigQueryTransaction) Commit() error {
Expand Down
113 changes: 108 additions & 5 deletions migrator.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
package bigquery

import (
"context"
"errors"
"fmt"
"slices"
"strings"

"gorm.io/driver/bigquery/driver"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
Expand All @@ -13,8 +20,11 @@ type Migrator struct {
}

func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT CURRENT_DATABASE()").Row().Scan(&name)
return
datasetID, err := m.getDatasetID()
if err != nil {
return ""
}
return datasetID
}

func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
Expand All @@ -40,7 +50,15 @@ func (m Migrator) DropIndex(value interface{}, name string) error {
func (m Migrator) HasTable(value interface{}) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw("SELECT count(*) FROM `INFORMATION_SCHEMA.TABLES` WHERE table_name = ?", stmt.Table).Row().Scan(&count)
// According to the BigQuery documentation, an INFORMATION_SCHEMA view must be qualified with a dataset or region.
// See: https://docs.cloud.google.com/bigquery/docs/information-schema-intro
//
// We are going to attempt to get the dataset ID from the connection and use it to qualify the INFORMATION_SCHEMA view.
datasetID, err := m.getDatasetID()
if err != nil {
return err
}
return m.DB.Raw("SELECT count(*) FROM `"+datasetID+".INFORMATION_SCHEMA.TABLES` WHERE table_name = ?", stmt.Table).Row().Scan(&count)
})

return count > 0
Expand All @@ -67,8 +85,17 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
name = field.DBName
}

// According to the BigQuery documentation, an INFORMATION_SCHEMA view must be qualified with a dataset or region.
// See: https://docs.cloud.google.com/bigquery/docs/information-schema-intro
//
// We are going to attempt to get the dataset ID from the connection and use it to qualify the INFORMATION_SCHEMA view.
datasetID, err := m.getDatasetID()
if err != nil {
return err
}

return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?",
"SELECT count(*) FROM `"+datasetID+".INFORMATION_SCHEMA.columns` WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND column_name = ?",
stmt.Table, name,
).Row().Scan(&count)
})
Expand All @@ -79,11 +106,87 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
// According to the BigQuery documentation, an INFORMATION_SCHEMA view must be qualified with a dataset or region.
// See: https://docs.cloud.google.com/bigquery/docs/information-schema-intro
//
// We are going to attempt to get the dataset ID from the connection and use it to qualify the INFORMATION_SCHEMA view.
datasetID, err := m.getDatasetID()
if err != nil {
return err
}

return m.DB.Raw(
"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND constraint_name = ?",
"SELECT count(*) FROM `"+datasetID+".INFORMATION_SCHEMA.table_constraints` WHERE table_schema = CURRENT_SCHEMA() AND table_name = ? AND constraint_name = ?",
stmt.Table, name,
).Row().Scan(&count)
})

return count > 0
}

// FullDataTypeOf returns field's db full data type
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
expr.SQL = m.DataTypeOf(field)

if field.NotNull {
expr.SQL += " NOT NULL"
}

if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") {
if field.DefaultValueInterface != nil {
defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}}
m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface)
expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface)
} else if field.DefaultValue != "(-)" {
expr.SQL += " DEFAULT " + field.DefaultValue
}
}

options := map[string]string{}
if field.Comment != "" {
options["description"] = field.Comment
}

if len(options) > 0 {
optionParts := []string{}
for key, value := range options {
optionParts = append(optionParts, fmt.Sprintf("%s = %s", key, logger.ExplainSQL("?", nil, `'`, value)))
}
slices.Sort(optionParts)
expr.SQL += " OPTIONS (" + strings.Join(optionParts, ", ") + ")"
}

return
}

// getDatasetID is a helper function to get the dataset ID from the connection.
func (m Migrator) getDatasetID() (string, error) {
sqlDB, err := m.DB.DB()
if err != nil {
return "", fmt.Errorf("could not get underlying database: %w", err)
}
ctx := context.Background()
conn, err := sqlDB.Conn(ctx)
if err != nil {
return "", fmt.Errorf("could not get connection: %w", err)
}

datasetID := ""
err = conn.Raw(func(rawConnection any) error {
bigQueryConnection, ok := rawConnection.(*driver.BigQueryConnection)
if !ok {
return errors.New("connection is not a *driver.BigQueryConnection")
}
dataset := bigQueryConnection.GetDataset()
if dataset == nil {
return errors.New("dataset is nil")
}
datasetID = dataset.DatasetID
return nil
})
if err != nil {
return "", fmt.Errorf("could not get dataset ID: %w", err)
}

return datasetID, nil
}