Skip to content

Commit 2e1a9fa

Browse files
committed
Add support for 'OnUpdate'
1 parent 847d907 commit 2e1a9fa

File tree

4 files changed

+186
-6
lines changed

4 files changed

+186
-6
lines changed

oracle/migrator.go

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,17 @@ func (m Migrator) CreateTable(values ...interface{}) error {
145145
}
146146
if constraint := rel.ParseConstraint(); constraint != nil {
147147
if constraint.Schema == stmt.Schema {
148+
// Oracle doesn’t support OnUpdate on foreign keys.
149+
// Use a trigger instead to propagate the update to the child table instead.
150+
if len(constraint.References) > 0 && constraint.OnUpdate != "" {
151+
constraint.OnUpdate = ""
152+
defer func(tx *gorm.DB, table string, constraint *schema.Constraint) {
153+
if err == nil {
154+
err = m.createUpadateCascadeTrigger(tx, constraint)
155+
}
156+
}(tx, stmt.Table, constraint)
157+
}
158+
148159
// If the same set of foreign keys already references the parent column,
149160
// remove duplicates to avoid ORA-02274: duplicate referential constraint specifications
150161
var foreignKeys []string
@@ -399,6 +410,32 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
399410
return columnTypes, execErr
400411
}
401412

413+
// CreateConstraint creates constraint based on the given 'value' and 'name'
414+
func (m Migrator) CreateConstraint(value interface{}, name string) error {
415+
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
416+
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
417+
if constraint != nil {
418+
if c, ok := constraint.(*schema.Constraint); ok {
419+
// Oracle doesn’t support OnUpdate on foreign keys.
420+
// Use a trigger instead to propagate the update to the child table instead.
421+
if len(c.References) > 0 && c.OnUpdate != "" {
422+
c.OnUpdate = ""
423+
constraint = c
424+
m.createUpadateCascadeTrigger(m.DB, c)
425+
}
426+
}
427+
428+
vars := []interface{}{clause.Table{Name: table}}
429+
if stmt.TableExpr != nil {
430+
vars[0] = stmt.TableExpr
431+
}
432+
sql, values := constraint.Build()
433+
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
434+
}
435+
return nil
436+
})
437+
}
438+
402439
// HasConstraint checks whether the table for the given `value` contains the specified constraint `name`
403440
func (m Migrator) HasConstraint(value interface{}, name string) bool {
404441
var count int64
@@ -418,6 +455,33 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
418455
return count > 0
419456
}
420457

458+
// DropConstraint drops constraint based on the given 'value' and 'name'
459+
func (m Migrator) DropConstraint(value interface{}, name string) error {
460+
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
461+
462+
constraint, _ := m.GuessConstraintInterfaceAndTable(stmt, name)
463+
464+
if c, ok := constraint.(*schema.Constraint); ok && c != nil {
465+
if len(c.References) > 0 && c.OnUpdate != "" {
466+
for i, fk := range c.ForeignKeys {
467+
triggerName := m.FkTriggerName(
468+
c.ReferenceSchema.Table,
469+
c.References[i].DBName,
470+
c.Schema.Table,
471+
fk.DBName,
472+
)
473+
return m.DB.Exec("DROP TRIGGER ?", clause.Column{Name: triggerName}).Error
474+
}
475+
}
476+
}
477+
return nil
478+
}); err != nil {
479+
return err
480+
}
481+
482+
return m.Migrator.DropConstraint(value, name)
483+
}
484+
421485
// DropIndex drops the index with the specified `name` from the table associated with `value`
422486
func (m Migrator) DropIndex(value interface{}, name string) error {
423487
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
@@ -569,3 +633,65 @@ func (m Migrator) isNumeric(s string) bool {
569633
_, err := strconv.ParseFloat(s, 64)
570634
return err == nil
571635
}
636+
637+
func (m Migrator) FkTriggerName(refTable string, refField string, table string, field string) string {
638+
return fmt.Sprintf("fk_trigger_%s_%s_%s_%s", refTable, refField, table, field)
639+
}
640+
641+
// Creates a trigger to cascade the update to the child table
642+
func (m Migrator) createUpadateCascadeTrigger(tx *gorm.DB, constraint *schema.Constraint) error {
643+
for i, fk := range constraint.ForeignKeys {
644+
var (
645+
tmpBuilder strings.Builder
646+
plsqlBuilder strings.Builder
647+
parentTable string = constraint.ReferenceSchema.Table
648+
parentField string = constraint.References[i].DBName
649+
table string = constraint.Schema.Table
650+
field string = fk.DBName
651+
652+
triggerName string = m.FkTriggerName(parentTable, parentField, table, field)
653+
654+
quotedParentTable string
655+
quotedParentField string
656+
quotedTable string
657+
quotedField string
658+
quotedTriggerName string
659+
)
660+
661+
// Initialize quoted variables according to the driver’s quoting rules
662+
writeQuotedIdentifier(&tmpBuilder, parentTable)
663+
quotedParentTable = tmpBuilder.String()
664+
tmpBuilder.Reset()
665+
666+
writeQuotedIdentifier(&tmpBuilder, parentField)
667+
quotedParentField = tmpBuilder.String()
668+
tmpBuilder.Reset()
669+
670+
writeQuotedIdentifier(&tmpBuilder, table)
671+
quotedTable = tmpBuilder.String()
672+
tmpBuilder.Reset()
673+
674+
writeQuotedIdentifier(&tmpBuilder, field)
675+
quotedField = tmpBuilder.String()
676+
tmpBuilder.Reset()
677+
678+
writeQuotedIdentifier(&tmpBuilder, triggerName)
679+
quotedTriggerName = tmpBuilder.String()
680+
tmpBuilder.Reset()
681+
682+
// Start PL/SQL block
683+
plsqlBuilder.WriteString("CREATE OR REPLACE TRIGGER " + quotedTriggerName + "\n")
684+
plsqlBuilder.WriteString("AFTER UPDATE OF " + quotedParentField + " ON " + quotedParentTable + "\n")
685+
plsqlBuilder.WriteString("FOR EACH ROW\n")
686+
plsqlBuilder.WriteString("BEGIN\n")
687+
plsqlBuilder.WriteString(" UPDATE " + quotedTable + "\n")
688+
plsqlBuilder.WriteString(" SET " + quotedField + " = :NEW." + quotedParentField + "\n")
689+
plsqlBuilder.WriteString(" WHERE " + quotedField + " = :OLD." + quotedParentField)
690+
plsqlBuilder.WriteString(";\n")
691+
plsqlBuilder.WriteString("END;")
692+
if err := tx.Exec(plsqlBuilder.String()).Error; err != nil {
693+
return err
694+
}
695+
}
696+
return nil
697+
}

tests/associations_test.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ func TestAssociationNotNullClear(t *testing.T) {
112112
}
113113

114114
func TestForeignKeyConstraints(t *testing.T) {
115-
t.Skip()
116115
type Profile struct {
117116
ID uint
118117
Name string
@@ -121,7 +120,7 @@ func TestForeignKeyConstraints(t *testing.T) {
121120

122121
type Member struct {
123122
ID uint
124-
Refer uint `gorm:"uniqueIndex"`
123+
Refer uint `gorm:"unique"`
125124
Name string
126125
Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:MemberID;References:Refer"`
127126
}
@@ -168,11 +167,10 @@ func TestForeignKeyConstraints(t *testing.T) {
168167
}
169168

170169
func TestForeignKeyConstraintsBelongsTo(t *testing.T) {
171-
t.Skip()
172170
type Profile struct {
173171
ID uint
174172
Name string
175-
Refer uint `gorm:"uniqueIndex"`
173+
Refer uint `gorm:"unique"`
176174
}
177175

178176
type Member struct {

tests/go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ go 1.24.4
55
require gorm.io/gorm v1.30.0
66

77
require (
8-
github.com/oracle-samples/gorm-oracle v0.0.1
8+
github.com/godror/godror v0.49.0
9+
github.com/oracle-samples/gorm-oracle v0.1.0
910
github.com/stretchr/testify v1.10.0
1011
)
1112

1213
require (
1314
github.com/VictoriaMetrics/easyproto v0.1.4 // indirect
1415
github.com/davecgh/go-spew v1.1.1 // indirect
1516
github.com/go-logfmt/logfmt v0.6.0 // indirect
16-
github.com/godror/godror v0.49.0 // indirect
1717
github.com/godror/knownpb v0.3.0 // indirect
1818
github.com/jinzhu/inflection v1.0.0 // indirect
1919
github.com/jinzhu/now v1.1.5 // indirect

tests/migrate_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,3 +1721,59 @@ func TestAutoMigrateDecimal(t *testing.T) {
17211721
decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
17221722
}
17231723
}
1724+
1725+
func TestMigrateOnUpdateConstraint(t *testing.T) {
1726+
type Owner struct {
1727+
ID int
1728+
Name string
1729+
}
1730+
1731+
type Pen struct {
1732+
gorm.Model
1733+
OwnerID int
1734+
Owner Owner `gorm:"constraint:OnUpdate:CASCADE,OnDelete:SET NULL;"`
1735+
}
1736+
1737+
DB.Migrator().DropTable(&Pen{}, &Owner{})
1738+
1739+
// Verify the trigger is created using CreateTable()
1740+
if err := DB.Migrator().CreateTable(&Owner{}, &Pen{}); err != nil {
1741+
t.Fatalf("Failed to create table, got error: %v", err)
1742+
}
1743+
1744+
triggerName := "fk_trigger_owners_id_pens_owner_id"
1745+
1746+
var count int
1747+
DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerName).Scan(&count)
1748+
if count != 1 {
1749+
t.Errorf("Should find the trigger %s", triggerName)
1750+
}
1751+
1752+
// Verify the trigger is created using CreateConstraint()
1753+
constraintName := "fk_pens_owner"
1754+
if err := DB.Migrator().DropConstraint(&Pen{}, constraintName); err != nil {
1755+
t.Errorf("failed to drop constraint %v, got error %v", constraintName, err)
1756+
}
1757+
1758+
if err := DB.Migrator().CreateConstraint(&Pen{}, constraintName); err != nil {
1759+
t.Errorf("failed to create constraint %v, got error %v", constraintName, err)
1760+
}
1761+
1762+
DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerName).Scan(&count)
1763+
if count != 1 {
1764+
t.Errorf("Should find the trigger %s", triggerName)
1765+
}
1766+
1767+
// Verify the trigger works
1768+
user := Pen{Owner: Owner{ID: 1, Name: "John"}}
1769+
DB.Create(&user)
1770+
1771+
DB.Model(user.Owner).Update("id", 100)
1772+
1773+
var user2 Pen
1774+
if err := DB.First(&user2, "\"id\" = ?", user.ID).Error; err != nil {
1775+
panic(fmt.Errorf("failed to find member, got error: %v", err))
1776+
} else if user2.OwnerID != 100 {
1777+
panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 100, user2.OwnerID))
1778+
}
1779+
}

0 commit comments

Comments
 (0)