diff --git a/sql/hash/hash.go b/sql/hash/hash.go index 62d5ed2c85..57feabd2fe 100644 --- a/sql/hash/hash.go +++ b/sql/hash/hash.go @@ -60,40 +60,80 @@ func HashOf(ctx *sql.Context, sch sql.Schema, row sql.Row) (uint64, error) { return 0, fmt.Errorf("error unwrapping value: %w", err) } - // TODO: we may not always have the type information available, so we check schema length. - // Then, defer to original behavior - if i >= len(sch) || v == nil { - _, err := fmt.Fprintf(hash, "%v", v) - if err != nil { + if v == nil { + if _, err := hash.WriteString(""); err != nil { return 0, err } continue } - switch typ := sch[i].Type.(type) { - case sql.ExtendedType: - // TODO: Doltgres follows Postgres conventions which don't align with the expectations of MySQL, - // so we're using the old (probably incorrect) behavior for now - _, err = fmt.Fprintf(hash, "%v", v) - if err != nil { - return 0, err + // TODO: we may not always have the type information available, so we check schema length. + // Then, defer to original behavior + if i < len(sch) { + switch typ := sch[i].Type.(type) { + case sql.ExtendedType: + // TODO: Doltgres follows Postgres conventions which don't align with the expectations of MySQL, + // so we're using the old (probably incorrect) behavior for now + _, err := hash.WriteString(fmt.Sprintf("%v", v)) + if err != nil { + return 0, err + } + continue + case types.StringType: + var strVal string + strVal, err = types.ConvertToString(ctx, v, typ, nil) + if err != nil { + return 0, err + } + err = typ.Collation().WriteWeightString(hash, strVal) + if err != nil { + return 0, err + } + continue } - case types.StringType: - var strVal string - strVal, err = types.ConvertToString(ctx, v, typ, nil) - if err != nil { - return 0, err + } + switch v := v.(type) { + case int: + _, err = hash.WriteString(strconv.FormatInt(int64(v), 10)) + case int8: + _, err = hash.WriteString(strconv.FormatInt(int64(v), 10)) + case int16: + _, err = hash.WriteString(strconv.FormatInt(int64(v), 10)) + case int32: + _, err = hash.WriteString(strconv.FormatInt(int64(v), 10)) + case int64: + _, err = hash.WriteString(strconv.FormatInt(v, 10)) + case uint: + _, err = hash.WriteString(strconv.FormatUint(uint64(v), 10)) + case uint8: + _, err = hash.WriteString(strconv.FormatUint(uint64(v), 10)) + case uint16: + _, err = hash.WriteString(strconv.FormatUint(uint64(v), 10)) + case uint32: + _, err = hash.WriteString(strconv.FormatUint(uint64(v), 10)) + case uint64: + _, err = hash.WriteString(strconv.FormatUint(v, 10)) + case float32: + str := strconv.FormatFloat(float64(v), 'f', -1, 32) + if str == "-0" { + str = "0" } - err = typ.Collation().WriteWeightString(hash, strVal) - if err != nil { - return 0, err + _, err = hash.WriteString(str) + case float64: + str := strconv.FormatFloat(v, 'f', -1, 64) + if str == "-0" { + str = "0" } + _, err = hash.WriteString(str) + case string: + _, err = hash.WriteString(v) + case []byte: + _, err = hash.Write(v) default: - // TODO: probably much faster to do this with a type switch - _, err = fmt.Fprintf(hash, "%v", v) - if err != nil { - return 0, err - } + _, err = hash.WriteString(fmt.Sprintf("%v", v)) + } + if err != nil { + return 0, err } } return hash.Sum64(), nil