Skip to content

Commit 8fa1d3e

Browse files
add new COMPLETION command
1 parent c6fb3f9 commit 8fa1d3e

File tree

3 files changed

+174
-17
lines changed

3 files changed

+174
-17
lines changed

docs/sphinx/esql.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ Commands
2525
:members:
2626
:exclude-members: __init__
2727

28+
.. autoclass:: elasticsearch.esql.esql.ChangePoint
29+
:members:
30+
:exclude-members: __init__
31+
2832
.. autoclass:: elasticsearch.esql.esql.Dissect
2933
:members:
3034
:exclude-members: __init__

elasticsearch/esql/esql.py

Lines changed: 101 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,47 @@ def change_point(self, value: FieldType) -> "ChangePoint":
134134
"""
135135
return ChangePoint(self, value)
136136

137+
def completion(self, *prompt: ExpressionType, **named_prompt: ExpressionType):
138+
"""The `COMPLETION` command allows you to send prompts and context to a Large
139+
Language Model (LLM) directly within your ES|QL queries, to perform text
140+
generation tasks.
141+
142+
:param prompt: The input text or expression used to prompt the LLM. This can
143+
be a string literal or a reference to a column containing text.
144+
:param named_prompt: The input text or expresion, given as a keyword argument.
145+
The argument name is used for the column name. If not
146+
specified, the results will be stored in a column named
147+
`completion`. If the specified column already exists, it
148+
will be overwritten with the new results.
149+
150+
Examples::
151+
152+
query1 = (
153+
ESQL.row(question="What is Elasticsearch?")
154+
.completion(E("question")).with_("test_completion_model")
155+
.keep("question", "completion")
156+
)
157+
query2 = (
158+
ESQL.row(question="What is Elasticsearch?")
159+
.completion(answer=E("question")).with_("test_completion_model")
160+
.keep("question", "answer")
161+
)
162+
query3 = (
163+
ESQL.from_("movies")
164+
.sort("rating DESC")
165+
.limit(10)
166+
.eval(prompt=\"\"\"CONCAT(
167+
"Summarize this movie using the following information: \n",
168+
"Title: ", title, "\n",
169+
"Synopsis: ", synopsis, "\n",
170+
"Actors: ", MV_CONCAT(actors, ", "), "\n",
171+
)\"\"\")
172+
.completion(summary="prompt").with_("test_completion_model")
173+
.keep("title", "summary", "rating")
174+
)
175+
"""
176+
return Completion(self, *prompt, **named_prompt)
177+
137178
def dissect(self, input: FieldType, pattern: str) -> "Dissect":
138179
"""``DISSECT`` enables you to extract structured data out of a string.
139180
@@ -306,43 +347,39 @@ def limit(self, max_number_of_rows: int) -> "Limit":
306347
"""
307348
return Limit(self, max_number_of_rows)
308349

309-
def lookup_join(self, lookup_index: IndexType, field: FieldType) -> "LookupJoin":
350+
def lookup_join(self, lookup_index: IndexType) -> "LookupJoin":
310351
"""`LOOKUP JOIN` enables you to add data from another index, AKA a 'lookup' index,
311352
to your ES|QL query results, simplifying data enrichment and analysis workflows.
312353
313354
:param lookup_index: The name of the lookup index. This must be a specific index
314355
name - wildcards, aliases, and remote cluster references are
315356
not supported. Indices used for lookups must be configured
316357
with the lookup index mode.
317-
:param field: The field to join on. This field must exist in both your current query
318-
results and in the lookup index. If the field contains multi-valued
319-
entries, those entries will not match anything (the added fields will
320-
contain null for those rows).
321358
322359
Examples::
323360
324361
query1 = (
325362
ESQL.from_("firewall_logs")
326-
.lookup_join("threat_list", "source.IP")
363+
.lookup_join("threat_list").on("source.IP")
327364
.where("threat_level IS NOT NULL")
328365
)
329366
query2 = (
330367
ESQL.from_("system_metrics")
331-
.lookup_join("host_inventory", "host.name")
332-
.lookup_join("ownerships", "host.name")
368+
.lookup_join("host_inventory").on("host.name")
369+
.lookup_join("ownerships").on("host.name")
333370
)
334371
query3 = (
335372
ESQL.from_("app_logs")
336-
.lookup_join("service_owners", "service_id")
373+
.lookup_join("service_owners").on("service_id")
337374
)
338375
query4 = (
339376
ESQL.from_("employees")
340377
.eval(language_code="languages")
341378
.where("emp_no >= 10091 AND emp_no < 10094")
342-
.lookup_join("languages_lookup", "language_code")
379+
.lookup_join("languages_lookup").on("language_code")
343380
)
344381
"""
345-
return LookupJoin(self, lookup_index, field)
382+
return LookupJoin(self, lookup_index)
346383

347384
def mv_expand(self, column: FieldType) -> "MvExpand":
348385
"""The `MV_EXPAND` processing command expands multivalued columns into one row per
@@ -635,6 +672,45 @@ def _render_internal(self) -> str:
635672
return f"CHANGE_POINT {self._value}{key}{names}"
636673

637674

675+
class Completion(ESQLBase):
676+
"""Implementation of the ``COMPLETION`` processing command.
677+
678+
This class inherits from :class:`ESQLBase <elasticsearch.esql.esql.ESQLBase>`,
679+
to make it possible to chain all the commands that belong to an ES|QL query
680+
in a single expression.
681+
"""
682+
683+
def __init__(self, parent: ESQLBase, *prompt: ExpressionType, **named_prompt: ExpressionType):
684+
if len(prompt) + len(named_prompt) > 1:
685+
raise ValueError(
686+
"this method requires either one positional or one keyword argument only"
687+
)
688+
super().__init__(parent)
689+
self._prompt = prompt
690+
self._named_prompt = named_prompt
691+
self._inference_id = None
692+
693+
def with_(self, inference_id: str) -> "Completion":
694+
"""Continuation of the `COMPLETION` command.
695+
696+
:param inference_id: The ID of the inference endpoint to use for the task. The
697+
inference endpoint must be configured with the completion
698+
task type.
699+
"""
700+
self._inference_id = inference_id
701+
return self
702+
703+
def _render_internal(self) -> str:
704+
if self._inference_id is None:
705+
raise ValueError('The completion command requires an inference ID')
706+
if self._named_prompt:
707+
column = list(self._named_prompt.keys())[0]
708+
prompt = list(self._named_prompt.values())[0]
709+
return f'COMPLETION {column} = {prompt} WITH {self._inference_id}'
710+
else:
711+
return f'COMPLETION {self._prompt[0]} WITH {self._inference_id}'
712+
713+
638714
class Dissect(ESQLBase):
639715
"""Implementation of the ``DISSECT`` processing command.
640716
@@ -861,12 +937,25 @@ class LookupJoin(ESQLBase):
861937
in a single expression.
862938
"""
863939

864-
def __init__(self, parent: ESQLBase, lookup_index: IndexType, field: FieldType):
940+
def __init__(self, parent: ESQLBase, lookup_index: IndexType):
865941
super().__init__(parent)
866942
self._lookup_index = lookup_index
943+
self._field = None
944+
945+
def on(self, field: FieldType):
946+
"""Continuation of the `LOOKUP_JOIN` command.
947+
948+
:param field: The field to join on. This field must exist in both your current query
949+
results and in the lookup index. If the field contains multi-valued
950+
entries, those entries will not match anything (the added fields will
951+
contain null for those rows).
952+
"""
867953
self._field = field
954+
return self
868955

869956
def _render_internal(self) -> str:
957+
if self._field is None:
958+
raise ValueError("Joins require a field to join on.")
870959
index = (
871960
self._lookup_index
872961
if isinstance(self._lookup_index, str)

test_elasticsearch/test_esql.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,70 @@ def test_change_point():
7474
)
7575

7676

77+
def test_completion():
78+
query = (
79+
ESQL.row(question="What is Elasticsearch?")
80+
.completion("question").with_("test_completion_model")
81+
.keep("question", "completion")
82+
)
83+
assert query.render() == """ROW question = "What is Elasticsearch?"
84+
| COMPLETION question WITH test_completion_model
85+
| KEEP question, completion"""
86+
87+
query = (
88+
ESQL.row(question="What is Elasticsearch?")
89+
.completion(answer=E("question")).with_("test_completion_model")
90+
.keep("question", "answer")
91+
)
92+
assert query.render() == """ROW question = "What is Elasticsearch?"
93+
| COMPLETION answer = question WITH test_completion_model
94+
| KEEP question, answer"""
95+
96+
query = (
97+
ESQL.from_("movies")
98+
.sort("rating DESC")
99+
.limit(10)
100+
.eval(prompt="""CONCAT(
101+
"Summarize this movie using the following information: \\n",
102+
"Title: ", title, "\\n",
103+
"Synopsis: ", synopsis, "\\n",
104+
"Actors: ", MV_CONCAT(actors, ", "), "\\n",
105+
)""")
106+
.completion(summary="prompt").with_("test_completion_model")
107+
.keep("title", "summary", "rating")
108+
)
109+
assert query.render() == """FROM movies
110+
| SORT rating DESC
111+
| LIMIT 10
112+
| EVAL prompt = CONCAT(
113+
"Summarize this movie using the following information: \\n",
114+
"Title: ", title, "\\n",
115+
"Synopsis: ", synopsis, "\\n",
116+
"Actors: ", MV_CONCAT(actors, ", "), "\\n",
117+
)
118+
| COMPLETION summary = prompt WITH test_completion_model
119+
| KEEP title, summary, rating"""
120+
121+
query = (
122+
ESQL.from_("movies")
123+
.sort("rating DESC")
124+
.limit(10)
125+
.eval(prompt=functions.concat(
126+
"Summarize this movie using the following information: \n",
127+
"Title: ", E("title"), "\n",
128+
"Synopsis: ", E("synopsis"), "\n",
129+
"Actors: ", functions.mv_concat(E("actors"), ", "), "\n",
130+
))
131+
.completion(summary="prompt").with_("test_completion_model")
132+
.keep("title", "summary", "rating")
133+
)
134+
assert query.render() == """FROM movies
135+
| SORT rating DESC
136+
| LIMIT 10
137+
| EVAL prompt = CONCAT("Summarize this movie using the following information: \\n", "Title: ", title, "\\n", "Synopsis: ", synopsis, "\\n", "Actors: ", MV_CONCAT(actors, ", "), "\\n")
138+
| COMPLETION summary = prompt WITH test_completion_model
139+
| KEEP title, summary, rating"""
140+
77141
def test_dissect():
78142
query = (
79143
ESQL.row(a="2023-01-23T12:15:00.000Z - some text - 127.0.0.1")
@@ -260,7 +324,7 @@ def test_limit():
260324
def test_lookup_join():
261325
query = (
262326
ESQL.from_("firewall_logs")
263-
.lookup_join("threat_list", "source.IP")
327+
.lookup_join("threat_list").on("source.IP")
264328
.where("threat_level IS NOT NULL")
265329
)
266330
assert (
@@ -272,8 +336,8 @@ def test_lookup_join():
272336

273337
query = (
274338
ESQL.from_("system_metrics")
275-
.lookup_join("host_inventory", "host.name")
276-
.lookup_join("ownerships", "host.name")
339+
.lookup_join("host_inventory").on("host.name")
340+
.lookup_join("ownerships").on("host.name")
277341
)
278342
assert (
279343
query.render()
@@ -282,7 +346,7 @@ def test_lookup_join():
282346
| LOOKUP JOIN ownerships ON host.name"""
283347
)
284348

285-
query = ESQL.from_("app_logs").lookup_join("service_owners", "service_id")
349+
query = ESQL.from_("app_logs").lookup_join("service_owners").on("service_id")
286350
assert (
287351
query.render()
288352
== """FROM app_logs
@@ -293,7 +357,7 @@ def test_lookup_join():
293357
ESQL.from_("employees")
294358
.eval(language_code="languages")
295359
.where(E("emp_no") >= 10091, E("emp_no") < 10094)
296-
.lookup_join("languages_lookup", "language_code")
360+
.lookup_join("languages_lookup").on("language_code")
297361
)
298362
assert (
299363
query.render()

0 commit comments

Comments
 (0)