Skip to content

Commit 5ea5557

Browse files
actions: instruct=False support
1 parent a083489 commit 5ea5557

File tree

3 files changed

+73
-70
lines changed

3 files changed

+73
-70
lines changed

src/lmql/language/compiler.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def scope(self, query: LMQLQuery):
118118

119119
# also collect variable reads from where clause
120120
if query.where is not None:
121-
FreeVarCollector(self.free_vars, exclude_criteria=[self.exclude_identifier]).visit(query.where)
121+
self.visit_where(query.where)
122122
if query.from_ast is not None:
123123
FreeVarCollector(self.free_vars, exclude_criteria=[self.exclude_identifier]).visit(query.from_ast)
124124
if query.decode is not None:
@@ -131,6 +131,9 @@ def scope(self, query: LMQLQuery):
131131

132132
query.scope = self
133133

134+
def visit_where(self, node):
135+
FreeVarCollector(self.free_vars, exclude_criteria=[self.exclude_identifier]).visit(node)
136+
134137
def visit_Expr(self, expr):
135138
if type(expr.value) is ast.Constant:
136139
self.scope_Constant(expr.value)
@@ -140,6 +143,8 @@ def visit_Expr(self, expr):
140143
def visit_BoolOp(self, node: ast.BoolOp) -> Any:
141144
if is_query_string_with_constraints(node):
142145
self.scope_Constant(node.values[0])
146+
for constraint in node.values[1:]:
147+
self.visit_where(constraint)
143148
else:
144149
super().generic_visit(node)
145150

@@ -519,7 +524,9 @@ def transform_node(self, expr, snf):
519524
assert type(expr.func) is ast.Name, "In LMQL constraint expressions, only function calls to direct function references are allowed: {}".format(astunparse.unparse(expr))
520525
tfunc = ast.unparse(expr.func)
521526
targs = [self.transform_node(a, snf) for a in expr.args]
522-
targs_list = ", ".join(targs)
527+
kwargs = [f"('__kw:{e.arg}', {self.transform_node(e.value, snf)})" for e in expr.keywords]
528+
targs_list = ", ".join(targs + kwargs)
529+
523530
return f"{OPS_NAMESPACE}.CallOp([{tfunc}, [{targs_list}]], locals(), globals())"
524531
elif type(expr) is ast.List:
525532
return self.default_transform_node(expr, snf).strip()

src/lmql/lib/actions.py

Lines changed: 56 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -74,86 +74,76 @@ def make_fct(f):
7474
@lmql.query
7575
async def fct_call(fcts):
7676
'''lmql
77-
incontext
78-
action_fcts = {str(f.__name__): make_fct(f) for f in fcts}
79-
"[CALL]"
80-
truncated = CALL
81-
if not CALL.endswith("|") and not CALL.endswith(DELIMITER_END):
77+
action_fcts = {str(f.__name__): make_fct(f) for f in fcts}
78+
"[CALL]" where STOPS_AT(CALL, "|") and STOPS_AT(CALL, DELIMITER_END)
79+
80+
truncated = CALL
81+
if not CALL.endswith("|") and not CALL.endswith(DELIMITER_END):
82+
return CALL
83+
else:
84+
if CALL.endswith(DELIMITER_END):
85+
CALL = CALL[:-len(DELIMITER_END)]
86+
else:
87+
CALL = CALL[:-len("|")]
88+
89+
if "(" not in CALL:
8290
return CALL
83-
else:
84-
if CALL.endswith(DELIMITER_END):
85-
CALL = CALL[:-len(DELIMITER_END)]
86-
else:
87-
CALL = CALL[:-len("|")]
88-
89-
if "(" not in CALL:
90-
return CALL
91-
92-
action, args = CALL.split("(", 1)
93-
action = action.strip()
94-
if action not in action_fcts.keys():
95-
print("unknown action", [action], list(action_fcts.keys()))
96-
" Unknown action: {action} {DELIMITER_END}"
97-
result = ""
91+
92+
action, args = CALL.split("(", 1)
93+
action = action.strip()
94+
if action not in action_fcts.keys():
95+
print("unknown action", [action], list(action_fcts.keys()))
96+
" Unknown action: {action} {DELIMITER_END}"
97+
result = ""
98+
return "(error)"
99+
else:
100+
try:
101+
result = await action_fcts[action].call("(" + args.strip())
102+
return DELIMITER + str(CALL) + "| " + str(result)
103+
except Exception:
104+
result = "Error."
98105
return "(error)"
99-
else:
100-
try:
101-
result = await action_fcts[action].call("(" + args.strip())
102-
return DELIMITER + str(CALL) + "| " + str(result)
103-
except Exception:
104-
result = "Error."
105-
return "(error)"
106-
where
107-
STOPS_AT(CALL, "|") and STOPS_AT(CALL, DELIMITER_END)
108106
'''
109107

110108

111109
@lmql.query
112110
async def inline_segment(fcts):
113111
'''lmql
114-
incontext
115-
"[SEGMENT]"
116-
if not SEGMENT.endswith(DELIMITER):
117-
return SEGMENT
118-
else:
119-
"[CALL]"
120-
result = CALL.split("|", 1)[1]
121-
return SEGMENT[:-len(DELIMITER)] + CALL + DELIMITER_END
122-
where
123-
STOPS_AT(SEGMENT, DELIMITER) and fct_call(CALL, fcts)
112+
"[SEGMENT]" where STOPS_AT(SEGMENT, DELIMITER)
113+
if not SEGMENT.endswith(DELIMITER):
114+
return SEGMENT
115+
else:
116+
"[CALL]" where fct_call(CALL, fcts) and len(TOKENS(CALL)) > 0
117+
result = CALL.split("|", 1)[1]
118+
return SEGMENT[:-len(DELIMITER)] + CALL + DELIMITER_END
124119
'''
125120

126121
@lmql.query
127-
async def inline_use(fcts):
122+
async def inline_use(fcts, instruct=True):
128123
'''lmql
129-
incontext
130-
action_fcts = {str(f.__name__): make_fct(f) for f in fcts}
131-
first_tool_name = list(action_fcts.keys())[0] if len(action_fcts) > 0 else "tool"
124+
action_fcts = {str(f.__name__): make_fct(f) for f in fcts}
125+
first_tool_name = list(action_fcts.keys())[0] if len(action_fcts) > 0 else "tool"
132126
133-
# add instruction prompt if no few-shot prompt was already used
134-
if not INLINE_USE_PROMPT in context.prompt:
135-
"""
136-
\n\n{:system} Instructions: In your reasoning, you can use the following tools:"""
137-
138-
for fct in action_fcts.values():
139-
"\n - {fct.name}: {fct.description} Usage: {DELIMITER}{fct.example} | {fct.example_result}{DELIMITER_END}"
140-
' Example Use: ... this means they had <<calc("5-2") | 3 >> 3 apples left...\n'
141-
" You can also use the tools multiple times in one reasoning step.\n\n"
142-
"Reasoning with Tools: {:assistant}"
143-
else:
144-
"\n\nInline Tool Use:\n\n"
127+
# add instruction prompt if no few-shot prompt was already used
128+
if instruct and not INLINE_USE_PROMPT in context.prompt:
129+
"""
130+
\n\nInstructions: In your reasoning, you can use the following tools:"""
145131
146-
# decode segment-by-segment, handling action calls along the way
147-
truncated = ""
148-
while True:
149-
"[SEGMENT]"
150-
if not SEGMENT.endswith(DELIMITER_END):
151-
" " # seems to be needed for now
152-
return truncated + SEGMENT
153-
truncated += SEGMENT
154-
return truncated
155-
where
156-
inline_segment(SEGMENT, fcts)
132+
for fct in action_fcts.values():
133+
"\n - {fct.name}: {fct.description} Usage: {DELIMITER}{fct.example} | {fct.example_result}{DELIMITER_END}"
134+
' Example Use: ... this means they had <<calc("5-2") | 3 >> 3 apples left...\n'
135+
" You can also use the tools multiple times in one reasoning step.\n\n"
136+
"Reasoning with Tools:\n\n"
137+
138+
# decode segment-by-segment, handling action calls along the way
139+
truncated = ""
140+
while True:
141+
"[SEGMENT]" where inline_segment(SEGMENT, fcts)
142+
if not SEGMENT.endswith(DELIMITER_END):
143+
" " # seems to be needed for now
144+
return truncated + SEGMENT
145+
truncated += SEGMENT
146+
return truncated
157147
'''
158148
inline_use.demonstrations = INLINE_USE_PROMPT
159149

src/lmql/ops/inline_call.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,13 @@ def forward(self, *args, **kwargs):
5757
def subinterpreter(self, runtime, prompt, args = None):
5858
query_kwargs = None
5959
if args is not None:
60-
query_kwargs, _ = self.query_fct.make_kwargs(*args[1:])
60+
args = args[1:]
61+
def is_kwarg(a):
62+
return type(a) is tuple and len(a) == 2 and type(a[0]) is str and a[0].startswith("__kw:")
63+
kwargs = {k[0][len("__kw:"):]: k[1] for k in args if is_kwarg(k)}
64+
args = [a for a in args if not is_kwarg(a)]
65+
66+
query_kwargs, _ = self.query_fct.make_kwargs(*args, **kwargs)
6167
self.query_kwargs = query_kwargs
6268
else:
6369
if self.query_kwargs is None:
@@ -89,7 +95,7 @@ def postprocess(self, operands, value):
8995
return si
9096

9197
def postprocess_order(self, other, operands, other_inputs, **kwargs):
92-
return 0 # other constraints cannot be compared
98+
return "before" # other constraints cannot be compared
9399

94100
@staticmethod
95101
def collect(op: Node):

0 commit comments

Comments
 (0)