@@ -5,6 +5,7 @@ package lua
55
66import (
77 "errors"
8+ "fmt"
89 "iter"
910 "strings"
1011
@@ -340,32 +341,92 @@ func matchByte(b byte, characterClass string) bool {
340341// For example, matchBracketClass(b, "^abc") checks whether b matches "[^abc]".
341342func matchBracketClass (b byte , set string ) bool {
342343 set , invert := strings .CutPrefix (set , "^" )
343- for i := 0 ; i < len (set ); i ++ {
344- switch p := set [i ]; {
345- case p == '%' :
346- // Escaped class.
347- i ++
348- if i >= len (set ) {
349- return false
350- }
351- if matchEscapedClass (b , set [i ]) {
352- return ! invert
353- }
354- case i + 2 < len (set ) && set [i + 1 ] == '-' :
355- // Range.
356- if p <= b && b <= set [i + 2 ] {
357- return ! invert
358- }
359- i += 2
360- case b == p :
344+ for len (set ) > 0 {
345+ curr , next , err := cutBracketClassItem (set )
346+ if err != nil {
347+ return false
348+ }
349+ matched := false
350+ switch {
351+ case curr [1 ] != "" :
352+ lo := curr [0 ][len (curr [0 ])- 1 ]
353+ hi := curr [1 ][len (curr [1 ])- 1 ]
354+ matched = lo <= b && b <= hi
355+ case curr [0 ][0 ] == '%' :
356+ matched = matchEscapedClass (b , curr [0 ][1 ])
357+ default :
358+ matched = b == curr [0 ][0 ]
359+ }
360+ if matched {
361361 return ! invert
362362 }
363+ set = next
363364 }
364365 return invert
365366}
366367
368+ // cutBracketClassItem returns set without the leading character class or range.
369+ // If charRange[1] != "", then the leading item is a range.
370+ func cutBracketClassItem (set string ) (charRange [2 ]string , rest string , err error ) {
371+ if len (set ) == 0 {
372+ return [2 ]string {}, "" , nil
373+ }
374+ end1 , err := bracketCharacterClassEnd (set )
375+ if err != nil {
376+ return [2 ]string {}, set , err
377+ }
378+
379+ // If a hyphen immediately follows the character class
380+ // and the hyphen is not the last character of the set,
381+ // then this is a range.
382+ start2 := end1 + 1
383+ if start2 >= len (set ) || set [end1 ] != '-' {
384+ return [2 ]string {set [:end1 ]}, set [end1 :], nil
385+ }
386+ if len (set ) >= 2 && set [0 ] == '%' && isKnownEscapedClass (set [1 ]) {
387+ return [2 ]string {}, set , errors .New ("malformed pattern (character class used in range)" )
388+ }
389+ end2 , err := characterClassEnd (set [start2 :])
390+ if err != nil {
391+ return [2 ]string {}, set , err
392+ }
393+ end2 += start2
394+ if len (set ) >= start2 + 2 && set [start2 ] == '%' && isKnownEscapedClass (set [start2 + 1 ]) {
395+ return [2 ]string {}, set , errors .New ("malformed pattern (character class used in range)" )
396+ }
397+ return [2 ]string {set [:end1 ], set [start2 :end2 ]}, set [end2 :], nil
398+ }
399+
400+ // bracketCharacterClassEnd returns the length of the Lua pattern character class
401+ // at the start of the given set.
402+ // bracketCharacterClassEnd returns 0 if and only if set is empty.
403+ //
404+ // Character classes recognized by bracketCharacterClassEnd
405+ // are bytes or escapes indicating sets of characters.
406+ func bracketCharacterClassEnd (set string ) (int , error ) {
407+ switch {
408+ case len (set ) == 0 :
409+ return 0 , nil
410+ case set [0 ] == '%' :
411+ if len (set ) < 2 {
412+ return - 1 , errors .New ("malformed pattern (ends with '%')" )
413+ }
414+ if isASCIIDigit (rune (set [1 ])) {
415+ return - 1 , errors .New ("patterns with backreferences not supported" )
416+ }
417+ if isASCIILetter (rune (set [1 ])) && ! isKnownEscapedClass (set [1 ]) {
418+ return - 1 , fmt .Errorf ("malformed pattern (unknown character class %q)" , set [:2 ])
419+ }
420+ return 2 , nil
421+ default :
422+ return 1 , nil
423+ }
424+ }
425+
367426// matchEscapedClass reports whether b matches the Lua pattern character class
368427// written as a percent sign followed by the byte p.
428+ //
429+ // If you change this function, update [isKnownEscapedClass].
369430func matchEscapedClass (b byte , p byte ) bool {
370431 var matched bool
371432 switch toLowerASCII (rune (p )) {
@@ -395,26 +456,34 @@ func matchEscapedClass(b byte, p byte) bool {
395456 return matched == isASCIILowercase (rune (p ))
396457}
397458
459+ // isKnownEscapedClass reports whether the given byte
460+ // forms a character class that is not a direct escape of the byte
461+ // when preceded by a '%'.
462+ // For example, isKnownEscapedClass('a') reports true
463+ // and isKnownEscapedClass('[') reports false.
464+ func isKnownEscapedClass (p byte ) bool {
465+ if ! isASCII (rune (p )) {
466+ return false
467+ }
468+ p = byte (toLowerASCII (rune (p )))
469+ return strings .IndexByte ("acdglpsuwx" , p ) != - 1
470+ }
471+
398472// characterClassEnd returns the length of the Lua pattern character class
399473// at the start of pattern.
400474// characterClassEnd returns 0 if and only if pattern is empty.
401475//
402476// Character classes are a byte, an escape indicating a set of characters,
403477// or a bracketed character class.
478+ // characterClassEnd is largely the same as [bracketCharacterClassEnd],
479+ // but parses bracketed character classes.
404480func characterClassEnd (pattern string ) (end int , err error ) {
405- switch {
406- case len (pattern ) == 0 :
407- return 0 , nil
408- case pattern [0 ] == '%' :
409- if len (pattern ) < 2 {
410- return - 1 , errors .New ("malformed pattern (ends with '%')" )
411- }
412- return 2 , nil
413- case pattern [0 ] == '[' :
481+ if len (pattern ) > 0 && pattern [0 ] == '[' {
414482 end := 1
415483 if strings .HasPrefix (pattern [end :], "^" ) {
416484 end ++
417485 }
486+ start := end
418487 if strings .HasPrefix (pattern [end :], "]" ) {
419488 // Don't let ']' in first position terminate class.
420489 end ++
@@ -425,13 +494,20 @@ func characterClassEnd(pattern string) (end int, err error) {
425494 // Skip escape.
426495 end ++
427496 case ']' :
497+ for set := pattern [start :end ]; len (set ) > 0 ; {
498+ _ , rest , err := cutBracketClassItem (set )
499+ if err != nil {
500+ return - 1 , err
501+ }
502+ set = rest
503+ }
428504 return end + 1 , nil
429505 }
430506 }
431507 return - 1 , errors .New ("malformed pattern (missing ']')" )
432- default :
433- return 1 , nil
434508 }
509+
510+ return bracketCharacterClassEnd (pattern )
435511}
436512
437513func isASCIILetter (c rune ) bool {
0 commit comments