Skip to content

Commit bc56d28

Browse files
committed
Add llmPredict built-in for LLM inference via HTTP
Register llmPredict through the full SystemDS compilation pipeline (Builtins, Opcodes, Types, DMLTranslator, HOP, LOP, CP instruction). LlmPredictCPInstruction sends HTTP POST to OpenAI-compatible servers with configurable concurrency. Includes 10 tests (7 mock, 3 live).
1 parent b394e32 commit bc56d28

File tree

11 files changed

+968
-3
lines changed

11 files changed

+968
-3
lines changed

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ public enum Builtins {
226226
LMDS("lmDS", true),
227227
LMPREDICT("lmPredict", true),
228228
LMPREDICT_STATS("lmPredictStats", true),
229+
LLMPREDICT("llmPredict", false, true),
229230
LOCAL("local", false),
230231
LOG("log", false),
231232
LOGSUMEXP("logSumExp", true),

src/main/java/org/apache/sysds/common/Opcodes.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ public enum Opcodes {
204204
GROUPEDAGG("groupedagg", InstructionType.ParameterizedBuiltin),
205205
RMEMPTY("rmempty", InstructionType.ParameterizedBuiltin),
206206
REPLACE("replace", InstructionType.ParameterizedBuiltin),
207+
LLMPREDICT("llmpredict", InstructionType.ParameterizedBuiltin),
207208
LOWERTRI("lowertri", InstructionType.ParameterizedBuiltin),
208209
UPPERTRI("uppertri", InstructionType.ParameterizedBuiltin),
209210
REXPAND("rexpand", InstructionType.ParameterizedBuiltin),

src/main/java/org/apache/sysds/common/Types.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ public static ReOrgOp valueOfByOpcode(String opcode) {
805805

806806
/** Parameterized operations that require named variable arguments */
807807
public enum ParamBuiltinOp {
808-
AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND,
808+
AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, LLMPREDICT, RMEMPTY, REPLACE, REXPAND,
809809
LOWER_TRI, UPPER_TRI,
810810
TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA,
811811
TOKENIZE, TOSTRING, LIST, PARAMSERV

src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ public Lop constructLops()
187187
case LOWER_TRI:
188188
case UPPER_TRI:
189189
case TOKENIZE:
190+
case LLMPREDICT:
190191
case TRANSFORMAPPLY:
191192
case TRANSFORMDECODE:
192193
case TRANSFORMCOLMAP:
@@ -758,7 +759,7 @@ && getTargetHop().areDimsBelowThreshold() ) {
758759
if (_op == ParamBuiltinOp.TRANSFORMCOLMAP || _op == ParamBuiltinOp.TRANSFORMMETA
759760
|| _op == ParamBuiltinOp.TOSTRING || _op == ParamBuiltinOp.LIST
760761
|| _op == ParamBuiltinOp.CDF || _op == ParamBuiltinOp.INVCDF
761-
|| _op == ParamBuiltinOp.PARAMSERV) {
762+
|| _op == ParamBuiltinOp.PARAMSERV || _op == ParamBuiltinOp.LLMPREDICT) {
762763
_etype = ExecType.CP;
763764
}
764765

@@ -768,7 +769,7 @@ && getTargetHop().areDimsBelowThreshold() ) {
768769
switch(_op) {
769770
case CONTAINS:
770771
if(getTargetHop().optFindExecType() == ExecType.SPARK)
771-
_etype = ExecType.SPARK;
772+
_etype = ExecType.SPARK;
772773
break;
773774
default:
774775
// Do not change execution type.

src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ public String getInstructions(String output)
176176
case CONTAINS:
177177
case REPLACE:
178178
case TOKENIZE:
179+
case LLMPREDICT:
179180
case TRANSFORMAPPLY:
180181
case TRANSFORMDECODE:
181182
case TRANSFORMCOLMAP:

src/main/java/org/apache/sysds/parser/DMLTranslator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2007,6 +2007,7 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu
20072007
case LOWER_TRI:
20082008
case UPPER_TRI:
20092009
case TOKENIZE:
2010+
case LLMPREDICT:
20102011
case TRANSFORMAPPLY:
20112012
case TRANSFORMDECODE:
20122013
case TRANSFORMCOLMAP:

src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
6161
pbHopMap.put(Builtins.GROUPEDAGG, ParamBuiltinOp.GROUPEDAGG);
6262
pbHopMap.put(Builtins.RMEMPTY, ParamBuiltinOp.RMEMPTY);
6363
pbHopMap.put(Builtins.REPLACE, ParamBuiltinOp.REPLACE);
64+
pbHopMap.put(Builtins.LLMPREDICT, ParamBuiltinOp.LLMPREDICT);
6465
pbHopMap.put(Builtins.LOWER_TRI, ParamBuiltinOp.LOWER_TRI);
6566
pbHopMap.put(Builtins.UPPER_TRI, ParamBuiltinOp.UPPER_TRI);
6667

@@ -211,6 +212,10 @@ public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<Stri
211212
validateOrder(output, conditional);
212213
break;
213214

215+
case LLMPREDICT:
216+
validateLlmPredict(output, conditional);
217+
break;
218+
214219
case TOKENIZE:
215220
validateTokenize(output, conditional);
216221
break;
@@ -614,6 +619,42 @@ private void validateTokenize(DataIdentifier output, boolean conditional)
614619
output.setDimensions(-1, -1);
615620
}
616621

622+
private void validateLlmPredict(DataIdentifier output, boolean conditional)
623+
{
624+
Set<String> valid = new HashSet<>(Arrays.asList(
625+
"target", "url", "model", "max_tokens", "temperature", "top_p", "concurrency"));
626+
checkInvalidParameters(getOpCode(), getVarParams(), valid);
627+
checkDataType(false, "llmPredict", TF_FN_PARAM_DATA, DataType.FRAME, conditional);
628+
checkStringParam(false, "llmPredict", "url", conditional);
629+
630+
// validate numeric parameter types at compile time (when literal).
631+
// Note: no range validation -- different LLM servers accept different
632+
// ranges (e.g. vLLM allows temperature=0.0, OpenAI requires >0).
633+
// Runtime errors from the server are more informative than
634+
// compile-time checks locked to one server's rules.
635+
checkNumericScalarParam("llmPredict", "max_tokens", conditional);
636+
checkNumericScalarParam("llmPredict", "temperature", conditional);
637+
checkNumericScalarParam("llmPredict", "top_p", conditional);
638+
checkNumericScalarParam("llmPredict", "concurrency", conditional);
639+
640+
output.setDataType(DataType.FRAME);
641+
output.setValueType(ValueType.STRING);
642+
output.setDimensions(-1, -1);
643+
}
644+
645+
private void checkNumericScalarParam(String fname, String pname, boolean conditional) {
646+
Expression expr = getVarParam(pname);
647+
if(expr == null) return;
648+
if(expr instanceof DataIdentifier) {
649+
DataIdentifier di = (DataIdentifier) expr;
650+
if(di.getDataType() != null && !di.getDataType().isScalar()) {
651+
raiseValidateError(
652+
String.format("Function %s: parameter '%s' must be a scalar, got %s.",
653+
fname, pname, di.getDataType()), conditional);
654+
}
655+
}
656+
}
657+
617658
// example: A = transformapply(target=X, meta=M, spec=s)
618659
private void validateTransformApply(DataIdentifier output, boolean conditional)
619660
{

src/main/java/org/apache/sysds/runtime/instructions/cp/LlmPredictCPInstruction.java

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.runtime.instructions.cp;
21+
22+
import java.io.IOException;
23+
import java.io.InputStream;
24+
import java.io.OutputStream;
25+
import java.net.ConnectException;
26+
import java.net.HttpURLConnection;
27+
import java.net.MalformedURLException;
28+
import java.net.SocketTimeoutException;
29+
import java.net.URI;
30+
import java.net.URISyntaxException;
31+
import java.nio.charset.StandardCharsets;
32+
import java.util.ArrayList;
33+
import java.util.LinkedHashMap;
34+
import java.util.List;
35+
import java.util.concurrent.Callable;
36+
import java.util.concurrent.ExecutorService;
37+
import java.util.concurrent.Executors;
38+
import java.util.concurrent.Future;
39+
40+
import org.apache.commons.lang3.tuple.Pair;
41+
import org.apache.sysds.common.Types.DataType;
42+
import org.apache.sysds.common.Types.ValueType;
43+
import org.apache.sysds.runtime.DMLRuntimeException;
44+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
45+
import org.apache.sysds.runtime.frame.data.FrameBlock;
46+
import org.apache.sysds.runtime.lineage.LineageItem;
47+
import org.apache.sysds.runtime.lineage.LineageItemUtils;
48+
import org.apache.wink.json4j.JSONObject;
49+
50+
public class LlmPredictCPInstruction extends ParameterizedBuiltinCPInstruction {
51+
52+
protected LlmPredictCPInstruction(LinkedHashMap<String, String> paramsMap,
53+
CPOperand out, String opcode, String istr) {
54+
super(null, paramsMap, out, opcode, istr);
55+
}
56+
57+
@Override
58+
public void processInstruction(ExecutionContext ec) {
59+
FrameBlock prompts = ec.getFrameInput(params.get("target"));
60+
String url = params.get("url");
61+
String model = params.containsKey("model") ?
62+
params.get("model") : null;
63+
int maxTokens = params.containsKey("max_tokens") ?
64+
Integer.parseInt(params.get("max_tokens")) : 512;
65+
double temperature = params.containsKey("temperature") ?
66+
Double.parseDouble(params.get("temperature")) : 0.0;
67+
double topP = params.containsKey("top_p") ?
68+
Double.parseDouble(params.get("top_p")) : 0.9;
69+
int concurrency = params.containsKey("concurrency") ?
70+
Integer.parseInt(params.get("concurrency")) : 1;
71+
concurrency = Math.max(1, Math.min(concurrency, 128));
72+
73+
int n = prompts.getNumRows();
74+
String[][] data = new String[n][];
75+
76+
List<Callable<String[]>> tasks = new ArrayList<>(n);
77+
for(int i = 0; i < n; i++) {
78+
String prompt = prompts.get(i, 0).toString();
79+
tasks.add(() -> callLlmEndpoint(prompt, url, model, maxTokens, temperature, topP));
80+
}
81+
82+
try {
83+
if(concurrency <= 1) {
84+
for(int i = 0; i < n; i++)
85+
data[i] = tasks.get(i).call();
86+
}
87+
else {
88+
ExecutorService pool = Executors.newFixedThreadPool(
89+
Math.min(concurrency, n));
90+
List<Future<String[]>> futures = pool.invokeAll(tasks);
91+
pool.shutdown();
92+
for(int i = 0; i < n; i++)
93+
data[i] = futures.get(i).get();
94+
}
95+
}
96+
catch(DMLRuntimeException e) {
97+
throw e;
98+
}
99+
catch(Exception e) {
100+
throw new DMLRuntimeException("llmPredict failed: " + e.getMessage(), e);
101+
}
102+
103+
ValueType[] schema = {ValueType.STRING, ValueType.STRING,
104+
ValueType.INT64, ValueType.INT64, ValueType.INT64};
105+
String[] colNames = {"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"};
106+
FrameBlock fbout = new FrameBlock(schema, colNames);
107+
for(String[] row : data)
108+
fbout.appendRow(row);
109+
110+
ec.setFrameOutput(output.getName(), fbout);
111+
ec.releaseFrameInput(params.get("target"));
112+
}
113+
114+
// No retry logic by design: as a database built-in, llmPredict should
115+
// fail fast on transient errors and let the caller (DML script) decide
116+
// whether and how to retry. Silent retries with backoff would make
117+
// execution time unpredictable.
118+
private static String[] callLlmEndpoint(String prompt, String url,
119+
String model, int maxTokens, double temperature, double topP) {
120+
long t0 = System.nanoTime();
121+
122+
// validate URL and open connection
123+
HttpURLConnection conn;
124+
try {
125+
conn = (HttpURLConnection) new URI(url).toURL().openConnection();
126+
}
127+
catch(URISyntaxException | MalformedURLException | IllegalArgumentException e) {
128+
throw new DMLRuntimeException(
129+
"llmPredict: invalid URL '" + url + "'. "
130+
+ "Expected format: http://host:port/v1/completions", e);
131+
}
132+
catch(IOException e) {
133+
throw new DMLRuntimeException(
134+
"llmPredict: cannot open connection to '" + url + "'.", e);
135+
}
136+
137+
try {
138+
JSONObject req = new JSONObject();
139+
if(model != null)
140+
req.put("model", model);
141+
req.put("prompt", prompt);
142+
req.put("max_tokens", maxTokens);
143+
req.put("temperature", temperature);
144+
req.put("top_p", topP);
145+
146+
conn.setRequestMethod("POST");
147+
conn.setRequestProperty("Content-Type", "application/json");
148+
conn.setConnectTimeout(10_000);
149+
conn.setReadTimeout(300_000);
150+
conn.setDoOutput(true);
151+
152+
try(OutputStream os = conn.getOutputStream()) {
153+
os.write(req.toString().getBytes(StandardCharsets.UTF_8));
154+
}
155+
156+
int httpCode = conn.getResponseCode();
157+
if(httpCode != 200) {
158+
String errBody = "";
159+
try(InputStream es = conn.getErrorStream()) {
160+
if(es != null)
161+
errBody = new String(es.readAllBytes(), StandardCharsets.UTF_8);
162+
}
163+
catch(Exception ignored) {}
164+
throw new DMLRuntimeException(
165+
"llmPredict: endpoint returned HTTP " + httpCode
166+
+ " for '" + url + "'."
167+
+ (errBody.isEmpty() ? "" : " Response: " + errBody));
168+
}
169+
170+
String body;
171+
try(InputStream is = conn.getInputStream()) {
172+
body = new String(is.readAllBytes(), StandardCharsets.UTF_8);
173+
}
174+
175+
JSONObject resp = new JSONObject(body);
176+
if(!resp.has("choices") || resp.getJSONArray("choices").length() == 0) {
177+
String errMsg = resp.has("error") ? resp.optString("error") : body;
178+
throw new DMLRuntimeException(
179+
"llmPredict: server response missing 'choices'. Response: " + errMsg);
180+
}
181+
String text = resp.getJSONArray("choices")
182+
.getJSONObject(0).getString("text");
183+
long elapsed = (System.nanoTime() - t0) / 1_000_000;
184+
int inTok = 0, outTok = 0;
185+
if(resp.has("usage")) {
186+
JSONObject usage = resp.getJSONObject("usage");
187+
inTok = usage.has("prompt_tokens") ? usage.getInt("prompt_tokens") : 0;
188+
outTok = usage.has("completion_tokens") ? usage.getInt("completion_tokens") : 0;
189+
}
190+
return new String[]{prompt, text,
191+
String.valueOf(elapsed), String.valueOf(inTok), String.valueOf(outTok)};
192+
}
193+
catch(ConnectException e) {
194+
throw new DMLRuntimeException(
195+
"llmPredict: connection refused to '" + url + "'. "
196+
+ "Ensure the LLM server is running and reachable.", e);
197+
}
198+
catch(SocketTimeoutException e) {
199+
throw new DMLRuntimeException(
200+
"llmPredict: timed out connecting to '" + url + "'. "
201+
+ "Ensure the LLM server is running and reachable.", e);
202+
}
203+
catch(IOException e) {
204+
throw new DMLRuntimeException(
205+
"llmPredict: I/O error communicating with '" + url + "'.", e);
206+
}
207+
catch(DMLRuntimeException e) {
208+
throw e;
209+
}
210+
catch(Exception e) {
211+
throw new DMLRuntimeException(
212+
"llmPredict: failed to get response from '" + url + "'.", e);
213+
}
214+
finally {
215+
conn.disconnect();
216+
}
217+
}
218+
219+
@Override
220+
public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
221+
CPOperand target = new CPOperand(params.get("target"), ValueType.STRING, DataType.FRAME);
222+
CPOperand urlOp = new CPOperand(params.get("url"), ValueType.STRING, DataType.SCALAR, true);
223+
return Pair.of(output.getName(),
224+
new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, urlOp)));
225+
}
226+
}

src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ else if(opcode.equals(Opcodes.TRANSFORMAPPLY.toString()) || opcode.equals(Opcode
158158
|| opcode.equals(Opcodes.TOSTRING.toString()) || opcode.equals(Opcodes.NVLIST.toString()) || opcode.equals(Opcodes.AUTODIFF.toString())) {
159159
return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str);
160160
}
161+
else if(opcode.equals(Opcodes.LLMPREDICT.toString())) {
162+
return new LlmPredictCPInstruction(paramsMap, out, opcode, str);
163+
}
161164
else if(Opcodes.PARAMSERV.toString().equals(opcode)) {
162165
return new ParamservBuiltinCPInstruction(null, paramsMap, out, opcode, str);
163166
}
@@ -324,6 +327,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.TOKENIZE.toString())) {
324327
ec.setFrameOutput(output.getName(), fbout);
325328
ec.releaseFrameInput(params.get("target"));
326329
}
330+
327331
else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMAPPLY.toString())) {
328332
// acquire locks
329333
FrameBlock data = ec.getFrameInput(params.get("target"));

0 commit comments

Comments
 (0)