From 378690e036b01e54f9d11aabe90f35beabad16db Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 10 Mar 2025 13:41:41 -0700 Subject: [PATCH 1/3] converted from _ast.py (o3-mini) --- src/ast.ts | 403 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 403 insertions(+) create mode 100644 src/ast.ts diff --git a/src/ast.ts b/src/ast.ts new file mode 100644 index 0000000..87732b7 --- /dev/null +++ b/src/ast.ts @@ -0,0 +1,403 @@ +export interface LarkGrammar { + name: string; + lark_grammar: string; +} + +export interface JsonGrammar { + name: string; + json_schema: Record; +} + +export interface LLGrammar { + grammars: Array; +} + +export abstract class ASTNode { + simplify(): ASTNode { + return this; + } +} + +export abstract class GrammarNode extends ASTNode { + simplify(): GrammarNode { + return this; + } + + children(): GrammarNode[] { + return []; + } + + /** + * If this returns true, then this node matches empty string and empty string only. + */ + isNull(): boolean { + return false; + } + + /** + * If this returns true, then this node will be compiled down to a regular expression. + * It cannot be recursive. + */ + isTerminal(): boolean { + return this.children().every((child) => child.isTerminal()); + } + + ll_grammar(): LLGrammar { + return new LLSerializer().serialize(this); + } +} + +export class LiteralNode extends GrammarNode { + constructor(public value: string) { + super(); + } + isNull(): boolean { + return this.value === ""; + } +} + +export class RegexNode extends GrammarNode { + constructor(public regex: string) { + super(); + } +} + +export class SelectNode extends GrammarNode { + constructor(public alternatives: GrammarNode[]) { + super(); + } + isNull(): boolean { + return this.alternatives.every((alt) => alt.isNull()); + } + simplify(): GrammarNode { + if (this.isNull()) return new LiteralNode(""); + const alts = this.alternatives + .map((alt) => alt.simplify()) + .filter((alt) => !alt.isNull()); + const node = alts.length === 1 ? alts[0] : new SelectNode(alts); + if (this.alternatives.some((alt) => alt.isNull())) + return new RepeatNode(node, 0, 1); + return node; + } + children(): GrammarNode[] { + return this.alternatives; + } +} + +export class JoinNode extends GrammarNode { + constructor(public nodes: GrammarNode[]) { + super(); + } + isNull(): boolean { + return this.nodes.every((node) => node.isNull()); + } + simplify(): GrammarNode { + if (this.isNull()) return new LiteralNode(""); + const simplified = this.nodes + .map((node) => node.simplify()) + .filter((node) => !node.isNull()); + return simplified.length === 1 ? simplified[0] : new JoinNode(simplified); + } + children(): GrammarNode[] { + return this.nodes; + } +} + +export class RepeatNode extends GrammarNode { + constructor( + public node: GrammarNode, + public min: number, + public max?: number + ) { + super(); + if (min < 0) throw new Error("min must be >= 0"); + if (max !== undefined && max < min) throw new Error("max must be >= min"); + } + isNull(): boolean { + return this.node.isNull() || (this.min === 0 && this.max === 0); + } + children(): GrammarNode[] { + return [this.node]; + } + simplify(): GrammarNode { + return new RepeatNode(this.node.simplify(), this.min, this.max); + } +} + +export class SubstringNode extends GrammarNode { + constructor(public chunks: string[]) { + super(); + } + isTerminal(): boolean { + return true; + } +} + +/** + * This creates a name for the given grammar node (value), which can be referenced + * via RuleRefNode (or directly). + * In Lark syntax this results in approx. "{name}: {value}" + * This can either Lark rule (non-terminal) or terminal definition + * (meaning name can be upper- or lowercase). + */ +export class RuleNode extends GrammarNode { + public capture?: string; + public list_append: boolean = false; + public temperature?: number; + public max_tokens?: number; + public stop?: RegexNode | LiteralNode; + public suffix?: LiteralNode; + public stop_capture?: string; + constructor(public name: string, public value: GrammarNode) { + super(); + if ( + (this.temperature !== undefined || + this.max_tokens !== undefined || + this.stop !== undefined || + this.suffix !== undefined || + this.stop_capture !== undefined) && + !(this.value.isTerminal() || this.value instanceof BaseSubgrammarNode) + ) { + throw new Error( + "RuleNode is not terminal, so it cannot have a temperature, max_tokens, or stop condition" + ); + } + } + isTerminal(): boolean { + return ( + this.capture === undefined && + this.temperature === undefined && + this.max_tokens === undefined && + this.stop === undefined && + this.suffix === undefined && + this.stop_capture === undefined && + this.value.isTerminal() && + !(this.value instanceof BaseSubgrammarNode) + ); + } + children(): GrammarNode[] { + return [this.value]; + } +} + +export class RuleRefNode extends GrammarNode { + private target?: RuleNode; + setTarget(target: RuleNode): void { + if (this.target) throw new Error("RuleRefNode target already set"); + this.target = target; + } + isTerminal(): boolean { + // RuleRefNode should only ever be used to enable recursive rule definitions, + // so it should never be terminal. + return false; + } +} + +export abstract class BaseSubgrammarNode extends GrammarNode { + constructor(public name: string) { + super(); + } + isTerminal(): boolean { + return false; + } +} + +export class SubgrammarNode extends BaseSubgrammarNode { + constructor( + name: string, + public body: GrammarNode, + public skip_regex?: string + ) { + super(name); + } +} + +export class JsonNode extends BaseSubgrammarNode { + constructor(name: string, public schema: Record) { + super(name); + } +} + +export class LLSerializer { + public grammars: { [name: string]: JsonGrammar | LarkGrammar } = {}; + public names: Map = new Map(); + + serialize(node: GrammarNode): LLGrammar { + if (node instanceof BaseSubgrammarNode) { + this.visit(node); + } else { + this.visit(new SubgrammarNode("main", node)); + } + const arr = Array.from(this.names.values()).map( + (name) => this.grammars[name] + ); + return { grammars: arr }; + } + visit(node: BaseSubgrammarNode): string { + if (this.names.has(node)) return this.names.get(node)!; + let name = node.name; + const used = new Set(this.names.values()); + if (used.has(name)) { + let i = 1; + while (used.has(`${name}_${i}`)) { + i++; + } + name = `${name}_${i}`; + } + if (node instanceof SubgrammarNode) { + // Important: insert name BEFORE visiting body to avoid infinite recursion + this.names.set(node, name); + const lark_grammar = + new LarkSerializer(this).serialize(node.body) + + (node.skip_regex ? `\n%ignore /${node.skip_regex}/` : ""); + this.grammars[name] = { + name, + lark_grammar, + }; + } else if (node instanceof JsonNode) { + this.names.set(node, name); + this.grammars[name] = { + name, + json_schema: node.schema, + }; + } else { + throw new TypeError(`Unknown subgrammar type: ${node}`); + } + return name; + } +} + +// LarkSerializer +export class LarkSerializer { + public rules: { [name: string]: string } = {}; + public names: Map = new Map(); + constructor(public llSerializer: LLSerializer) {} + serialize(node: GrammarNode): string { + if (node instanceof RuleNode && node.name === "start") { + this.visit(node); + } else { + this.visit(new RuleNode("start", node)); + } + let res = "%llguidance {}\n\n"; + if (!("start" in this.rules)) { + if ("START" in this.rules) res += "start: START\n"; + } + let prevNl = true; + for (const name of this.names.values()) { + let s = this.rules[name]; + if (!prevNl && s.indexOf("\n") !== -1) res += "\n"; + res += s + "\n"; + prevNl = s.indexOf("\n") !== -1; + if (prevNl) res += "\n"; + } + return res; + } + visit(node: GrammarNode, top: boolean = false): string { + if (node instanceof BaseSubgrammarNode) { + return "@" + this.llSerializer.visit(node); + } + if (node instanceof RuleNode) { + if (this.names.has(node)) return this.names.get(node)!; + let name = this.normalizeName(node.name, node.isTerminal()); + const used = new Set(this.names.values()); + if (used.has(name)) { + let i = 1; + while (used.has(`${name}_${i}`)) { + i++; + } + name = `${name}_${i}`; + } + this.names.set(node, name); + let res = name; + const attrs: string[] = []; + if (node.capture !== undefined) { + let captureName = node.capture; + if (node.list_append) { + captureName = `__LIST_APPEND:${captureName}`; + } + attrs.push(`capture=${JSON.stringify(captureName)}`); + } else { + attrs.push("capture"); + } + if (node.temperature !== undefined) { + attrs.push(`temperature=${node.temperature}`); + } + if (node.max_tokens !== undefined) { + attrs.push(`max_tokens=${node.max_tokens}`); + } + if (node.stop) { + attrs.push(`stop=${this.visit(node.stop)}`); + } + if (node.suffix) { + attrs.push(`suffix=${this.visit(node.suffix)}`); + } + if (node.stop_capture !== undefined) { + attrs.push(`stop_capture=${JSON.stringify(node.stop_capture)}`); + } + if (attrs.length > 0) res += `[${attrs.join(", ")}]`; + res += ": " + this.visit(node.value.simplify(), true); + this.rules[name] = res; + return name; + } + if (node.isNull()) return '""'; + if (node instanceof LiteralNode) { + return JSON.stringify(node.value); + } + if (node instanceof RegexNode) { + let rx = node.regex; + if (rx === undefined) rx = "(?s:.*)"; + return this.regex(rx); + } + if (node instanceof SelectNode) { + if (top) { + return node.alternatives + .map((alt) => this.visit(alt)) + .join("\n | "); + } else { + return ( + "(" + + node.alternatives.map((alt) => this.visit(alt)).join(" | ") + + ")" + ); + } + } + if (node instanceof JoinNode) { + return node.nodes + .filter((n) => !n.isNull()) + .map((n) => this.visit(n)) + .join(" "); + } + if (node instanceof RepeatNode) { + let inner = this.visit(node.node); + if (node.node instanceof JoinNode || node.node instanceof RepeatNode) { + inner = `(${inner})`; + } + if (node.min === 0 && node.max === undefined) return `${inner}*`; + if (node.min === 1 && node.max === undefined) return `${inner}+`; + if (node.min === 0 && node.max === 1) return `${inner}?`; + if (node.max === undefined) return `${inner}{${node.min},}`; + return `${inner}{${node.min},${node.max}}`; + } + if (node instanceof SubstringNode) { + return `%regex ${JSON.stringify( + { substring_chunks: node.chunks }, + null, + 2 + )}`; + } + if (node instanceof RuleRefNode) { + if (!node["target"]) throw new Error("RuleRefNode has no target"); + return this.visit(node["target"]); + } + throw new TypeError(`Unknown node type: ${node}`); + } + normalizeName(name: string, terminal: boolean): string { + let newName = name.replace(/-/g, "_"); + newName = newName.replace(/([a-z])([A-Z])/g, "$1_$2"); + return terminal ? newName.toUpperCase() : newName.toLowerCase(); + } + regex(pattern: string): string { + const escaped = pattern.replace(/(? Date: Mon, 10 Mar 2025 13:47:28 -0700 Subject: [PATCH 2/3] make simplify static --- src/ast.ts | 63 +++++++++++++++++++++++++++--------------------------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/src/ast.ts b/src/ast.ts index 87732b7..bde6978 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -12,17 +12,9 @@ export interface LLGrammar { grammars: Array; } -export abstract class ASTNode { - simplify(): ASTNode { - return this; - } -} +export abstract class ASTNode {} export abstract class GrammarNode extends ASTNode { - simplify(): GrammarNode { - return this; - } - children(): GrammarNode[] { return []; } @@ -69,16 +61,6 @@ export class SelectNode extends GrammarNode { isNull(): boolean { return this.alternatives.every((alt) => alt.isNull()); } - simplify(): GrammarNode { - if (this.isNull()) return new LiteralNode(""); - const alts = this.alternatives - .map((alt) => alt.simplify()) - .filter((alt) => !alt.isNull()); - const node = alts.length === 1 ? alts[0] : new SelectNode(alts); - if (this.alternatives.some((alt) => alt.isNull())) - return new RepeatNode(node, 0, 1); - return node; - } children(): GrammarNode[] { return this.alternatives; } @@ -91,13 +73,6 @@ export class JoinNode extends GrammarNode { isNull(): boolean { return this.nodes.every((node) => node.isNull()); } - simplify(): GrammarNode { - if (this.isNull()) return new LiteralNode(""); - const simplified = this.nodes - .map((node) => node.simplify()) - .filter((node) => !node.isNull()); - return simplified.length === 1 ? simplified[0] : new JoinNode(simplified); - } children(): GrammarNode[] { return this.nodes; } @@ -119,9 +94,6 @@ export class RepeatNode extends GrammarNode { children(): GrammarNode[] { return [this.node]; } - simplify(): GrammarNode { - return new RepeatNode(this.node.simplify(), this.min, this.max); - } } export class SubstringNode extends GrammarNode { @@ -267,7 +239,6 @@ export class LLSerializer { } } -// LarkSerializer export class LarkSerializer { public rules: { [name: string]: string } = {}; public names: Map = new Map(); @@ -335,7 +306,7 @@ export class LarkSerializer { attrs.push(`stop_capture=${JSON.stringify(node.stop_capture)}`); } if (attrs.length > 0) res += `[${attrs.join(", ")}]`; - res += ": " + this.visit(node.value.simplify(), true); + res += ": " + this.visit(simplify(node.value), true); this.rules[name] = res; return name; } @@ -401,3 +372,33 @@ export class LarkSerializer { return `/${escaped}/`; } } + +export function simplify(node: GrammarNode): GrammarNode { + if (node instanceof SelectNode) { + if (node.isNull()) return new LiteralNode(""); + const simplifiedAlts = node.alternatives + .map(simplify) + .filter((alt) => !alt.isNull()); + const simplifiedNode = + simplifiedAlts.length === 1 + ? simplifiedAlts[0] + : new SelectNode(simplifiedAlts); + if (node.alternatives.some((alt) => alt.isNull())) + return new RepeatNode(simplifiedNode, 0, 1); + return simplifiedNode; + } else if (node instanceof JoinNode) { + if (node.isNull()) return new LiteralNode(""); + const simplifiedNodes = node.nodes + .map(simplify) + .filter((child) => !child.isNull()); + return simplifiedNodes.length === 1 + ? simplifiedNodes[0] + : new JoinNode(simplifiedNodes); + } else if (node instanceof RepeatNode) { + return new RepeatNode(simplify(node.node), node.min, node.max); + } + if (node.children().length > 0) + throw new Error("Unexpected node type: " + node.constructor.name); + // For child-less nodes, return as is + return node; +} From e38f3034a49e2eee30f409aa9461b754d8ba374f Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Mon, 10 Mar 2025 18:14:44 -0700 Subject: [PATCH 3/3] remove old code --- src/api.ts | 217 +------------------- src/ast.ts | 60 +++--- src/cli.ts | 1 - src/gen.ts | 86 ++++---- src/grammarnode.ts | 455 ----------------------------------------- src/index.ts | 5 +- src/regexnode.ts | 97 --------- test/inferstop.test.ts | 35 ---- 8 files changed, 79 insertions(+), 877 deletions(-) delete mode 100644 src/grammarnode.ts delete mode 100644 src/regexnode.ts delete mode 100644 test/inferstop.test.ts diff --git a/src/api.ts b/src/api.ts index 3d56e1d..d98b360 100644 --- a/src/api.ts +++ b/src/api.ts @@ -1,181 +1,16 @@ /// This represents a collection of grammars, with a designated /// "start" grammar at first position. -/// Grammars can refer to each other via GrammarRef nodes. export interface TopLevelGrammar { grammars: GrammarWithLexer[]; max_tokens?: number; - test_trace?: boolean; } -export const DEFAULT_CONTEXTUAL: boolean = true; - -/// The start symbol is at nodes[0] export interface GrammarWithLexer { - nodes: NodeJSON[]; - - /// When enabled, the grammar can use `Lexeme` but not `Gen`. - /// When disabled, the grammar can use `Gen` but not `Lexeme`. - /// `String` is allowed in either case as a shorthand for either `Lexeme` or `Gen`. - greedy_lexer?: boolean; - - /// Only applies to greedy_lexer grammars. - /// This adds a new lexeme that will be ignored when parsing. - greedy_skip_rx?: RegexSpec; - - /// The default value for 'contextual' in Lexeme nodes. - contextual?: boolean; - - /// When set, the regexps can be referenced by their id (position in this list). - rx_nodes: RegexJSON[]; - - /// If set, the grammar will allow skip_rx as the first lexeme. - allow_initial_skip?: boolean; - - /// Normally, when a sequence of bytes is forced by grammar, it is tokenized - /// canonically and forced as tokens. - /// With `no_forcing`, we let the model decide on tokenization. - /// This generally reduces both quality and speed, so should not be used - /// outside of testing. - no_forcing?: boolean; - - /// If set, the grammar will allow invalid utf8 byte sequences. - /// Any Unicode regex will cause an error. - allow_invalid_utf8?: boolean; -} - -export type NodeJSON = - // Terminals: - /// Force generation of the specific string. - | { String: NodeString } - /// Generate according to regex. - | { Gen: NodeGen } - /// Lexeme in a greedy grammar. - | { Lexeme: NodeLexeme } - /// Generate according to specified grammar. - | { GenGrammar: NodeGenGrammar } - // Non-terminals: - /// Generate one of the options. - | { Select: NodeSelect } - /// Generate all of the nodes in sequence. - | { Join: NodeJoin }; - -/// Optional fields allowed on any Node -export interface NodeProps { - max_tokens?: number; name?: string; - capture_name?: string; -} - -export interface NodeString extends NodeProps { - literal: string; -} - -export interface NodeGen extends NodeProps { - /// Regular expression matching the body of generation. - body_rx: RegexSpec; - - /// The whole generation must match `body_rx + stop_rx`. - /// Whatever matched `stop_rx` is discarded. - /// If `stop_rx` is empty, it's assumed to be EOS. - stop_rx: RegexSpec; - - /// When set, the string matching `stop_rx` will be output as a capture - /// with the given name. - stop_capture_name?: string; - - /// Lazy gen()s take the shortest match. Non-lazy take the longest. - /// If not specified, the gen() is lazy if stop_rx is non-empty. - lazy?: boolean; - - /// Override sampling temperature. - temperature?: number; -} - -export interface NodeLexeme extends NodeProps { - /// The regular expression that will greedily match the input. - rx: RegexSpec; - - /// If false, all other lexemes are excluded when this lexeme is recognized. - /// This is normal behavior for keywords in programming languages. - /// Set to true for eg. a JSON schema with both `/"type"/` and `/"[^"]*"/` as lexemes, - /// or for "get"/"set" contextual keywords in C#. - contextual?: boolean; - - /// Override sampling temperature. - temperature?: number; - - /// When set, the lexeme will be quoted as a JSON string. - /// For example, /[a-z"]+/ will be quoted as /([a-z]|\\")+/ - json_string?: boolean; - - /// It lists the allowed escape sequences, typically one of: - /// "nrbtf\\\"u" - to allow all JSON escapes, including \u00XX for control characters - /// this is the default - /// "nrbtf\\\"" - to disallow \u00XX control characters - /// "nrt\\\"" - to also disallow unusual escapes (\f and \b) - /// "" - to disallow all escapes - /// Note that \uXXXX for non-control characters (code points above U+001F) are never allowed, - /// as they never have to be quoted in JSON. - json_allowed_escapes?: string; - - /// When set and json_string is also set, "..." will not be added around the regular expression. - json_raw?: boolean; -} - -export interface NodeGenGrammar extends NodeProps { - grammar: GrammarId; - - /// Override sampling temperature. - temperature?: number; -} - -export interface NodeSelect extends NodeProps { - among: NodeId[]; + lark_grammar?: string; + json_schema?: Record; } -export interface NodeJoin extends NodeProps { - sequence: NodeId[]; -} - -export type RegexJSON = - /// Intersection of the regexes - | { And: RegexId[] } - /// Union of the regexes - | { Or: RegexId[] } - /// Concatenation of the regexes - | { Concat: RegexId[] } - /// Matches the regex; should be at the end of the main regex. - /// The length of the lookahead can be recovered from the engine. - | { LookAhead: RegexId } - /// Matches everything the regex doesn't match. - /// Can lead to invalid utf8. - | { Not: RegexId } - /// Repeat the regex at least min times, at most max times - | { Repeat: [RegexId, number, number?] } - /// Matches the empty string. Same as Concat([]). - // | "EmptyString" - /// Matches nothing. Same as Or([]). - // | "NoMatch" - /// Compile the regex using the regex_syntax crate - | { Regex: string } - /// Matches this string only - | { Literal: string } - /// Matches this string of bytes only. Can lead to invalid utf8. - | { ByteLiteral: number[] } - /// Matches this byte only. If byte is not in 0..127, it may lead to invalid utf8 - | { Byte: number } - /// Matches any byte in the set, expressed as bitset. - /// Can lead to invalid utf8 if the set is not a subset of 0..127 - | { ByteSet: number[] }; - -// The actual wire format allows for direct strings, but we always use nodes -// TODO-SERVER -export type RegexSpec = string | RegexId; - -export type GrammarId = number; -export type NodeId = number; -export type RegexId = number; - // Output of llguidance parser export interface BytesOutput { @@ -233,51 +68,3 @@ export interface ParserStats { hidden_bytes: number; } -// AICI stuff: - -export interface RunUsageResponse { - sampled_tokens: number; - ff_tokens: number; - cost: number; -} - -export interface InitialRunResponse { - id: string; - object: "initial-run"; - created: number; - model: string; -} - -export interface RunResponse { - object: "run"; - forks: RunForkResponse[]; - usage: RunUsageResponse; -} - -export interface RunForkResponse { - index: number; - finish_reason?: string; - text: string; - error: string; - logs: string; - storage: any[]; - micros: number; -} - -export type AssistantPromptRole = "system" | "user" | "assistant"; - -export interface AssistantPrompt { - role: AssistantPromptRole; - content: string; -} - -export interface RunRequest { - controller: string; - controller_arg: { grammar: TopLevelGrammar }; - prompt?: string; // Optional with a default value - messages?: AssistantPrompt[]; // Optional with a default value - temperature?: number; // Optional with a default value of 0.0 - top_p?: number; // Optional with a default value of 1.0 - top_k?: number; // Optional with a default value of -1 - max_tokens?: number; // Optional with a default value based on context size -} diff --git a/src/ast.ts b/src/ast.ts index bde6978..6f003e4 100644 --- a/src/ast.ts +++ b/src/ast.ts @@ -15,7 +15,7 @@ export interface LLGrammar { export abstract class ASTNode {} export abstract class GrammarNode extends ASTNode { - children(): GrammarNode[] { + children(): ReadonlyArray { return []; } @@ -34,9 +34,19 @@ export abstract class GrammarNode extends ASTNode { return this.children().every((child) => child.isTerminal()); } - ll_grammar(): LLGrammar { + llGrammar(): LLGrammar { return new LLSerializer().serialize(this); } + + static from(grammar: string | GrammarNode | RegExp): GrammarNode { + if (typeof grammar === "string") { + return new LiteralNode(grammar); + } + if (grammar instanceof RegExp) { + return new RegexNode(grammar.source); + } + return grammar; + } } export class LiteralNode extends GrammarNode { @@ -61,7 +71,7 @@ export class SelectNode extends GrammarNode { isNull(): boolean { return this.alternatives.every((alt) => alt.isNull()); } - children(): GrammarNode[] { + children(): ReadonlyArray { return this.alternatives; } } @@ -73,7 +83,7 @@ export class JoinNode extends GrammarNode { isNull(): boolean { return this.nodes.every((node) => node.isNull()); } - children(): GrammarNode[] { + children(): ReadonlyArray { return this.nodes; } } @@ -82,16 +92,16 @@ export class RepeatNode extends GrammarNode { constructor( public node: GrammarNode, public min: number, - public max?: number + public max: number | null ) { super(); if (min < 0) throw new Error("min must be >= 0"); - if (max !== undefined && max < min) throw new Error("max must be >= min"); + if (max !== null && max < min) throw new Error("max must be >= min"); } isNull(): boolean { return this.node.isNull() || (this.min === 0 && this.max === 0); } - children(): GrammarNode[] { + children(): ReadonlyArray { return [this.node]; } } @@ -114,20 +124,20 @@ export class SubstringNode extends GrammarNode { */ export class RuleNode extends GrammarNode { public capture?: string; - public list_append: boolean = false; + public listAppend: boolean = false; public temperature?: number; - public max_tokens?: number; - public stop?: RegexNode | LiteralNode; - public suffix?: LiteralNode; - public stop_capture?: string; + public maxTokens?: number; + public stop?: GrammarNode; + public suffix?: GrammarNode; + public stopCapture?: string; constructor(public name: string, public value: GrammarNode) { super(); if ( (this.temperature !== undefined || - this.max_tokens !== undefined || + this.maxTokens !== undefined || this.stop !== undefined || this.suffix !== undefined || - this.stop_capture !== undefined) && + this.stopCapture !== undefined) && !(this.value.isTerminal() || this.value instanceof BaseSubgrammarNode) ) { throw new Error( @@ -139,15 +149,15 @@ export class RuleNode extends GrammarNode { return ( this.capture === undefined && this.temperature === undefined && - this.max_tokens === undefined && + this.maxTokens === undefined && this.stop === undefined && this.suffix === undefined && - this.stop_capture === undefined && + this.stopCapture === undefined && this.value.isTerminal() && !(this.value instanceof BaseSubgrammarNode) ); } - children(): GrammarNode[] { + children(): ReadonlyArray { return [this.value]; } } @@ -283,7 +293,7 @@ export class LarkSerializer { const attrs: string[] = []; if (node.capture !== undefined) { let captureName = node.capture; - if (node.list_append) { + if (node.listAppend) { captureName = `__LIST_APPEND:${captureName}`; } attrs.push(`capture=${JSON.stringify(captureName)}`); @@ -293,8 +303,8 @@ export class LarkSerializer { if (node.temperature !== undefined) { attrs.push(`temperature=${node.temperature}`); } - if (node.max_tokens !== undefined) { - attrs.push(`max_tokens=${node.max_tokens}`); + if (node.maxTokens !== undefined) { + attrs.push(`max_tokens=${node.maxTokens}`); } if (node.stop) { attrs.push(`stop=${this.visit(node.stop)}`); @@ -302,8 +312,8 @@ export class LarkSerializer { if (node.suffix) { attrs.push(`suffix=${this.visit(node.suffix)}`); } - if (node.stop_capture !== undefined) { - attrs.push(`stop_capture=${JSON.stringify(node.stop_capture)}`); + if (node.stopCapture !== undefined) { + attrs.push(`stop_capture=${JSON.stringify(node.stopCapture)}`); } if (attrs.length > 0) res += `[${attrs.join(", ")}]`; res += ": " + this.visit(simplify(node.value), true); @@ -343,10 +353,10 @@ export class LarkSerializer { if (node.node instanceof JoinNode || node.node instanceof RepeatNode) { inner = `(${inner})`; } - if (node.min === 0 && node.max === undefined) return `${inner}*`; - if (node.min === 1 && node.max === undefined) return `${inner}+`; + if (node.min === 0 && node.max === null) return `${inner}*`; + if (node.min === 1 && node.max === null) return `${inner}+`; if (node.min === 0 && node.max === 1) return `${inner}?`; - if (node.max === undefined) return `${inner}{${node.min},}`; + if (node.max === null) return `${inner}{${node.min},}`; return `${inner}{${node.min},${node.max}}`; } if (node instanceof SubstringNode) { diff --git a/src/cli.ts b/src/cli.ts index 4ef02b7..d125c92 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -1,6 +1,5 @@ import { gen, - GrammarNode, grm, keyword, lexeme, diff --git a/src/gen.ts b/src/gen.ts index 5f2fbf1..a0c4363 100644 --- a/src/gen.ts +++ b/src/gen.ts @@ -1,19 +1,16 @@ import { - Gen, + LiteralNode, + RegexNode, GrammarNode, - Grammar, - Join, - Lexeme, - Select, - StringLiteral, -} from "./grammarnode"; -import { RegexNode, BaseNode } from "./regexnode"; + RuleNode, + JoinNode, + SelectNode, + RepeatNode, +} from "./ast"; import { assert } from "./util"; -export { GrammarNode, RegexNode, BaseNode }; -export type { Grammar }; - -export type RegexDef = RegExp | RegexNode; +export type RegexDef = RegExp | GrammarNode; +export type Grammar = string | GrammarNode; export interface GenOptions { name?: string; @@ -29,7 +26,7 @@ function isPlainObject(obj: any): boolean { } function isRegexDef(obj: any): boolean { - return obj instanceof RegExp || obj instanceof RegexNode; + return obj instanceof RegExp || obj instanceof GrammarNode; } export function gen(options?: GenOptions): GrammarNode; @@ -50,72 +47,69 @@ export function gen(...args: any[]): GrammarNode { if (isPlainObject(args[0])) options = args.shift(); assert(args.length == 0); - const stop = !options.stop - ? undefined - : typeof options.stop == "string" - ? RegexNode.literal(options.stop) - : RegexNode.from(options.stop); - const g = new Gen(RegexNode.from(regex ?? options.regex ?? /.*/), stop); - if (options.maxTokens !== undefined) g.maxTokens = options.maxTokens; - if (options.temperature !== undefined) g.temperature = options.temperature; + const stop = !options.stop ? undefined : GrammarNode.from(options.stop); + name ??= options.name; - if (name !== undefined) { - if (options.listAppend) name = Gen.LIST_APPEND_PREFIX + name; - // TODO-SERVER: capture name on gen doesn't work - const r = new Join([g]); - r.captureName = name; - return r; - } + const body = RegexNode.from(regex ?? options.regex ?? /.*/); + const g = new RuleNode(name ?? "r", body); + + if (options.maxTokens !== undefined) g.maxTokens = options.maxTokens; + if (options.temperature !== undefined) g.temperature = options.temperature; + if (options.listAppend) g.listAppend = true; + if (name !== undefined) g.capture = name; + g.stop = stop; return g; } export function capture(name: string, grammar: Grammar) { - const r = new Join([GrammarNode.from(grammar)]); - r.captureName = name; - return r; + const g = new RuleNode(name, GrammarNode.from(grammar)); + return g; } export function select(...values: Grammar[]) { - return new Select(values.map(GrammarNode.from)); + return new SelectNode(values.map(GrammarNode.from)); } export function join(...values: Grammar[]) { - return new Join(values.map(GrammarNode.from)); + return new JoinNode(values.map(GrammarNode.from)); } export function lexeme(rx: RegexDef) { - return new Lexeme(RegexNode.from(rx)); + return RegexNode.from(rx); } export function keyword(s: string) { - return new Lexeme(RegexNode.literal(s), true); + return new LiteralNode(s); } export function str(s: string) { - return new StringLiteral(s); + return new LiteralNode(s); +} + +export function repeat(g: Grammar, min: number, max: number | null) { + return new RepeatNode(GrammarNode.from(g), min, max); } export function oneOrMore(g: Grammar) { - const inner = GrammarNode.from(g); - const n = new Select([inner]); - n.among.push(join(n, inner)); - return n; + return repeat(g, 1, null); } export function zeroOrMore(g: Grammar) { - const n = new Select([str("")]); - n.among.push(join(n, g)); - return n; + return repeat(g, 0, null); +} + +export function optional(g: Grammar) { + return repeat(g, 0, 1); } function concatStrings(acc: GrammarNode[]) { for (let i = 1; i < acc.length; ++i) { const a = acc[i - 1]; const b = acc[i]; - if (a instanceof StringLiteral && b instanceof StringLiteral) { - acc[i - 1] = str(a.literal + b.literal); + if (a instanceof LiteralNode && b instanceof LiteralNode) { + acc[i - 1] = str(a.value + b.value); acc.splice(i, 1); i--; } @@ -213,5 +207,5 @@ export function grm( if (acc.length == 0) return str(""); else if (acc.length == 1) return acc[0]; - else return new Join(acc); + else return new JoinNode(acc); } diff --git a/src/grammarnode.ts b/src/grammarnode.ts deleted file mode 100644 index 140aca4..0000000 --- a/src/grammarnode.ts +++ /dev/null @@ -1,455 +0,0 @@ -import { - GrammarId, - GrammarWithLexer, - NodeJSON, - NodeProps, - RegexJSON, - RegexSpec, - TopLevelGrammar, -} from "./api"; -import { BaseNode, RegexNode } from "./regexnode"; -import { assert } from "./util"; - -export type Grammar = GrammarNode | string; - -export abstract class GrammarNode extends BaseNode { - maxTokens?: number; - captureName?: string; - nullable = false; - - static LIST_APPEND_PREFIX = "__LIST_APPEND:"; - - protected constructor() { - super(); - } - - abstract serializeInner(s: Serializer): NodeJSON; - - serialize(): TopLevelGrammar { - const gens = this.rightNodes().filter( - (n) => n instanceof Gen && n.stop === undefined - ) as Gen[]; - for (const g of gens) { - if (g._inferStop === undefined) g._inferStop = ""; - } - return new NestedGrammar(this).serialize(); - } - - join(other: Grammar): Join { - return new Join([this, GrammarNode.from(other)]); - } - - toString() { - return this.pp(); - } - - pp() { - const useCount = new Map(); - - { - const visited = new Set(); - const visit = (n: GrammarNode) => { - const v = useCount.get(n) ?? 0; - useCount.set(n, v + 1); - if (visited.has(n)) return; - visited.add(n); - n.getChildren()?.forEach(visit); - }; - visit(this); - } - - { - const visited = new Set(); - const visit = (n: GrammarNode) => { - if (visited.has(n)) return "#" + n.id; - visited.add(n); - const ch = n.getChildren()?.map(visit); - let res = useCount.get(n) > 1 ? `#${n.id}: ` : ``; - res += n.ppInner(ch); - if (ch) return `(${res})`; - else return res; - }; - return visit(this); - } - } - - getChildren(): readonly GrammarNode[] | undefined { - return undefined; - } - - leftNodes(): GrammarNode[] { - return Array.from(this.sideNodes(false)); - } - - rightNodes(): GrammarNode[] { - return Array.from(this.sideNodes(true)); - } - - private sideNodes(right: boolean): Set { - const r = new Set(); - const todo: GrammarNode[] = [this]; - while (todo.length > 0) { - const e = todo.pop(); - if (r.has(e)) continue; - r.add(e); - if (e instanceof Select) { - todo.push(...e.getChildren()); - } else if (e instanceof Join) { - const elts = e.sequence.slice(); - if (right) elts.reverse(); - for (const e of elts) { - todo.push(e); - if (!e.nullable) break; - } - } - } - return r; - } - - abstract ppInner(children?: string[]): string; - - static from(s: Grammar) { - if (typeof s === "string") return new StringLiteral(s); - return s; - } -} - -function ppProps(n: GrammarNode & { temperature?: number }) { - let res = ""; - if (n.maxTokens !== undefined) res += ` maxTokens:${n.maxTokens}`; - if (n.temperature !== undefined) res += ` temp:${n.temperature}`; - if (n.captureName !== undefined) - res += ` name:${JSON.stringify(n.captureName)}`; - return res; -} - -export class Gen extends GrammarNode { - public temperature?: number; - public lazy?: boolean; - public _inferStop?: string; - - constructor(public regex: RegexNode, public stop?: RegexNode) { - super(); - Serializer.checkRegex(this.regex); - Serializer.checkRegex(this.stop); - // if (!stop) this.lazy = true; - } - - override ppInner() { - const stop = this.stop - ? this.stop.pp() - : this._inferStop !== undefined - ? JSON.stringify(this._inferStop) - : "???"; - return ( - `gen(` + - `regex:${this.regex.pp()} ` + - `stop:${stop}` + - ppProps(this) + - `)` - ); - } - - override serializeInner(s: Serializer): NodeJSON { - const stop = this.stop - ? s.regex(this.stop) - : this._inferStop !== undefined - ? this._inferStop - : undefined; - if (stop === undefined) throw new Error(`can't infer gen({ stop: ... })`); - return { - Gen: { - body_rx: s.regex(this.regex), - // TODO-SERVER: passing noMatch doesn't work - need "" - stop_rx: stop, - temperature: this.temperature, - lazy: this.lazy, - }, - }; - } -} - -export class Select extends GrammarNode { - constructor(public among: GrammarNode[]) { - super(); - this.nullable = this.among.some((e) => e.nullable); - } - - override getChildren() { - return this.among; - } - - override ppInner(children?: string[]) { - return children.join(" | ") + ppProps(this); - } - - override serializeInner(s: Serializer): NodeJSON { - return { - Select: { - among: this.among.map(s.serialize), - }, - }; - } -} - -export class Join extends GrammarNode { - constructor(public sequence: GrammarNode[]) { - super(); - this.nullable = this.sequence.every((e) => e.nullable); - this.inferGenStop(); - } - - private inferGenStop() { - for (let i = 0; i < this.sequence.length - 1; ++i) { - const gens = this.sequence[i] - .rightNodes() - .filter((n) => n instanceof Gen && n.stop === undefined) as Gen[]; - if (gens.length == 0) continue; - const infer: string[] = []; - for (const n of this.sequence[i + 1].leftNodes()) { - if (n instanceof Select || n instanceof Join) continue; - if (n instanceof StringLiteral) { - infer.push(n.literal); - } else { - throw new Error( - `can't infer gen({ stop: ... }): ${gens[0].pp()} followed by ${n.pp()}` - ); - } - } - if (infer.length > 0) { - const inferStop = infer[0][0]; - if (infer.some((x) => x[0] != inferStop)) - throw new Error( - `can't infer gen({ stop: ...}): ${gens[0].pp()} is followed by ${infer}` - ); - for (const g of gens) { - if (g._inferStop !== undefined && g._inferStop != inferStop) - throw new Error( - `can't infer gen({ stop: ...}): ${g.pp()} already has ${ - g._inferStop - }; setting to ${inferStop}` - ); - g._inferStop = inferStop; - } - } - } - } - - override ppInner(children?: string[]) { - return children.join(" + ") + ppProps(this); - } - - override getChildren() { - return this.sequence; - } - - override serializeInner(s: Serializer): NodeJSON { - return { - Join: { - sequence: this.sequence.map(s.serialize), - }, - }; - } -} - -export class StringLiteral extends GrammarNode { - constructor(public literal: string) { - super(); - this.nullable = this.literal == ""; - } - - override ppInner() { - return JSON.stringify(this.literal) + ppProps(this); - } - - override serializeInner(s: Serializer): NodeJSON { - return { - String: { - literal: this.literal, - }, - }; - } -} - -export class Lexeme extends GrammarNode { - public temperature?: number; - - constructor(public rx: RegexNode, public contextual?: boolean) { - super(); - Serializer.checkRegex(rx); - } - - override ppInner() { - const kw = this.contextual ? "keyword" : "lexeme"; - return `${kw}(${this.rx.pp()}${ppProps(this)})`; - } - - override serializeInner(s: Serializer): NodeJSON { - return { - Lexeme: { - rx: s.regex(this.rx), - contextual: this.contextual, - temperature: this.temperature, - }, - }; - } -} - -export class NestedGrammar extends GrammarNode { - public temperature?: number; - - constructor(public start: GrammarNode, public skip_rx?: RegexNode) { - super(); - Serializer.checkRegex(this.skip_rx); - } - - override serialize(): TopLevelGrammar { - return Serializer.grammar(this); - } - - override getChildren() { - return [this.start]; - } - - override serializeInner(s: Serializer): NodeJSON { - const gid = s.saveGrammar(this); - return { - GenGrammar: { - grammar: gid, - temperature: this.temperature, - }, - }; - } - - override ppInner(children?: string[]): string { - return `grammar(${children[0]})`; - } -} - -function nodeProps(node: NodeJSON): NodeProps { - return Object.values(node)[0]; -} - -class Serializer { - private nodeCache: Map = new Map(); - private grammarCache: Map = new Map(); - private rxCache: Map = new Map(); - private grmNodes: NodeJSON[] = []; - private rxNodes: RegexJSON[] = []; - private rxHashCons: Map = new Map(); - private grammars: GrammarWithLexer[] = []; - private grammarSrc: NestedGrammar[] = []; - - constructor() { - this.serialize = this.serialize.bind(this); - this.regex = this.regex.bind(this); - } - - saveGrammar(n: NestedGrammar) { - let gid = this.grammarCache.get(n); - if (gid !== undefined) return gid; - gid = this.grammars.length; - this.grammars.push({ - greedy_lexer: false, // no longer used - contextual: false, - greedy_skip_rx: undefined, - rx_nodes: [], - nodes: [], - }); - this.grammarSrc.push(n); - this.grammarCache.set(n, gid); - return gid; - } - - private fixpoint() { - // note that this.grammarSrc.length grows during this loop - for (let i = 0; i < this.grammarSrc.length; ++i) { - const g = this.grammars[i]; - this.grmNodes = g.nodes; - this.rxNodes = g.rx_nodes; - this.rxHashCons.clear(); - this.nodeCache.clear(); - this.rxCache.clear(); - const s = this.grammarSrc[i]; - const id = this.serialize(s.start); - if (s.skip_rx) g.greedy_skip_rx = this.regex(s.skip_rx); - assert(id == 0); - } - } - - static grammar(top: NestedGrammar): TopLevelGrammar { - const s = new Serializer(); - const id = s.saveGrammar(top); - assert(id == 0); - s.fixpoint(); - return { - grammars: s.grammars, - }; - } - - /** - * Throws is regex is recursive. - */ - static checkRegex(n: RegexNode) { - const s = new Serializer(); - s.regex(n); - } - - serialize(n: GrammarNode): GrammarId { - let sid = this.nodeCache.get(n); - if (sid !== undefined) return sid; - sid = this.grmNodes.length; - this.grmNodes.push(null); - this.nodeCache.set(n, sid); - const serial = n.serializeInner(this); - const props = nodeProps(serial); - if (n.maxTokens !== undefined) props.max_tokens = n.maxTokens; - if (n.captureName !== undefined) props.capture_name = n.captureName; - this.grmNodes[sid] = serial; - return sid; - } - - regex(top?: RegexNode): RegexSpec { - if (top === undefined) top = RegexNode.noMatch(); - - const cache = this.rxCache; - - const lookup = (n: RegexNode) => { - const r = cache.get(n); - if (r == null) throw new Error("circular regex"); - assert(typeof r == "number"); - return r; - }; - - const missing = (n: RegexNode) => !cache.has(n); - - const todo = [top]; - while (todo.length > 0) { - const n = todo.pop(); - let sid = cache.get(n); - if (sid === null) throw new Error("circular regex"); - cache.set(n, null); - - const unfinished = n.getChildren()?.filter(missing); - if (unfinished?.length) { - unfinished.reverse(); - todo.push(n, ...unfinished); - continue; - } - - const serial = n.serializeInner(lookup); - const key = JSON.stringify(serial); - const cached = this.rxHashCons.get(key); - if (cached !== undefined) { - sid = cached; - } else { - sid = this.rxNodes.length; - this.rxNodes.push(serial); - this.rxHashCons.set(key, sid); - } - cache.set(n, sid); - } - - return lookup(top); - } -} diff --git a/src/index.ts b/src/index.ts index 1bfddf7..363f9b2 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,8 +1,7 @@ export * from "./gen"; -export { BaseNode, RegexNode } from "./regexnode"; export { Session, Generation } from "./client"; export type { GenerationOptions } from "./client"; -import * as grammar from "./grammarnode"; +import * as ast from "./ast"; import * as api from "./api"; -export { grammar, api }; +export { ast, api }; diff --git a/src/regexnode.ts b/src/regexnode.ts deleted file mode 100644 index 45d6bc5..0000000 --- a/src/regexnode.ts +++ /dev/null @@ -1,97 +0,0 @@ -import { RegexJSON } from "./api"; -import { assert, panic } from "./util"; - -export class BaseNode { - static nextNodeId = 1; - - id: number; - - protected constructor() { - this.id = BaseNode.nextNodeId++; - } -} - -const REGEX_PLACEHOLDER = -1000; -export class RegexNode extends BaseNode { - private constructor( - private simple: RegexJSON, - private children?: RegexNode[] - ) { - super(); - } - - getChildren(): RegexNode[] | undefined { - return this.children; - } - - serializeInner(rec: (n: RegexNode) => number): RegexJSON { - const simple = JSON.parse(JSON.stringify(this.simple)); - const key = Object.keys(simple)[0]; - let arg = Object.values(simple)[0]; - - if (this.children) { - const mapped = this.children.map(rec); - if (Array.isArray(arg)) { - arg = mapped.concat(arg); - } else if (arg === REGEX_PLACEHOLDER) { - assert(mapped.length == 1); - arg = mapped[0]; - } else { - panic(); - } - simple[key] = arg; - } - - return simple; - } - - static literal(s: string) { - return new RegexNode({ Literal: s }); - } - - static regex(s: string | RegExp) { - if (typeof s != "string") s = s.source; - return new RegexNode({ Regex: s }); - } - - static noMatch() { - return new RegexNode({ Or: [] }); - } - - static from(s: undefined | RegexNode | RegExp) { - if (s === undefined) return RegexNode.noMatch(); - if (s instanceof RegexNode) return s; - return RegexNode.regex(s); - } - - asRegexString() { - if ("Regex" in this.simple) { - return this.simple["Regex"]; - } else { - return null; - } - } - - pp() { - let res = ""; - - const visit = (n: RegexNode) => { - if (res.length > 1024) return; - if ("Literal" in n.simple) { - const lit = n.simple["Literal"]; - res += JSON.stringify(lit); - } else if ("Regex" in n.simple) { - const rx = n.simple["Regex"]; - res += "" + new RegExp(rx); - } else { - res += JSON.stringify(this.simple); - } - }; - - visit(this); - - if (res.length > 1024) res += "..."; - - return res; - } -} diff --git a/test/inferstop.test.ts b/test/inferstop.test.ts deleted file mode 100644 index 31c0318..0000000 --- a/test/inferstop.test.ts +++ /dev/null @@ -1,35 +0,0 @@ -import { - gen, - GrammarNode, - grm, - keyword, - lexeme, - oneOrMore, - select, - str, - join, -} from "../src/index"; - -import { test } from "uvu"; -import * as assert from "uvu/assert"; - -test("infer stop", () => { - function tstInfer(stop: string, mk: (stop: string) => GrammarNode) { - assert.equal(mk(stop).serialize(), mk(undefined).serialize()); - } - - tstInfer(".", (stop) => grm`${gen({ stop })}.`); - tstInfer(".", (stop) => grm`${gen({ stop })}${select(".a", ".b")}`); - tstInfer("", (stop) => grm`${gen({ stop })}`); - -}); - -test("infer error", () => { - // assert.throws(() => grm`${gen()}`.serialize(), /infer/); - assert.throws(() => grm`${gen()}${gen()}`, /followed/); - assert.throws(() => grm`${gen()}${select("XXX", "YYY")}`, /XXX/); - const g = gen(); - assert.throws(() => grm`${g}X${g}Y`, /X/); -}); - -test.run();