diff --git a/dialect_mysql.go b/dialect_mysql.go index 06606b8b..b70617a2 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -26,11 +26,19 @@ type MySQLDialect struct { // Encoding is the character encoding to use for created tables Encoding string + + // TypeMap overrides the default column types + TypeMap map[reflect.Type]string } func (d MySQLDialect) QuerySuffix() string { return ";" } func (d MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + if d.TypeMap != nil { + if typ, ok := d.TypeMap[val]; ok { + return typ + } + } switch val.Kind() { case reflect.Ptr: return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) diff --git a/dialect_mysql_test.go b/dialect_mysql_test.go index 9acfee35..b37b0fdc 100644 --- a/dialect_mysql_test.go +++ b/dialect_mysql_test.go @@ -30,10 +30,15 @@ var _ = Describe("MySQLDialect", func() { dialect gorp.MySQLDialect ) + type IntAsBool int + JustBeforeEach(func() { dialect = gorp.MySQLDialect{ Engine: engine, Encoding: encoding, + TypeMap: map[reflect.Type]string{ + reflect.TypeOf(IntAsBool(0)): "boolean", + }, } }) @@ -64,6 +69,7 @@ var _ = Describe("MySQLDialect", func() { Entry("default-size string", "", 0, false, "varchar(255)"), Entry("sized string", "", 50, false, "varchar(50)"), Entry("large string", "", 1024, false, "text"), + Entry("custome type map", IntAsBool(1), 0, false, "boolean"), ) Describe("AutoIncrStr", func() { diff --git a/dialect_oracle.go b/dialect_oracle.go index c381380f..d3945e4f 100644 --- a/dialect_oracle.go +++ b/dialect_oracle.go @@ -18,7 +18,11 @@ import ( ) // Implementation of Dialect for Oracle databases. -type OracleDialect struct{} +type OracleDialect struct { + + // TypeMap overrides the default column types + TypeMap map[reflect.Type]string +} func (d OracleDialect) QuerySuffix() string { return "" } @@ -27,6 +31,11 @@ func (d OracleDialect) CreateIndexSuffix() string { return "" } func (d OracleDialect) DropIndexSuffix() string { return "" } func (d OracleDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + if d.TypeMap != nil { + if typ, ok := d.TypeMap[val]; ok { + return typ + } + } switch val.Kind() { case reflect.Ptr: return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) diff --git a/dialect_postgres.go b/dialect_postgres.go index a53973cd..35ac0b20 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -20,11 +20,19 @@ import ( type PostgresDialect struct { suffix string + + // TypeMap overrides the default column types + TypeMap map[reflect.Type]string } func (d PostgresDialect) QuerySuffix() string { return ";" } func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + if d.TypeMap != nil { + if typ, ok := d.TypeMap[val]; ok { + return typ + } + } switch val.Kind() { case reflect.Ptr: return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) diff --git a/dialect_sqlite.go b/dialect_sqlite.go index 7d9b2975..82f36962 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -18,11 +18,19 @@ import ( type SqliteDialect struct { suffix string + + // TypeMap overrides the default column types + TypeMap map[reflect.Type]string } func (d SqliteDialect) QuerySuffix() string { return ";" } func (d SqliteDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + if d.TypeMap != nil { + if typ, ok := d.TypeMap[val]; ok { + return typ + } + } switch val.Kind() { case reflect.Ptr: return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) diff --git a/dialect_sqlserver.go b/dialect_sqlserver.go index 8808af59..5033aee0 100644 --- a/dialect_sqlserver.go +++ b/dialect_sqlserver.go @@ -25,9 +25,17 @@ type SqlServerDialect struct { // If set to "2005" legacy datatypes will be used Version string + + // TypeMap overrides the default column types + TypeMap map[reflect.Type]string } func (d SqlServerDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + if d.TypeMap != nil { + if typ, ok := d.TypeMap[val]; ok { + return typ + } + } switch val.Kind() { case reflect.Ptr: return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) diff --git a/gorp_test.go b/gorp_test.go index 98ea4d3e..a2a0e7f8 100644 --- a/gorp_test.go +++ b/gorp_test.go @@ -2557,9 +2557,9 @@ func connect(driver string) *sql.DB { func dialectAndDriver() (gorp.Dialect, string) { switch os.Getenv("GORP_TEST_DIALECT") { case "mysql": - return gorp.MySQLDialect{"InnoDB", "UTF8"}, "mymysql" + return gorp.MySQLDialect{"InnoDB", "UTF8", nil}, "mymysql" case "gomysql": - return gorp.MySQLDialect{"InnoDB", "UTF8"}, "mysql" + return gorp.MySQLDialect{"InnoDB", "UTF8", nil}, "mysql" case "postgres": return gorp.PostgresDialect{}, "postgres" case "sqlite":