diff --git a/memory/table.go b/memory/table.go index 9d697afe0e..543c01263b 100644 --- a/memory/table.go +++ b/memory/table.go @@ -134,7 +134,7 @@ func stripTblNames(e sql.Expression) (sql.Expression, transform.TreeIdentity, er case *expression.GetField: // strip table names ne := expression.NewGetField(e.Index(), e.Type(), e.Name(), e.IsNullable()) - ne = ne.WithQuotedNames(sql.GlobalParser, e.IsQuotedIdentifier()) + ne = ne.WithQuotedNames(sql.GlobalSchemaFormatter, e.IsQuotedIdentifier()) return ne, transform.NewTree, nil default: } diff --git a/sql/analyzer/analyzer.go b/sql/analyzer/analyzer.go index fc00960285..aaf447eb65 100644 --- a/sql/analyzer/analyzer.go +++ b/sql/analyzer/analyzer.go @@ -264,14 +264,15 @@ func (ab *Builder) Build() *Analyzer { } return &Analyzer{ - Debug: debug || ab.debug, - Verbose: verbose, - contextStack: make([]string, 0), - Batches: batches, - Catalog: NewCatalog(ab.provider), - Coster: memo.NewDefaultCoster(), - ExecBuilder: rowexec.DefaultBuilder, - Parser: sql.GlobalParser, + Debug: debug || ab.debug, + Verbose: verbose, + contextStack: make([]string, 0), + Batches: batches, + Catalog: NewCatalog(ab.provider), + Coster: memo.NewDefaultCoster(), + ExecBuilder: rowexec.DefaultBuilder, + Parser: sql.GlobalParser, + SchemaFormatter: sql.GlobalSchemaFormatter, } } @@ -296,6 +297,8 @@ type Analyzer struct { Runner StatementRunner // Parser is the parser used to parse SQL statements. Parser sql.Parser + // SchemaFormatter is used to format the schema of a node to a string. + SchemaFormatter sql.SchemaFormatter } // NewDefault creates a default Analyzer instance with all default Rules and configuration. diff --git a/sql/analyzer/resolve_column_defaults.go b/sql/analyzer/resolve_column_defaults.go index 001735321b..93e24737ce 100644 --- a/sql/analyzer/resolve_column_defaults.go +++ b/sql/analyzer/resolve_column_defaults.go @@ -315,7 +315,7 @@ func quoteDefaultColumnValueNames(ctx *sql.Context, a *Analyzer, n sql.Node, _ * switch node := n.(type) { case *plan.AlterDefaultSet: eWrapper := expression.WrapExpression(node.Default) - newExpr, same, err := quoteIdentifiers(a.Parser, eWrapper) + newExpr, same, err := quoteIdentifiers(a.SchemaFormatter, eWrapper) if err != nil { return node, transform.SameTree, err } @@ -335,7 +335,7 @@ func quoteDefaultColumnValueNames(ctx *sql.Context, a *Analyzer, n sql.Node, _ * return e, transform.SameTree, nil } - return quoteIdentifiers(a.Parser, eWrapper) + return quoteIdentifiers(a.SchemaFormatter, eWrapper) }) case *plan.ResolvedTable: ct, ok := node.Table.(*information_schema.ColumnsTable) @@ -354,7 +354,7 @@ func quoteDefaultColumnValueNames(ctx *sql.Context, a *Analyzer, n sql.Node, _ * return e, transform.SameTree, nil } - return quoteIdentifiers(a.Parser, eWrapper) + return quoteIdentifiers(a.SchemaFormatter, eWrapper) }) if err != nil { @@ -376,7 +376,7 @@ func quoteDefaultColumnValueNames(ctx *sql.Context, a *Analyzer, n sql.Node, _ * }) } -func quoteIdentifiers(parser sql.Parser, wrap *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) { +func quoteIdentifiers(schemaFormatter sql.SchemaFormatter, wrap *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) { newDefault, ok := wrap.Unwrap().(*sql.ColumnDefaultValue) if !ok { return wrap, transform.SameTree, nil @@ -388,7 +388,7 @@ func quoteIdentifiers(parser sql.Parser, wrap *expression.Wrapper) (sql.Expressi newExpr, same, err := transform.Expr(newDefault.Expr, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { if e, isGf := expr.(*expression.GetField); isGf { - return e.WithQuotedNames(parser, true), transform.NewTree, nil + return e.WithQuotedNames(schemaFormatter, true), transform.NewTree, nil } return expr, transform.SameTree, nil }) diff --git a/sql/auth.go b/sql/auth.go index dea287a734..cb8088acc1 100644 --- a/sql/auth.go +++ b/sql/auth.go @@ -84,5 +84,5 @@ func GetAuthorizationHandlerFactory() AuthorizationHandlerFactory { if globalAuthorizationHandlerFactory != nil { return globalAuthorizationHandlerFactory } - return emptyAuthorizationHandlerFactory{} + return NoopAuthorizationHandlerFactory{} } diff --git a/sql/auth_empty.go b/sql/auth_empty.go deleted file mode 100644 index 6fb3f1b804..0000000000 --- a/sql/auth_empty.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2024 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sql - -import ( - ast "github.com/dolthub/vitess/go/vt/sqlparser" -) - -// emptyAuthorizationHandlerFactory is the AuthorizationHandlerFactory for emptyAuthorizationHandler. -type emptyAuthorizationHandlerFactory struct{} - -var _ AuthorizationHandlerFactory = emptyAuthorizationHandlerFactory{} - -// CreateHandler implements the AuthorizationHandlerFactory interface. -func (emptyAuthorizationHandlerFactory) CreateHandler(cat Catalog) AuthorizationHandler { - return emptyAuthorizationHandler{} -} - -// emptyAuthorizationHandler will always return a "true" result. -type emptyAuthorizationHandler struct{} - -var _ AuthorizationHandler = emptyAuthorizationHandler{} - -// NewQueryState implements the AuthorizationHandler interface. -func (emptyAuthorizationHandler) NewQueryState(ctx *Context) AuthorizationQueryState { - return nil -} - -// HandleAuth implements the AuthorizationHandler interface. -func (emptyAuthorizationHandler) HandleAuth(ctx *Context, aqs AuthorizationQueryState, auth ast.AuthInformation) error { - return nil -} - -// HandleAuthNode implements the AuthorizationHandler interface. -func (emptyAuthorizationHandler) HandleAuthNode(ctx *Context, state AuthorizationQueryState, node AuthorizationCheckerNode) error { - return nil -} - -// CheckDatabase implements the AuthorizationHandler interface. -func (emptyAuthorizationHandler) CheckDatabase(ctx *Context, aqs AuthorizationQueryState, dbName string) error { - return nil -} - -// CheckSchema implements the AuthorizationHandler interface. -func (emptyAuthorizationHandler) CheckSchema(ctx *Context, aqs AuthorizationQueryState, dbName string, schemaName string) error { - return nil -} - -// CheckTable implements the AuthorizationHandler interface. -func (emptyAuthorizationHandler) CheckTable(ctx *Context, aqs AuthorizationQueryState, dbName string, schemaName string, tableName string) error { - return nil -} diff --git a/sql/auth_noop.go b/sql/auth_noop.go new file mode 100644 index 0000000000..b8b4f036e8 --- /dev/null +++ b/sql/auth_noop.go @@ -0,0 +1,64 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + ast "github.com/dolthub/vitess/go/vt/sqlparser" +) + +// NoopAuthorizationHandlerFactory is the AuthorizationHandlerFactory for emptyAuthorizationHandler. +type NoopAuthorizationHandlerFactory struct{} + +var _ AuthorizationHandlerFactory = NoopAuthorizationHandlerFactory{} + +// CreateHandler implements the AuthorizationHandlerFactory interface. +func (NoopAuthorizationHandlerFactory) CreateHandler(cat Catalog) AuthorizationHandler { + return NoopAuthorizationHandler{} +} + +// NoopAuthorizationHandler will always return a "true" result. +type NoopAuthorizationHandler struct{} + +var _ AuthorizationHandler = NoopAuthorizationHandler{} + +// NewQueryState implements the AuthorizationHandler interface. +func (NoopAuthorizationHandler) NewQueryState(ctx *Context) AuthorizationQueryState { + return nil +} + +// HandleAuth implements the AuthorizationHandler interface. +func (NoopAuthorizationHandler) HandleAuth(ctx *Context, aqs AuthorizationQueryState, auth ast.AuthInformation) error { + return nil +} + +// HandleAuthNode implements the AuthorizationHandler interface. +func (NoopAuthorizationHandler) HandleAuthNode(ctx *Context, state AuthorizationQueryState, node AuthorizationCheckerNode) error { + return nil +} + +// CheckDatabase implements the AuthorizationHandler interface. +func (NoopAuthorizationHandler) CheckDatabase(ctx *Context, aqs AuthorizationQueryState, dbName string) error { + return nil +} + +// CheckSchema implements the AuthorizationHandler interface. +func (NoopAuthorizationHandler) CheckSchema(ctx *Context, aqs AuthorizationQueryState, dbName string, schemaName string) error { + return nil +} + +// CheckTable implements the AuthorizationHandler interface. +func (NoopAuthorizationHandler) CheckTable(ctx *Context, aqs AuthorizationQueryState, dbName string, schemaName string, tableName string) error { + return nil +} diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index 583972ba17..b5e421847a 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -38,8 +38,8 @@ type GetField struct { fieldType2 sql.Type2 nullable bool - // parser is the parser used to parse the expression and print it - parser sql.Parser + // schemaFormatter is the schemaFormatter used to quote field names + schemaFormatter sql.SchemaFormatter // quoteName indicates whether the field name should be quoted when printed with String() quoteName bool @@ -170,7 +170,7 @@ func (p *GetField) String() string { // stripped away. The output of this method is load-bearing in many places of analysis and execution. if p.table == "" { if p.quoteName { - return p.parser.QuoteIdentifier(p.name) + return p.schemaFormatter.QuoteIdentifier(p.name) } return p.name } @@ -197,10 +197,10 @@ func (p *GetField) WithIndex(n int) sql.Expression { } // WithQuotedNames returns a copy of this expression with the backtick names flag set to the given value. -func (p *GetField) WithQuotedNames(parser sql.Parser, quoteNames bool) *GetField { +func (p *GetField) WithQuotedNames(formatter sql.SchemaFormatter, quoteNames bool) *GetField { p2 := *p p2.quoteName = quoteNames - p2.parser = parser + p2.schemaFormatter = formatter return &p2 } diff --git a/sql/parser.go b/sql/parser.go index e1f79372ad..48d894f349 100644 --- a/sql/parser.go +++ b/sql/parser.go @@ -28,6 +28,11 @@ import ( // It defaults to MysqlParser. var GlobalParser Parser = NewMysqlParser() +// GlobalSchemaFormatter is a temporary variable to expose Doltgres schema formatter. +// It defaults to MySqlSchemaFormatter. +var GlobalSchemaFormatter SchemaFormatter = &MySqlSchemaFormatter{} + +// Parser knows how to transform a SQL statement into an AST type Parser interface { // ParseSimple takes a |query| and returns the parsed statement. If |query| represents a no-op statement, // such as ";" or "-- comment", then implementations must return Vitess' ErrEmpty error. @@ -45,17 +50,40 @@ type Parser interface { // the index of the start of the next query. If |query| represents a no-op statement, such as ";" or "-- comment", // then implementations must return Vitess' ErrEmpty error. ParseOneWithOptions(context.Context, string, ast.ParserOptions) (ast.Statement, int, error) +} + +// SchemaFormatter knows how to format a schema into a string +type SchemaFormatter interface { + // GenerateCreateTableStatement returns 'CREATE TABLE' statement with given table names + // and column definition statements in order and the collation and character set names for the table + GenerateCreateTableStatement(tblName string, colStmts []string, temp, autoInc, tblCharsetName, tblCollName, comment string) string + // GenerateCreateTableColumnDefinition returns column definition string for 'CREATE TABLE' statement for given column. + // This part comes first in the 'CREATE TABLE' statement. + GenerateCreateTableColumnDefinition(col *Column, colDefault, onUpdate string, tableCollation CollationID) string + // GenerateCreateTablePrimaryKeyDefinition returns primary key definition string for 'CREATE TABLE' statement + // for given column(s). This part comes after each column definitions. + GenerateCreateTablePrimaryKeyDefinition(pkCols []string) string + // GenerateCreateTableIndexDefinition returns index definition string for 'CREATE TABLE' statement + // for given index. This part comes after primary key definition if there is any. Implementors can signal that the + // index definition provided cannot be included with the second return param + GenerateCreateTableIndexDefinition(isUnique, isSpatial, isFullText, isVector bool, indexID string, indexCols []string, comment string) (string, bool) + // GenerateCreateTableForiegnKeyDefinition returns foreign key constraint definition string for 'CREATE TABLE' statement + // for given foreign key. This part comes after index definitions if there are any. + GenerateCreateTableForiegnKeyDefinition(fkName string, fkCols []string, parentTbl string, parentCols []string, onDelete, onUpdate string) string + // GenerateCreateTableCheckConstraintClause returns check constraint clause string for 'CREATE TABLE' statement + // for given check constraint. This part comes the last and after foreign key definitions if there are any. + GenerateCreateTableCheckConstraintClause(checkName, checkExpr string, enforced bool) string // QuoteIdentifier returns the identifier given quoted according to this parser's dialect. This is used to // standardize identifiers that cannot be parsed without quoting, because they break the normal identifier naming // rules (such as containing spaces) QuoteIdentifier(identifier string) string } -var _ Parser = &MysqlParser{} - // MysqlParser is a mysql syntax parser used as parser in the engine for Dolt. type MysqlParser struct{} +var _ Parser = &MysqlParser{} + // NewMysqlParser creates new MysqlParser func NewMysqlParser() *MysqlParser { return &MysqlParser{} @@ -105,6 +133,148 @@ func RemoveSpaceAndDelimiter(query string, d rune) string { }) } -func (m *MysqlParser) QuoteIdentifier(identifier string) string { - return fmt.Sprintf("`%s`", strings.ReplaceAll(identifier, "`", "``")) +type MySqlSchemaFormatter struct{} + +var _ SchemaFormatter = &MySqlSchemaFormatter{} + +// GenerateCreateTableStatement implements the SchemaFormatter interface. +func (m *MySqlSchemaFormatter) GenerateCreateTableStatement(tblName string, colStmts []string, temp, autoInc, tblCharsetName, tblCollName, comment string) string { + if comment != "" { + // Escape any single quotes in the comment and add the COMMENT keyword + comment = strings.ReplaceAll(comment, "'", "''") + comment = fmt.Sprintf(" COMMENT='%s'", comment) + } + + if autoInc != "" { + autoInc = fmt.Sprintf(" AUTO_INCREMENT=%s", autoInc) + } + + return fmt.Sprintf( + "CREATE%s TABLE %s (\n%s\n) ENGINE=InnoDB%s DEFAULT CHARSET=%s COLLATE=%s%s", + temp, + m.QuoteIdentifier(tblName), + strings.Join(colStmts, ",\n"), + autoInc, + tblCharsetName, + tblCollName, + comment, + ) +} + +// GenerateCreateTableColumnDefinition implements the SchemaFormatter interface. +func (m *MySqlSchemaFormatter) GenerateCreateTableColumnDefinition(col *Column, colDefault, onUpdate string, tableCollation CollationID) string { + var colTypeString string + if collationType, ok := col.Type.(TypeWithCollation); ok { + colTypeString = collationType.StringWithTableCollation(tableCollation) + } else { + colTypeString = col.Type.String() + } + stmt := fmt.Sprintf(" %s %s", m.QuoteIdentifier(col.Name), colTypeString) + if !col.Nullable { + stmt = fmt.Sprintf("%s NOT NULL", stmt) + } + + if col.AutoIncrement { + stmt = fmt.Sprintf("%s AUTO_INCREMENT", stmt) + } + + if c, ok := col.Type.(SpatialColumnType); ok { + if s, d := c.GetSpatialTypeSRID(); d { + stmt = fmt.Sprintf("%s /*!80003 SRID %v */", stmt, s) + } + } + + if col.Generated != nil { + storedStr := "" + if !col.Virtual { + storedStr = " STORED" + } + stmt = fmt.Sprintf("%s GENERATED ALWAYS AS %s%s", stmt, col.Generated.String(), storedStr) + } + + if col.Default != nil && col.Generated == nil { + stmt = fmt.Sprintf("%s DEFAULT %s", stmt, colDefault) + } + + if col.OnUpdate != nil { + stmt = fmt.Sprintf("%s ON UPDATE %s", stmt, onUpdate) + } + + if col.Comment != "" { + stmt = fmt.Sprintf("%s COMMENT '%s'", stmt, col.Comment) + } + return stmt +} + +// GenerateCreateTablePrimaryKeyDefinition implements the SchemaFormatter interface. +func (m *MySqlSchemaFormatter) GenerateCreateTablePrimaryKeyDefinition(pkCols []string) string { + return fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(m.QuoteIdentifiers(pkCols), ",")) +} + +// GenerateCreateTableIndexDefinition implements the SchemaFormatter interface. +func (m *MySqlSchemaFormatter) GenerateCreateTableIndexDefinition(isUnique, isSpatial, isFullText, isVector bool, indexID string, indexCols []string, comment string) (string, bool) { + unique := "" + if isUnique { + unique = "UNIQUE " + } + + spatial := "" + if isSpatial { + unique = "SPATIAL " + } + + fulltext := "" + if isFullText { + fulltext = "FULLTEXT " + } + + vector := "" + if isVector { + vector = "VECTOR " + } + + key := fmt.Sprintf(" %s%s%s%sKEY %s (%s)", unique, spatial, fulltext, vector, m.QuoteIdentifier(indexID), strings.Join(indexCols, ",")) + if comment != "" { + key = fmt.Sprintf("%s COMMENT '%s'", key, comment) + } + return key, true +} + +// GenerateCreateTableForiegnKeyDefinition implements the SchemaFormatter interface. +func (m *MySqlSchemaFormatter) GenerateCreateTableForiegnKeyDefinition(fkName string, fkCols []string, parentTbl string, parentCols []string, onDelete, onUpdate string) string { + keyCols := strings.Join(m.QuoteIdentifiers(fkCols), ",") + refCols := strings.Join(m.QuoteIdentifiers(parentCols), ",") + fkey := fmt.Sprintf(" CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)", m.QuoteIdentifier(fkName), keyCols, m.QuoteIdentifier(parentTbl), refCols) + if onDelete != "" { + fkey = fmt.Sprintf("%s ON DELETE %s", fkey, onDelete) + } + if onUpdate != "" { + fkey = fmt.Sprintf("%s ON UPDATE %s", fkey, onUpdate) + } + return fkey +} + +// GenerateCreateTableCheckConstraintClause implements the SchemaFormatter interface. +func (m *MySqlSchemaFormatter) GenerateCreateTableCheckConstraintClause(checkName, checkExpr string, enforced bool) string { + cc := fmt.Sprintf(" CONSTRAINT %s CHECK (%s)", m.QuoteIdentifier(checkName), checkExpr) + if !enforced { + cc = fmt.Sprintf("%s /*!80016 NOT ENFORCED */", cc) + } + return cc +} + +// QuoteIdentifier wraps the specified identifier in backticks and escapes all occurrences of backticks in the +// identifier by replacing them with double backticks. +func (m *MySqlSchemaFormatter) QuoteIdentifier(id string) string { + return fmt.Sprintf("`%s`", strings.ReplaceAll(id, "`", "``")) +} + +// QuoteIdentifiers wraps each of the specified identifiers in backticks, escapes all occurrences of backticks in +// the identifier, and returns a slice of the quoted identifiers. +func (m *MySqlSchemaFormatter) QuoteIdentifiers(ids []string) []string { + quoted := make([]string, len(ids)) + for i, id := range ids { + quoted[i] = m.QuoteIdentifier(id) + } + return quoted } diff --git a/sql/planbuilder/builder.go b/sql/planbuilder/builder.go index a518fdadd5..6ca527e1ce 100644 --- a/sql/planbuilder/builder.go +++ b/sql/planbuilder/builder.go @@ -111,8 +111,9 @@ type ProcContext struct { // New takes ctx, catalog, event scheduler, and parser. If the parser is nil, then default parser is mysql parser. func New(ctx *sql.Context, cat sql.Catalog, es sql.EventScheduler, p sql.Parser) *Builder { if p == nil { - p = sql.NewMysqlParser() + p = sql.GlobalParser } + var state sql.AuthorizationQueryState if cat != nil { state = cat.AuthorizationHandler().NewQueryState(ctx) diff --git a/sql/rowexec/show_iters.go b/sql/rowexec/show_iters.go index cf618c3b42..ec6ab5a92f 100644 --- a/sql/rowexec/show_iters.go +++ b/sql/rowexec/show_iters.go @@ -452,8 +452,11 @@ func (i *showCreateTablesIter) produceCreateTableStatement(ctx *sql.Context, tab } } - colStmts = append(colStmts, sql.GenerateCreateTableIndexDefinition(index.IsUnique(), index.IsSpatial(), - index.IsFullText(), index.IsVector(), index.ID(), indexCols, index.Comment())) + indexDefn, shouldInclude := sql.GenerateCreateTableIndexDefinition(index.IsUnique(), index.IsSpatial(), + index.IsFullText(), index.IsVector(), index.ID(), indexCols, index.Comment()) + if shouldInclude { + colStmts = append(colStmts, indexDefn) + } } fkt, err := getForeignKeyTable(table) diff --git a/sql/sqlfmt.go b/sql/sqlfmt.go index 4b6811bd65..dfc9f665e4 100644 --- a/sql/sqlfmt.go +++ b/sql/sqlfmt.go @@ -14,11 +14,6 @@ package sql -import ( - "fmt" - "strings" -) - // All functions here are used together to generate 'CREATE TABLE' statement. Each function takes what it requires // to build the definition, which are mostly exact names or values (e.g. columns, indexes names, types, etc.) // These functions allow creating the compatible 'CREATE TABLE' statement from both GMS and Dolt, which use different @@ -27,140 +22,43 @@ import ( // GenerateCreateTableStatement returns 'CREATE TABLE' statement with given table names // and column definition statements in order and the collation and character set names for the table func GenerateCreateTableStatement(tblName string, colStmts []string, temp, autoInc, tblCharsetName, tblCollName, comment string) string { - if comment != "" { - // Escape any single quotes in the comment and add the COMMENT keyword - comment = strings.ReplaceAll(comment, "'", "''") - comment = fmt.Sprintf(" COMMENT='%s'", comment) - } - - if autoInc != "" { - autoInc = fmt.Sprintf(" AUTO_INCREMENT=%s", autoInc) - } - - return fmt.Sprintf( - "CREATE%s TABLE %s (\n%s\n) ENGINE=InnoDB%s DEFAULT CHARSET=%s COLLATE=%s%s", - temp, - QuoteIdentifier(tblName), - strings.Join(colStmts, ",\n"), - autoInc, - tblCharsetName, - tblCollName, - comment, - ) + return GlobalSchemaFormatter.GenerateCreateTableStatement(tblName, colStmts, temp, autoInc, tblCharsetName, tblCollName, comment) } // GenerateCreateTableColumnDefinition returns column definition string for 'CREATE TABLE' statement for given column. // This part comes first in the 'CREATE TABLE' statement. func GenerateCreateTableColumnDefinition(col *Column, colDefault, onUpdate string, tableCollation CollationID) string { - var colTypeString string - if collationType, ok := col.Type.(TypeWithCollation); ok { - colTypeString = collationType.StringWithTableCollation(tableCollation) - } else { - colTypeString = col.Type.String() - } - stmt := fmt.Sprintf(" %s %s", QuoteIdentifier(col.Name), colTypeString) - if !col.Nullable { - stmt = fmt.Sprintf("%s NOT NULL", stmt) - } - - if col.AutoIncrement { - stmt = fmt.Sprintf("%s AUTO_INCREMENT", stmt) - } - - if c, ok := col.Type.(SpatialColumnType); ok { - if s, d := c.GetSpatialTypeSRID(); d { - stmt = fmt.Sprintf("%s /*!80003 SRID %v */", stmt, s) - } - } - - if col.Generated != nil { - storedStr := "" - if !col.Virtual { - storedStr = " STORED" - } - stmt = fmt.Sprintf("%s GENERATED ALWAYS AS %s%s", stmt, col.Generated.String(), storedStr) - } - - if col.Default != nil && col.Generated == nil { - stmt = fmt.Sprintf("%s DEFAULT %s", stmt, colDefault) - } - - if col.OnUpdate != nil { - stmt = fmt.Sprintf("%s ON UPDATE %s", stmt, onUpdate) - } - - if col.Comment != "" { - stmt = fmt.Sprintf("%s COMMENT '%s'", stmt, col.Comment) - } - return stmt + return GlobalSchemaFormatter.GenerateCreateTableColumnDefinition(col, colDefault, onUpdate, tableCollation) } // GenerateCreateTablePrimaryKeyDefinition returns primary key definition string for 'CREATE TABLE' statement // for given column(s). This part comes after each column definitions. func GenerateCreateTablePrimaryKeyDefinition(pkCols []string) string { - return fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(QuoteIdentifiers(pkCols), ",")) + return GlobalSchemaFormatter.GenerateCreateTablePrimaryKeyDefinition(pkCols) } // GenerateCreateTableIndexDefinition returns index definition string for 'CREATE TABLE' statement // for given index. This part comes after primary key definition if there is any. -func GenerateCreateTableIndexDefinition(isUnique, isSpatial, isFullText, isVector bool, indexID string, indexCols []string, comment string) string { - unique := "" - if isUnique { - unique = "UNIQUE " - } - - spatial := "" - if isSpatial { - unique = "SPATIAL " - } - - fulltext := "" - if isFullText { - fulltext = "FULLTEXT " - } - - vector := "" - if isVector { - vector = "VECTOR " - } - - key := fmt.Sprintf(" %s%s%s%sKEY %s (%s)", unique, spatial, fulltext, vector, QuoteIdentifier(indexID), strings.Join(indexCols, ",")) - if comment != "" { - key = fmt.Sprintf("%s COMMENT '%s'", key, comment) - } - return key +func GenerateCreateTableIndexDefinition(isUnique, isSpatial, isFullText, isVector bool, indexID string, indexCols []string, comment string) (string, bool) { + return GlobalSchemaFormatter.GenerateCreateTableIndexDefinition(isUnique, isSpatial, isFullText, isVector, indexID, indexCols, comment) } // GenerateCreateTableForiegnKeyDefinition returns foreign key constraint definition string for 'CREATE TABLE' statement // for given foreign key. This part comes after index definitions if there are any. func GenerateCreateTableForiegnKeyDefinition(fkName string, fkCols []string, parentTbl string, parentCols []string, onDelete, onUpdate string) string { - keyCols := strings.Join(QuoteIdentifiers(fkCols), ",") - refCols := strings.Join(QuoteIdentifiers(parentCols), ",") - fkey := fmt.Sprintf(" CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)", QuoteIdentifier(fkName), keyCols, QuoteIdentifier(parentTbl), refCols) - if onDelete != "" { - fkey = fmt.Sprintf("%s ON DELETE %s", fkey, onDelete) - } - if onUpdate != "" { - fkey = fmt.Sprintf("%s ON UPDATE %s", fkey, onUpdate) - } - return fkey + return GlobalSchemaFormatter.GenerateCreateTableForiegnKeyDefinition(fkName, fkCols, parentTbl, parentCols, onDelete, onUpdate) } // GenerateCreateTableCheckConstraintClause returns check constraint clause string for 'CREATE TABLE' statement // for given check constraint. This part comes the last and after foreign key definitions if there are any. func GenerateCreateTableCheckConstraintClause(checkName, checkExpr string, enforced bool) string { - cc := fmt.Sprintf(" CONSTRAINT %s CHECK (%s)", QuoteIdentifier(checkName), checkExpr) - if !enforced { - cc = fmt.Sprintf("%s /*!80016 NOT ENFORCED */", cc) - } - return cc + return GlobalSchemaFormatter.GenerateCreateTableCheckConstraintClause(checkName, checkExpr, enforced) } // QuoteIdentifier wraps the specified identifier in backticks and escapes all occurrences of backticks in the // identifier by replacing them with double backticks. func QuoteIdentifier(id string) string { - id = strings.ReplaceAll(id, "`", "``") - return fmt.Sprintf("`%s`", id) + return GlobalSchemaFormatter.QuoteIdentifier(id) } // QuoteIdentifiers wraps each of the specified identifiers in backticks, escapes all occurrences of backticks in @@ -168,7 +66,7 @@ func QuoteIdentifier(id string) string { func QuoteIdentifiers(ids []string) []string { quoted := make([]string, len(ids)) for i, id := range ids { - quoted[i] = QuoteIdentifier(id) + quoted[i] = GlobalSchemaFormatter.QuoteIdentifier(id) } return quoted }