Skip to content

Commit f297903

Browse files
committed
Make pattern bracketed character classes well-defined
1 parent 1a1d83f commit f297903

File tree

3 files changed

+138
-38
lines changed

3 files changed

+138
-38
lines changed

internal/lua/pattern.go

Lines changed: 105 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package lua
55

66
import (
77
"errors"
8+
"fmt"
89
"iter"
910
"strings"
1011

@@ -340,32 +341,92 @@ func matchByte(b byte, characterClass string) bool {
340341
// For example, matchBracketClass(b, "^abc") checks whether b matches "[^abc]".
341342
func matchBracketClass(b byte, set string) bool {
342343
set, invert := strings.CutPrefix(set, "^")
343-
for i := 0; i < len(set); i++ {
344-
switch p := set[i]; {
345-
case p == '%':
346-
// Escaped class.
347-
i++
348-
if i >= len(set) {
349-
return false
350-
}
351-
if matchEscapedClass(b, set[i]) {
352-
return !invert
353-
}
354-
case i+2 < len(set) && set[i+1] == '-':
355-
// Range.
356-
if p <= b && b <= set[i+2] {
357-
return !invert
358-
}
359-
i += 2
360-
case b == p:
344+
for len(set) > 0 {
345+
curr, next, err := cutBracketClassItem(set)
346+
if err != nil {
347+
return false
348+
}
349+
matched := false
350+
switch {
351+
case curr[1] != "":
352+
lo := curr[0][len(curr[0])-1]
353+
hi := curr[1][len(curr[1])-1]
354+
matched = lo <= b && b <= hi
355+
case curr[0][0] == '%':
356+
matched = matchEscapedClass(b, curr[0][1])
357+
default:
358+
matched = b == curr[0][0]
359+
}
360+
if matched {
361361
return !invert
362362
}
363+
set = next
363364
}
364365
return invert
365366
}
366367

368+
// cutBracketClassItem returns set without the leading character class or range.
369+
// If charRange[1] != "", then the leading item is a range.
370+
func cutBracketClassItem(set string) (charRange [2]string, rest string, err error) {
371+
if len(set) == 0 {
372+
return [2]string{}, "", nil
373+
}
374+
end1, err := bracketCharacterClassEnd(set)
375+
if err != nil {
376+
return [2]string{}, set, err
377+
}
378+
379+
// If a hyphen immediately follows the character class
380+
// and the hyphen is not the last character of the set,
381+
// then this is a range.
382+
start2 := end1 + 1
383+
if start2 >= len(set) || set[end1] != '-' {
384+
return [2]string{set[:end1]}, set[end1:], nil
385+
}
386+
if len(set) >= 2 && set[0] == '%' && isKnownEscapedClass(set[1]) {
387+
return [2]string{}, set, errors.New("malformed pattern (character class used in range)")
388+
}
389+
end2, err := characterClassEnd(set[start2:])
390+
if err != nil {
391+
return [2]string{}, set, err
392+
}
393+
end2 += start2
394+
if len(set) >= start2+2 && set[start2] == '%' && isKnownEscapedClass(set[start2+1]) {
395+
return [2]string{}, set, errors.New("malformed pattern (character class used in range)")
396+
}
397+
return [2]string{set[:end1], set[start2:end2]}, set[end2:], nil
398+
}
399+
400+
// bracketCharacterClassEnd returns the length of the Lua pattern character class
401+
// at the start of the given set.
402+
// bracketCharacterClassEnd returns 0 if and only if set is empty.
403+
//
404+
// Character classes recognized by bracketCharacterClassEnd
405+
// are bytes or escapes indicating sets of characters.
406+
func bracketCharacterClassEnd(set string) (int, error) {
407+
switch {
408+
case len(set) == 0:
409+
return 0, nil
410+
case set[0] == '%':
411+
if len(set) < 2 {
412+
return -1, errors.New("malformed pattern (ends with '%')")
413+
}
414+
if isASCIIDigit(rune(set[1])) {
415+
return -1, errors.New("patterns with backreferences not supported")
416+
}
417+
if isASCIILetter(rune(set[1])) && !isKnownEscapedClass(set[1]) {
418+
return -1, fmt.Errorf("malformed pattern (unknown character class %q)", set[:2])
419+
}
420+
return 2, nil
421+
default:
422+
return 1, nil
423+
}
424+
}
425+
367426
// matchEscapedClass reports whether b matches the Lua pattern character class
368427
// written as a percent sign followed by the byte p.
428+
//
429+
// If you change this function, update [isKnownEscapedClass].
369430
func matchEscapedClass(b byte, p byte) bool {
370431
var matched bool
371432
switch toLowerASCII(rune(p)) {
@@ -395,26 +456,34 @@ func matchEscapedClass(b byte, p byte) bool {
395456
return matched == isASCIILowercase(rune(p))
396457
}
397458

459+
// isKnownEscapedClass reports whether the given byte
460+
// forms a character class that is not a direct escape of the byte
461+
// when preceded by a '%'.
462+
// For example, isKnownEscapedClass('a') reports true
463+
// and isKnownEscapedClass('[') reports false.
464+
func isKnownEscapedClass(p byte) bool {
465+
if !isASCII(rune(p)) {
466+
return false
467+
}
468+
p = byte(toLowerASCII(rune(p)))
469+
return strings.IndexByte("acdglpsuwx", p) != -1
470+
}
471+
398472
// characterClassEnd returns the length of the Lua pattern character class
399473
// at the start of pattern.
400474
// characterClassEnd returns 0 if and only if pattern is empty.
401475
//
402476
// Character classes are a byte, an escape indicating a set of characters,
403477
// or a bracketed character class.
478+
// characterClassEnd is largely the same as [bracketCharacterClassEnd],
479+
// but parses bracketed character classes.
404480
func characterClassEnd(pattern string) (end int, err error) {
405-
switch {
406-
case len(pattern) == 0:
407-
return 0, nil
408-
case pattern[0] == '%':
409-
if len(pattern) < 2 {
410-
return -1, errors.New("malformed pattern (ends with '%')")
411-
}
412-
return 2, nil
413-
case pattern[0] == '[':
481+
if len(pattern) > 0 && pattern[0] == '[' {
414482
end := 1
415483
if strings.HasPrefix(pattern[end:], "^") {
416484
end++
417485
}
486+
start := end
418487
if strings.HasPrefix(pattern[end:], "]") {
419488
// Don't let ']' in first position terminate class.
420489
end++
@@ -425,13 +494,20 @@ func characterClassEnd(pattern string) (end int, err error) {
425494
// Skip escape.
426495
end++
427496
case ']':
497+
for set := pattern[start:end]; len(set) > 0; {
498+
_, rest, err := cutBracketClassItem(set)
499+
if err != nil {
500+
return -1, err
501+
}
502+
set = rest
503+
}
428504
return end + 1, nil
429505
}
430506
}
431507
return -1, errors.New("malformed pattern (missing ']')")
432-
default:
433-
return 1, nil
434508
}
509+
510+
return bracketCharacterClassEnd(pattern)
435511
}
436512

437513
func isASCIILetter(c rune) bool {

internal/lua/stringlib.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@ const StringLibraryName = "string"
2828
//
2929
// # Differences from de facto C implementation
3030
//
31-
// - Patterns do not support backreferences (i.e. %0 - %9),
32-
// balances (i.e. %b), or frontiers (%f).
33-
// Attempting to use any of these pattern items will raise an error.
34-
// - Character sets with classes in ranges (e.g. [%a-z]) raise an error
35-
// instead of silently exhibiting undefined behavior.
31+
// - Patterns do not support backreferences (i.e. %0 - %9) or balances (i.e. %b).
32+
// Attempting to use either of these pattern items will raise an error.
33+
// - In patterns, character sets with classes in ranges (e.g. [%a-z])
34+
// raise an error instead of silently exhibiting undefined behavior.
35+
// However, ranges using escapes (e.g. [%]-`]) are well-defined in this implementation.
36+
// - string.unpack is currently missing.
37+
// (https://github.com/256lights/zb/issues/79)
3638
//
3739
// [string manipulation library]: https://www.lua.org/manual/5.4/manual.html#6.4
3840
func OpenString(ctx context.Context, l *State) (int, error) {
@@ -67,7 +69,8 @@ func OpenString(ctx context.Context, l *State) (int, error) {
6769
luacode.UnaryMinus,
6870
}
6971
metaMethods := make(map[string]Function, len(operators)+1)
70-
metaMethods[luacode.TagMethodIndex.String()] = nil
72+
indexMethod := luacode.TagMethodIndex.String()
73+
metaMethods[indexMethod] = nil
7174
for _, op := range operators {
7275
op := op // Capture constant instead of loop variable.
7376
metaMethods[op.TagMethod().String()] = func(ctx context.Context, l *State) (int, error) {
@@ -77,7 +80,7 @@ func OpenString(ctx context.Context, l *State) (int, error) {
7780

7881
NewLib(l, metaMethods)
7982
l.PushValue(-2)
80-
l.RawSetField(-2, "__index")
83+
l.RawSetField(-2, indexMethod)
8184

8285
// Set string metatable.
8386
l.PushString("")

internal/lua/stringlib_test.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ func TestStringFind(t *testing.T) {
108108
init int64
109109
plain bool
110110

111-
want []any
111+
want []any
112+
wantError string
112113
}{
113114
{
114115
s: "",
@@ -494,6 +495,18 @@ func TestStringFind(t *testing.T) {
494495
init: 1,
495496
want: []any{int64(4), int64(12)},
496497
},
498+
{
499+
s: "abc",
500+
pattern: "[%a-z]",
501+
init: 1,
502+
wantError: "character class used in range",
503+
},
504+
{
505+
s: "a^c",
506+
pattern: "[%]-`]",
507+
init: 1,
508+
want: []any{int64(2), int64(2)},
509+
},
497510
}
498511

499512
ctx := context.Background()
@@ -531,7 +544,15 @@ func TestStringFind(t *testing.T) {
531544
testName += ")"
532545

533546
if err := state.Call(ctx, state.Top()-top, MultipleReturns); err != nil {
534-
t.Errorf("%s: %v", testName, err)
547+
if test.wantError == "" {
548+
t.Errorf("%s: %v", testName, err)
549+
} else if got := err.Error(); !strings.Contains(got, test.wantError) {
550+
t.Errorf("%s raised: %s; want message to contain %q", testName, got, test.wantError)
551+
}
552+
return
553+
}
554+
if test.wantError != "" {
555+
t.Errorf("%s did not raise an error", testName)
535556
return
536557
}
537558

0 commit comments

Comments
 (0)