Skip to content

Commit e7afe7a

Browse files
Merge pull request #103 from oracle-samples/48-invalid-plsql-generated-for-bulk-insert-with-returning-clause
Fix #48 Add support for serializer
2 parents b0805d2 + 1482dc8 commit e7afe7a

File tree

3 files changed

+130
-14
lines changed

3 files changed

+130
-14
lines changed

oracle/common.go

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func getOracleArrayType(field *schema.Field, values []any) string {
8282
case schema.Bytes:
8383
return "TABLE OF BLOB"
8484
default:
85-
return "TABLE OF VARCHAR2(4000)" // Safe default
85+
return "TABLE OF " + strings.ToUpper(string(field.DataType))
8686
}
8787
}
8888

@@ -113,11 +113,47 @@ func findFieldByDBName(schema *schema.Schema, dbName string) *schema.Field {
113113
return nil
114114
}
115115

116+
// Extra data types to determine the destination type for OUT parameters
117+
// when using a serializer
118+
const (
119+
Timestamp schema.DataType = "timestamp"
120+
TimestampWithTimeZone schema.DataType = "timestamp with time zone"
121+
)
122+
116123
// Create typed destination for OUT parameters
117124
func createTypedDestination(f *schema.Field) interface{} {
118125
if f == nil {
119-
var s string
120-
return &s
126+
return new(string)
127+
}
128+
129+
// If the field has a serializer, the field type may not be directly related to the column type in the database.
130+
// In this case, determine the destination type using the field's data type, which is the column type in the
131+
// database.
132+
if f.Serializer != nil {
133+
dt := strings.ToLower(string(f.DataType))
134+
switch schema.DataType(dt) {
135+
case schema.Bool:
136+
return new(bool)
137+
case schema.Uint:
138+
return new(uint64)
139+
case schema.Int:
140+
return new(int64)
141+
case schema.Float:
142+
return new(float64)
143+
case schema.String:
144+
return new(string)
145+
case Timestamp:
146+
fallthrough
147+
case TimestampWithTimeZone:
148+
fallthrough
149+
case schema.Time:
150+
return new(time.Time)
151+
case schema.Bytes:
152+
return new([]byte)
153+
default:
154+
// Fallback
155+
return new(string)
156+
}
121157
}
122158

123159
ft := f.FieldType
@@ -166,8 +202,7 @@ func createTypedDestination(f *schema.Field) interface{} {
166202
}
167203

168204
// Fallback
169-
var s string
170-
return &s
205+
return new(string)
171206
}
172207

173208
// Convert values for Oracle-specific types
@@ -229,6 +264,13 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
229264
return nil
230265
}
231266

267+
// Deserialize data into objects when a serializer is used
268+
if field.Serializer != nil {
269+
serializerField := field.NewValuePool.Get().(sql.Scanner)
270+
serializerField.Scan(value)
271+
return serializerField
272+
}
273+
232274
targetType := field.FieldType
233275
var converted any
234276

oracle/update.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,6 @@ func checkMissingWhereConditions(db *gorm.DB) {
168168
}
169169
// Has non-soft-delete equality condition, this is valid
170170
hasMeaningfulConditions = true
171-
break
172171
case clause.IN:
173172
// Has IN condition with values, this is valid
174173
if len(e.Values) > 0 {
@@ -187,11 +186,9 @@ func checkMissingWhereConditions(db *gorm.DB) {
187186
}
188187
// Has non-soft-delete expression condition, consider it valid
189188
hasMeaningfulConditions = true
190-
break
191189
case clause.AndConditions, clause.OrConditions:
192190
// Complex conditions are likely valid (but we could be more thorough here)
193191
hasMeaningfulConditions = true
194-
break
195192
case clause.Where:
196193
// Handle nested WHERE clauses - recursively check their expressions
197194
if len(e.Exprs) > 0 {
@@ -208,7 +205,6 @@ func checkMissingWhereConditions(db *gorm.DB) {
208205
default:
209206
// Unknown condition types - assume they're meaningful for safety
210207
hasMeaningfulConditions = true
211-
break
212208
}
213209

214210
// If we found meaningful conditions, we can stop checking

tests/serializer_test.go

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ type SerializerStruct struct {
6060
Roles3 *Roles `gorm:"serializer:json;not null"`
6161
Contracts map[string]interface{} `gorm:"serializer:json"`
6262
JobInfo Job `gorm:"type:bytes;serializer:gob"`
63-
CreatedTime int64 `gorm:"serializer:unixtime;type:timestamp"` // store time in db, use int as field type
64-
UpdatedTime *int64 `gorm:"serializer:unixtime;type:timestamp"` // store time in db, use int as field type
63+
CreatedTime int64 `gorm:"serializer:unixtime;type:timestamp with time zone"` // store time in db, use int as field type
64+
UpdatedTime *int64 `gorm:"serializer:unixtime;type:timestamp with time zone"` // store time in db, use int as field type
6565
CustomSerializerString string `gorm:"serializer:custom"`
6666
EncryptedString EncryptedString
6767
}
@@ -122,7 +122,9 @@ func (c *CustomSerializer) Value(ctx context.Context, field *schema.Field, dst r
122122
}
123123

124124
func TestSerializer(t *testing.T) {
125-
schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
125+
if _, ok := schema.GetSerializer("custom"); !ok {
126+
schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
127+
}
126128
DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{}))
127129
if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil {
128130
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
@@ -168,8 +170,82 @@ func TestSerializer(t *testing.T) {
168170
}
169171
}
170172

173+
// Issue 48: https://github.com/oracle-samples/gorm-oracle/issues/48
174+
func TestSerializerBulkInsert(t *testing.T) {
175+
if _, ok := schema.GetSerializer("custom"); !ok {
176+
schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
177+
}
178+
DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{}))
179+
if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil {
180+
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
181+
}
182+
183+
createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
184+
updatedAt := createdAt.Unix()
185+
186+
data := []SerializerStruct{
187+
{
188+
Name: []byte("jinzhu"),
189+
Roles: []string{"r1", "r2"},
190+
Roles3: &Roles{},
191+
Contracts: map[string]interface{}{"name": "jinzhu", "age": 10},
192+
EncryptedString: EncryptedString("pass"),
193+
CreatedTime: createdAt.Unix(),
194+
UpdatedTime: &updatedAt,
195+
JobInfo: Job{
196+
Title: "programmer",
197+
Number: 9920,
198+
Location: "Kenmawr",
199+
IsIntern: false,
200+
},
201+
CustomSerializerString: "world",
202+
},
203+
{
204+
Name: []byte("john"),
205+
Roles: []string{"l1", "l2"},
206+
Roles3: &Roles{},
207+
Contracts: map[string]interface{}{"name": "john", "age": 20},
208+
EncryptedString: EncryptedString("pass"),
209+
CreatedTime: createdAt.Unix(),
210+
UpdatedTime: &updatedAt,
211+
JobInfo: Job{
212+
Title: "manager",
213+
Number: 7710,
214+
Location: "Redwood City",
215+
IsIntern: false,
216+
},
217+
CustomSerializerString: "foo",
218+
},
219+
}
220+
221+
if err := DB.Create(&data).Error; err != nil {
222+
t.Fatalf("failed to create data, got error %v", err)
223+
}
224+
225+
var result []SerializerStruct
226+
if err := DB.Find(&result).Error; err != nil {
227+
t.Fatalf("failed to query data, got error %v", err)
228+
}
229+
230+
tests.AssertEqual(t, result, data)
231+
232+
// Update all the "roles" columns to "n1"
233+
if err := DB.Model(&SerializerStruct{}).Where("\"roles\" IS NOT NULL").Update("roles", []string{"n1"}).Error; err != nil {
234+
t.Fatalf("failed to update data's roles, got error %v", err)
235+
}
236+
237+
var count int64
238+
if err := DB.Model(&SerializerStruct{}).Where("\"roles\" = ?", "n1").Count(&count).Error; err != nil {
239+
t.Fatalf("failed to query data, got error %v", err)
240+
}
241+
242+
tests.AssertEqual(t, count, 2)
243+
}
244+
171245
func TestSerializerZeroValue(t *testing.T) {
172-
schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
246+
if _, ok := schema.GetSerializer("custom"); !ok {
247+
schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
248+
}
173249
DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{}))
174250
if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil {
175251
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)
@@ -200,7 +276,9 @@ func TestSerializerZeroValue(t *testing.T) {
200276
}
201277

202278
func TestSerializerAssignFirstOrCreate(t *testing.T) {
203-
schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
279+
if _, ok := schema.GetSerializer("custom"); !ok {
280+
schema.RegisterSerializer("custom", NewCustomSerializer("hello"))
281+
}
204282
DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{}))
205283
if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil {
206284
t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err)

0 commit comments

Comments
 (0)