Skip to content

Commit c2a262c

Browse files
committed
Add support for SET NULL and SET DEFAULT
1 parent 2e1a9fa commit c2a262c

File tree

4 files changed

+177
-79
lines changed

4 files changed

+177
-79
lines changed

README.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,49 @@ func main() {
3636
}
3737
```
3838

39+
## Documentation
40+
41+
### OnUpdate Foreign Key Constraint
42+
43+
Since Oracle doesn’t support `ON UPDATE` in foreign keys, the driver simulates it using **triggers**.
44+
45+
When a field has a constraint tagged with `OnUpdate`, the driver:
46+
47+
1. Skips generating the unsupported `ON UPDATE` clause in the foreign key definition.
48+
2. Creates a trigger on the parent table that automatically cascades updates to the child table(s) whenever the referenced column is changed.
49+
50+
The `OnUpdate` tag accepts the following values (case-insensitive): `CASCADE`, `SET NULL`, and `SET DEFAULT`.
51+
52+
Take the following struct for an example:
53+
54+
```go
55+
type Profile struct {
56+
ID uint
57+
Name string
58+
Refer uint
59+
}
60+
61+
type Member struct {
62+
ID uint
63+
Name string
64+
ProfileID uint
65+
Profile Profile `gorm:"Constraint:OnUpdate:CASCADE"`
66+
}
67+
```
68+
69+
Trigger SQL created by the driver when migrating:
70+
71+
```sql
72+
CREATE OR REPLACE TRIGGER "fk_trigger_profiles_id_members_profile_id"
73+
AFTER UPDATE OF "id" ON "profiles"
74+
FOR EACH ROW
75+
BEGIN
76+
UPDATE "members"
77+
SET "profile_id" = :NEW."id"
78+
WHERE "profile_id" = :OLD."id";
79+
END;
80+
```
81+
3982
## Contributing
4083

4184
This project welcomes contributions from the community. Before submitting a pull request, please [review our contribution guide](./CONTRIBUTING.md)

oracle/common.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,12 @@ func writeQuotedIdentifier(builder *strings.Builder, identifier string) {
424424
builder.WriteByte('"')
425425
}
426426

427+
func QuoteIdentifier(identifier string) string {
428+
var builder strings.Builder
429+
writeQuotedIdentifier(&builder, identifier)
430+
return builder.String()
431+
}
432+
427433
// writeTableRecordCollectionDecl writes the PL/SQL declarations needed to
428434
// define a custom record type and a collection of that record type,
429435
// based on the schema of the given table.

oracle/migrator.go

Lines changed: 52 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,14 @@ func (m Migrator) CreateTable(values ...interface{}) error {
148148
// Oracle doesn’t support OnUpdate on foreign keys.
149149
// Use a trigger instead to propagate the update to the child table instead.
150150
if len(constraint.References) > 0 && constraint.OnUpdate != "" {
151-
constraint.OnUpdate = ""
152-
defer func(tx *gorm.DB, table string, constraint *schema.Constraint) {
151+
defer func(tx *gorm.DB, table string, constraint *schema.Constraint, onUpdate string) {
153152
if err == nil {
153+
// retore the OnUpdate value
154+
constraint.OnUpdate = onUpdate
154155
err = m.createUpadateCascadeTrigger(tx, constraint)
155156
}
156-
}(tx, stmt.Table, constraint)
157+
}(tx, stmt.Table, constraint, constraint.OnUpdate)
158+
constraint.OnUpdate = ""
157159
}
158160

159161
// If the same set of foreign keys already references the parent column,
@@ -419,9 +421,9 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
419421
// Oracle doesn’t support OnUpdate on foreign keys.
420422
// Use a trigger instead to propagate the update to the child table instead.
421423
if len(c.References) > 0 && c.OnUpdate != "" {
424+
m.createUpadateCascadeTrigger(m.DB, c)
422425
c.OnUpdate = ""
423426
constraint = c
424-
m.createUpadateCascadeTrigger(m.DB, c)
425427
}
426428
}
427429

@@ -640,58 +642,57 @@ func (m Migrator) FkTriggerName(refTable string, refField string, table string,
640642

641643
// Creates a trigger to cascade the update to the child table
642644
func (m Migrator) createUpadateCascadeTrigger(tx *gorm.DB, constraint *schema.Constraint) error {
645+
onUpdate := strings.TrimSpace(strings.ToLower(constraint.OnUpdate))
646+
if onUpdate != "cascade" && onUpdate != "set null" && onUpdate != "set default" {
647+
return nil
648+
}
649+
650+
parentTable := constraint.ReferenceSchema.Table
651+
quotedParentTable := QuoteIdentifier(parentTable)
652+
table := constraint.Schema.Table
653+
quotedTable := QuoteIdentifier(table)
654+
643655
for i, fk := range constraint.ForeignKeys {
644-
var (
645-
tmpBuilder strings.Builder
646-
plsqlBuilder strings.Builder
647-
parentTable string = constraint.ReferenceSchema.Table
648-
parentField string = constraint.References[i].DBName
649-
table string = constraint.Schema.Table
650-
field string = fk.DBName
651-
652-
triggerName string = m.FkTriggerName(parentTable, parentField, table, field)
653-
654-
quotedParentTable string
655-
quotedParentField string
656-
quotedTable string
657-
quotedField string
658-
quotedTriggerName string
656+
parentField := constraint.References[i].DBName
657+
quotedParentField := QuoteIdentifier(parentField)
658+
field := fk.DBName
659+
quotedField := QuoteIdentifier(field)
660+
triggerName := m.FkTriggerName(parentTable, parentField, table, field)
661+
quotedTriggerName := QuoteIdentifier(triggerName)
662+
663+
var updateValue string
664+
switch onUpdate {
665+
case "cascade":
666+
updateValue = ":NEW." + quotedParentField
667+
case "set null":
668+
updateValue = "NULL"
669+
case "set default":
670+
updateValue = "DEFAULT"
671+
}
672+
673+
plsql := fmt.Sprintf(
674+
`CREATE OR REPLACE TRIGGER %s
675+
AFTER UPDATE OF %s ON %s
676+
FOR EACH ROW
677+
BEGIN
678+
UPDATE %s
679+
SET %s = %s
680+
WHERE %s = :OLD.%s;
681+
END;`,
682+
quotedTriggerName,
683+
quotedParentField,
684+
quotedParentTable,
685+
quotedTable,
686+
quotedField,
687+
updateValue,
688+
quotedField,
689+
quotedParentField,
659690
)
660691

661-
// Initialize quoted variables according to the driver’s quoting rules
662-
writeQuotedIdentifier(&tmpBuilder, parentTable)
663-
quotedParentTable = tmpBuilder.String()
664-
tmpBuilder.Reset()
665-
666-
writeQuotedIdentifier(&tmpBuilder, parentField)
667-
quotedParentField = tmpBuilder.String()
668-
tmpBuilder.Reset()
669-
670-
writeQuotedIdentifier(&tmpBuilder, table)
671-
quotedTable = tmpBuilder.String()
672-
tmpBuilder.Reset()
673-
674-
writeQuotedIdentifier(&tmpBuilder, field)
675-
quotedField = tmpBuilder.String()
676-
tmpBuilder.Reset()
677-
678-
writeQuotedIdentifier(&tmpBuilder, triggerName)
679-
quotedTriggerName = tmpBuilder.String()
680-
tmpBuilder.Reset()
681-
682-
// Start PL/SQL block
683-
plsqlBuilder.WriteString("CREATE OR REPLACE TRIGGER " + quotedTriggerName + "\n")
684-
plsqlBuilder.WriteString("AFTER UPDATE OF " + quotedParentField + " ON " + quotedParentTable + "\n")
685-
plsqlBuilder.WriteString("FOR EACH ROW\n")
686-
plsqlBuilder.WriteString("BEGIN\n")
687-
plsqlBuilder.WriteString(" UPDATE " + quotedTable + "\n")
688-
plsqlBuilder.WriteString(" SET " + quotedField + " = :NEW." + quotedParentField + "\n")
689-
plsqlBuilder.WriteString(" WHERE " + quotedField + " = :OLD." + quotedParentField)
690-
plsqlBuilder.WriteString(";\n")
691-
plsqlBuilder.WriteString("END;")
692-
if err := tx.Exec(plsqlBuilder.String()).Error; err != nil {
692+
if err := tx.Exec(plsql).Error; err != nil {
693693
return err
694694
}
695695
}
696+
696697
return nil
697698
}

tests/migrate_test.go

Lines changed: 76 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1728,52 +1728,100 @@ func TestMigrateOnUpdateConstraint(t *testing.T) {
17281728
Name string
17291729
}
17301730

1731-
type Pen struct {
1731+
type Pen1 struct {
17321732
gorm.Model
17331733
OwnerID int
1734-
Owner Owner `gorm:"constraint:OnUpdate:CASCADE,OnDelete:SET NULL;"`
1734+
Owner Owner `gorm:"constraint:OnUpdate:CASCADE;"`
17351735
}
17361736

1737-
DB.Migrator().DropTable(&Pen{}, &Owner{})
1737+
type Pen2 struct {
1738+
gorm.Model
1739+
OwnerID int `gorm:"default: 18"`
1740+
Owner Owner `gorm:"constraint:OnUpdate:SET DEFAULT;"`
1741+
}
17381742

1739-
// Verify the trigger is created using CreateTable()
1740-
if err := DB.Migrator().CreateTable(&Owner{}, &Pen{}); err != nil {
1741-
t.Fatalf("Failed to create table, got error: %v", err)
1743+
type Pen3 struct {
1744+
gorm.Model
1745+
OwnerID int
1746+
Owner Owner `gorm:"constraint:OnUpdate:SET NULL;"`
17421747
}
17431748

1744-
triggerName := "fk_trigger_owners_id_pens_owner_id"
1749+
DB.Migrator().DropTable(&Owner{}, &Pen1{}, &Pen2{}, &Pen3{})
17451750

1746-
var count int
1747-
DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerName).Scan(&count)
1748-
if count != 1 {
1749-
t.Errorf("Should find the trigger %s", triggerName)
1751+
// Test 1: Verify the trigger is created using CreateTable()
1752+
if err := DB.Migrator().CreateTable(&Owner{}, &Pen1{}, &Pen2{}, &Pen3{}); err != nil {
1753+
t.Fatalf("Failed to create table, got error: %v", err)
17501754
}
17511755

1752-
// Verify the trigger is created using CreateConstraint()
1753-
constraintName := "fk_pens_owner"
1754-
if err := DB.Migrator().DropConstraint(&Pen{}, constraintName); err != nil {
1755-
t.Errorf("failed to drop constraint %v, got error %v", constraintName, err)
1756+
triggerNames := []string{
1757+
"fk_trigger_owners_id_pen1_owner_id",
1758+
"fk_trigger_owners_id_pen2_owner_id",
1759+
"fk_trigger_owners_id_pen3_owner_id",
17561760
}
17571761

1758-
if err := DB.Migrator().CreateConstraint(&Pen{}, constraintName); err != nil {
1759-
t.Errorf("failed to create constraint %v, got error %v", constraintName, err)
1762+
for _, triggerName := range triggerNames {
1763+
var count int
1764+
DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerName).Scan(&count)
1765+
if count != 1 {
1766+
t.Errorf("Should find the trigger %s", triggerName)
1767+
}
17601768
}
17611769

1762-
DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerName).Scan(&count)
1763-
if count != 1 {
1764-
t.Errorf("Should find the trigger %s", triggerName)
1770+
// Test 2: Verify the trigger is created using CreateConstraint()
1771+
penStructs := []interface{}{&Pen1{}, &Pen2{}, &Pen3{}}
1772+
constraintNames := []string{"fk_pen1_owner", "fk_pen2_owner", "fk_pen3_owner"}
1773+
for i := range 3 {
1774+
if err := DB.Migrator().DropConstraint(penStructs[i], constraintNames[i]); err != nil {
1775+
t.Errorf("failed to drop constraint %v, got error %v", constraintNames[i], err)
1776+
}
1777+
1778+
if err := DB.Migrator().CreateConstraint(penStructs[i], constraintNames[i]); err != nil {
1779+
t.Errorf("failed to create constraint %v, got error %v", constraintNames[i], err)
1780+
}
1781+
1782+
var count int
1783+
DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerNames[i]).Scan(&count)
1784+
if count != 1 {
1785+
t.Errorf("Should find the trigger %s", triggerNames[i])
1786+
}
17651787
}
17661788

1767-
// Verify the trigger works
1768-
user := Pen{Owner: Owner{ID: 1, Name: "John"}}
1769-
DB.Create(&user)
1789+
// Test 3: Verify each trigger work
1790+
pen1 := Pen1{Owner: Owner{ID: 1, Name: "John"}}
1791+
DB.Create(&pen1)
1792+
DB.Model(pen1.Owner).Update("id", 100)
1793+
1794+
var updatedPen1 Pen1
1795+
if err := DB.First(&updatedPen1, "\"id\" = ?", pen1.ID).Error; err != nil {
1796+
panic(fmt.Errorf("failed to find member, got error: %v", err))
1797+
} else if updatedPen1.OwnerID != 100 {
1798+
panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 100, updatedPen1.OwnerID))
1799+
}
1800+
1801+
pen2 := Pen2{Owner: Owner{ID: 2, Name: "Mary"}}
1802+
DB.Create(&pen2)
1803+
// When the ID in the owners table is updated, the primary key in pen2 (owner_id column)
1804+
// is set to its default value (18). To avoid violating the foreign key constraint in pen2,
1805+
// we need to insert this record into the owners table in advance.
1806+
owner := Owner{ID: 18, Name: "MaryBackup"}
1807+
DB.Create(&owner)
1808+
DB.Model(pen2.Owner).Update("id", 200)
1809+
1810+
var updatedPen2 Pen2
1811+
if err := DB.First(&updatedPen2, "\"id\" = ?", pen2.ID).Error; err != nil {
1812+
panic(fmt.Errorf("failed to find member, got error: %v", err))
1813+
} else if updatedPen2.OwnerID != 18 {
1814+
panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 18, updatedPen2.OwnerID))
1815+
}
17701816

1771-
DB.Model(user.Owner).Update("id", 100)
1817+
pen3 := Pen3{Owner: Owner{ID: 3, Name: "Jane"}}
1818+
DB.Create(&pen3)
1819+
DB.Model(pen3.Owner).Update("id", 300)
17721820

1773-
var user2 Pen
1774-
if err := DB.First(&user2, "\"id\" = ?", user.ID).Error; err != nil {
1821+
var updatedPen3 Pen3
1822+
if err := DB.First(&updatedPen3, "\"id\" = ?", pen3.ID).Error; err != nil {
17751823
panic(fmt.Errorf("failed to find member, got error: %v", err))
1776-
} else if user2.OwnerID != 100 {
1777-
panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 100, user2.OwnerID))
1824+
} else if updatedPen3.OwnerID != 0 {
1825+
panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 0, updatedPen3.OwnerID))
17781826
}
17791827
}

0 commit comments

Comments
 (0)