Skip to content

Commit b8d91cb

Browse files
Support update and delete returning using output (#116)
* add from as valid clause for update * add support to RETURNING using OUTPUT
1 parent 9c41053 commit b8d91cb

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

sqlserver.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,13 @@ func New(config Config) gorm.Dialector {
4040
}
4141

4242
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
43-
4443
// register callbacks
45-
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{})
44+
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
45+
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT"},
46+
QueryClauses: []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"},
47+
UpdateClauses: []string{"UPDATE", "SET", "RETURNING", "FROM", "WHERE"},
48+
DeleteClauses: []string{"DELETE", "FROM", "RETURNING", "WHERE"},
49+
})
4650
db.Callback().Create().Replace("gorm:create", Create)
4751
db.Callback().Update().Replace("gorm:update", Update)
4852

@@ -97,6 +101,34 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
97101
}
98102
}
99103
},
104+
"RETURNING": func(c clause.Clause, builder clause.Builder) {
105+
if returning, ok := c.Expression.(clause.Returning); ok {
106+
if stmt, ok := builder.(*gorm.Statement); ok {
107+
var outputTable string
108+
if _, ok := stmt.Clauses["UPDATE"]; ok {
109+
outputTable = "INSERTED"
110+
} else if _, ok := stmt.Clauses["DELETE"]; ok {
111+
outputTable = "DELETED"
112+
}
113+
114+
if outputTable != "" {
115+
stmt.WriteString("OUTPUT ")
116+
117+
if len(returning.Columns) > 0 {
118+
columns := []clause.Column{}
119+
for _, column := range returning.Columns {
120+
column.Table = outputTable
121+
columns = append(columns, column)
122+
}
123+
returning.Columns = columns
124+
returning.Build(stmt)
125+
} else {
126+
stmt.WriteString(outputTable + ".*")
127+
}
128+
}
129+
}
130+
}
131+
},
100132
}
101133
}
102134

update.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import (
55
"gorm.io/gorm/callbacks"
66
)
77

8-
var updateFunc = callbacks.Update(&callbacks.Config{})
8+
var updateFunc = callbacks.Update(&callbacks.Config{
9+
UpdateClauses: []string{"UPDATE", "SET", "RETURNING", "FROM", "WHERE"},
10+
})
911

1012
func Update(db *gorm.DB) {
1113
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.AutoIncrement {

0 commit comments

Comments
 (0)