Skip to content

Commit 95224a7

Browse files
authored
Merge pull request #557 from go-jet/first_row_strict
Check strict scan only on first row.
2 parents eaaa328 + ac8d24f commit 95224a7

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

qrm/qrm.go

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,8 @@ func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interfac
241241
return fmt.Errorf("jet: failed to scan a row into destination, %w", err)
242242
}
243243

244-
if scanContext.rowNum == 1 && GlobalConfig.StrictScan {
245-
scanContext.EnsureEveryColumnRead() // can panic
246-
}
247-
248-
if GlobalConfig.StrictFieldMapping {
249-
scanContext.EnsureEveryFieldMapped() // can panic
244+
if scanContext.rowNum == 1 {
245+
scanContext.ensureStrictness()
250246
}
251247

252248
return nil
@@ -291,11 +287,8 @@ func queryToSlice(ctx context.Context, db Queryable, query string, args []interf
291287
return scanContext.rowNum, err
292288
}
293289

294-
if scanContext.rowNum == 1 && GlobalConfig.StrictScan {
295-
scanContext.EnsureEveryColumnRead()
296-
}
297-
if scanContext.rowNum == 1 && GlobalConfig.StrictFieldMapping {
298-
scanContext.EnsureEveryFieldMapped()
290+
if scanContext.rowNum == 1 {
291+
scanContext.ensureStrictness()
299292
}
300293
}
301294

qrm/scan_context.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,17 @@ func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
6767
}, nil
6868
}
6969

70-
func (s *ScanContext) EnsureEveryColumnRead() {
70+
func (s *ScanContext) ensureStrictness() { // can panic
71+
if GlobalConfig.StrictScan {
72+
s.ensureEveryColumnRead() // can panic
73+
}
74+
75+
if GlobalConfig.StrictFieldMapping {
76+
s.ensureEveryFieldMapped() // can panic
77+
}
78+
}
79+
80+
func (s *ScanContext) ensureEveryColumnRead() {
7181
var neverUsedColumns []string
7282

7383
for index, read := range s.columnIndexRead {
@@ -95,13 +105,13 @@ func (s *ScanContext) recordUnmappedField(structType reflect.Type, parentField *
95105

96106
fieldIdent := fmt.Sprintf("%s.%s", typeName, field.Name)
97107
if parentField != nil {
98-
fieldIdent = fmt.Sprintf("%s %s.%s", parentField.Name, typeName, field.Name)
108+
fieldIdent = fmt.Sprintf("%s %s", parentField.Name, fieldIdent)
99109
}
100110

101111
s.unmappedFields = append(s.unmappedFields, fmt.Sprintf("'%s'", fieldIdent))
102112
}
103113

104-
func (s *ScanContext) EnsureEveryFieldMapped() {
114+
func (s *ScanContext) ensureEveryFieldMapped() {
105115
if len(s.unmappedFields) == 0 {
106116
return
107117
}
@@ -328,7 +338,7 @@ func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int {
328338
// rowElemValue always returns non-ptr value,
329339
// invalid value is nil
330340
func (s *ScanContext) rowElemValue(index int) reflect.Value {
331-
if s.rowNum == 1 {
341+
if s.rowNum == 1 && GlobalConfig.StrictScan {
332342
s.columnIndexRead[index] = true
333343
}
334344
scannedValue := reflect.ValueOf(s.row[index])

0 commit comments

Comments
 (0)