Skip to content

Commit 62a9adb

Browse files
Initial commit for GPULlama3.java REST API support
1 parent 4473b26 commit 62a9adb

File tree

10 files changed

+678
-23
lines changed

10 files changed

+678
-23
lines changed

llama-tornado

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,14 @@ class LlamaRunner:
168168
]
169169
)
170170

171+
# Choose main class based on mode
172+
main_class = "org.beehive.gpullama3.api.LLMApiApplication" if args.service else "org.beehive.gpullama3.LlamaApp"
173+
171174
module_config.extend(
172175
[
173176
"-cp",
174177
f"{self.llama_root}/target/gpu-llama3-1.0-SNAPSHOT.jar",
175-
"org.beehive.gpullama3.LlamaApp",
178+
main_class,
176179
]
177180
)
178181
cmd.extend(module_config)
@@ -181,33 +184,28 @@ class LlamaRunner:
181184

182185
def _add_llama_args(self, cmd: List[str], args: argparse.Namespace) -> List[str]:
183186
"""Add LLaMA-specific arguments to the command."""
187+
184188
llama_args = [
185-
"-m",
186-
args.model_path,
187-
"--temperature",
188-
str(args.temperature),
189-
"--top-p",
190-
str(args.top_p),
191-
"--seed",
192-
str(args.seed),
193-
"--max-tokens",
194-
str(args.max_tokens),
195-
"--stream",
196-
str(args.stream).lower(),
197-
"--echo",
198-
str(args.echo).lower(),
189+
"--model", args.model_path,
190+
"--temperature", str(args.temperature),
191+
"--top-p", str(args.top_p),
192+
"--seed", str(args.seed),
193+
"--max-tokens", str(args.max_tokens),
194+
"--stream", str(args.stream).lower(),
195+
"--echo", str(args.echo).lower(),
196+
"--instruct" # Both modes use instruct
199197
]
200198

201-
if args.prompt:
202-
llama_args.extend(["-p", args.prompt])
199+
# Only add prompt-related args for standalone mode
200+
if not hasattr(args, 'service') or not args.service:
201+
if hasattr(args, 'prompt') and args.prompt:
202+
llama_args.extend(["-p", args.prompt])
203203

204-
if args.system_prompt:
205-
llama_args.extend(["-sp", args.system_prompt])
204+
if hasattr(args, 'system_prompt') and args.system_prompt:
205+
llama_args.extend(["-sp", args.system_prompt])
206206

207-
if args.interactive:
208-
llama_args.append("--interactive")
209-
elif args.instruct:
210-
llama_args.append("--instruct")
207+
if hasattr(args, 'interactive') and args.interactive:
208+
llama_args[-1] = "--interactive" # Replace --instruct
211209

212210
return cmd + llama_args
213211

@@ -219,6 +217,19 @@ class LlamaRunner:
219217
cmd = self._build_base_command(args)
220218
cmd = self._add_llama_args(cmd, args)
221219

220+
# Show service-specific information
221+
if args.service:
222+
print("Starting TornadoVM LLM REST API Service...")
223+
print(f"Model: {args.model_path}")
224+
print("API endpoints will be available at:")
225+
print(" - http://localhost:8080/v1/completions")
226+
print(" - http://localhost:8080/v1/completions/stream")
227+
print(" - http://localhost:8080/v1/models")
228+
print(" - http://localhost:8080/v1/health")
229+
print("\nPress Ctrl+C to stop the service")
230+
print("-" * 60)
231+
232+
222233
# Print command if requested (before verbose output)
223234
if args.show_command:
224235
print("Full Java command:")
@@ -368,6 +379,11 @@ def create_parser() -> argparse.ArgumentParser:
368379
default=True,
369380
help="Run in instruction mode (default)",
370381
)
382+
mode_group.add_argument(
383+
"--service",
384+
action="store_true",
385+
help="Run as REST API service instead of standalone application"
386+
)
371387

372388
# Hardware configuration
373389
hw_group = parser.add_argument_group("Hardware Configuration")

pom.xml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
<maven.compiler.source>21</maven.compiler.source>
1313
<maven.compiler.target>21</maven.compiler.target>
1414
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
15+
<spring.boot.version>3.2.0</spring.boot.version>
16+
<jakarta.version>3.0.0</jakarta.version>
17+
<jackson.version>2.16.1</jackson.version>
1518
</properties>
1619

1720
<dependencies>
@@ -32,6 +35,26 @@
3235
<artifactId>tornado-runtime</artifactId>
3336
<version>1.1.2-dev</version>
3437
</dependency>
38+
39+
<!-- Spring Boot Starter Web -->
40+
<dependency>
41+
<groupId>org.springframework.boot</groupId>
42+
<artifactId>spring-boot-starter-web</artifactId>
43+
<version>${spring.boot.version}</version>
44+
</dependency>
45+
46+
<dependency>
47+
<groupId>jakarta.annotation</groupId>
48+
<artifactId>jakarta.annotation-api</artifactId>
49+
<version>${jakarta.version}</version>
50+
</dependency>
51+
52+
<!-- For JSON processing -->
53+
<dependency>
54+
<groupId>com.fasterxml.jackson.core</groupId>
55+
<artifactId>jackson-databind</artifactId>
56+
<version>${jackson.version}</version>
57+
</dependency>
3558
</dependencies>
3659

3760
<build>
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package org.beehive.gpullama3.api;
2+
3+
import org.springframework.boot.SpringApplication;
4+
import org.springframework.boot.autoconfigure.SpringBootApplication;
5+
6+
@SpringBootApplication(scanBasePackages = "org.beehive.gpullama3")
7+
public class LLMApiApplication {
8+
9+
public static void main(String[] args) {
10+
System.out.println("Starting TornadoVM LLM API Server...");
11+
System.out.println("Command line arguments: " + String.join(" ", args));
12+
13+
// Let Options.parseOptions() handle validation - no duplication
14+
SpringApplication.run(LLMApiApplication.class, args);
15+
}
16+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package org.beehive.gpullama3.api.config;
2+
3+
import org.beehive.gpullama3.model.Model;
4+
import org.beehive.gpullama3.Options;
5+
import org.beehive.gpullama3.api.service.ModelInitializationService;
6+
import org.springframework.context.annotation.Bean;
7+
import org.springframework.context.annotation.Configuration;
8+
9+
@Configuration
10+
public class ModelConfiguration {
11+
12+
/**
13+
* Expose Model as a Spring bean using the initialized service
14+
*/
15+
@Bean
16+
public Model model(ModelInitializationService initService) {
17+
return initService.getModel();
18+
}
19+
20+
/**
21+
* Expose Options as a Spring bean using the initialized service
22+
*/
23+
@Bean
24+
public Options options(ModelInitializationService initService) {
25+
return initService.getOptions();
26+
}
27+
}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package org.beehive.gpullama3.api.controller;

import org.beehive.gpullama3.api.model.CompletionRequest;
import org.beehive.gpullama3.api.model.CompletionResponse;
import org.beehive.gpullama3.api.service.LLMService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

/**
 * REST controller exposing OpenAI-style completion endpoints under {@code /v1}:
 * non-streaming and streaming (SSE) completions, a model listing, and a health probe.
 *
 * <p>Response payloads are built as {@link LinkedHashMap}s rather than anonymous inner
 * classes: anonymous classes capture an implicit reference to the controller and depend on
 * public-field serialization, and {@code Map.of} is ruled out because error bodies may
 * carry a null {@code message} (e.g. {@code e.getMessage()} can be null).
 */
@RestController
@RequestMapping("/v1")
@CrossOrigin(origins = "*")
public class CompletionController {

    /** Backend service that runs inference; constructor-injected for immutability/testability. */
    private final LLMService llmService;

    @Autowired
    public CompletionController(LLMService llmService) {
        this.llmService = llmService;
    }

    /**
     * Non-streaming completion. Rejects requests flagged {@code stream=true} (those belong
     * on {@code /v1/completions/stream}) and empty prompts, then returns the generated text
     * together with rough token-usage estimates.
     *
     * @param request the completion parameters (prompt, sampling settings, stop sequences)
     * @return future completing with the generated text wrapped in an OpenAI-style response
     * @throws IllegalArgumentException if streaming was requested or the prompt is empty
     */
    @PostMapping("/completions")
    public CompletableFuture<ResponseEntity<CompletionResponse>> createCompletion(@RequestBody CompletionRequest request) {

        System.out.println("Received completion request: " + request);

        if (Boolean.TRUE.equals(request.getStream())) {
            throw new IllegalArgumentException("Use /v1/completions/stream for streaming requests");
        }
        requireNonEmptyPrompt(request);

        return llmService.generateCompletion(
                request.getPrompt(),
                request.getMaxTokens(),
                request.getTemperature(),
                request.getTopP(),
                request.getStopSequences()
        ).thenApply(generatedText -> {
            CompletionResponse response = new CompletionResponse();
            response.setModel(request.getModel());
            response.setChoices(List.of(new CompletionResponse.Choice(generatedText, 0, "stop")));

            // Rough ~4 chars/token estimate; replace with real tokenizer counts when available.
            int promptTokens = request.getPrompt().length() / 4;
            int completionTokens = generatedText.length() / 4;
            response.setUsage(new CompletionResponse.Usage(promptTokens, completionTokens));

            System.out.println("Completion response prepared, length: " + generatedText.length());

            return ResponseEntity.ok(response);
        });
    }

    /**
     * Streaming completion over Server-Sent Events. Validation happens synchronously so a
     * bad request fails fast; generation is handed to the service, which emits through the
     * returned {@link SseEmitter}.
     *
     * @param request the completion parameters
     * @return an SSE emitter the service writes generated tokens to (no timeout)
     * @throws IllegalArgumentException if the prompt is null or empty
     */
    @PostMapping(value = "/completions/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public SseEmitter createStreamingCompletion(@RequestBody CompletionRequest request) {

        System.out.println("Received streaming completion request: " + request);

        requireNonEmptyPrompt(request);

        // No timeout: generation length is model-dependent and bounded by max-tokens.
        SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);

        llmService.generateStreamingCompletion(
                request.getPrompt(),
                request.getMaxTokens(),
                request.getTemperature(),
                request.getTopP(),
                request.getStopSequences(),
                emitter
        );

        return emitter;
    }

    /**
     * Lists the single model served by this process, in OpenAI {@code /v1/models} shape.
     */
    @GetMapping("/models")
    public ResponseEntity<Object> listModels() {
        Map<String, Object> model = new LinkedHashMap<>();
        model.put("id", "gpullama3");
        model.put("object", "model");
        model.put("created", System.currentTimeMillis() / 1000);
        model.put("owned_by", "beehive");

        Map<String, Object> body = new LinkedHashMap<>();
        body.put("object", "list");
        body.put("data", List.of(model));
        return ResponseEntity.ok(body);
    }

    /**
     * Liveness probe; always reports healthy while the process is up.
     */
    @GetMapping("/health")
    public ResponseEntity<Object> health() {
        Map<String, Object> body = new LinkedHashMap<>();
        body.put("status", "healthy");
        body.put("timestamp", System.currentTimeMillis());
        return ResponseEntity.ok(body);
    }

    /**
     * Maps validation failures raised by this controller to HTTP 400.
     */
    @ExceptionHandler(IllegalArgumentException.class)
    public ResponseEntity<Object> handleBadRequest(IllegalArgumentException e) {
        Map<String, Object> body = new LinkedHashMap<>();
        body.put("error", e.getMessage());
        body.put("timestamp", System.currentTimeMillis());
        return ResponseEntity.badRequest().body(body);
    }

    /**
     * Catch-all mapping to HTTP 500; logs the stack trace and returns a sanitized body.
     */
    @ExceptionHandler(Exception.class)
    public ResponseEntity<Object> handleInternalError(Exception e) {
        System.err.println("Internal server error: " + e.getMessage());
        e.printStackTrace();

        Map<String, Object> body = new LinkedHashMap<>();
        body.put("error", "Internal server error");
        body.put("message", e.getMessage()); // may be null; LinkedHashMap tolerates it
        body.put("timestamp", System.currentTimeMillis());
        return ResponseEntity.internalServerError().body(body);
    }

    /** Shared prompt validation for both completion endpoints. */
    private static void requireNonEmptyPrompt(CompletionRequest request) {
        if (request.getPrompt() == null || request.getPrompt().trim().isEmpty()) {
            throw new IllegalArgumentException("Prompt cannot be null or empty");
        }
    }
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package org.beehive.gpullama3.api.model;

import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.List;

/**
 * Request body for the {@code /v1/completions} endpoints, mirroring the OpenAI
 * completion-request schema. Snake_case JSON names are mapped via {@link JsonProperty};
 * all sampling parameters carry sensible defaults so a minimal request needs only a prompt.
 */
public class CompletionRequest {

    // Defaults apply when the client omits the field.
    private String model = "gpullama3";
    private String prompt;

    @JsonProperty("max_tokens")
    private Integer maxTokens = 100;

    private Double temperature = 0.7;

    @JsonProperty("top_p")
    private Double topP = 0.9;

    @JsonProperty("stop")
    private List<String> stopSequences;

    private Boolean stream = false;

    /** No-arg constructor required for JSON deserialization. */
    public CompletionRequest() {}

    public String getModel() {
        return model;
    }

    public void setModel(String model) {
        this.model = model;
    }

    public String getPrompt() {
        return prompt;
    }

    public void setPrompt(String prompt) {
        this.prompt = prompt;
    }

    public Integer getMaxTokens() {
        return maxTokens;
    }

    public void setMaxTokens(Integer maxTokens) {
        this.maxTokens = maxTokens;
    }

    public Double getTemperature() {
        return temperature;
    }

    public void setTemperature(Double temperature) {
        this.temperature = temperature;
    }

    public Double getTopP() {
        return topP;
    }

    public void setTopP(Double topP) {
        this.topP = topP;
    }

    public List<String> getStopSequences() {
        return stopSequences;
    }

    public void setStopSequences(List<String> stopSequences) {
        this.stopSequences = stopSequences;
    }

    public Boolean getStream() {
        return stream;
    }

    public void setStream(Boolean stream) {
        this.stream = stream;
    }

    /**
     * Debug representation; the prompt is truncated to its first 50 characters to keep
     * log lines short, and stop sequences are intentionally omitted.
     */
    @Override
    public String toString() {
        // A null prompt renders as prompt='null', matching plain string concatenation.
        final String promptPreview =
                (prompt == null) ? null : prompt.substring(0, Math.min(50, prompt.length())) + "...";
        StringBuilder sb = new StringBuilder("CompletionRequest{");
        sb.append("model='").append(model).append('\'');
        sb.append(", prompt='").append(promptPreview).append('\'');
        sb.append(", maxTokens=").append(maxTokens);
        sb.append(", temperature=").append(temperature);
        sb.append(", topP=").append(topP);
        sb.append(", stream=").append(stream);
        sb.append('}');
        return sb.toString();
    }
}

0 commit comments

Comments
 (0)