Skip to content

Commit 3891928

Browse files
committed
Add structural tag
1 parent d818bf4 commit 3891928

File tree

9 files changed

+459
-15
lines changed

9 files changed

+459
-15
lines changed

Sources/CMLXStructured/grammar_compiler.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,24 @@ extern "C" void* compile_json_schema_grammar(
6060
}
6161
}
6262

63+
extern "C" void* compile_structural_tag(
64+
void* tokenizer_info,
65+
const char* structural_tag_utf8,
66+
size_t structural_tag_len
67+
) {
68+
try {
69+
const std::string structural_tag(structural_tag_utf8, structural_tag_len);
70+
auto& tokenizer_info_ptr = *static_cast<TokenizerInfo*>(tokenizer_info);
71+
auto* compiled_grammar_ptr = new CompiledGrammar(
72+
GrammarCompiler(tokenizer_info_ptr).CompileStructuralTag(structural_tag)
73+
);
74+
return compiled_grammar_ptr;
75+
} catch (const std::exception& e) {
76+
catch_error(e.what());
77+
return nullptr;
78+
}
79+
}
80+
6381
extern "C" void compiled_grammar_free(void* compiled_grammar) {
6482
if (compiled_grammar) {
6583
delete static_cast<CompiledGrammar*>(compiled_grammar);

Sources/CMLXStructured/include/mlx_structured/grammar_compiler.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ void* compile_json_schema_grammar(
2525
int indent
2626
);
2727

28+
/*
 * Compiles a structural tag description into a compiled grammar.
 *
 * tokenizer_info       Opaque tokenizer-info handle.
 * structural_tag_utf8  Pointer to the structural tag text.
 * structural_tag_len   Length of that text in bytes.
 *
 * Returns an opaque compiled-grammar handle, or a null pointer when
 * compilation fails. Release the handle with compiled_grammar_free().
 */
void* compile_structural_tag(
29+
void* tokenizer_info,
30+
const char* structural_tag_utf8,
31+
size_t structural_tag_len
32+
);
33+
2834
void compiled_grammar_free(void* compiled_grammar);
2935

3036
#ifdef __cplusplus

Sources/MLXStructured/Backends/XGrammar.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ final class XGrammar {
9797
schema.utf8CString.withUnsafeBufferPointer {
9898
compile_json_schema_grammar(tokenizerInfo, $0.baseAddress, $0.count, Int32(indent ?? -1))
9999
}
100+
case .structural(let tag):
101+
tag.utf8CString.withUnsafeBufferPointer {
102+
compile_structural_tag(tokenizerInfo, $0.baseAddress, $0.count)
103+
}
100104
}
101105

102106
defer {

Sources/MLXStructured/Grammar+Generable.swift renamed to Sources/MLXStructured/Grammar/Grammar+Generable.swift

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,7 @@ public extension Grammar {
1717
static func schema<Content: Generable>(generable type: Content.Type, indent: Int? = nil) throws -> Grammar {
1818
let encoder = JSONEncoder()
1919
let data = try encoder.encode(type.generationSchema)
20-
guard let string = String(data: data, encoding: .utf8) else {
21-
throw EncodingError.invalidValue(
22-
type,
23-
EncodingError.Context(codingPath: [], debugDescription: "Failed to encode generation schema using UTF-8.")
24-
)
25-
}
20+
let string = String(decoding: data, as: UTF8.self)
2621
return .schema(string, indent: indent)
2722
}
2823
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//
2+
// Grammar+Schema.swift
3+
// MLXStructured
4+
//
5+
// Created by Ivan Petrukha on 04.10.2025.
6+
//
7+
8+
import Foundation
9+
import JSONSchema
10+
11+
public extension Grammar {
    /// Builds a JSON-schema-constrained grammar from a `JSONSchema` value.
    ///
    /// The schema is serialized with a default `JSONEncoder`, and the resulting
    /// JSON text is stored in the `.schema` case.
    ///
    /// - Parameters:
    ///   - schema: Schema to encode. Defaults to `.object()`.
    ///   - indent: Passed through unchanged to the `.schema` case.
    /// - Returns: A `.schema` grammar carrying the encoded JSON text.
    /// - Throws: Any error raised by `JSONEncoder.encode(_:)`.
    static func schema(_ schema: JSONSchema = .object(), indent: Int? = nil) throws -> Grammar {
        let encodedSchema = try JSONEncoder().encode(schema)
        let jsonText = String(decoding: encodedSchema, as: UTF8.self)
        return .schema(jsonText, indent: indent)
    }
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//
2+
// Grammar+Structural.swift
3+
// MLXStructured
4+
//
5+
// Created by Ivan Petrukha on 27.09.2025.
6+
//
7+
8+
import Foundation
9+
10+
public extension Grammar {
    /// Creates a structural-tag grammar from a format described with `FormatBuilder`.
    ///
    /// The builder result is wrapped in a `StructuralTag`, serialized with a
    /// default `JSONEncoder`, and stored as JSON text in the `.structural` case.
    ///
    /// - Parameter content: Builder closure that produces the root format.
    /// - Throws: Any error raised by `JSONEncoder.encode(_:)`.
    init(@FormatBuilder _ content: () -> Encodable) throws {
        let structuralTag = StructuralTag(format: content())
        let payload = try JSONEncoder().encode(structuralTag)
        self = .structural(String(decoding: payload, as: UTF8.self))
    }
}

Sources/MLXStructured/Grammar.swift renamed to Sources/MLXStructured/Grammar/Grammar.swift

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,12 @@ public enum Grammar {
1212
case ebnf(String)
1313
case regex(String)
1414
case schema(String, indent: Int? = nil)
15-
}
16-
17-
public extension Grammar {
18-
static func schema(_ schema: JSONSchema = .object(), indent: Int? = nil) throws -> Grammar {
19-
let encoder = JSONEncoder()
20-
let data = try encoder.encode(schema)
21-
let string = String(decoding: data, as: UTF8.self)
22-
return .schema(string, indent: indent)
23-
}
15+
case structural(String)
2416
}
2517

2618
public extension Grammar {
2719

20+
@available(*, deprecated, message: "Prefer constructing prompt manually, this property will be removed in the future versions")
2821
var raw: String {
2922
switch self {
3023
case .ebnf(let ebnf):
@@ -33,9 +26,12 @@ public extension Grammar {
3326
return regex
3427
case .schema(let schema, _):
3528
return schema
29+
case .structural(let tag):
30+
return tag
3631
}
3732
}
3833

34+
@available(*, deprecated, message: "Prefer constructing prompt manually, this property will be removed in the future versions")
3935
var guidance: String? {
4036
switch self {
4137
case .ebnf:
@@ -44,6 +40,8 @@ public extension Grammar {
4440
return "Output is regex constrained: \(regex)"
4541
case .schema(let schema, _):
4642
return "Output is JSON schema constrained: \(schema)"
43+
case .structural:
44+
return nil
4745
}
4846
}
4947
}
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
//
2+
// StructuralTag+Builder.swift
3+
// MLXStructured
4+
//
5+
// Created by Ivan Petrukha on 25.09.2025.
6+
//
7+
8+
/// Result builder whose body evaluates to exactly one `Encodable` format.
@resultBuilder
public enum FormatBuilder {

    /// Passes a single expression through unchanged.
    public static func buildExpression(_ expression: Encodable) -> Encodable {
        return expression
    }

    /// A block holds exactly one component, which becomes the block's value.
    public static func buildBlock(_ component: Encodable) -> Encodable {
        return component
    }

    /// Substitutes `AnyTextFormat()` when the optional component is `nil`.
    public static func buildOptional(_ component: Encodable?) -> Encodable {
        guard let present = component else {
            return AnyTextFormat()
        }
        return present
    }

    /// Forwards the component produced by the first branch of a conditional.
    public static func buildEither(first component: Encodable) -> Encodable {
        return component
    }

    /// Forwards the component produced by the second branch of a conditional.
    public static func buildEither(second component: Encodable) -> Encodable {
        return component
    }

    /// Forwards a component produced inside an availability check.
    public static func buildLimitedAvailability(_ component: Encodable) -> Encodable {
        return component
    }
}
35+
36+
/// Result builder that gathers any number of `Encodable` formats into one flat array.
@resultBuilder
public enum FormatListBuilder {

    /// Wraps a single expression in a one-element list.
    public static func buildExpression(_ expression: Encodable) -> [Encodable] {
        return [expression]
    }

    /// Concatenates the lists produced by each statement, preserving their order.
    public static func buildBlock(_ components: [Encodable]...) -> [Encodable] {
        return Array(components.joined())
    }

    /// Yields an empty list when the optional component is `nil`.
    public static func buildOptional(_ component: [Encodable]?) -> [Encodable] {
        guard let present = component else {
            return []
        }
        return present
    }

    /// Forwards the list produced by the first branch of a conditional.
    public static func buildEither(first component: [Encodable]) -> [Encodable] {
        return component
    }

    /// Forwards the list produced by the second branch of a conditional.
    public static func buildEither(second component: [Encodable]) -> [Encodable] {
        return component
    }

    /// Flattens the per-iteration lists of a `for` loop into one list.
    public static func buildArray(_ components: [[Encodable]]) -> [Encodable] {
        return Array(components.joined())
    }

    /// Forwards a list produced inside an availability check.
    public static func buildLimitedAvailability(_ component: [Encodable]) -> [Encodable] {
        return component
    }
}
67+
68+
/// Result builder that gathers any number of `TagFormat` values into one flat array.
@resultBuilder
public enum TagListBuilder {

    /// Wraps a single tag in a one-element list.
    public static func buildExpression(_ expression: TagFormat) -> [TagFormat] {
        return [expression]
    }

    /// Concatenates the lists produced by each statement, preserving their order.
    public static func buildBlock(_ components: [TagFormat]...) -> [TagFormat] {
        return Array(components.joined())
    }

    /// Yields an empty list when the optional component is `nil`.
    public static func buildOptional(_ component: [TagFormat]?) -> [TagFormat] {
        guard let present = component else {
            return []
        }
        return present
    }

    /// Forwards the list produced by the first branch of a conditional.
    public static func buildEither(first component: [TagFormat]) -> [TagFormat] {
        return component
    }

    /// Forwards the list produced by the second branch of a conditional.
    public static func buildEither(second component: [TagFormat]) -> [TagFormat] {
        return component
    }

    /// Flattens the per-iteration lists of a `for` loop into one list.
    public static func buildArray(_ components: [[TagFormat]]) -> [TagFormat] {
        return Array(components.joined())
    }

    /// Forwards a list produced inside an availability check.
    public static func buildLimitedAvailability(_ component: [TagFormat]) -> [TagFormat] {
        return component
    }
}
99+
100+
public extension StructuralTag {
    /// Creates a structural tag whose format is produced by a `FormatBuilder` closure.
    ///
    /// - Parameter content: Builder closure that produces the root format.
    init(@FormatBuilder _ content: () -> Encodable) {
        let rootFormat = content()
        self.init(format: rootFormat)
    }
}
105+
106+
public extension SequenceFormat {
    /// Creates a sequence format from the elements listed in a `FormatListBuilder` closure.
    ///
    /// - Parameter content: Builder closure that produces the elements, in order.
    init(@FormatListBuilder _ content: () -> [Encodable]) {
        let builtElements = content()
        self.init(elements: builtElements)
    }
}
111+
112+
public extension OrFormat {
    /// Creates an or-format from the elements listed in a `FormatListBuilder` closure.
    ///
    /// - Parameter content: Builder closure that produces the elements.
    init(@FormatListBuilder _ content: () -> [Encodable]) {
        let builtElements = content()
        self.init(elements: builtElements)
    }
}
117+
118+
public extension TagFormat {
    /// Creates a tag whose content is produced by a `FormatBuilder` closure.
    ///
    /// - Parameters:
    ///   - begin: Forwarded as the `begin` argument of the designated initializer.
    ///   - end: Forwarded as the `end` argument of the designated initializer.
    ///   - content: Builder closure that produces the tag's content.
    init(begin: String, end: String, @FormatBuilder _ content: () -> Encodable) {
        let body = content()
        self.init(begin: begin, content: body, end: end)
    }
}
123+
124+
public extension TriggeredTagsFormat {
    /// Creates a triggered-tags format whose tags are listed in a `TagListBuilder` closure.
    ///
    /// - Parameters:
    ///   - triggers: Forwarded as the `triggers` argument of the designated initializer.
    ///   - options: Forwarded as the `options` argument. Defaults to the empty set.
    ///   - content: Builder closure that produces the tags.
    init(triggers: [String], options: Options = [], @TagListBuilder _ content: () -> [TagFormat]) {
        let builtTags = content()
        self.init(triggers: triggers, tags: builtTags, options: options)
    }
}
129+
130+
131+
public extension OrFormat {
    /// Returns a new `OrFormat` containing this value's elements followed by `formats`.
    ///
    /// The receiver is never modified, so the method is non-mutating and can be
    /// called on `let` constants and temporaries as well as on variables.
    ///
    /// - Parameter formats: Elements to add after the existing ones.
    /// - Returns: A new `OrFormat` with the combined elements.
    func appending(_ formats: [Encodable]) -> OrFormat {
        OrFormat(elements: elements + formats)
    }
}
136+
137+
public extension SequenceFormat {
    /// Returns a new `SequenceFormat` containing this value's elements followed by `formats`.
    ///
    /// The receiver is never modified, so the method is non-mutating and can be
    /// called on `let` constants and temporaries as well as on variables.
    ///
    /// - Parameter formats: Elements to add after the existing ones.
    /// - Returns: A new `SequenceFormat` with the combined elements.
    func appending(_ formats: [Encodable]) -> SequenceFormat {
        SequenceFormat(elements: elements + formats)
    }
}

0 commit comments

Comments
 (0)