Skip to content

Commit d8da52a

Browse files
authored
Enum case matching by remaining buffer length (#26)
This PR adds support of matching enum case using remaining buffer length. This matching method is exclusive with byte matching. The changes (implementation, test, and specs) are reviewed by the owner.
1 parent 4aa4a47 commit d8da52a

File tree

13 files changed

+795
-46
lines changed

13 files changed

+795
-46
lines changed

Sources/BinaryParseKit/BinaryParseKit.swift

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,44 @@ public macro match(bytes: [UInt8]) = #externalMacro(
401401
type: "EmptyPeerMacro",
402402
)
403403

404+
/// Matches when the remaining buffer bytes equals the specified length exactly.
405+
///
406+
/// Use this macro to match cases based on remaining data size rather than specific byte patterns.
407+
/// The buffer is NOT consumed, allowing subsequent parsing operations to read the data.
408+
///
409+
/// - Parameter length: The exact number of remaining bytes to match
410+
///
411+
/// - Note: An enum using `@ParseEnum` can only use byte-based matching OR length-based matching,
412+
/// but not both. `@matchDefault` is allowed with either strategy. That is, it's not allowed to
413+
/// use a mix of `@match(byte:)`/`@match(bytes:)`/`@matchAndTake(byte:)` and `@match(length:)`
414+
/// in the same enum.
415+
///
416+
///
417+
/// Example:
418+
/// ```swift
419+
/// @ParseEnum
420+
/// enum VariableSizeData {
421+
/// @match(length: 4)
422+
/// @parse(endianness: .big)
423+
/// case shortPayload(UInt32)
424+
///
425+
/// @match(length: 8)
426+
/// @parse(endianness: .big)
427+
/// case longPayload(UInt64)
428+
///
429+
/// @matchDefault
430+
/// case unknown
431+
/// }
432+
///
433+
/// let data = try VariableSizeData(parsing: Data([0x12, 0x34, 0x56, 0x78]))
434+
/// // data == .shortPayload(0x12345678) because buffer has exactly 4 bytes
435+
/// ```
436+
@attached(peer)
437+
public macro match(length: Int) = #externalMacro(
438+
module: "BinaryParseKitMacros",
439+
type: "EmptyPeerMacro",
440+
)
441+
404442
/// Matches and consumes bytes from the buffer using the enum's raw value.
405443
///
406444
/// Use this macro for ``Matchable`` enums where each case's raw value

Sources/BinaryParseKit/Utils/ParsingUtils.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import BinaryParsing
88

99
/// Matches the given bytes in the input parser span.
10-
/// - Warning: This function is used to `@parseEnum` macro and should not be used directly.
10+
/// - Warning: This function is used by `@ParseEnum` macro and should not be used directly.
1111
@inlinable
1212
public func __match(_ bytes: borrowing [UInt8], in input: inout BinaryParsing.ParserSpan) -> Bool {
1313
if bytes.isEmpty { return true }
@@ -22,6 +22,13 @@ public func __match(_ bytes: borrowing [UInt8], in input: inout BinaryParsing.Pa
2222
return toMatch == bytes
2323
}
2424

25+
/// Matches when the remaining bytes in the input parser span equals the specified length.
26+
/// - Warning: This function is used by `@ParseEnum` macro and should not be used directly.
27+
@inlinable
28+
public func __match(length: Int, in input: borrowing BinaryParsing.ParserSpan) -> Bool {
29+
input.count == length
30+
}
31+
2532
/// Asserts that the given type conforms to `Parsable`.
2633
/// - Warning: This function is used to `@parse` macro and should not be used directly.
2734
@inlinable

Sources/BinaryParseKitMacros/Macros/ParseEnum/ConsructParseEnumMacro.swift renamed to Sources/BinaryParseKitMacros/Macros/ParseEnum/ConstructParseEnumMacro.swift

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//
2-
// ConsructParseEnumMacro.swift
2+
// ConstructParseEnumMacro.swift
33
// BinaryParseKit
44
//
55
// Created by Larry Zeng on 7/26/25.
@@ -43,13 +43,36 @@ public struct ConstructEnumParseMacro: ExtensionMacro {
4343
"\(accessorInfo.parsingAccessor) init(parsing span: inout \(raw: Constants.BinaryParsing.parserSpan)) throws(\(raw: Constants.BinaryParsing.thrownParsingError))",
4444
) {
4545
for caseParseInfo in parseInfo.caseParseInfo {
46-
let toBeMatched = caseParseInfo.bytesToMatch(of: type)
46+
// Generate the match condition based on match type
47+
let matchCondition = try ConditionElementListSyntax {
48+
if let matchLength = caseParseInfo.lengthToMatch() {
49+
// Length-based matching: use __match(length:in:) with borrowing span
50+
ExprSyntax(
51+
"\(raw: Constants.UtilityFunctions.matchLength)(length: \(matchLength), in: span)",
52+
)
53+
} else if let toBeMatched = caseParseInfo.bytesToMatch(of: type) {
54+
// Byte-array-based matching: use __matchBytes with inout span
55+
ExprSyntax("\(raw: Constants.UtilityFunctions.matchBytes)(\(toBeMatched), in: &span)")
56+
} else if caseParseInfo.matchAction.target.isDefaultMatch {
57+
// Default matching: always true
58+
ExprSyntax("true")
59+
} else {
60+
// Otherwise, it's a failure on our side
61+
throw ParseEnumMacroError.unexpectedError(
62+
description: "Failed to obtain matching bytes for \(caseParseInfo.caseElementName)",
63+
)
64+
}
65+
}
4766

48-
try IfExprSyntax(
49-
"if \(raw: Constants.UtilityFunctions.matchBytes)(\(toBeMatched), in: &span)",
50-
) {
67+
try IfExprSyntax("if \(matchCondition)") {
5168
if caseParseInfo.matchAction.matchPolicy == .matchAndTake {
52-
"try span.seek(toRelativeOffset: \(toBeMatched).count)"
69+
if let toBeMatched = caseParseInfo.bytesToMatch(of: type) {
70+
"try span.seek(toRelativeOffset: \(toBeMatched).count)"
71+
} else {
72+
throw ParseEnumMacroError.unexpectedError(
73+
description: "Failed to obtain matching bytes for \(caseParseInfo.caseElementName) when taking",
74+
)
75+
}
5376
}
5477

5578
var arguments: OrderedDictionary<TokenSyntax, EnumCaseParameterParseInfo> = [:]
@@ -179,16 +202,29 @@ public struct ConstructEnumParseMacro: ExtensionMacro {
179202
}
180203
})
181204

182-
let caseCodeBlock = CodeBlockItemListSyntax {
183-
let toBeMatched = caseParseInfo.bytesToMatch(of: type)
205+
let caseCodeBlock = try CodeBlockItemListSyntax {
206+
let bytesTakenInMatching = context.makeUniqueName("bytesTakenInMatching")
207+
208+
if caseParseInfo.matchAction.target.isLengthMatch
209+
|| caseParseInfo.matchAction.target.isDefaultMatch {
210+
// For length-based matching, use empty bytes array (similar to matchDefault)
211+
"let \(bytesTakenInMatching): [UInt8] = []"
212+
} else if let bytesToMatch = caseParseInfo.bytesToMatch(of: type) {
213+
"let \(bytesTakenInMatching): [UInt8] = \(bytesToMatch)"
214+
} else {
215+
throw ParseEnumMacroError.unexpectedError(
216+
description: "Failed to obtain matching bytes for \(caseParseInfo.caseElementName)",
217+
)
218+
}
219+
184220
let matchPolicy = caseParseInfo.matchAction.matchPolicy
185221

186222
let fields = ArrayExprSyntax(elements: generatePrintableFields(parseSkipMacroInfo))
187223

188224
#"""
189225
return .enum(
190226
.init(
191-
bytes: \#(toBeMatched),
227+
bytes: \#(bytesTakenInMatching),
192228
parseType: .\#(raw: matchPolicy),
193229
fields: \#(fields),
194230
)

Sources/BinaryParseKitMacros/Macros/ParseEnum/EnumCaseParseInfo.swift

Lines changed: 93 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,26 +45,92 @@ enum EnumParseAction {
4545
case skip(SkipMacroInfo)
4646
}
4747

48+
enum EnumMatchTarget {
49+
// @match() -> .bytes(nil)
50+
// @match(byte: byte) -> .byte([byte])
51+
// @match(bytes: bytes) -> .bytes(bytes)
52+
// @matchAndTake() -> .bytes(nil)
53+
// @matchAndTake(byte: byte) -> .byte([byte])
54+
// @matchAndTake(bytes: bytes) -> .bytes(bytes)
55+
case bytes(ExprSyntax?)
56+
// @match(length: length) -> .length(length)
57+
case length(ExprSyntax)
58+
// @matchDefault
59+
case `default`
60+
61+
var matchBytes: ExprSyntax?? {
62+
if case let .bytes(bytes) = self {
63+
bytes
64+
} else {
65+
nil
66+
}
67+
}
68+
69+
var matchLength: ExprSyntax? {
70+
if case let .length(length) = self {
71+
length
72+
} else {
73+
nil
74+
}
75+
}
76+
77+
var isLengthMatch: Bool {
78+
if case .length = self {
79+
true
80+
} else {
81+
false
82+
}
83+
}
84+
85+
var isByteMatch: Bool {
86+
if case .bytes = self {
87+
true
88+
} else {
89+
false
90+
}
91+
}
92+
93+
var isDefaultMatch: Bool {
94+
if case .default = self {
95+
true
96+
} else {
97+
false
98+
}
99+
}
100+
}
101+
48102
struct EnumCaseMatchAction {
49-
let matchBytes: ExprSyntax?
103+
let target: EnumMatchTarget
50104
let matchPolicy: EnumCaseMatchPolicy
51105

52106
static func match(bytes: ExprSyntax?) -> EnumCaseMatchAction {
53-
EnumCaseMatchAction(matchBytes: bytes, matchPolicy: .match)
107+
EnumCaseMatchAction(target: .bytes(bytes), matchPolicy: .match)
108+
}
109+
110+
static func matchLength(_ length: ExprSyntax) -> EnumCaseMatchAction {
111+
EnumCaseMatchAction(target: .length(length), matchPolicy: .match)
54112
}
55113

56114
static func matchDefault() -> EnumCaseMatchAction {
57-
EnumCaseMatchAction(matchBytes: "[]", matchPolicy: .matchDefault)
115+
EnumCaseMatchAction(target: .default, matchPolicy: .matchDefault)
58116
}
59117

60118
static func matchAndTake(bytes: ExprSyntax?) -> EnumCaseMatchAction {
61-
EnumCaseMatchAction(matchBytes: bytes, matchPolicy: .matchAndTake)
119+
EnumCaseMatchAction(target: .bytes(bytes), matchPolicy: .matchAndTake)
62120
}
63121

64122
static func parseMatch(from attribute: AttributeSyntax) throws(ParseEnumMacroError) -> EnumCaseMatchAction {
65123
let arguments = attribute.arguments?.as(LabeledExprListSyntax.self)
66-
let bytes = try parseBytesArgument(in: arguments, at: 0)
67124

125+
// Check if this is @match(length:)
126+
if let args = arguments,
127+
let firstArg = args.first,
128+
firstArg.label?.text == "length" {
129+
return .matchLength(firstArg.expression)
130+
}
131+
132+
// Otherwise, parse as byte-based match
133+
let bytes = try parseBytesArgument(in: arguments, at: 0)
68134
return .match(bytes: bytes)
69135
}
70136

@@ -117,22 +183,35 @@ struct EnumCaseParseInfo {
117183
let matchAction: EnumCaseMatchAction
118184
let parseActions: [EnumParseAction]
119185
let caseElementName: TokenSyntax
120-
121-
init(matchAction: EnumCaseMatchAction, parseActions: [EnumParseAction], caseElementName: TokenSyntax) {
186+
let source: Syntax
187+
188+
init(
189+
matchAction: EnumCaseMatchAction,
190+
parseActions: [EnumParseAction],
191+
caseElementName: TokenSyntax,
192+
source: Syntax,
193+
) {
122194
self.matchAction = matchAction
123195
self.parseActions = parseActions
124196
self.caseElementName = caseElementName.trimmed
197+
self.source = source
125198
}
126199

127-
func bytesToMatch(of type: some TypeSyntaxProtocol) -> ExprSyntax {
128-
if let matchBytes = matchAction.matchBytes {
129-
matchBytes
130-
} else {
131-
ExprSyntax(
132-
"(\(type).\(caseElementName) as any \(raw: Constants.Protocols.matchableProtocol)).bytesToMatch()",
133-
)
200+
func bytesToMatch(of type: some TypeSyntaxProtocol) -> ExprSyntax? {
201+
matchAction.target.matchBytes.map { matchBytes in
202+
if let matchBytes {
203+
matchBytes
204+
} else {
205+
ExprSyntax(
206+
"(\(type).\(caseElementName) as any \(raw: Constants.Protocols.matchableProtocol)).bytesToMatch()",
207+
)
208+
}
134209
}
135210
}
211+
212+
func lengthToMatch() -> ExprSyntax? {
213+
matchAction.target.matchLength
214+
}
136215
}
137216

138217
struct EnumParseInfo {

Sources/BinaryParseKitMacros/Macros/ParseEnum/ParseEnumCase.swift

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ class ParseEnumCase: SyntaxVisitor {
1818
private(set) var parsedInfo: EnumParseInfo?
1919

2020
private var hasMatchDefault = false
21+
private var hasByteMatch = false
22+
private var hasLengthMatch = false
2123

2224
private var errors: [Diagnostic] = []
2325

@@ -107,6 +109,29 @@ class ParseEnumCase: SyntaxVisitor {
107109
),
108110
)
109111
}
112+
113+
// Track matching strategy for mutual exclusivity validation
114+
if matchAction.target.isLengthMatch {
115+
if hasByteMatch {
116+
errors.append(
117+
.init(
118+
node: node,
119+
message: ParseEnumMacroError.mixedMatchingStrategies,
120+
),
121+
)
122+
}
123+
hasLengthMatch = true
124+
} else if matchAction.target.isByteMatch {
125+
if hasLengthMatch {
126+
errors.append(
127+
.init(
128+
node: node,
129+
message: ParseEnumMacroError.mixedMatchingStrategies,
130+
),
131+
)
132+
}
133+
hasByteMatch = true
134+
}
110135
}
111136

112137
for currentCaseElement in currentCaseElements {
@@ -120,6 +145,7 @@ class ParseEnumCase: SyntaxVisitor {
120145
matchAction: matchAction,
121146
parseActions: enumParseActions,
122147
caseElementName: currentCaseElement.name,
148+
source: Syntax(node),
123149
),
124150
)
125151
}

Sources/BinaryParseKitMacros/Macros/ParseEnum/ParseEnumMacroError.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ enum ParseEnumMacroError: Error, DiagnosticMessage {
1515
case defaultCaseMustBeLast
1616
case onlyOneMatchDefaultAllowed
1717
case matchDefaultShouldBeLast
18+
case mixedMatchingStrategies
1819
case unexpectedError(description: String)
1920

2021
var message: String {
@@ -29,6 +30,8 @@ enum ParseEnumMacroError: Error, DiagnosticMessage {
2930
case .defaultCaseMustBeLast: "The `matchDefault` case must be the last case in the enum."
3031
case .onlyOneMatchDefaultAllowed: "Only one `matchDefault` case is allowed in a enum."
3132
case .matchDefaultShouldBeLast: "The `matchDefault` case should be the last case in the enum."
33+
case .mixedMatchingStrategies:
34+
"An enum cannot mix byte-based matching (@match, @match(byte:), @match(bytes:), @matchAndTake) with length-based matching (@match(length:))."
3235
case let .unexpectedError(description: description):
3336
"Unexpected error: \(description)"
3437
}

Sources/BinaryParseKitMacros/Macros/Supports/Constants.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ extension Constants {
4141
extension Constants {
4242
enum UtilityFunctions {
4343
static let matchBytes = PackageMember(name: "__match")
44+
static let matchLength = PackageMember(name: "__match")
4445
static let assertParsable = PackageMember(name: "__assertParsable")
4546
static let assertSizedParsable = PackageMember(name: "__assertSizedParsable")
4647
static let assertEndianParsable = PackageMember(name: "__assertEndianParsable")

0 commit comments

Comments
 (0)