Skip to content

Commit ab09da7

Browse files
committed
More unicode.
1 parent a159b54 commit ab09da7

File tree

2 files changed

+57
-10
lines changed

2 files changed

+57
-10
lines changed

ext/unicode/unicode.go

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,18 @@
55
// - LIKE and REGEXP operators,
66
// - collation sequences.
77
//
8-
// It also provides, from PostgreSQL:
9-
// - unaccent(),
10-
// - initcap().
11-
//
128
// The implementation is not 100% compatible with the [ICU extension]:
139
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
1410
// - the LIKE operator follows [strings.EqualFold] rules;
1511
// - the REGEXP operator uses Go [regexp/syntax];
1612
// - collation sequences use [collate].
1713
//
14+
// It also provides (approximately) from PostgreSQL:
15+
// - casefold(),
16+
// - initcap(),
17+
// - normalize(),
18+
// - unaccent().
19+
//
1820
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
1921
//
2022
// [ICU extension]: https://sqlite.org/src/dir/ext/icu
@@ -48,21 +50,24 @@ var RegisterLike = true
4850
// Register registers Unicode aware functions for a database connection.
4951
func Register(db *sqlite3.Conn) error {
5052
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
51-
var errs util.ErrorJoiner
53+
var lkfn sqlite3.ScalarFunction
5254
if RegisterLike {
53-
errs.Join(
54-
db.CreateFunction("like", 2, flags, like),
55-
db.CreateFunction("like", 3, flags, like))
55+
lkfn = like
5656
}
57-
errs.Join(
57+
return errors.Join(
58+
db.CreateFunction("like", 2, flags, lkfn),
59+
db.CreateFunction("like", 3, flags, lkfn),
5860
db.CreateFunction("upper", 1, flags, upper),
5961
db.CreateFunction("upper", 2, flags, upper),
6062
db.CreateFunction("lower", 1, flags, lower),
6163
db.CreateFunction("lower", 2, flags, lower),
6264
db.CreateFunction("regexp", 2, flags, regex),
6365
db.CreateFunction("initcap", 1, flags, initcap),
6466
db.CreateFunction("initcap", 2, flags, initcap),
67+
db.CreateFunction("casefold", 1, flags, casefold),
6568
db.CreateFunction("unaccent", 1, flags, unaccent),
69+
db.CreateFunction("normalize", 1, flags, normalize),
70+
db.CreateFunction("normalize", 2, flags, normalize),
6671
db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY,
6772
func(ctx sqlite3.Context, arg ...sqlite3.Value) {
6873
name := arg[1].Text()
@@ -76,7 +81,6 @@ func Register(db *sqlite3.Conn) error {
7681
return // notest
7782
}
7883
}))
79-
return errors.Join(errs...)
8084
}
8185

8286
// RegisterCollation registers a Unicode collation sequence for a database connection.
@@ -154,6 +158,10 @@ func initcap(ctx sqlite3.Context, arg ...sqlite3.Value) {
154158
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
155159
}
156160

161+
func casefold(ctx sqlite3.Context, arg ...sqlite3.Value) {
162+
ctx.ResultRawText(cases.Fold().Bytes(arg[0].RawText()))
163+
}
164+
157165
func unaccent(ctx sqlite3.Context, arg ...sqlite3.Value) {
158166
unaccent := transform.Chain(norm.NFD, runes.Remove(runes.In(unicode.Mn)), norm.NFC)
159167
res, _, err := transform.Bytes(unaccent, arg[0].RawText())
@@ -164,6 +172,31 @@ func unaccent(ctx sqlite3.Context, arg ...sqlite3.Value) {
164172
}
165173
}
166174

175+
func normalize(ctx sqlite3.Context, arg ...sqlite3.Value) {
176+
form := norm.NFC
177+
if len(arg) > 1 {
178+
switch strings.ToUpper(arg[1].Text()) {
179+
case "NFC":
180+
//
181+
case "NFD":
182+
form = norm.NFD
183+
case "NFKC":
184+
form = norm.NFKC
185+
case "NFKD":
186+
form = norm.NFKD
187+
default:
188+
ctx.ResultError(util.ErrorString("unicode: invalid form"))
189+
return
190+
}
191+
}
192+
res, _, err := transform.Bytes(form, arg[0].RawText())
193+
if err != nil {
194+
ctx.ResultError(err) // notest
195+
} else {
196+
ctx.ResultRawText(res)
197+
}
198+
}
199+
167200
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
168201
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
169202
if !ok {

ext/unicode/unicode_test.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ func TestRegister(t *testing.T) {
4949
{`upper('Dünyanın İlk Borsası', 'tr-TR')`, "DÜNYANIN İLK BORSASI"},
5050
{`initcap('Kad je hladno Marko nosi džemper')`, "Kad Je Hladno Marko Nosi Džemper"},
5151
{`initcap('Kad je hladno Marko nosi džemper', 'hr-HR')`, "Kad Je Hladno Marko Nosi Džemper"},
52+
{`normalize(X'61cc88')`, "ä"},
53+
{`normalize(X'61cc88', 'NFC' )`, "ä"},
54+
{`normalize(X'61cc88', 'NFKC')`, "ä"},
55+
{`normalize('ä', 'NFD' )`, "\x61\xcc\x88"},
56+
{`normalize('ä', 'NFKD')`, "\x61\xcc\x88"},
57+
{`casefold('Maße')`, "masse"},
5258
{`unaccent('Hôtel')`, "Hotel"},
5359
{`'Hello' REGEXP 'ell'`, "1"},
5460
{`'Hello' REGEXP 'el.'`, "1"},
@@ -208,6 +214,14 @@ func TestRegister_error(t *testing.T) {
208214
t.Errorf("got %v, want sqlite3.ERROR", err)
209215
}
210216

217+
err = db.Exec(`SELECT normalize('', 'NF')`)
218+
if err == nil {
219+
t.Error("want error")
220+
}
221+
if !errors.Is(err, sqlite3.ERROR) {
222+
t.Errorf("got %v, want sqlite3.ERROR", err)
223+
}
224+
211225
err = db.Exec(`SELECT 'hello' REGEXP '\'`)
212226
if err == nil {
213227
t.Error("want error")

0 commit comments

Comments
 (0)