Skip to content

Commit 649ef38

Browse files
committed
resolve conflict in tests/generics_test.go
2 parents cbefda8 + ede94ac commit 649ef38

30 files changed

+2265
-966
lines changed

oracle/clause_builder.go

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,6 @@ func OnConflictClauseBuilder(c clause.Clause, builder clause.Builder) {
300300
if len(conflictColumns) == 0 {
301301
// If no columns specified, try to use primary key fields as default
302302
if stmt.Schema == nil || len(stmt.Schema.PrimaryFields) == 0 {
303-
stmt.AddError(fmt.Errorf("OnConflict requires either explicit columns or primary key fields"))
304303
return
305304
}
306305
for _, primaryField := range stmt.Schema.PrimaryFields {
@@ -326,7 +325,16 @@ func OnConflictClauseBuilder(c clause.Clause, builder clause.Builder) {
326325
missingColumns = append(missingColumns, conflictCol.Name)
327326
}
328327
}
328+
329329
if len(missingColumns) > 0 {
330+
// primary keys with auto increment will always be missing from create values columns
331+
for _, missingCol := range missingColumns {
332+
field := stmt.Schema.LookUpField(missingCol)
333+
if field != nil && field.PrimaryKey && field.AutoIncrement {
334+
return
335+
}
336+
}
337+
330338
var selectedColumns []string
331339
for col := range selectedColumnSet {
332340
selectedColumns = append(selectedColumns, col)
@@ -336,6 +344,34 @@ func OnConflictClauseBuilder(c clause.Clause, builder clause.Builder) {
336344
return
337345
}
338346

347+
// exclude primary key, default value columns from merge update clause
348+
if len(onConflict.DoUpdates) > 0 {
349+
hasPrimaryKey := false
350+
351+
for _, assignment := range onConflict.DoUpdates {
352+
field := stmt.Schema.LookUpField(assignment.Column.Name)
353+
if field != nil && field.PrimaryKey {
354+
hasPrimaryKey = true
355+
break
356+
}
357+
}
358+
359+
if hasPrimaryKey {
360+
onConflict.DoUpdates = nil
361+
columns := make([]string, 0, len(values.Columns)-1)
362+
for _, col := range values.Columns {
363+
field := stmt.Schema.LookUpField(col.Name)
364+
365+
if field != nil && !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
366+
strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
367+
columns = append(columns, col.Name)
368+
}
369+
370+
}
371+
onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
372+
}
373+
}
374+
339375
// Build MERGE statement
340376
buildMergeInClause(stmt, onConflict, values, conflictColumns)
341377
}

oracle/common.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,50 @@ func writeQuotedIdentifier(builder *strings.Builder, identifier string) {
424424
builder.WriteByte('"')
425425
}
426426

427+
// writeTableRecordCollectionDecl writes the PL/SQL declarations needed to
428+
// define a custom record type and a collection of that record type,
429+
// based on the schema of the given table.
430+
//
431+
// Specifically, it generates:
432+
// - A RECORD type (`t_record`) with fields corresponding to the table's columns.
433+
// - A nested TABLE type (`t_records`) of `t_record`.
434+
//
435+
// The declarations are written into the provided strings.Builder in the
436+
// correct PL/SQL syntax, so they can be used as part of a larger PL/SQL block.
437+
//
438+
// Example output:
439+
//
440+
// TYPE t_record IS RECORD (
441+
// "id" "users"."id"%TYPE,
442+
// "created_at" "users"."created_at"%TYPE,
443+
// ...
444+
// );
445+
// TYPE t_records IS TABLE OF t_record;
446+
//
447+
// Parameters:
448+
// - plsqlBuilder: The builder to write the PL/SQL code into.
449+
// - dbNames: The slice containing the column names.
450+
// - table: The table name
451+
func writeTableRecordCollectionDecl(plsqlBuilder *strings.Builder, dbNames []string, table string) {
452+
// Declare a record where each element has the same structure as a row from the given table
453+
plsqlBuilder.WriteString(" TYPE t_record IS RECORD (\n")
454+
for i, field := range dbNames {
455+
if i > 0 {
456+
plsqlBuilder.WriteString(",\n")
457+
}
458+
plsqlBuilder.WriteString(" ")
459+
writeQuotedIdentifier(plsqlBuilder, field)
460+
plsqlBuilder.WriteString(" ")
461+
writeQuotedIdentifier(plsqlBuilder, table)
462+
plsqlBuilder.WriteString(".")
463+
writeQuotedIdentifier(plsqlBuilder, field)
464+
plsqlBuilder.WriteString("%TYPE")
465+
}
466+
plsqlBuilder.WriteString("\n")
467+
plsqlBuilder.WriteString(" );\n")
468+
plsqlBuilder.WriteString(" TYPE t_records IS TABLE OF t_record;\n")
469+
}
470+
427471
// Helper function to check if a value represents NULL
428472
func isNullValue(value interface{}) bool {
429473
if value == nil {

oracle/create.go

Lines changed: 60 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ func validateCreateData(stmt *gorm.Statement) error {
176176

177177
// Build PL/SQL block for bulk INSERT/MERGE with RETURNING
178178
func buildBulkInsertPLSQL(db *gorm.DB, createValues clause.Values) {
179+
sanitizeCreateValuesForBulkArrays(db.Statement, &createValues)
180+
179181
stmt := db.Statement
180182
schema := stmt.Schema
181183

@@ -217,7 +219,6 @@ func buildBulkInsertPLSQL(db *gorm.DB, createValues clause.Values) {
217219
conflictColumns := onConflict.Columns
218220
if len(conflictColumns) == 0 {
219221
if len(schema.PrimaryFields) == 0 {
220-
db.AddError(fmt.Errorf("OnConflict requires either explicit columns or primary key fields"))
221222
return
222223
}
223224
for _, primaryField := range schema.PrimaryFields {
@@ -238,6 +239,8 @@ func buildBulkInsertPLSQL(db *gorm.DB, createValues clause.Values) {
238239

239240
// Build PL/SQL block for bulk MERGE with RETURNING (OnConflict case)
240241
func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClause clause.Clause) {
242+
sanitizeCreateValuesForBulkArrays(db.Statement, &createValues)
243+
241244
stmt := db.Statement
242245
schema := stmt.Schema
243246

@@ -251,7 +254,6 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
251254
conflictColumns := onConflict.Columns
252255
if len(conflictColumns) == 0 {
253256
if schema == nil || len(schema.PrimaryFields) == 0 {
254-
db.AddError(fmt.Errorf("OnConflict requires either explicit columns or primary key fields"))
255257
return
256258
}
257259
for _, primaryField := range schema.PrimaryFields {
@@ -265,9 +267,11 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
265267
valuesColumnMap[strings.ToUpper(column.Name)] = true
266268
}
267269

270+
// Filter conflict columns to remove non unique columns
268271
var filteredConflictColumns []clause.Column
269272
for _, conflictCol := range conflictColumns {
270-
if valuesColumnMap[strings.ToUpper(conflictCol.Name)] {
273+
field := stmt.Schema.LookUpField(conflictCol.Name)
274+
if valuesColumnMap[strings.ToUpper(conflictCol.Name)] && (field.Unique || field.AutoIncrement) {
271275
filteredConflictColumns = append(filteredConflictColumns, conflictCol)
272276
}
273277
}
@@ -285,9 +289,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
285289

286290
// Start PL/SQL block
287291
plsqlBuilder.WriteString("DECLARE\n")
288-
plsqlBuilder.WriteString(" TYPE t_records IS TABLE OF ")
289-
writeQuotedIdentifier(&plsqlBuilder, stmt.Table)
290-
plsqlBuilder.WriteString("%ROWTYPE;\n")
292+
writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
291293
plsqlBuilder.WriteString(" l_affected_records t_records;\n")
292294

293295
// Create array types and variables for each column
@@ -336,6 +338,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
336338

337339
// Build ON clause using conflict columns
338340
plsqlBuilder.WriteString(" ON (")
341+
339342
for idx, conflictCol := range conflictColumns {
340343
if idx > 0 {
341344
plsqlBuilder.WriteString(" AND ")
@@ -409,6 +412,25 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
409412
}
410413
}
411414
plsqlBuilder.WriteString("\n")
415+
} else {
416+
onCols := map[string]struct{}{}
417+
for _, c := range conflictColumns {
418+
onCols[strings.ToUpper(c.Name)] = struct{}{}
419+
}
420+
421+
// Picking the first non-ON column from the INSERT/MERGE columns
422+
var noopCol string
423+
for _, c := range createValues.Columns {
424+
if _, inOn := onCols[strings.ToUpper(c.Name)]; !inOn {
425+
noopCol = c.Name
426+
break
427+
}
428+
}
429+
plsqlBuilder.WriteString(" WHEN MATCHED THEN UPDATE SET t.")
430+
writeQuotedIdentifier(&plsqlBuilder, noopCol)
431+
plsqlBuilder.WriteString(" = s.")
432+
writeQuotedIdentifier(&plsqlBuilder, noopCol)
433+
plsqlBuilder.WriteString("\n")
412434
}
413435

414436
// WHEN NOT MATCHED THEN INSERT (unless DoNothing for inserts)
@@ -526,9 +548,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
526548

527549
// Start PL/SQL block
528550
plsqlBuilder.WriteString("DECLARE\n")
529-
plsqlBuilder.WriteString(" TYPE t_records IS TABLE OF ")
530-
writeQuotedIdentifier(&plsqlBuilder, stmt.Table)
531-
plsqlBuilder.WriteString("%ROWTYPE;\n")
551+
writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
532552
plsqlBuilder.WriteString(" l_inserted_records t_records;\n")
533553

534554
// Create array types and variables for each column
@@ -791,19 +811,6 @@ func handleSingleRowReturning(db *gorm.DB) {
791811
}
792812
}
793813

794-
// Simplified RETURNING clause addition for single row operations
795-
func addReturningClause(db *gorm.DB, fields []*schema.Field) {
796-
if len(fields) == 0 {
797-
return
798-
}
799-
800-
columns := make([]clause.Column, len(fields))
801-
for idx, field := range fields {
802-
columns[idx] = clause.Column{Name: field.DBName}
803-
}
804-
db.Statement.AddClauseIfNotExists(clause.Returning{Columns: columns})
805-
}
806-
807814
// Handle bulk RETURNING results for PL/SQL operations
808815
func getBulkReturningValues(db *gorm.DB, rowCount int) {
809816
if db.Statement.Schema == nil {
@@ -923,3 +930,34 @@ func handleLastInsertId(db *gorm.DB, result sql.Result) {
923930
}
924931
}
925932
}
933+
934+
// This replaces expressions (clause.Expr) in bulk insert values
935+
// with appropriate NULL placeholders based on the column's data type. This ensures that
936+
// PL/SQL array binding remains consistent and avoids unsupported expressions during
937+
// FORALL bulk operations.
938+
func sanitizeCreateValuesForBulkArrays(stmt *gorm.Statement, cv *clause.Values) {
939+
for r := range cv.Values {
940+
for c, col := range cv.Columns {
941+
v := cv.Values[r][c]
942+
switch v.(type) {
943+
case clause.Expr:
944+
if f := findFieldByDBName(stmt.Schema, col.Name); f != nil {
945+
switch f.DataType {
946+
case schema.Int, schema.Uint:
947+
cv.Values[r][c] = sql.NullInt64{}
948+
case schema.Float:
949+
cv.Values[r][c] = sql.NullFloat64{}
950+
case schema.String:
951+
cv.Values[r][c] = sql.NullString{}
952+
case schema.Time:
953+
cv.Values[r][c] = sql.NullTime{}
954+
default:
955+
cv.Values[r][c] = nil
956+
}
957+
} else {
958+
cv.Values[r][c] = nil
959+
}
960+
}
961+
}
962+
}
963+
}

oracle/delete.go

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,22 @@ func Delete(db *gorm.DB) {
9393
addPrimaryKeyWhereClause(db)
9494
}
9595

96+
// redirect soft-delete to update clause with bulk returning
97+
if stmt.Schema != nil {
98+
if deletedAtField := stmt.Schema.LookUpField("deleted_at"); deletedAtField != nil && !stmt.Unscoped {
99+
for _, c := range stmt.Schema.DeleteClauses {
100+
stmt.AddClause(c)
101+
}
102+
delete(stmt.Clauses, "DELETE")
103+
delete(stmt.Clauses, "FROM")
104+
stmt.SQL.Reset()
105+
stmt.Vars = stmt.Vars[:0]
106+
stmt.AddClauseIfNotExists(clause.Update{})
107+
Update(db)
108+
return
109+
}
110+
}
111+
96112
// This prevents soft deletes from bypassing the safety check
97113
checkMissingWhereConditions(db)
98114
if db.Error != nil {
@@ -239,9 +255,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {
239255

240256
// Start PL/SQL block
241257
plsqlBuilder.WriteString("DECLARE\n")
242-
plsqlBuilder.WriteString(" TYPE t_records IS TABLE OF ")
243-
writeQuotedIdentifier(&plsqlBuilder, stmt.Table)
244-
plsqlBuilder.WriteString("%ROWTYPE;\n")
258+
writeTableRecordCollectionDecl(&plsqlBuilder, stmt.Schema.DBNames, stmt.Table)
245259
plsqlBuilder.WriteString(" l_deleted_records t_records;\n")
246260
plsqlBuilder.WriteString("BEGIN\n")
247261

@@ -434,25 +448,7 @@ func executeDelete(db *gorm.DB) {
434448
_, hasReturning := stmt.Clauses["RETURNING"]
435449

436450
if hasReturning {
437-
// For RETURNING, we need to check if it's a soft delete or hard delete
438-
if stmt.Schema != nil {
439-
if deletedAtField := stmt.Schema.LookUpField("deleted_at"); deletedAtField != nil && !stmt.Unscoped {
440-
// Soft delete with RETURNING - use QueryContext
441-
if rows, err := stmt.ConnPool.QueryContext(stmt.Context, stmt.SQL.String(), stmt.Vars...); err == nil {
442-
defer rows.Close()
443-
gorm.Scan(rows, db, gorm.ScanInitialized)
444-
445-
if stmt.Result != nil {
446-
stmt.Result.RowsAffected = db.RowsAffected
447-
}
448-
} else {
449-
db.AddError(err)
450-
}
451-
return
452-
}
453-
}
454-
455-
// Hard delete with RETURNING - use ExecContext (for PL/SQL blocks)
451+
// Hard delete & soft delete with RETURNING - use ExecContext (for PL/SQL blocks)
456452
result, err := stmt.ConnPool.ExecContext(stmt.Context, stmt.SQL.String(), stmt.Vars...)
457453
if err == nil {
458454
db.RowsAffected, _ = result.RowsAffected()

oracle/oracle.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ func (d Dialector) Initialize(db *gorm.DB) (err error) {
102102
callback.Create().Replace("gorm:create", Create)
103103
callback.Delete().Replace("gorm:delete", Delete)
104104
callback.Update().Replace("gorm:update", Update)
105+
callback.Query().Before("gorm:query").Register("oracle:before_query", BeforeQuery)
105106

106107
maps.Copy(db.ClauseBuilders, OracleClauseBuilders())
107108

0 commit comments

Comments
 (0)