Skip to content

Commit c0e81fb

Browse files
authored
Merge pull request #424 from Shopify/refactor-uuid-as-id
NewPaginationKeyFromRow refactor
2 parents 8370cd0 + 5734959 commit c0e81fb

File tree

7 files changed

+80
-265
lines changed

7 files changed

+80
-265
lines changed

batch_writer.go

Lines changed: 7 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77

88
sql "github.com/Shopify/ghostferry/sqlwrapper"
99

10-
"github.com/go-mysql-org/go-mysql/schema"
1110
"github.com/sirupsen/logrus"
1211
)
1312

@@ -57,65 +56,15 @@ func (w *BatchWriter) WriteRowBatch(batch *RowBatch) error {
5756
return nil
5857
}
5958

60-
var startPaginationKeypos, endPaginationKeypos PaginationKey
61-
var err error
62-
6359
paginationColumn := batch.TableSchema().GetPaginationColumn()
6460

65-
switch paginationColumn.Type {
66-
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
67-
var startValue, endValue uint64
68-
startValue, err = values[0].GetUint64(batch.PaginationKeyIndex())
69-
if err != nil {
70-
return err
71-
}
72-
endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
73-
if err != nil {
74-
return err
75-
}
76-
startPaginationKeypos = NewUint64Key(startValue)
77-
endPaginationKeypos = NewUint64Key(endValue)
78-
79-
case schema.TYPE_BINARY, schema.TYPE_STRING:
80-
startValueInterface := values[0][batch.PaginationKeyIndex()]
81-
endValueInterface := values[len(values)-1][batch.PaginationKeyIndex()]
82-
83-
getBytes := func(val interface{}) ([]byte, error) {
84-
switch v := val.(type) {
85-
case []byte:
86-
return v, nil
87-
case string:
88-
return []byte(v), nil
89-
default:
90-
return nil, fmt.Errorf("expected binary/string pagination key, got %T", val)
91-
}
92-
}
93-
94-
startValue, err := getBytes(startValueInterface)
95-
if err != nil {
96-
return err
97-
}
98-
99-
endValue, err := getBytes(endValueInterface)
100-
if err != nil {
101-
return err
102-
}
103-
104-
startPaginationKeypos = NewBinaryKey(startValue)
105-
endPaginationKeypos = NewBinaryKey(endValue)
106-
107-
default:
108-
var startValue, endValue uint64
109-
startValue, err = values[0].GetUint64(batch.PaginationKeyIndex())
110-
if err != nil {
111-
return err
112-
}
113-
endValue, err = values[len(values)-1].GetUint64(batch.PaginationKeyIndex())
114-
if err != nil {
115-
return err
116-
}
117-
startPaginationKeypos = NewUint64Key(startValue)
118-
endPaginationKeypos = NewUint64Key(endValue)
61+
startPaginationKeypos, err := NewPaginationKeyFromRow(values[0], batch.PaginationKeyIndex(), paginationColumn)
62+
if err != nil {
63+
return err
64+
}
65+
endPaginationKeypos, err := NewPaginationKeyFromRow(values[len(values)-1], batch.PaginationKeyIndex(), paginationColumn)
66+
if err != nil {
67+
return err
11968
}
12069

12170
db := batch.TableSchema().Schema

cursor.go

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -262,43 +262,10 @@ func (c *Cursor) Fetch(db SqlPreparer) (batch *RowBatch, paginationKeypos Pagina
262262

263263
if len(batchData) > 0 {
264264
lastRowData := batchData[len(batchData)-1]
265-
266-
switch c.paginationKeyColumn.Type {
267-
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
268-
var value uint64
269-
value, err = lastRowData.GetUint64(paginationKeyIndex)
270-
if err != nil {
271-
logger.WithError(err).Error("failed to get uint64 paginationKey value")
272-
return
273-
}
274-
paginationKeypos = NewUint64Key(value)
275-
276-
case schema.TYPE_BINARY, schema.TYPE_STRING:
277-
valueInterface := lastRowData[paginationKeyIndex]
278-
279-
var valueBytes []byte
280-
switch v := valueInterface.(type) {
281-
case []byte:
282-
valueBytes = v
283-
case string:
284-
valueBytes = []byte(v)
285-
default:
286-
err = fmt.Errorf("expected binary pagination key to be []byte or string, got %T", valueInterface)
287-
logger.WithError(err).Error("failed to get binary paginationKey value")
288-
return
289-
}
290-
291-
paginationKeypos = NewBinaryKey(valueBytes)
292-
293-
default:
294-
// Fallback for other integer types
295-
var value uint64
296-
value, err = lastRowData.GetUint64(paginationKeyIndex)
297-
if err != nil {
298-
logger.WithError(err).Error("failed to get uint64 paginationKey value")
299-
return
300-
}
301-
paginationKeypos = NewUint64Key(value)
265+
paginationKeypos, err = NewPaginationKeyFromRow(lastRowData, paginationKeyIndex, c.paginationKeyColumn)
266+
if err != nil {
267+
logger.WithError(err).Error("failed to get paginationKey value")
268+
return
302269
}
303270
}
304271

data_iterator.go

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66

77
sql "github.com/Shopify/ghostferry/sqlwrapper"
88

9-
"github.com/go-mysql-org/go-mysql/schema"
109
"github.com/sirupsen/logrus"
1110
)
1211

@@ -115,40 +114,13 @@ func (d *DataIterator) Run(tables []*TableSchema) {
115114
paginationColumn := table.GetPaginationColumn()
116115

117116
for i, rowData := range batch.Values() {
118-
var paginationKeyStr string
119-
120-
switch paginationColumn.Type {
121-
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
122-
paginationKeyUint, err := rowData.GetUint64(batch.PaginationKeyIndex())
123-
if err != nil {
124-
logger.WithError(err).Error("failed to get uint64 paginationKey data")
125-
return err
126-
}
127-
paginationKeyStr = NewUint64Key(paginationKeyUint).String()
128-
129-
case schema.TYPE_BINARY, schema.TYPE_STRING:
130-
paginationKeyInterface := rowData[batch.PaginationKeyIndex()]
131-
var paginationKeyBytes []byte
132-
switch v := paginationKeyInterface.(type) {
133-
case []byte:
134-
paginationKeyBytes = v
135-
case string:
136-
paginationKeyBytes = []byte(v)
137-
default:
138-
return fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface)
139-
}
140-
paginationKeyStr = NewBinaryKey(paginationKeyBytes).String()
141-
142-
default:
143-
paginationKeyUint, err := rowData.GetUint64(batch.PaginationKeyIndex())
144-
if err != nil {
145-
logger.WithError(err).Error("failed to get paginationKey data")
146-
return err
147-
}
148-
paginationKeyStr = NewUint64Key(paginationKeyUint).String()
117+
paginationKey, err := NewPaginationKeyFromRow(rowData, batch.PaginationKeyIndex(), paginationColumn)
118+
if err != nil {
119+
logger.WithError(err).Error("failed to get paginationKey data")
120+
return err
149121
}
150122

151-
fingerprints[paginationKeyStr] = rowData[len(rowData)-1].([]byte)
123+
fingerprints[paginationKey.String()] = rowData[len(rowData)-1].([]byte)
152124
rows[i] = rowData[:len(rowData)-1]
153125
}
154126

dml_events.go

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -576,35 +576,9 @@ func paginationKeyFromEventData(table *TableSchema, rowData RowData) (string, er
576576
return "", err
577577
}
578578

579-
paginationColumn := table.GetPaginationColumn()
580-
paginationKeyIndex := table.GetPaginationKeyIndex()
581-
582-
switch paginationColumn.Type {
583-
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
584-
paginationKeyUint, err := rowData.GetUint64(paginationKeyIndex)
585-
if err != nil {
586-
return "", err
587-
}
588-
return NewUint64Key(paginationKeyUint).String(), nil
589-
590-
case schema.TYPE_BINARY, schema.TYPE_STRING:
591-
paginationKeyInterface := rowData[paginationKeyIndex]
592-
var paginationKeyBytes []byte
593-
switch v := paginationKeyInterface.(type) {
594-
case []byte:
595-
paginationKeyBytes = v
596-
case string:
597-
paginationKeyBytes = []byte(v)
598-
default:
599-
return "", fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface)
600-
}
601-
return NewBinaryKey(paginationKeyBytes).String(), nil
602-
603-
default:
604-
paginationKeyUint, err := rowData.GetUint64(paginationKeyIndex)
605-
if err != nil {
606-
return "", err
607-
}
608-
return NewUint64Key(paginationKeyUint).String(), nil
579+
paginationKey, err := NewPaginationKeyFromRow(rowData, table.GetPaginationKeyIndex(), table.GetPaginationColumn())
580+
if err != nil {
581+
return "", err
609582
}
583+
return paginationKey.String(), nil
610584
}

inline_verifier.go

Lines changed: 8 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -335,34 +335,11 @@ func (v *InlineVerifier) CheckFingerprintInline(tx *sql.Tx, targetSchema, target
335335

336336
paginationKeys := make([]interface{}, len(sourceBatch.Values()))
337337
for i, row := range sourceBatch.Values() {
338-
switch paginationColumn.Type {
339-
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
340-
paginationKeyUint, err := row.GetUint64(sourceBatch.PaginationKeyIndex())
341-
if err != nil {
342-
return nil, err
343-
}
344-
paginationKeys[i] = paginationKeyUint
345-
346-
case schema.TYPE_BINARY, schema.TYPE_STRING:
347-
paginationKeyInterface := row[sourceBatch.PaginationKeyIndex()]
348-
var paginationKeyBytes []byte
349-
switch v := paginationKeyInterface.(type) {
350-
case []byte:
351-
paginationKeyBytes = v
352-
case string:
353-
paginationKeyBytes = []byte(v)
354-
default:
355-
return nil, fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface)
356-
}
357-
paginationKeys[i] = paginationKeyBytes
358-
359-
default:
360-
paginationKeyUint, err := row.GetUint64(sourceBatch.PaginationKeyIndex())
361-
if err != nil {
362-
return nil, err
363-
}
364-
paginationKeys[i] = paginationKeyUint
338+
paginationKey, err := NewPaginationKeyFromRow(row, sourceBatch.PaginationKeyIndex(), paginationColumn)
339+
if err != nil {
340+
return nil, err
365341
}
342+
paginationKeys[i] = paginationKey.SQLValue()
366343
}
367344

368345
// Fetch target data
@@ -376,36 +353,11 @@ func (v *InlineVerifier) CheckFingerprintInline(tx *sql.Tx, targetSchema, target
376353
sourceDecompressedData := make(map[string]map[string][]byte)
377354

378355
for _, rowData := range sourceBatch.Values() {
379-
var paginationKeyStr string
380-
381-
switch paginationColumn.Type {
382-
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
383-
paginationKeyUint, err := rowData.GetUint64(sourceBatch.PaginationKeyIndex())
384-
if err != nil {
385-
return nil, err
386-
}
387-
paginationKeyStr = NewUint64Key(paginationKeyUint).String()
388-
389-
case schema.TYPE_BINARY, schema.TYPE_STRING:
390-
paginationKeyInterface := rowData[sourceBatch.PaginationKeyIndex()]
391-
var paginationKeyBytes []byte
392-
switch v := paginationKeyInterface.(type) {
393-
case []byte:
394-
paginationKeyBytes = v
395-
case string:
396-
paginationKeyBytes = []byte(v)
397-
default:
398-
return nil, fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface)
399-
}
400-
paginationKeyStr = NewBinaryKey(paginationKeyBytes).String()
401-
402-
default:
403-
paginationKeyUint, err := rowData.GetUint64(sourceBatch.PaginationKeyIndex())
404-
if err != nil {
405-
return nil, err
406-
}
407-
paginationKeyStr = NewUint64Key(paginationKeyUint).String()
356+
paginationKey, err := NewPaginationKeyFromRow(rowData, sourceBatch.PaginationKeyIndex(), paginationColumn)
357+
if err != nil {
358+
return nil, err
408359
}
360+
paginationKeyStr := paginationKey.String()
409361

410362
sourceDecompressedData[paginationKeyStr] = make(map[string][]byte)
411363
for idx, col := range table.Columns {

iterative_verifier.go

Lines changed: 8 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -320,31 +320,12 @@ func (v *IterativeVerifier) GetHashes(db *sql.DB, schemaName, tableName, paginat
320320
return nil, err
321321
}
322322

323-
var paginationKeyStr string
324-
switch paginationColumn.Type {
325-
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
326-
paginationKeyUint, err := rowData.GetUint64(0)
327-
if err != nil {
328-
return nil, err
329-
}
330-
paginationKeyStr = NewUint64Key(paginationKeyUint).String()
331-
332-
case schema.TYPE_BINARY, schema.TYPE_STRING:
333-
paginationKeyBytes, ok := rowData[0].([]byte)
334-
if !ok {
335-
return nil, fmt.Errorf("expected []byte for binary pagination key, got %T", rowData[0])
336-
}
337-
paginationKeyStr = NewBinaryKey(paginationKeyBytes).String()
338-
339-
default:
340-
paginationKeyUint, err := rowData.GetUint64(0)
341-
if err != nil {
342-
return nil, err
343-
}
344-
paginationKeyStr = NewUint64Key(paginationKeyUint).String()
323+
paginationKey, err := NewPaginationKeyFromRow(rowData, 0, paginationColumn)
324+
if err != nil {
325+
return nil, err
345326
}
346327

347-
resultSet[paginationKeyStr] = rowData[1].([]byte)
328+
resultSet[paginationKey.String()] = rowData[1].([]byte)
348329
}
349330
return resultSet, nil
350331
}
@@ -422,34 +403,11 @@ func (v *IterativeVerifier) iterateTableFingerprints(table *TableSchema, mismatc
422403
paginationColumn := table.GetPaginationColumn()
423404

424405
for _, rowData := range batch.Values() {
425-
switch paginationColumn.Type {
426-
case schema.TYPE_NUMBER, schema.TYPE_MEDIUM_INT:
427-
paginationKeyUint, err := rowData.GetUint64(batch.PaginationKeyIndex())
428-
if err != nil {
429-
return err
430-
}
431-
paginationKeys = append(paginationKeys, paginationKeyUint)
432-
433-
case schema.TYPE_BINARY, schema.TYPE_STRING:
434-
paginationKeyInterface := rowData[batch.PaginationKeyIndex()]
435-
var paginationKeyBytes []byte
436-
switch v := paginationKeyInterface.(type) {
437-
case []byte:
438-
paginationKeyBytes = v
439-
case string:
440-
paginationKeyBytes = []byte(v)
441-
default:
442-
return fmt.Errorf("expected binary/string pagination key, got %T", paginationKeyInterface)
443-
}
444-
paginationKeys = append(paginationKeys, paginationKeyBytes)
445-
446-
default:
447-
paginationKeyUint, err := rowData.GetUint64(batch.PaginationKeyIndex())
448-
if err != nil {
449-
return err
450-
}
451-
paginationKeys = append(paginationKeys, paginationKeyUint)
406+
paginationKey, err := NewPaginationKeyFromRow(rowData, batch.PaginationKeyIndex(), paginationColumn)
407+
if err != nil {
408+
return err
452409
}
410+
paginationKeys = append(paginationKeys, paginationKey.SQLValue())
453411
}
454412

455413
mismatchedPaginationKeys, err := v.compareFingerprints(paginationKeys, batch.TableSchema())

0 commit comments

Comments
 (0)