Skip to content

Commit 648521d

Browse files
Fixing null time returned value
1 parent ce1074b commit 648521d

File tree

4 files changed

+96
-49
lines changed

4 files changed

+96
-49
lines changed

oracle/common.go

Lines changed: 92 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,30 @@ func findFieldByDBName(schema *schema.Schema, dbName string) *schema.Field {
100100
}
101101

102102
// Create typed destination for OUT parameters
103-
func createTypedDestination(fieldType reflect.Type) interface{} {
104-
// Handle pointer types
105-
if fieldType.Kind() == reflect.Ptr {
106-
fieldType = fieldType.Elem()
103+
func createTypedDestination(f *schema.Field) interface{} {
104+
if f == nil {
105+
var s string
106+
return &s
107107
}
108108

109-
// Type-safe handling for known GORM types and SQL null types
110-
switch fieldType {
111-
case reflect.TypeOf(gorm.DeletedAt{}):
109+
ft := f.FieldType
110+
for ft.Kind() == reflect.Ptr {
111+
ft = ft.Elem()
112+
}
113+
114+
if ft == reflect.TypeOf(gorm.DeletedAt{}) {
112115
return new(sql.NullTime)
113-
case reflect.TypeOf(time.Time{}):
116+
}
117+
if ft == reflect.TypeOf(time.Time{}) {
118+
if !f.NotNull { // nullable column => keep NULLs
119+
return new(sql.NullTime)
120+
}
114121
return new(time.Time)
122+
}
123+
124+
switch ft {
125+
case reflect.TypeOf(sql.NullTime{}):
126+
return new(sql.NullTime)
115127
case reflect.TypeOf(sql.NullInt64{}):
116128
return new(sql.NullInt64)
117129
case reflect.TypeOf(sql.NullInt32{}):
@@ -120,33 +132,28 @@ func createTypedDestination(fieldType reflect.Type) interface{} {
120132
return new(sql.NullFloat64)
121133
case reflect.TypeOf(sql.NullBool{}):
122134
return new(sql.NullBool)
123-
case reflect.TypeOf(sql.NullTime{}):
124-
return new(sql.NullTime)
125135
}
126136

127-
// Handle primitive types by Kind
128-
switch fieldType.Kind() {
137+
switch ft.Kind() {
138+
case reflect.String:
139+
return new(string)
140+
141+
case reflect.Bool:
142+
return new(int64)
143+
129144
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
130-
return new(int64) // Oracle returns NUMBER as int64
131-
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
145+
return new(int64)
146+
147+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
132148
return new(uint64)
149+
133150
case reflect.Float32, reflect.Float64:
134-
return new(float64) // Oracle returns FLOAT as float64
135-
case reflect.Bool:
136-
return new(int64) // Oracle NUMBER(1) for boolean
137-
case reflect.String:
138-
return new(string)
139-
case reflect.Struct:
140-
// For time.Time specifically
141-
if fieldType == reflect.TypeOf(time.Time{}) {
142-
return new(time.Time)
143-
}
144-
// For other structs, use string as safe fallback
145-
return new(string)
146-
default:
147-
// For unknown types, use string as safe fallback
148-
return new(string)
151+
return new(float64)
149152
}
153+
154+
// Fallback
155+
var s string
156+
return &s
150157
}
151158

152159
// Convert values for Oracle-specific types
@@ -182,7 +189,7 @@ func convertValue(val interface{}) interface{} {
182189

183190
// Convert Oracle values back to Go types
184191
func convertFromOracleToField(value interface{}, field *schema.Field) interface{} {
185-
if value == nil {
192+
if value == nil || field == nil {
186193
return nil
187194
}
188195

@@ -194,7 +201,6 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
194201

195202
var converted interface{}
196203

197-
// Handle special types first using type-safe comparisons
198204
switch targetType {
199205
case reflect.TypeOf(gorm.DeletedAt{}):
200206
if nullTime, ok := value.(sql.NullTime); ok {
@@ -203,7 +209,31 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
203209
converted = gorm.DeletedAt{}
204210
}
205211
case reflect.TypeOf(time.Time{}):
206-
converted = value
212+
switch vv := value.(type) {
213+
case time.Time:
214+
converted = vv
215+
case sql.NullTime:
216+
if vv.Valid {
217+
converted = vv.Time
218+
} else {
219+
// DB returned NULL
220+
if isPtr {
221+
return nil // -> *time.Time(nil)
222+
}
223+
// non-pointer time.Time: represent NULL as zero time
224+
return time.Time{}
225+
}
226+
default:
227+
converted = value
228+
}
229+
230+
case reflect.TypeOf(sql.NullTime{}):
231+
if nullTime, ok := value.(sql.NullTime); ok {
232+
converted = nullTime
233+
} else {
234+
converted = sql.NullTime{}
235+
}
236+
207237
case reflect.TypeOf(sql.NullInt64{}):
208238
if nullInt, ok := value.(sql.NullInt64); ok {
209239
converted = nullInt
@@ -228,25 +258,19 @@ func convertFromOracleToField(value interface{}, field *schema.Field) interface{
228258
} else {
229259
converted = sql.NullBool{}
230260
}
231-
case reflect.TypeOf(sql.NullTime{}):
232-
if nullTime, ok := value.(sql.NullTime); ok {
233-
converted = nullTime
234-
} else {
235-
converted = sql.NullTime{}
236-
}
237261
default:
238-
// Handle primitive types
262+
// primitives and everything else
239263
converted = convertPrimitiveType(value, targetType)
240264
}
241265

242-
// Handle pointer types
243-
if isPtr && converted != nil {
244-
if isZeroValueForPointer(converted, targetType) {
266+
// Pointer targets: nil for "zero-ish", else allocate and set.
267+
if isPtr {
268+
if isZeroFor(targetType, converted) {
245269
return nil
246270
}
247271
ptr := reflect.New(targetType)
248272
ptr.Elem().Set(reflect.ValueOf(converted))
249-
converted = ptr.Interface()
273+
return ptr.Interface()
250274
}
251275

252276
return converted
@@ -426,8 +450,6 @@ func isNullValue(value interface{}) bool {
426450

427451
// Check for different NULL types
428452
switch v := value.(type) {
429-
case sql.NullString:
430-
return !v.Valid
431453
case sql.NullInt64:
432454
return !v.Valid
433455
case sql.NullInt32:
@@ -442,3 +464,28 @@ func isNullValue(value interface{}) bool {
442464
return false
443465
}
444466
}
467+
468+
func isZeroFor(t reflect.Type, v interface{}) bool {
469+
if v == nil {
470+
return true
471+
}
472+
rv := reflect.ValueOf(v)
473+
if !rv.IsValid() {
474+
return true
475+
}
476+
// exact type match?
477+
if rv.Type() == t {
478+
// special-case time.Time
479+
if t == reflect.TypeOf(time.Time{}) {
480+
return rv.Interface().(time.Time).IsZero()
481+
}
482+
// generic zero check
483+
z := reflect.Zero(t)
484+
return reflect.DeepEqual(rv.Interface(), z.Interface())
485+
}
486+
// If types differ (e.g., sql.NullTime), treat invalid as zero
487+
if nt, ok := v.(sql.NullTime); ok {
488+
return !nt.Valid
489+
}
490+
return false
491+
}

oracle/create.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ func buildBulkMergePLSQL(db *gorm.DB, createValues clause.Values, onConflictClau
474474
for rowIdx := 0; rowIdx < len(createValues.Values); rowIdx++ {
475475
for _, column := range allColumns {
476476
if field := findFieldByDBName(schema, column); field != nil {
477-
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field.FieldType)})
477+
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
478478
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_affected_records.COUNT > %d THEN :%d := l_affected_records(%d).", rowIdx, outParamIndex+1, rowIdx+1))
479479
writeQuotedIdentifier(&plsqlBuilder, column)
480480
plsqlBuilder.WriteString("; END IF;\n")
@@ -586,7 +586,7 @@ func buildBulkInsertOnlyPLSQL(db *gorm.DB, createValues clause.Values) {
586586
quotedColumn := columnBuilder.String()
587587

588588
if field := findFieldByDBName(schema, column); field != nil {
589-
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field.FieldType)})
589+
stmt.Vars = append(stmt.Vars, sql.Out{Dest: createTypedDestination(field)})
590590
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_inserted_records.COUNT > %d THEN :%d := l_inserted_records(%d).%s; END IF;\n",
591591
rowIdx, outParamIndex+1, rowIdx+1, quotedColumn))
592592
outParamIndex++

oracle/delete.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ func buildBulkDeletePLSQL(db *gorm.DB) {
254254
for _, column := range allColumns {
255255
field := findFieldByDBName(schema, column)
256256
if field != nil {
257-
dest := createTypedDestination(field.FieldType)
257+
dest := createTypedDestination(field)
258258
stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})
259259

260260
plsqlBuilder.WriteString(fmt.Sprintf(" IF l_deleted_records.COUNT > %d THEN\n", rowIdx))

oracle/update.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ func buildUpdatePLSQL(db *gorm.DB) {
522522
for _, column := range allColumns {
523523
field := findFieldByDBName(schema, column)
524524
if field != nil {
525-
dest := createTypedDestination(field.FieldType)
525+
dest := createTypedDestination(field)
526526
stmt.Vars = append(stmt.Vars, sql.Out{Dest: dest})
527527
}
528528
}

0 commit comments

Comments
 (0)