diff --git a/go.mod b/go.mod
index b70d909..0dda3e2 100644
--- a/go.mod
+++ b/go.mod
@@ -2,18 +2,24 @@ module github.com/oracle-samples/gorm-oracle
 
 go 1.24.4
 
-require gorm.io/gorm v1.30.0
+require gorm.io/gorm v1.31.0
 
-require github.com/godror/godror v0.49.0
+require github.com/godror/godror v0.49.3
 
 require (
+	filippo.io/edwards25519 v1.1.0 // indirect
 	github.com/VictoriaMetrics/easyproto v0.1.4 // indirect
 	github.com/go-logfmt/logfmt v0.6.0 // indirect
+	github.com/go-sql-driver/mysql v1.8.1 // indirect
 	github.com/godror/knownpb v0.3.0 // indirect
+	github.com/google/uuid v1.6.0 // indirect
 	github.com/jinzhu/inflection v1.0.0 // indirect
 	github.com/jinzhu/now v1.1.5 // indirect
-	golang.org/x/exp v0.0.0-20250531010427-b6e5de432a8b // indirect
-	golang.org/x/term v0.27.0 // indirect
-	golang.org/x/text v0.25.0 // indirect
-	google.golang.org/protobuf v1.36.6 // indirect
+	golang.org/x/exp v0.0.0-20250911091902-df9299821621 // indirect
+	golang.org/x/sys v0.36.0 // indirect
+	golang.org/x/term v0.35.0 // indirect
+	golang.org/x/text v0.29.0 // indirect
+	google.golang.org/protobuf v1.36.9 // indirect
+	gorm.io/datatypes v1.2.6 // indirect
+	gorm.io/driver/mysql v1.5.6 // indirect
 )
diff --git a/oracle/common.go b/oracle/common.go
index 365f948..c849ede 100644
--- a/oracle/common.go
+++ b/oracle/common.go
@@ -40,11 +40,13 @@ package oracle
 
 import (
 	"database/sql"
+	"encoding/json"
 	"fmt"
 	"reflect"
 	"strings"
 	"time"
 
+	"gorm.io/datatypes"
 	"gorm.io/gorm"
 	"gorm.io/gorm/schema"
 )
@@ -174,6 +176,17 @@ func convertValue(val interface{}) interface{} {
 	}
 
 	switch v := val.(type) {
+	case json.RawMessage:
+		if v == nil {
+			return nil
+		}
+		return []byte(v)
+	case *json.RawMessage:
+		if v == nil {
+			return nil
+		}
+		b := []byte(*v)
+		return b
 	case bool:
 		if v {
 			return 1
@@ -198,7 +211,25 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
 	if isPtr {
 		targetType = targetType.Elem()
 	}
-
+	if field.FieldType == reflect.TypeOf(json.RawMessage{}) {
+		switch v := value.(type) {
+		case []byte:
+			return json.RawMessage(v) // from BLOB
+		case *[]byte:
+			if v == nil {
+				return json.RawMessage(nil)
+			}
+			return json.RawMessage(*v)
+		}
+	}
+	if isJSONField(field) {
+		switch v := value.(type) {
+		case string:
+			return datatypes.JSON([]byte(v))
+		case []byte:
+			return datatypes.JSON(v)
+		}
+	}
 	var converted interface{}
 
 	switch targetType {
@@ -276,6 +307,24 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
 	return converted
 }
 
+func isJSONField(f *schema.Field) bool {
+	_rawMsgT := reflect.TypeOf(json.RawMessage{})
+	_gormJSON := reflect.TypeOf(datatypes.JSON{})
+	if f == nil {
+		return false
+	}
+	ft := f.FieldType
+	return ft == _rawMsgT || ft == _gormJSON
+}
+
+func isRawMessageField(f *schema.Field) bool {
+	t := f.FieldType
+	for t.Kind() == reflect.Ptr {
+		t = t.Elem()
+	}
+	return t == reflect.TypeOf(json.RawMessage(nil))
+}
+
 // Helper function to handle primitive type conversions
 func convertPrimitiveType(value interface{}, targetType reflect.Type) interface{} {
 	switch targetType.Kind() {
diff --git a/oracle/create.go b/oracle/create.go
index 6278dcf..e7ba921 100644
--- a/oracle/create.go
+++ b/oracle/create.go
@@ -507,15 +507,37 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
 	}
 	plsqlBuilder.WriteString("\n BULK COLLECT INTO l_affected_records;\n")
 
-	// Add OUT parameter population
+	// Add OUT parameter population (JSON serialized to CLOB)
 	outParamIndex := len(stmt.Vars)
 	for rowIdx := 0; rowIdx < len(createValues.Values); rowIdx++ {
 		for _, column := range allColumns {
 			if field := findFieldByDBName(schema, column); field != nil {
-				stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
-				plsqlBuilder.WriteString(fmt.Sprintf(" IF l_affected_records.COUNT > %d THEN :%d := l_affected_records(%d).", rowIdx, outParamIndex+1, rowIdx+1))
-				db.QuoteTo(&plsqlBuilder, column)
-				plsqlBuilder.WriteString("; END IF;\n")
+				if isJSONField(field) {
+					if isRawMessageField(field) {
+						// Column is a BLOB, return raw bytes; no JSON_SERIALIZE
+						stmt.Vars = append(stmt.Vars, sql.Out{Dest: new([]byte)})
+						plsqlBuilder.WriteString(fmt.Sprintf(
+							" IF l_affected_records.COUNT > %d THEN :%d := l_affected_records(%d).",
+							rowIdx, outParamIndex+1, rowIdx+1,
+						))
+						writeQuotedIdentifier(&plsqlBuilder, column)
+						plsqlBuilder.WriteString("; END IF;\n")
+					} else {
+						// datatypes.JSON (text-based) -> serialize to CLOB
+						stmt.Vars = append(stmt.Vars, sql.Out{Dest: new(string)})
+						plsqlBuilder.WriteString(fmt.Sprintf(
+							" IF l_affected_records.COUNT > %d THEN :%d := JSON_SERIALIZE(l_affected_records(%d).",
+							rowIdx, outParamIndex+1, rowIdx+1,
+						))
+						writeQuotedIdentifier(&plsqlBuilder, column)
+						plsqlBuilder.WriteString(" RETURNING CLOB); END IF;\n")
+					}
+				} else {
+					stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
+					plsqlBuilder.WriteString(fmt.Sprintf(" IF l_affected_records.COUNT > %d THEN :%d := l_affected_records(%d).", rowIdx, outParamIndex+1, rowIdx+1))
+					writeQuotedIdentifier(&plsqlBuilder, column)
+					plsqlBuilder.WriteString("; END IF;\n")
+				}
 				outParamIndex++
 			}
 		}
@@ -613,7 +635,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
 	}
 	plsqlBuilder.WriteString("\n BULK COLLECT INTO l_inserted_records;\n")
 
-	// Add OUT parameter population
+	// Add OUT parameter population (JSON serialized to CLOB)
 	outParamIndex := len(stmt.Vars)
 	for rowIdx := 0; rowIdx < len(createValues.Values); rowIdx++ {
 		for _, column := range allColumns {
@@ -622,9 +644,29 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
 			quotedColumn := columnBuilder.String()
 
 			if field := findFieldByDBName(schema, column); field != nil {
-				stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
-				plsqlBuilder.WriteString(fmt.Sprintf(" IF l_inserted_records.COUNT > %d THEN :%d := l_inserted_records(%d).%s; END IF;\n",
-					rowIdx, outParamIndex+1, rowIdx+1, quotedColumn))
+				if isJSONField(field) {
+					if isRawMessageField(field) {
+						// Column is a BLOB, return raw bytes; no JSON_SERIALIZE
+						stmt.Vars = append(stmt.Vars, sql.Out{Dest: new([]byte)})
+						plsqlBuilder.WriteString(fmt.Sprintf(
+							" IF l_inserted_records.COUNT > %d THEN :%d := l_inserted_records(%d).%s; END IF;\n",
+							rowIdx, outParamIndex+1, rowIdx+1, quotedColumn,
+						))
+					} else {
+						// datatypes.JSON (text-based) -> serialize to CLOB
+						stmt.Vars = append(stmt.Vars, sql.Out{Dest: new(string)})
+						plsqlBuilder.WriteString(fmt.Sprintf(
+							" IF l_inserted_records.COUNT > %d THEN :%d := JSON_SERIALIZE(l_inserted_records(%d).%s RETURNING CLOB); END IF;\n",
+							rowIdx, outParamIndex+1, rowIdx+1, quotedColumn,
+						))
+					}
+				} else {
+					stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
+					plsqlBuilder.WriteString(fmt.Sprintf(
+						" IF l_inserted_records.COUNT > %d THEN :%d := l_inserted_records(%d).%s; END IF;\n",
+						rowIdx, outParamIndex+1, rowIdx+1, quotedColumn,
+					))
+				}
 				outParamIndex++
 			}
 		}
diff --git a/oracle/delete.go b/oracle/delete.go
index 817f69a..b658971 100644
--- a/oracle/delete.go
+++ b/oracle/delete.go
@@ -283,27 +283,48 @@ func buildBulkDeletePLSQL(db *gorm.DB) {
 	}
 	plsqlBuilder.WriteString("\n BULK COLLECT INTO l_deleted_records;\n")
 
-	// Create OUT parameters for each field and each row that will be deleted
+	// Create OUT parameters for each field and each row that will be deleted (JSON-safe)
 	outParamIndex := len(stmt.Vars)
-	//TODO make it configurable
-	estimatedRows := 100 // Estimate maximum rows to delete
+	// Fixed cap on the number of returned rows; TODO: make it configurable
+	estimatedRows := 100
 	for rowIdx := 0; rowIdx < estimatedRows; rowIdx++ {
 		for _, column := range allColumns {
-			field := findFieldByDBName(schema, column)
-			if field != nil {
-				dest := createTypedDestination(field)
-				stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})
-
-				plsqlBuilder.WriteString(fmt.Sprintf(" IF l_deleted_records.COUNT > %d THEN\n", rowIdx))
-				plsqlBuilder.WriteString(fmt.Sprintf(" :%d := l_deleted_records(%d).", outParamIndex+1, rowIdx+1))
-				db.QuoteTo(&plsqlBuilder, column)
-				plsqlBuilder.WriteString(";\n")
-				plsqlBuilder.WriteString(" END IF;\n")
+			if field := findFieldByDBName(schema, column); field != nil {
+				if isJSONField(field) {
+					if isRawMessageField(field) {
+						// Column is a BLOB, return raw bytes; no JSON_SERIALIZE
+						stmt.Vars = append(stmt.Vars, sql.Out{Dest: new([]byte)})
+						plsqlBuilder.WriteString(fmt.Sprintf(
+							" IF l_deleted_records.COUNT > %d THEN :%d := l_deleted_records(%d).",
+							rowIdx, outParamIndex+1, rowIdx+1,
+						))
+						writeQuotedIdentifier(&plsqlBuilder, column)
+						plsqlBuilder.WriteString("; END IF;\n")
+					} else {
+						// JSON -> text bind
+						stmt.Vars = append(stmt.Vars, sql.Out{Dest: new(string)})
+						plsqlBuilder.WriteString(fmt.Sprintf(" IF l_deleted_records.COUNT > %d THEN\n", rowIdx))
+						plsqlBuilder.WriteString(fmt.Sprintf(" :%d := JSON_SERIALIZE(l_deleted_records(%d).", outParamIndex+1, rowIdx+1))
+						writeQuotedIdentifier(&plsqlBuilder, column)
+						plsqlBuilder.WriteString(" RETURNING CLOB);\n")
+						plsqlBuilder.WriteString(" END IF;\n")
+					}
+				} else {
+					// non-JSON as before
+					dest := createTypedDestination(field)
+					stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})
+					plsqlBuilder.WriteString(fmt.Sprintf(" IF l_deleted_records.COUNT > %d THEN\n", rowIdx))
+					plsqlBuilder.WriteString(fmt.Sprintf(" :%d := l_deleted_records(%d).", outParamIndex+1, rowIdx+1))
+					writeQuotedIdentifier(&plsqlBuilder, column)
+					plsqlBuilder.WriteString(";\n")
+					plsqlBuilder.WriteString(" END IF;\n")
+				}
 				outParamIndex++
 			}
 		}
 	}
+
 	plsqlBuilder.WriteString("END;")
 
 	stmt.SQL.Reset()
diff --git a/oracle/update.go b/oracle/update.go
index 2a2ee19..a54eebb 100644
--- a/oracle/update.go
+++ b/oracle/update.go
@@ -542,7 +542,18 @@ func buildUpdatePLSQL(db *gorm.DB) {
 	for _, column := range allColumns {
 		field := findFieldByDBName(schema, column)
 		if field != nil {
-			dest := createTypedDestination(field)
+			var dest interface{}
+			if isJSONField(field) {
+				if isRawMessageField(field) {
+					// RawMessage -> BLOB -> []byte
+					dest = new([]byte)
+				} else {
+					// datatypes.JSON -> text -> string (CLOB)
+					dest = new(string)
+				}
+			} else {
+				dest = createTypedDestination(field)
+			}
 			stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})
 		}
 	}
@@ -553,18 +564,32 @@ func buildUpdatePLSQL(db *gorm.DB) {
 		for colIdx, column := range allColumns {
 			field := findFieldByDBName(schema, column)
 			if field != nil {
-				// Calculate the correct parameter index (1-based for Oracle)
 				paramIndex := outParamStartIndex + (rowIdx * len(allColumns)) + colIdx + 1
 
-				// Add the assignment to PL/SQL with correct parameter reference
-				plsqlBuilder.WriteString(fmt.Sprintf(" IF l_updated_records.COUNT > %d THEN\n", rowIdx))
-				plsqlBuilder.WriteString(fmt.Sprintf(" :%d := l_updated_records(%d).", paramIndex, rowIdx+1))
-				db.QuoteTo(&plsqlBuilder, column)
-				plsqlBuilder.WriteString(";\n")
-				plsqlBuilder.WriteString(" END IF;\n")
+				plsqlBuilder.WriteString(fmt.Sprintf(" IF l_updated_records.COUNT > %d THEN ", rowIdx))
+				plsqlBuilder.WriteString(fmt.Sprintf(":%d := ", paramIndex))
+
+				if isJSONField(field) {
+					if isRawMessageField(field) {
+						plsqlBuilder.WriteString(fmt.Sprintf("l_updated_records(%d).", rowIdx+1))
+						writeQuotedIdentifier(&plsqlBuilder, column)
+					} else {
+						// serialize JSON so it binds as text
+						plsqlBuilder.WriteString("JSON_SERIALIZE(")
+						plsqlBuilder.WriteString(fmt.Sprintf("l_updated_records(%d).", rowIdx+1))
+						writeQuotedIdentifier(&plsqlBuilder, column)
+						plsqlBuilder.WriteString(" RETURNING CLOB)")
+					}
+				} else {
+					plsqlBuilder.WriteString(fmt.Sprintf("l_updated_records(%d).", rowIdx+1))
+					writeQuotedIdentifier(&plsqlBuilder, column)
+				}
+
+				plsqlBuilder.WriteString("; END IF;\n")
 			}
 		}
 	}
+
 	plsqlBuilder.WriteString("END;")
 
 	stmt.SQL.Reset()
diff --git a/tests/json_bulk_test.go b/tests/json_bulk_test.go
new file mode 100644
index 0000000..e0a2eef
--- /dev/null
+++ b/tests/json_bulk_test.go
@@ -0,0 +1,193 @@
+/*
+** Copyright (c) 2025 Oracle and/or its affiliates.
+**
+** The Universal Permissive License (UPL), Version 1.0
+**
+** Subject to the condition set forth below, permission is hereby granted to any
+** person obtaining a copy of this software, associated documentation and/or data
+** (collectively the "Software"), free of charge and under any and all copyright
+** rights in the Software, and any and all patent rights owned or freely
+** licensable by each licensor hereunder covering either (i) the unmodified
+** Software as contributed to or provided by such licensor, or (ii) the Larger
+** Works (as defined below), to deal in both
+**
+** (a) the Software, and
+** (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if
+** one is included with the Software (each a "Larger Work" to which the Software
+** is contributed by such licensors),
+**
+** without restriction, including without limitation the rights to copy, create
+** derivative works of, display, perform, and distribute the Software and make,
+** use, sell, offer for sale, import, export, have made, and have sold the
+** Software and the Larger Work(s), and to sublicense the foregoing rights on
+** either these or other terms.
+**
+** This license is subject to the following condition:
+** The above copyright notice and either this complete permission notice or at
+** a minimum a reference to the UPL must be included in all copies or
+** substantial portions of the Software.
+**
+** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+** FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+** AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+** LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+** OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+** SOFTWARE.
+ */
+
+package tests
+
+import (
+	"encoding/json"
+	"errors"
+	"testing"
+
+	"gorm.io/datatypes"
+	"gorm.io/gorm"
+	"gorm.io/gorm/clause"
+)
+
+func TestBasicCRUD_JSONText(t *testing.T) {
+	type JsonRecord struct {
+		ID         uint           `gorm:"primaryKey;autoIncrement;column:record_id"`
+		Name       string         `gorm:"column:name"`
+		Properties datatypes.JSON `gorm:"column:properties"`
+	}
+
+	DB.Migrator().DropTable(&JsonRecord{})
+	if err := DB.AutoMigrate(&JsonRecord{}); err != nil {
+		t.Fatalf("migrate failed: %v", err)
+	}
+
+	// INSERT
+	rec := JsonRecord{
+		Name:       "json-text",
+		Properties: datatypes.JSON([]byte(`{"env":"prod","owner":"team-x"}`)),
+	}
+	if err := DB.Create(&rec).Error; err != nil {
+		t.Fatalf("create failed: %v", err)
+	}
+	if rec.ID == 0 {
+		t.Fatalf("expected ID to be set")
+	}
+
+	// UPDATE (with RETURNING)
+	var ret JsonRecord
+	if err := DB.
+		Clauses(clause.Returning{
+			Columns: []clause.Column{
+				{Name: "record_id"},
+				{Name: "name"},
+				{Name: "properties"},
+			},
+		}).
+		Model(&ret).
+		Where("\"record_id\" = ?", rec.ID).
+		Updates(map[string]any{
+			"name":       "json-text-upd",
+			"properties": datatypes.JSON([]byte(`{"env":"staging","owner":"team-y","flag":true}`)),
+		}).Error; err != nil {
+		t.Fatalf("update returning failed: %v", err)
+	}
+	if ret.ID != rec.ID || ret.Name != "json-text-upd" || len(ret.Properties) == 0 {
+		t.Fatalf("unexpected returning row: %#v", ret)
+	}
+
+	// DELETE (with RETURNING)
+	var deleted []JsonRecord
+	if err := DB.
+		Where("\"record_id\" = ?", rec.ID).
+		Clauses(clause.Returning{
+			Columns: []clause.Column{
+				{Name: "record_id"},
+				{Name: "name"},
+				{Name: "properties"},
+			},
+		}).
+		Delete(&deleted).Error; err != nil {
+		t.Fatalf("delete returning failed: %v", err)
+	}
+	if len(deleted) != 1 || deleted[0].ID != rec.ID {
+		t.Fatalf("unexpected deleted rows: %#v", deleted)
+	}
+
+	// verify gone
+	var check JsonRecord
+	err := DB.First(&check, "\"record_id\" = ?", rec.ID).Error
+	if !errors.Is(err, gorm.ErrRecordNotFound) {
+		t.Fatalf("expected not found after delete, got: %v", err)
+	}
+}
+
+func TestBasicCRUD_RawMessage(t *testing.T) {
+	type RawRecord struct {
+		ID         uint            `gorm:"primaryKey;autoIncrement;column:record_id"`
+		Name       string          `gorm:"column:name"`
+		Properties json.RawMessage `gorm:"column:properties"`
+	}
+
+	DB.Migrator().DropTable(&RawRecord{})
+	if err := DB.AutoMigrate(&RawRecord{}); err != nil {
+		t.Fatalf("migrate failed: %v", err)
+	}
+
+	// INSERT
+	rec := RawRecord{
+		Name:       "raw-json",
+		Properties: json.RawMessage(`{"a":1,"b":"x"}`),
+	}
+	if err := DB.Create(&rec).Error; err != nil {
+		t.Fatalf("create failed: %v", err)
+	}
+	if rec.ID == 0 {
+		t.Fatalf("expected ID to be set")
+	}
+
+	// UPDATE (with RETURNING)
+	var ret RawRecord
+	if err := DB.
+		Clauses(clause.Returning{
+			Columns: []clause.Column{
+				{Name: "record_id"},
+				{Name: "name"},
+				{Name: "properties"},
+			},
+		}).
+		Model(&ret).
+		Where("\"record_id\" = ?", rec.ID).
+		Updates(map[string]any{
+			"name":       "raw-json-upd",
+			"properties": json.RawMessage(`{"a":2,"c":true}`),
+		}).Error; err != nil {
+		t.Fatalf("update returning failed: %v", err)
+	}
+	if ret.ID != rec.ID || ret.Name != "raw-json-upd" || len(ret.Properties) == 0 {
+		t.Fatalf("unexpected returning row: %#v", ret)
+	}
+
+	// DELETE (with RETURNING)
+	var deleted []RawRecord
+	if err := DB.
+		Where("\"record_id\" = ?", rec.ID).
+		Clauses(clause.Returning{
+			Columns: []clause.Column{
+				{Name: "record_id"},
+				{Name: "name"},
+				{Name: "properties"},
+			},
+		}).
+		Delete(&deleted).Error; err != nil {
+		t.Fatalf("delete returning failed: %v", err)
+	}
+	if len(deleted) != 1 || deleted[0].ID != rec.ID {
+		t.Fatalf("unexpected deleted rows: %#v", deleted)
+	}
+
+	// verify gone
+	var check RawRecord
+	err := DB.First(&check, "\"record_id\" = ?", rec.ID).Error
+	if !errors.Is(err, gorm.ErrRecordNotFound) {
+		t.Fatalf("expected not found after delete, got: %v", err)
+	}
+}