diff --git a/migrator.go b/migrator.go index e6f2160..44a3939 100644 --- a/migrator.go +++ b/migrator.go @@ -15,10 +15,10 @@ import ( const indexSQL = ` SELECT + col.name AS column_name, i.name AS index_name, i.is_unique, - i.is_primary_key, - col.name AS column_name + i.is_primary_key FROM sys.indexes i LEFT JOIN sys.index_columns ic ON ic.object_id = i.object_id AND ic.index_id = i.index_id @@ -499,11 +499,10 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error } type Index struct { - TableName string - ColumnName string - IndexName string - IsUnique sql.NullBool - IsPrimaryKey sql.NullBool + ColumnName string `gorm:"column:column_name"` + IndexName string `gorm:"column:index_name"` + IsUnique sql.NullBool `gorm:"column:is_unique"` + IsPrimaryKey sql.NullBool `gorm:"column:is_primary_key"` } func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) { diff --git a/migrator_test.go b/migrator_test.go index 056a58a..63a86e0 100644 --- a/migrator_test.go +++ b/migrator_test.go @@ -1,6 +1,7 @@ package sqlserver_test import ( + "encoding/json" "os" "reflect" "testing" @@ -209,26 +210,26 @@ func (*TestTableFieldCommentUpdate) TableName() string { return "test_table_fiel func TestMigrator_MigrateColumnComment(t *testing.T) { db, err := gorm.Open(sqlserver.Open(sqlserverDSN)) if err != nil { - t.Error(err) + t.Fatal(err) } - migrator := db.Debug().Migrator() + dm := db.Debug().Migrator() tableModel := new(TestTableFieldComment) defer func() { - if err = migrator.DropTable(tableModel); err != nil { + if err = dm.DropTable(tableModel); err != nil { t.Errorf("couldn't drop table %q, got error: %v", tableModel.TableName(), err) } }() - if err = migrator.AutoMigrate(tableModel); err != nil { + if err = dm.AutoMigrate(tableModel); err != nil { t.Fatal(err) } tableModelUpdate := new(TestTableFieldCommentUpdate) - if err = migrator.AutoMigrate(tableModelUpdate); err != nil { + if err = dm.AutoMigrate(tableModelUpdate); err != nil { t.Error(err) } - if m, ok := migrator.(sqlserver.Migrator); ok { + if m, ok := dm.(sqlserver.Migrator); ok { stmt := db.Model(tableModelUpdate).Find(nil).Statement if stmt == nil || stmt.Schema == nil { t.Fatal("expected Statement.Schema, got nil") @@ -248,3 +249,56 @@ func TestMigrator_MigrateColumnComment(t *testing.T) { t.Logf("got comments: %#v", gotComments) } } + +func TestMigrator_GetIndexes(t *testing.T) { + db, err := gorm.Open(sqlserver.Open(sqlserverDSN)) + if err != nil { + t.Fatal(err) + } + dm := db.Debug().Migrator() + + type testTableIndex struct { + Test uint64 `gorm:"index"` + } + type testTableUnique struct { + ID string `gorm:"index:unique_id,class:UNIQUE,where:id IS NOT NULL"` + } + type testTablePrimaryKey struct { + ID string `gorm:"primaryKey"` + } + + type args struct { + value interface{} + } + tests := []struct { + name string + args args + wantErr bool + }{ + {name: "index", args: args{value: new(testTableIndex)}}, + {name: "unique", args: args{value: new(testTableUnique)}}, + {name: "primaryKey", args: args{value: new(testTablePrimaryKey)}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err = dm.AutoMigrate(tt.args.value); err != nil { + t.Error(err) + } + got, gotErr := dm.GetIndexes(tt.args.value) + if (gotErr != nil) != tt.wantErr { + t.Errorf("GetIndexes() error = %v, wantErr %v", gotErr, tt.wantErr) + return + } + for _, index := range got { + _, validUnique := index.Unique() + _, validPK := index.PrimaryKey() + indexBytes, _ := json.Marshal(index) + if index.Name() == "" && !validUnique && !validPK { + t.Errorf("GetIndexes() got = %s empty", indexBytes) + } else { + t.Logf("GetIndexes() got = %s", indexBytes) + } + } + }) + } +}