diff --git a/sql/types/enum.go b/sql/types/enum.go index 957b084bb0..c01b0de0da 100644 --- a/sql/types/enum.go +++ b/sql/types/enum.go @@ -359,7 +359,11 @@ func (t EnumType) WithNewCollation(collation sql.CollationID) (sql.Type, error) // StringWithTableCollation implements sql.TypeWithCollation interface. func (t EnumType) StringWithTableCollation(tableCollation sql.CollationID) string { - s := fmt.Sprintf("enum('%v')", strings.Join(t.idxToVal, `','`)) + escapedValues := make([]string, len(t.idxToVal)) + for i, value := range t.idxToVal { + escapedValues[i] = strings.ReplaceAll(value, "'", "''") + } + s := fmt.Sprintf("enum('%s')", strings.Join(escapedValues, `','`)) if t.CharacterSet() != tableCollation.CharacterSet() { s += " CHARACTER SET " + t.CharacterSet().String() } diff --git a/sql/types/set.go b/sql/types/set.go index 6e52a30d8b..3691639ab7 100644 --- a/sql/types/set.go +++ b/sql/types/set.go @@ -306,7 +306,12 @@ func (t SetType) WithNewCollation(collation sql.CollationID) (sql.Type, error) { // StringWithTableCollation implements sql.TypeWithCollation interface. func (t SetType) StringWithTableCollation(tableCollation sql.CollationID) string { - s := fmt.Sprintf("set('%v')", strings.Join(t.Values(), `','`)) + values := t.Values() + escapedValues := make([]string, len(values)) + for i, value := range values { + escapedValues[i] = strings.ReplaceAll(value, "'", "''") + } + s := fmt.Sprintf("set('%s')", strings.Join(escapedValues, `','`)) if t.CharacterSet() != tableCollation.CharacterSet() { s += " CHARACTER SET " + t.CharacterSet().String() }