From e50b1c0e2aa19b2c59a19427466407b3b53cf12b Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Fri, 5 Sep 2025 12:24:57 +0300 Subject: [PATCH 01/12] Initial commit for GPULlama3.java REST API support --- llama-tornado | 62 ++++--- pom.xml | 23 +++ .../gpullama3/api/LLMApiApplication.java | 16 ++ .../api/config/ModelConfiguration.java | 27 +++ .../api/controller/CompletionController.java | 130 +++++++++++++ .../api/model/CompletionRequest.java | 59 ++++++ .../api/model/CompletionResponse.java | 93 ++++++++++ .../gpullama3/api/service/LLMService.java | 173 ++++++++++++++++++ .../service/ModelInitializationService.java | 86 +++++++++ .../api/service/TokenizerService.java | 32 ++++ 10 files changed, 678 insertions(+), 23 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/api/LLMApiApplication.java create mode 100644 src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java create mode 100644 src/main/java/org/beehive/gpullama3/api/controller/CompletionController.java create mode 100644 src/main/java/org/beehive/gpullama3/api/model/CompletionRequest.java create mode 100644 src/main/java/org/beehive/gpullama3/api/model/CompletionResponse.java create mode 100644 src/main/java/org/beehive/gpullama3/api/service/LLMService.java create mode 100644 src/main/java/org/beehive/gpullama3/api/service/ModelInitializationService.java create mode 100644 src/main/java/org/beehive/gpullama3/api/service/TokenizerService.java diff --git a/llama-tornado b/llama-tornado index b59473f..6c62b01 100755 --- a/llama-tornado +++ b/llama-tornado @@ -169,11 +169,14 @@ class LlamaRunner: ] ) + # Choose main class based on mode + main_class = "org.beehive.gpullama3.api.LLMApiApplication" if args.service else "org.beehive.gpullama3.LlamaApp" + module_config.extend( [ "-cp", self._find_llama_jar(), - "org.beehive.gpullama3.LlamaApp", + main_class, ] ) cmd.extend(module_config) @@ -200,33 +203,28 @@ class LlamaRunner: def _add_llama_args(self, cmd: List[str], args: argparse.Namespace) -> List[str]: """Add LLaMA-specific arguments to the command.""" + llama_args = [ - "-m", - args.model_path, - "--temperature", - str(args.temperature), - "--top-p", - str(args.top_p), - "--seed", - str(args.seed), - "--max-tokens", - str(args.max_tokens), - "--stream", - str(args.stream).lower(), - "--echo", - str(args.echo).lower(), + "--model", args.model_path, + "--temperature", str(args.temperature), + "--top-p", str(args.top_p), + "--seed", str(args.seed), + "--max-tokens", str(args.max_tokens), + "--stream", str(args.stream).lower(), + "--echo", str(args.echo).lower(), + "--instruct" # Both modes use instruct ] - if args.prompt: - llama_args.extend(["-p", args.prompt]) + # Only add prompt-related args for standalone mode + if not hasattr(args, 'service') or not args.service: + if hasattr(args, 'prompt') and args.prompt: + llama_args.extend(["-p", args.prompt]) - if args.system_prompt: - llama_args.extend(["-sp", args.system_prompt]) + if hasattr(args, 'system_prompt') and args.system_prompt: + llama_args.extend(["-sp", args.system_prompt]) - if args.interactive: - llama_args.append("--interactive") - elif args.instruct: - llama_args.append("--instruct") + if hasattr(args, 'interactive') and args.interactive: + llama_args[-1] = "--interactive" # Replace --instruct return cmd + llama_args @@ -238,6 +236,19 @@ class LlamaRunner: cmd = self._build_base_command(args) cmd = self._add_llama_args(cmd, args) + # Show service-specific information + if args.service: + print("Starting TornadoVM LLM REST API 
Service...") + print(f"Model: {args.model_path}") + print("API endpoints will be available at:") + print(" - http://localhost:8080/v1/completions") + print(" - http://localhost:8080/v1/completions/stream") + print(" - http://localhost:8080/v1/models") + print(" - http://localhost:8080/v1/health") + print("\nPress Ctrl+C to stop the service") + print("-" * 60) + + # Print command if requested (before verbose output) if args.show_command: print("Full Java command:") @@ -390,6 +401,11 @@ def create_parser() -> argparse.ArgumentParser: default=True, help="Run in instruction mode (default)", ) + mode_group.add_argument( + "--service", + action="store_true", + help="Run as REST API service instead of standalone application" + ) # Hardware configuration hw_group = parser.add_argument_group("Hardware Configuration") diff --git a/pom.xml b/pom.xml index 624fc67..a031256 100644 --- a/pom.xml +++ b/pom.xml @@ -12,6 +12,9 @@ 21 21 UTF-8 + 3.2.0 + 3.0.0 + 2.16.1 @@ -32,6 +35,26 @@ tornado-runtime 1.1.2-dev + + + + org.springframework.boot + spring-boot-starter-web + ${spring.boot.version} + + + + jakarta.annotation + jakarta.annotation-api + ${jakarta.version} + + + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + diff --git a/src/main/java/org/beehive/gpullama3/api/LLMApiApplication.java b/src/main/java/org/beehive/gpullama3/api/LLMApiApplication.java new file mode 100644 index 0000000..2ff1d60 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/api/LLMApiApplication.java @@ -0,0 +1,16 @@ +package org.beehive.gpullama3.api; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +@SpringBootApplication(scanBasePackages = "org.beehive.gpullama3") +public class LLMApiApplication { + + public static void main(String[] args) { + System.out.println("Starting TornadoVM LLM API Server..."); + System.out.println("Command line arguments: " + String.join(" ", args)); + + // Let Options.parseOptions() handle validation - no duplication + SpringApplication.run(LLMApiApplication.class, args); + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java b/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java new file mode 100644 index 0000000..9a3399d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java @@ -0,0 +1,27 @@ +package org.beehive.gpullama3.api.config; + +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.Options; +import org.beehive.gpullama3.api.service.ModelInitializationService; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class ModelConfiguration { + + /** + * Expose Model as a Spring bean using the initialized service + */ + @Bean + public Model model(ModelInitializationService initService) { + return initService.getModel(); + } + + /** + * Expose Options as a Spring bean using the initialized service + */ + @Bean + public Options options(ModelInitializationService initService) { + return initService.getOptions(); + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/api/controller/CompletionController.java b/src/main/java/org/beehive/gpullama3/api/controller/CompletionController.java new file mode 100644 index 0000000..6137128 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/api/controller/CompletionController.java @@ -0,0 +1,130 @@ 
+package org.beehive.gpullama3.api.controller; + +import org.beehive.gpullama3.api.model.CompletionRequest; +import org.beehive.gpullama3.api.model.CompletionResponse; +import org.beehive.gpullama3.api.service.LLMService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.util.Arrays; +import java.util.concurrent.CompletableFuture; + +@RestController +@RequestMapping("/v1") +@CrossOrigin(origins = "*") +public class CompletionController { + + @Autowired + private LLMService llmService; + + @PostMapping("/completions") + public CompletableFuture> createCompletion(@RequestBody CompletionRequest request) { + + System.out.println("Received completion request: " + request); + + if (Boolean.TRUE.equals(request.getStream())) { + throw new IllegalArgumentException("Use /v1/completions/stream for streaming requests"); + } + + // Validate request + if (request.getPrompt() == null || request.getPrompt().trim().isEmpty()) { + throw new IllegalArgumentException("Prompt cannot be null or empty"); + } + + return llmService.generateCompletion( + request.getPrompt(), + request.getMaxTokens(), + request.getTemperature(), + request.getTopP(), + request.getStopSequences() + ).thenApply(generatedText -> { + CompletionResponse response = new CompletionResponse(); + response.setModel(request.getModel()); + + CompletionResponse.Choice choice = new CompletionResponse.Choice( + generatedText, 0, "stop"); + response.setChoices(Arrays.asList(choice)); + + // Calculate rough token counts (you might want to make this more accurate) + int promptTokens = request.getPrompt().length() / 4; // Rough estimate + int completionTokens = generatedText.length() / 4; // Rough estimate + CompletionResponse.Usage usage = new CompletionResponse.Usage(promptTokens, completionTokens); + response.setUsage(usage); + + System.out.println("Completion response prepared, length: " + generatedText.length()); + + return ResponseEntity.ok(response); + }); + } + + @PostMapping(value = "/completions/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public SseEmitter createStreamingCompletion(@RequestBody CompletionRequest request) { + + System.out.println("Received streaming completion request: " + request); + + // Validate request + if (request.getPrompt() == null || request.getPrompt().trim().isEmpty()) { + throw new IllegalArgumentException("Prompt cannot be null or empty"); + } + + SseEmitter emitter = new SseEmitter(Long.MAX_VALUE); + + llmService.generateStreamingCompletion( + request.getPrompt(), + request.getMaxTokens(), + request.getTemperature(), + request.getTopP(), + request.getStopSequences(), + emitter + ); + + return emitter; + } + + @GetMapping("/models") + public ResponseEntity listModels() { + return ResponseEntity.ok(new Object() { + public final String object = "list"; + public final Object[] data = new Object[] { + new Object() { + public final String id = "gpullama3"; + public final String object = "model"; + public final long created = System.currentTimeMillis() / 1000; + public final String owned_by = "beehive"; + } + }; + }); + } + + @GetMapping("/health") + public ResponseEntity health() { + return ResponseEntity.ok(new Object() { + public final String status = "healthy"; + public final long timestamp = System.currentTimeMillis(); + }); + } + + // Global exception handler 
for this controller + @ExceptionHandler(IllegalArgumentException.class) + public ResponseEntity handleBadRequest(IllegalArgumentException e) { + return ResponseEntity.badRequest().body(new Object() { + public final String error = e.getMessage(); + public final long timestamp = System.currentTimeMillis(); + }); + } + + @ExceptionHandler(Exception.class) + public ResponseEntity handleInternalError(Exception e) { + System.err.println("Internal server error: " + e.getMessage()); + e.printStackTrace(); + + return ResponseEntity.internalServerError().body(new Object() { + public final String error = "Internal server error"; + public final String message = e.getMessage(); + public final long timestamp = System.currentTimeMillis(); + }); + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/api/model/CompletionRequest.java b/src/main/java/org/beehive/gpullama3/api/model/CompletionRequest.java new file mode 100644 index 0000000..56a304e --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/api/model/CompletionRequest.java @@ -0,0 +1,59 @@ +package org.beehive.gpullama3.api.model; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +public class CompletionRequest { + private String model = "gpullama3"; + private String prompt; + + @JsonProperty("max_tokens") + private Integer maxTokens = 100; + + private Double temperature = 0.7; + + @JsonProperty("top_p") + private Double topP = 0.9; + + @JsonProperty("stop") + private List stopSequences; + + private Boolean stream = false; + + // Constructors + public CompletionRequest() {} + + // Getters and Setters + public String getModel() { return model; } + public void setModel(String model) { this.model = model; } + + public String getPrompt() { return prompt; } + public void setPrompt(String prompt) { this.prompt = prompt; } + + public Integer getMaxTokens() { return maxTokens; } + public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } + + public Double getTemperature() { return temperature; } + public void setTemperature(Double temperature) { this.temperature = temperature; } + + public Double getTopP() { return topP; } + public void setTopP(Double topP) { this.topP = topP; } + + public List getStopSequences() { return stopSequences; } + public void setStopSequences(List stopSequences) { this.stopSequences = stopSequences; } + + public Boolean getStream() { return stream; } + public void setStream(Boolean stream) { this.stream = stream; } + + @Override + public String toString() { + return "CompletionRequest{" + + "model='" + model + '\'' + + ", prompt='" + (prompt != null ? prompt.substring(0, Math.min(50, prompt.length())) + "..." 
: null) + '\'' + + ", maxTokens=" + maxTokens + + ", temperature=" + temperature + + ", topP=" + topP + + ", stream=" + stream + + '}'; + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/api/model/CompletionResponse.java b/src/main/java/org/beehive/gpullama3/api/model/CompletionResponse.java new file mode 100644 index 0000000..b3cecf3 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/api/model/CompletionResponse.java @@ -0,0 +1,93 @@ +package org.beehive.gpullama3.api.model; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +public class CompletionResponse { + private String id; + private String object = "text_completion"; + private Long created; + private String model; + private List choices; + private Usage usage; + + public static class Choice { + private String text; + private Integer index; + + @JsonProperty("finish_reason") + private String finishReason; + + public Choice() {} + + public Choice(String text, Integer index, String finishReason) { + this.text = text; + this.index = index; + this.finishReason = finishReason; + } + + // Getters and Setters + public String getText() { return text; } + public void setText(String text) { this.text = text; } + + public Integer getIndex() { return index; } + public void setIndex(Integer index) { this.index = index; } + + public String getFinishReason() { return finishReason; } + public void setFinishReason(String finishReason) { this.finishReason = finishReason; } + } + + public static class Usage { + @JsonProperty("prompt_tokens") + private Integer promptTokens; + + @JsonProperty("completion_tokens") + private Integer completionTokens; + + @JsonProperty("total_tokens") + private Integer totalTokens; + + public Usage() {} + + public Usage(Integer promptTokens, Integer completionTokens) { + this.promptTokens = promptTokens; + this.completionTokens = completionTokens; + this.totalTokens = promptTokens + completionTokens; + } + + // Getters and Setters + public Integer getPromptTokens() { return promptTokens; } + public void setPromptTokens(Integer promptTokens) { this.promptTokens = promptTokens; } + + public Integer getCompletionTokens() { return completionTokens; } + public void setCompletionTokens(Integer completionTokens) { this.completionTokens = completionTokens; } + + public Integer getTotalTokens() { return totalTokens; } + public void setTotalTokens(Integer totalTokens) { this.totalTokens = totalTokens; } + } + + // Constructors + public CompletionResponse() { + this.id = "cmpl-" + System.currentTimeMillis(); + this.created = System.currentTimeMillis() / 1000; + } + + // Getters and Setters + public String getId() { return id; } + public void setId(String id) { this.id = id; } + + public String getObject() { return object; } + public void setObject(String object) { this.object = object; } + + public Long getCreated() { return created; } + public void setCreated(Long created) { this.created = created; } + + public String getModel() { return model; } + public void setModel(String model) { this.model = model; } + + public List getChoices() { return choices; } + public void setChoices(List choices) { this.choices = choices; } + + public Usage getUsage() { return usage; } + public void setUsage(Usage usage) { this.usage = usage; } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java new file mode 100644 index 0000000..41741ac --- /dev/null +++ 
+package org.beehive.gpullama3.api.service;
+
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.sampler.Sampler;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Service;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.util.*;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.IntConsumer;
+
+@Service
+public class LLMService {
+
+    @Autowired
+    private ModelInitializationService initService;
+
+    @Autowired
+    private TokenizerService tokenizerService;
+
+    public CompletableFuture<String> generateCompletion(
+            String prompt,
+            int maxTokens,
+            double temperature,
+            double topP,
+            List<String> stopSequences) {
+
+        return CompletableFuture.supplyAsync(() -> {
+            try {
+                System.out.println("Starting completion generation...");
+                System.out.println("Prompt: " + prompt.substring(0, Math.min(50, prompt.length())) + "...");
+                System.out.println("Max tokens: " + maxTokens + ", Temperature: " + temperature);
+
+                // Get initialized components
+                Model model = initService.getModel();
+
+                // Convert prompt to tokens
+                List<Integer> promptTokens = tokenizerService.encode(prompt);
+                System.out.println("Prompt tokens: " + promptTokens.size());
+
+                // Convert stop sequences to token sets
+                Set<Integer> stopTokens = new HashSet<>();
+                if (stopSequences != null) {
+                    for (String stop : stopSequences) {
+                        stopTokens.addAll(tokenizerService.encode(stop));
+                    }
+                    System.out.println("Stop tokens: " + stopTokens.size());
+                }
+
+                // Create custom sampler with request-specific parameters
+                //Sampler sampler = initService.createCustomSampler(temperature, topP, System.currentTimeMillis());
+                Sampler sampler = initService.getSampler();
+
+                // Create state based on model type
+                State state = createStateForModel(model);
+
+                // Generate tokens using the existing method
+                List<Integer> generatedTokens = model.generateTokens(
+                        state,
+                        0,
+                        promptTokens,
+                        stopTokens,
+                        maxTokens,
+                        sampler,
+                        false,
+                        token -> {} // No callback for non-streaming
+                );
+
+                // Decode tokens back to text
+                String result = tokenizerService.decode(generatedTokens);
+                System.out.println("Generated " + generatedTokens.size() + " tokens");
+                System.out.println("Completion finished successfully");
+
+                return result;
+
+            } catch (Exception e) {
+                System.err.println("Error generating completion: " + e.getMessage());
+                e.printStackTrace();
+                throw new RuntimeException("Error generating completion", e);
+            }
+        });
+    }
+
+    public void generateStreamingCompletion(
+            String prompt,
+            int maxTokens,
+            double temperature,
+            double topP,
+            List<String> stopSequences,
+            SseEmitter emitter) {
+
+        CompletableFuture.runAsync(() -> {
+            try {
+                System.out.println("Starting streaming completion generation...");
+
+                Model model = initService.getModel();
+
+                List<Integer> promptTokens = tokenizerService.encode(prompt);
+
+                Set<Integer> stopTokens = new HashSet<>();
+                if (stopSequences != null) {
+                    for (String stop : stopSequences) {
+                        stopTokens.addAll(tokenizerService.encode(stop));
+                    }
+                }
+
+                //Sampler sampler = initService.createCustomSampler(temperature, topP, System.currentTimeMillis());
+                Sampler sampler = initService.getSampler();
+                State state = createStateForModel(model);
+
+                final int[] tokenCount = {0};
+
+                // Streaming callback
+                IntConsumer tokenCallback = token -> {
+                    try {
+                        String tokenText = tokenizerService.decode(List.of(token));
+                        tokenCount[0]++;
+
+                        // SseEmitter adds the SSE "data:" framing itself, so send only the payload
+                        String eventData = String.format(
+                                "{\"choices\":[{\"text\":\"%s\",\"index\":0,\"finish_reason\":null}]}",
+                                escapeJson(tokenText)
+                        );
+
+                        emitter.send(SseEmitter.event().data(eventData));
+
+                        if (tokenCount[0] % 10 == 0) {
+                            System.out.println("Streamed " + tokenCount[0] + " tokens");
+                        }
+
+                    } catch (Exception e) {
+                        System.err.println("Error in streaming callback: " + e.getMessage());
+                        emitter.completeWithError(e);
+                    }
+                };
+
+                model.generateTokens(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenCallback);
+
+                // Send completion event
+                emitter.send(SseEmitter.event().data("[DONE]"));
+                emitter.complete();
+
+                System.out.println("Streaming completion finished. Total tokens: " + tokenCount[0]);
+
+            } catch (Exception e) {
+                System.err.println("Error in streaming generation: " + e.getMessage());
+                e.printStackTrace();
+                emitter.completeWithError(e);
+            }
+        });
+    }
+
+    /**
+     * Create appropriate State subclass based on the model type
+     */
+    private State createStateForModel(Model model) {
+        try {
+            return model.createNewState();
+        } catch (Exception e) {
+            throw new RuntimeException("Failed to create state for model", e);
+        }
+    }
+
+    private String escapeJson(String str) {
+        if (str == null) return "";
+        // Escape backslashes first so the escapes added below are not doubled
+        return str.replace("\\", "\\\\")
+                  .replace("\"", "\\\"")
+                  .replace("\n", "\\n")
+                  .replace("\r", "\\r")
+                  .replace("\t", "\\t");
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/api/service/ModelInitializationService.java b/src/main/java/org/beehive/gpullama3/api/service/ModelInitializationService.java
new file mode 100644
index 0000000..c2c30b4
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/api/service/ModelInitializationService.java
@@ -0,0 +1,86 @@
+package org.beehive.gpullama3.api.service;
+
+import org.beehive.gpullama3.LlamaApp;
+import org.beehive.gpullama3.Options;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.inference.sampler.Sampler;
+import org.springframework.boot.ApplicationArguments;
+import org.springframework.stereotype.Service;
+
+import jakarta.annotation.PostConstruct;
+import java.io.IOException;
+
+import static org.beehive.gpullama3.inference.sampler.Sampler.createSampler;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel;
+
+@Service
+public class ModelInitializationService {
+
+    private final ApplicationArguments args;
+
+    private Options options;
+    private Model model;
+    private Sampler sampler;
+
+    public ModelInitializationService(ApplicationArguments args) {
+        this.args = args;
+    }
+
+    @PostConstruct
+    public void init() {
+        try {
+            System.out.println("=== Model Initialization Service ===");
+            System.out.println("Initializing model service...");
+
+            // Step 1: Parse options from command line arguments
+            System.out.println("Step 1: Parsing options...");
+            options = Options.parseOptions(args.getSourceArgs());
+            System.out.println("✓ Options parsed successfully");
+
+            // Step 2: Load model
+            System.out.println("\nStep 2: Loading model...");
+            System.out.println("Loading model from: " + options.modelPath());
+            model = loadModel(options);
+            System.out.println("✓ Model loaded successfully");
+            System.out.println("  Model type: " + model.getClass().getSimpleName());
+            System.out.println("  Vocabulary size: " + model.configuration().vocabularySize());
+            System.out.println("  Context length: " + model.configuration().contextLength());
+
+            // Step 3: Create default sampler
+            System.out.println("\nStep 3: Creating default sampler...");
+            sampler = createSampler(model, options);
+            System.out.println("✓ Default sampler created");
+            System.out.println("  Sampler type: " + sampler.getClass().getSimpleName());
+
+            System.out.println("\n✓ Model service initialization completed successfully!");
+            System.out.println("=== Ready to serve requests ===\n");
+
+        } catch (Exception e) {
+            System.err.println("✗ Failed to initialize model service: " + e.getMessage());
+            e.printStackTrace();
+            throw new RuntimeException("Model initialization failed", e);
+        }
+    }
+
+    // Getters for other services to access the initialized components
+    public Options getOptions() {
+        if (options == null) {
+            throw new IllegalStateException("Model service not initialized yet");
+        }
+        return options;
+    }
+
+    public Model getModel() {
+        if (model == null) {
+            throw new IllegalStateException("Model service not initialized yet");
+        }
+        return model;
+    }
+
+    public Sampler getSampler() {
+        if (sampler == null) {
+            throw new IllegalStateException("Model service not initialized yet");
+        }
+        return sampler;
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/api/service/TokenizerService.java b/src/main/java/org/beehive/gpullama3/api/service/TokenizerService.java
new file mode 100644
index 0000000..0120887
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/api/service/TokenizerService.java
@@ -0,0 +1,32 @@
+package org.beehive.gpullama3.api.service;
+
+import org.beehive.gpullama3.model.Model;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.stereotype.Service;
+import java.util.List;
+
+@Service
+public class TokenizerService {
+
+    @Autowired
+    private ModelInitializationService initService;
+
+    public List<Integer> encode(String text) {
+        Model model = initService.getModel();
+        // Use your model's tokenizer - adapt this to your actual tokenizer interface
+        // This assumes your Model has a tokenizer() method that returns a Tokenizer
+        return model.tokenizer().encodeAsList(text);
+    }
+
+    public String decode(List<Integer> tokens) {
+        Model model = initService.getModel();
+        // Use your model's tokenizer for decoding
+        return model.tokenizer().decode(tokens);
+    }
+
+//    public String decode(int token) {
+//        Model model = initService.getModel();
+//        // Convenience method for single token decoding
+//        return model.tokenizer().decode(token);
+//    }
+}
\ No newline at end of file
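A minimal client sketch for the completions API introduced above. This is illustrative only: it assumes the service is running on localhost:8080 (the default shown in the llama-tornado banner), and the JSON field names follow CompletionRequest's @JsonProperty mappings; the class name is hypothetical.

    import java.net.URI;
    import java.net.http.HttpClient;
    import java.net.http.HttpRequest;
    import java.net.http.HttpResponse;

    public class CompletionClientSketch {
        public static void main(String[] args) throws Exception {
            HttpClient client = HttpClient.newHttpClient();
            // Fields mirror CompletionRequest: prompt, max_tokens, temperature, top_p, stream
            String body = """
                    {"prompt": "Why is the sky blue?", "max_tokens": 64,
                     "temperature": 0.7, "top_p": 0.9, "stream": false}""";
            HttpRequest request = HttpRequest.newBuilder()
                    .uri(URI.create("http://localhost:8080/v1/completions"))
                    .header("Content-Type", "application/json")
                    .POST(HttpRequest.BodyPublishers.ofString(body))
                    .build();
            HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
            System.out.println(response.statusCode());
            System.out.println(response.body()); // CompletionResponse JSON with choices[] and usage
        }
    }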
From a4e103d599f1d0ff4b425019f2d8cede3a38aa80 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Tue, 16 Sep 2025 19:07:25 +0300
Subject: [PATCH 02/12] [WIP] Simplify service logic with ChatController

---
 .../api/controller/ChatController.java       |  79 +++++++++++
 .../api/controller/CompletionController.java | 130 ------------------
 .../service/ModelInitializationService.java  |  86 ------------
 .../api/service/TokenizerService.java        |  32 -----
 4 files changed, 79 insertions(+), 248 deletions(-)
 create mode 100644 src/main/java/org/beehive/gpullama3/api/controller/ChatController.java
 delete mode 100644 src/main/java/org/beehive/gpullama3/api/controller/CompletionController.java
 delete mode 100644 src/main/java/org/beehive/gpullama3/api/service/ModelInitializationService.java
 delete mode 100644 src/main/java/org/beehive/gpullama3/api/service/TokenizerService.java

diff --git a/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java b/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java
new file mode 100644
index 0000000..df1e036
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java
@@ -0,0 +1,79 @@
+package org.beehive.gpullama3.api.controller;
+
+import org.beehive.gpullama3.api.service.LLMService;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.http.MediaType;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestBody;
+import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.util.Map;
+
+@RestController
+public class ChatController {
+
+    @Autowired
+    private LLMService llmService;
+
+    @PostMapping("/chat")
+    public Map<String, String> chat(@RequestBody ChatRequest request) {
+        logRequest("NON_STREAMING", request, 150, 0.7, 0.9);
+
+        if (request.getMessage() == null || request.getMessage().trim().isEmpty()) {
+            throw new IllegalArgumentException("Message cannot be empty");
+        }
+
+        String response = llmService.generateResponse(request.getMessage(), request.getSystemMessage());
+
+        return Map.of("response", response);
+    }
+
+    @PostMapping(value = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
+    public SseEmitter streamChat(@RequestBody ChatRequest request) {
+        logRequest("STREAMING", request, 150, 0.7, 0.9);
+
+        if (request.getMessage() == null || request.getMessage().trim().isEmpty()) {
+            throw new IllegalArgumentException("Message cannot be empty");
+        }
+
+        SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
+        llmService.generateStreamingResponse(request.getMessage(), request.getSystemMessage(), emitter);
+
+        return emitter;
+    }
+
+    @GetMapping("/health")
+    public Map<String, String> health() {
+        return Map.of("status", "healthy", "timestamp", String.valueOf(System.currentTimeMillis()));
+    }
+
+    private void logRequest(String type, ChatRequest request, int maxTokens, double temperature, double topP) {
+        System.out.printf("REQUEST [%s] user='%s' system='%s' maxTokens=%d temp=%.2f topP=%.2f%n",
+                type,
+                truncate(request.getMessage(), 100),
+                request.getSystemMessage() != null ? truncate(request.getSystemMessage(), 40) : "none",
+                maxTokens,
+                temperature,
+                topP
+        );
+    }
+
+    private String truncate(String text, int maxLength) {
+        if (text == null) return "null";
+        return text.length() > maxLength ? text.substring(0, maxLength) + "..." : text;
+    }
+
+    // Simple request class for custom parameters
+    public static class ChatRequest {
+        private String message;
+        private String systemMessage;
+
+        public String getMessage() { return message; }
+        public void setMessage(String message) { this.message = message; }
+
+        public String getSystemMessage() { return systemMessage; }
+        public void setSystemMessage(String systemMessage) { this.systemMessage = systemMessage; }
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/api/controller/CompletionController.java b/src/main/java/org/beehive/gpullama3/api/controller/CompletionController.java
deleted file mode 100644
index 6137128..0000000
--- a/src/main/java/org/beehive/gpullama3/api/controller/CompletionController.java
+++ /dev/null
@@ -1,130 +0,0 @@
-package org.beehive.gpullama3.api.controller;
-
-import org.beehive.gpullama3.api.model.CompletionRequest;
-import org.beehive.gpullama3.api.model.CompletionResponse;
-import org.beehive.gpullama3.api.service.LLMService;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.http.MediaType;
-import org.springframework.http.ResponseEntity;
-import org.springframework.web.bind.annotation.*;
-import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
-
-import java.util.Arrays;
-import java.util.concurrent.CompletableFuture;
-
-@RestController
-@RequestMapping("/v1")
-@CrossOrigin(origins = "*")
-public class CompletionController {
-
-    @Autowired
-    private LLMService llmService;
-
-    @PostMapping("/completions")
-    public CompletableFuture<ResponseEntity<CompletionResponse>> createCompletion(@RequestBody CompletionRequest request) {
-
-        System.out.println("Received completion request: " + request);
-
-        if (Boolean.TRUE.equals(request.getStream())) {
-            throw new IllegalArgumentException("Use /v1/completions/stream for streaming requests");
-        }
-
-        // Validate request
-        if (request.getPrompt() == null || request.getPrompt().trim().isEmpty()) {
-            throw new IllegalArgumentException("Prompt cannot be null or empty");
-        }
-
-        return llmService.generateCompletion(
-                request.getPrompt(),
-                request.getMaxTokens(),
-                request.getTemperature(),
-                request.getTopP(),
-                request.getStopSequences()
-        ).thenApply(generatedText -> {
-            CompletionResponse response = new CompletionResponse();
-            response.setModel(request.getModel());
-
-            CompletionResponse.Choice choice = new CompletionResponse.Choice(
-                    generatedText, 0, "stop");
-            response.setChoices(Arrays.asList(choice));
-
-            // Calculate rough token counts (you might want to make this more accurate)
-            int promptTokens = request.getPrompt().length() / 4; // Rough estimate
-            int completionTokens = generatedText.length() / 4; // Rough estimate
-            CompletionResponse.Usage usage = new CompletionResponse.Usage(promptTokens, completionTokens);
-            response.setUsage(usage);
-
-            System.out.println("Completion response prepared, length: " + generatedText.length());
-
-            return ResponseEntity.ok(response);
-        });
-    }
-
-    @PostMapping(value = "/completions/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
-    public SseEmitter createStreamingCompletion(@RequestBody CompletionRequest request) {
-
-        System.out.println("Received streaming completion request: " + request);
-
-        // Validate request
-        if (request.getPrompt() == null || request.getPrompt().trim().isEmpty()) {
-            throw new IllegalArgumentException("Prompt cannot be null or empty");
-        }
-
-        SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
-
-        llmService.generateStreamingCompletion(
-                request.getPrompt(),
-                request.getMaxTokens(),
-                request.getTemperature(),
-                request.getTopP(),
-                request.getStopSequences(),
-                emitter
-        );
-
-        return emitter;
-    }
-
-    @GetMapping("/models")
-    public ResponseEntity<?> listModels() {
-        return ResponseEntity.ok(new Object() {
-            public final String object = "list";
-            public final Object[] data = new Object[] {
-                    new Object() {
-                        public final String id = "gpullama3";
-                        public final String object = "model";
-                        public final long created = System.currentTimeMillis() / 1000;
-                        public final String owned_by = "beehive";
-                    }
-            };
-        });
-    }
-
-    @GetMapping("/health")
-    public ResponseEntity<?> health() {
-        return ResponseEntity.ok(new Object() {
-            public final String status = "healthy";
-            public final long timestamp = System.currentTimeMillis();
-        });
-    }
-
-    // Global exception handler for this controller
-    @ExceptionHandler(IllegalArgumentException.class)
-    public ResponseEntity<?> handleBadRequest(IllegalArgumentException e) {
-        return ResponseEntity.badRequest().body(new Object() {
-            public final String error = e.getMessage();
-            public final long timestamp = System.currentTimeMillis();
-        });
-    }
-
-    @ExceptionHandler(Exception.class)
-    public ResponseEntity<?> handleInternalError(Exception e) {
-        System.err.println("Internal server error: " + e.getMessage());
-        e.printStackTrace();
-
-        return ResponseEntity.internalServerError().body(new Object() {
-            public final String error = "Internal server error";
-            public final String message = e.getMessage();
-            public final long timestamp = System.currentTimeMillis();
-        });
-    }
-}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/api/service/ModelInitializationService.java b/src/main/java/org/beehive/gpullama3/api/service/ModelInitializationService.java
deleted file mode 100644
index c2c30b4..0000000
--- a/src/main/java/org/beehive/gpullama3/api/service/ModelInitializationService.java
+++ /dev/null
@@ -1,86 +0,0 @@
-package org.beehive.gpullama3.api.service;
-
-import org.beehive.gpullama3.LlamaApp;
-import org.beehive.gpullama3.Options;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.inference.sampler.Sampler;
-import org.springframework.boot.ApplicationArguments;
-import org.springframework.stereotype.Service;
-
-import jakarta.annotation.PostConstruct;
-import java.io.IOException;
-
-import static org.beehive.gpullama3.inference.sampler.Sampler.createSampler;
-import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel;
-
-@Service
-public class ModelInitializationService {
-
-    private final ApplicationArguments args;
-
-    private Options options;
-    private Model model;
-    private Sampler sampler;
-
-    public ModelInitializationService(ApplicationArguments args) {
-        this.args = args;
-    }
-
-    @PostConstruct
-    public void init() {
-        try {
-            System.out.println("=== Model Initialization Service ===");
-            System.out.println("Initializing model service...");
-
-            // Step 1: Parse options from command line arguments
-            System.out.println("Step 1: Parsing options...");
-            options = Options.parseOptions(args.getSourceArgs());
-            System.out.println("✓ Options parsed successfully");
-
-            // Step 2: Load model
-            System.out.println("\nStep 2: Loading model...");
-            System.out.println("Loading model from: " + options.modelPath());
-            model = loadModel(options);
-            System.out.println("✓ Model loaded successfully");
-            System.out.println("  Model type: " + model.getClass().getSimpleName());
-            System.out.println("  Vocabulary size: " + model.configuration().vocabularySize());
-            System.out.println("  Context length: " + model.configuration().contextLength());
-
-            // Step 3: Create default sampler
-            System.out.println("\nStep 3: Creating default sampler...");
-            sampler = createSampler(model, options);
-            System.out.println("✓ Default sampler created");
-            System.out.println("  Sampler type: " + sampler.getClass().getSimpleName());
-
-            System.out.println("\n✓ Model service initialization completed successfully!");
-            System.out.println("=== Ready to serve requests ===\n");
-
-        } catch (Exception e) {
-            System.err.println("✗ Failed to initialize model service: " + e.getMessage());
-            e.printStackTrace();
-            throw new RuntimeException("Model initialization failed", e);
-        }
-    }
-
-    // Getters for other services to access the initialized components
-    public Options getOptions() {
-        if (options == null) {
-            throw new IllegalStateException("Model service not initialized yet");
-        }
-        return options;
-    }
-
-    public Model getModel() {
-        if (model == null) {
-            throw new IllegalStateException("Model service not initialized yet");
-        }
-        return model;
-    }
-
-    public Sampler getSampler() {
-        if (sampler == null) {
-            throw new IllegalStateException("Model service not initialized yet");
-        }
-        return sampler;
-    }
-}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/api/service/TokenizerService.java b/src/main/java/org/beehive/gpullama3/api/service/TokenizerService.java
deleted file mode 100644
index 0120887..0000000
--- a/src/main/java/org/beehive/gpullama3/api/service/TokenizerService.java
+++ /dev/null
@@ -1,32 +0,0 @@
-package org.beehive.gpullama3.api.service;
-
-import org.beehive.gpullama3.model.Model;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.stereotype.Service;
-import java.util.List;
-
-@Service
-public class TokenizerService {
-
-    @Autowired
-    private ModelInitializationService initService;
-
-    public List<Integer> encode(String text) {
-        Model model = initService.getModel();
-        // Use your model's tokenizer - adapt this to your actual tokenizer interface
-        // This assumes your Model has a tokenizer() method that returns a Tokenizer
-        return model.tokenizer().encodeAsList(text);
-    }
-
-    public String decode(List<Integer> tokens) {
-        Model model = initService.getModel();
-        // Use your model's tokenizer for decoding
-        return model.tokenizer().decode(tokens);
-    }
-
-//    public String decode(int token) {
-//        Model model = initService.getModel();
-//        // Convenience method for single token decoding
-//        return model.tokenizer().decode(token);
-//    }
-}
\ No newline at end of file
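A sketch of consuming the /chat/stream endpoint introduced in the patch above. Assumptions: the service is reachable on localhost:8080, and each SSE event carries one decoded token as its data payload (per the SseEmitter usage above); the class name is hypothetical.

    import java.net.URI;
    import java.net.http.HttpClient;
    import java.net.http.HttpRequest;
    import java.net.http.HttpResponse;
    import java.util.stream.Stream;

    public class ChatStreamClientSketch {
        public static void main(String[] args) throws Exception {
            HttpClient client = HttpClient.newHttpClient();
            HttpRequest request = HttpRequest.newBuilder()
                    .uri(URI.create("http://localhost:8080/chat/stream"))
                    .header("Content-Type", "application/json")
                    .header("Accept", "text/event-stream")
                    .POST(HttpRequest.BodyPublishers.ofString(
                            "{\"message\": \"Hello!\", \"systemMessage\": \"You are a helpful assistant.\"}"))
                    .build();
            HttpResponse<Stream<String>> response =
                    client.send(request, HttpResponse.BodyHandlers.ofLines());
            // Each SSE event arrives as a "data:<token>" line; blank lines separate events.
            response.body()
                    .filter(line -> line.startsWith("data:"))
                    .map(line -> line.substring(5))
                    .takeWhile(data -> !data.equals("[DONE]"))
                    .forEach(System.out::print);
        }
    }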
From 9a1a47514b506b34e13cf849085a7571eeaf5850 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Tue, 16 Sep 2025 19:08:13 +0300
Subject: [PATCH 03/12] Update llama-tornado python script for service

---
 llama-tornado | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/llama-tornado b/llama-tornado
index 6c62b01..c618130 100755
--- a/llama-tornado
+++ b/llama-tornado
@@ -204,6 +204,14 @@ class LlamaRunner:
     def _add_llama_args(self, cmd: List[str], args: argparse.Namespace) -> List[str]:
         """Add LLaMA-specific arguments to the command."""
 
+        # For service mode, only pass the model path and max-tokens
+        if hasattr(args, 'service') and args.service:
+            llama_args = [
+                "--model", args.model_path,
+                "--max-tokens", str(args.max_tokens),
+            ]
+            return cmd + llama_args
+
         llama_args = [
             "--model", args.model_path,
             "--temperature", str(args.temperature),
@@ -238,14 +246,19 @@
 
         # Show service-specific information
         if args.service:
-            print("Starting 
TornadoVM LLM REST API Service...") + print("Starting GPULlama3.java REST API Service...") print(f"Model: {args.model_path}") - print("API endpoints will be available at:") - print(" - http://localhost:8080/v1/completions") - print(" - http://localhost:8080/v1/completions/stream") - print(" - http://localhost:8080/v1/models") - print(" - http://localhost:8080/v1/health") - print("\nPress Ctrl+C to stop the service") + print("API endpoints available at:") + print(" - http://localhost:8080/chat") + print(" - http://localhost:8080/chat/stream") + print(" - http://localhost:8080/health") + print("") + print("Example usage:") + print(' curl -X POST http://localhost:8080/chat \\') + print(' -H "Content-Type: application/json" \\') + print(' -d \'{"message": "Hello!"}\'') + print("") + print("Press Ctrl+C to stop the service") print("-" * 60) From 3eabd7517b2c7c1ee44ecf056bbc9066239b1773 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 16 Sep 2025 19:09:42 +0300 Subject: [PATCH 04/12] Add serviceMode Options field and parseServiceOptions method --- .../java/org/beehive/gpullama3/Options.java | 57 +++++++++++++++++-- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/Options.java b/src/main/java/org/beehive/gpullama3/Options.java index 9892448..67d0ef3 100644 --- a/src/main/java/org/beehive/gpullama3/Options.java +++ b/src/main/java/org/beehive/gpullama3/Options.java @@ -5,12 +5,15 @@ import java.nio.file.Paths; public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive, float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo, - boolean useTornadovm) { + boolean useTornadovm, boolean serviceMode) { public static final int DEFAULT_MAX_TOKENS = 1024; public Options { - require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\""); + // Skip prompt validation in service mode + if (!serviceMode) { + require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. 
--prompt \"Why is the sky blue?\""); + } require(0 <= temperature, "Invalid argument: --temperature must be non-negative"); require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); } @@ -61,7 +64,7 @@ public static Options getDefaultOptions() { boolean echo = false; boolean useTornadoVM = getDefaultTornadoVM(); - return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM); + return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM, false); } public static Options parseOptions(String[] args) { @@ -123,6 +126,52 @@ public static Options parseOptions(String[] args) { useTornadovm = getDefaultTornadoVM(); } - return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm); + return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm, false); + } + + public static Options parseServiceOptions(String[] args) { + Path modelPath = null; + int maxTokens = 512; // Default context length + + for (int i = 0; i < args.length; i++) { + String optionName = args[i]; + require(optionName.startsWith("-"), "Invalid option %s", optionName); + + String nextArg; + if (optionName.contains("=")) { + String[] parts = optionName.split("=", 2); + optionName = parts[0]; + nextArg = parts[1]; + } else { + if (i + 1 >= args.length) continue; // Skip if no next arg + nextArg = args[i + 1]; + i += 1; // skip arg + } + + // Only parse these options in service mode + switch (optionName) { + case "--model", "-m" -> modelPath = Paths.get(nextArg); + case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg); + } + } + + require(modelPath != null, "Missing argument: --model is required"); + + // Create service-mode Options object + return new Options( + modelPath, + null, // prompt - not used in service + null, // systemPrompt - handled per request + null, // suffix - not used + false, // interactive - not used in service + 0.7f, // temperature - default, overridden per request + 0.9f, // topp - default, overridden per request + System.nanoTime(), // seed - default + maxTokens, + false, // stream - handled per request + false, // echo - not used in service + getDefaultTornadoVM(), + true + ); } } From 60c67e6333c96db040f0db7008d39ddf945746b7 Mon Sep 17 00:00:00 2001 From: Orion Papadakis Date: Tue, 16 Sep 2025 19:11:11 +0300 Subject: [PATCH 05/12] [WIP] Update service logic --- .../api/config/ModelConfiguration.java | 10 +- .../gpullama3/api/service/LLMService.java | 286 ++++++++++-------- 2 files changed, 167 insertions(+), 129 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java b/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java index 9a3399d..72e5a62 100644 --- a/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java +++ b/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java @@ -1,8 +1,8 @@ package org.beehive.gpullama3.api.config; +import org.beehive.gpullama3.api.service.LLMService; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.Options; -import org.beehive.gpullama3.api.service.ModelInitializationService; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -13,15 +13,15 @@ public class 
ModelConfiguration {
 
      * Expose Model as a Spring bean using the initialized service
      */
     @Bean
-    public Model model(ModelInitializationService initService) {
-        return initService.getModel();
+    public Model model(LLMService llmService) {
+        return llmService.getModel();
     }
 
     /**
      * Expose Options as a Spring bean using the initialized service
      */
     @Bean
-    public Options options(ModelInitializationService initService) {
-        return initService.getOptions();
+    public Options options(LLMService llmService) {
+        return llmService.getOptions();
     }
 }
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
index 41741ac..87fda42 100644
--- a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
+++ b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
@@ -1,173 +1,211 @@
 package org.beehive.gpullama3.api.service;
 
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.inference.state.State;
+import jakarta.annotation.PostConstruct;
+import org.beehive.gpullama3.Options;
 import org.beehive.gpullama3.inference.sampler.Sampler;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.format.ChatFormat;
+import org.beehive.gpullama3.model.loader.ModelLoader;
+import org.springframework.boot.ApplicationArguments;
-import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
 
-import java.util.*;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
-import java.util.function.IntConsumer;
+
+import static org.beehive.gpullama3.inference.sampler.Sampler.selectSampler;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel;
 
 @Service
 public class LLMService {
 
-    @Autowired
-    private ModelInitializationService initService;
+    private final ApplicationArguments args;
 
-    @Autowired
-    private TokenizerService tokenizerService;
+    private Options options;
+    private Model model;
 
-    public CompletableFuture<String> generateCompletion(
-            String prompt,
-            int maxTokens,
-            double temperature,
-            double topP,
-            List<String> stopSequences) {
+    public LLMService(ApplicationArguments args) {
+        this.args = args;
+    }
 
-        return CompletableFuture.supplyAsync(() -> {
-            try {
-                System.out.println("Starting completion generation...");
-                System.out.println("Prompt: " + prompt.substring(0, Math.min(50, prompt.length())) + "...");
-                System.out.println("Max tokens: " + maxTokens + ", Temperature: " + temperature);
-
-                // Get initialized components
-                Model model = initService.getModel();
-
-                // Convert prompt to tokens
-                List<Integer> promptTokens = tokenizerService.encode(prompt);
-                System.out.println("Prompt tokens: " + promptTokens.size());
-
-                // Convert stop sequences to token sets
-                Set<Integer> stopTokens = new HashSet<>();
-                if (stopSequences != null) {
-                    for (String stop : stopSequences) {
-                        stopTokens.addAll(tokenizerService.encode(stop));
-                    }
-                    System.out.println("Stop tokens: " + stopTokens.size());
-                }
+    @PostConstruct
+    public void init() {
+        try {
+            System.out.println("Initializing LLM service...");
+
+            // Step 1: Parse service options
+            System.out.println("Step 1: Parsing service options...");
+            options = Options.parseServiceOptions(args.getSourceArgs());
+            System.out.println("Model path: " + options.modelPath());
+            System.out.println("Context length: " + options.maxTokens());
+
+            // Step 2: Load model weights
+            System.out.println("\nStep 2: Loading model...");
+            System.out.println("Loading model from: " + options.modelPath());
+            model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
+            System.out.println("✓ Model loaded successfully");
+            System.out.println("  Model type: " + model.getClass().getSimpleName());
+            System.out.println("  Vocabulary size: " + model.configuration().vocabularySize());
+            System.out.println("  Context length: " + model.configuration().contextLength());
+
+            System.out.println("\n✓ Model service initialization completed successfully!");
+            System.out.println("=== Ready to serve requests ===\n");
 
-                // Create custom sampler with request-specific parameters
-                //Sampler sampler = initService.createCustomSampler(temperature, topP, System.currentTimeMillis());
-                Sampler sampler = initService.getSampler();
+        } catch (Exception e) {
+            System.err.println("✗ Failed to initialize model service: " + e.getMessage());
+            e.printStackTrace();
+            throw new RuntimeException("Model initialization failed", e);
+        }
+    }
 
-                // Create state based on model type
-                State state = createStateForModel(model);
+    public String generateResponse(String message, String systemMessage) {
+        return generateResponse(message, systemMessage, 150, 0.7, 0.9);
+    }
 
-                // Generate tokens using the existing method
-                List<Integer> generatedTokens = model.generateTokens(
-                        state,
-                        0,
-                        promptTokens,
-                        stopTokens,
-                        maxTokens,
-                        sampler,
-                        false,
-                        token -> {} // No callback for non-streaming
-                );
+    public String generateResponse(String message, String systemMessage, int maxTokens, double temperature, double topP) {
+        try {
+            // Create sampler and state like runInstructOnce
+            Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, System.currentTimeMillis());
+            State state = model.createNewState();
 
-                // Decode tokens back to text
-                String result = tokenizerService.decode(generatedTokens);
-                System.out.println("Generated " + generatedTokens.size() + " tokens");
-                System.out.println("Completion finished successfully");
+            // Use model's ChatFormat
+            ChatFormat chatFormat = model.chatFormat();
+            List<Integer> promptTokens = new ArrayList<>();
 
-                return result;
+            // Add begin of text if needed
+            if (model.shouldAddBeginOfText()) {
+                promptTokens.add(chatFormat.getBeginOfText());
+            }
 
-            } catch (Exception e) {
-                System.err.println("Error generating completion: " + e.getMessage());
-                e.printStackTrace();
-                throw new RuntimeException("Error generating completion", e);
+            // Add system message properly formatted
+            if (model.shouldAddSystemPrompt() && systemMessage != null && !systemMessage.trim().isEmpty()) {
+                promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage)));
             }
-        });
-    }
 
-    public void generateStreamingCompletion(
-            String prompt,
-            int maxTokens,
-            double temperature,
-            double topP,
-            List<String> stopSequences,
-            SseEmitter emitter) {
+            // Add user message properly formatted
+            promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message)));
+            promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+
+            // Handle reasoning tokens if needed (for Deepseek-R1-Distill-Qwen)
+            if (model.shouldIncludeReasoning()) {
+                List<Integer> thinkStartTokens = model.tokenizer().encode("<think>\n", model.tokenizer().getSpecialTokens().keySet());
+                promptTokens.addAll(thinkStartTokens);
+            }
+
+            // Use proper stop tokens from chat format
+            Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+            long startTime = System.currentTimeMillis();
+
+            // Use CPU path for now (GPU path disabled as noted)
+            List<Integer> generatedTokens = model.generateTokens(
+                    state, 0, promptTokens, stopTokens, maxTokens, sampler, false, token -> {}
+            );
+            // Remove stop tokens if present
+            if (!generatedTokens.isEmpty() && stopTokens.contains(generatedTokens.getLast())) {
+                generatedTokens.removeLast();
+            }
+
+            long duration = System.currentTimeMillis() - startTime;
+            double tokensPerSecond = generatedTokens.size() * 1000.0 / duration;
+            System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
+                    generatedTokens.size(), duration, tokensPerSecond);
+
+
+            String responseText = model.tokenizer().decode(generatedTokens);
+
+            // Add reasoning prefix for non-streaming if needed
+            if (model.shouldIncludeReasoning()) {
+                responseText = "<think>\n" + responseText;
+            }
+
+            return responseText;
+
+        } catch (Exception e) {
+            System.err.println("FAILED " + e.getMessage());
+            throw new RuntimeException("Failed to generate response", e);
+        }
+    }
 
+    public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter) {
         CompletableFuture.runAsync(() -> {
             try {
-                System.out.println("Starting streaming completion generation...");
-
-                Model model = initService.getModel();
+                Sampler sampler = selectSampler(model.configuration().vocabularySize(), 0.7f, 0.9f, System.currentTimeMillis());
+                State state = model.createNewState();
 
-                List<Integer> promptTokens = tokenizerService.encode(prompt);
+                // Use proper chat format like in runInstructOnce
+                ChatFormat chatFormat = model.chatFormat();
+                List<Integer> promptTokens = new ArrayList<>();
 
-                Set<Integer> stopTokens = new HashSet<>();
-                if (stopSequences != null) {
-                    for (String stop : stopSequences) {
-                        stopTokens.addAll(tokenizerService.encode(stop));
-                    }
+                if (model.shouldAddBeginOfText()) {
+                    promptTokens.add(chatFormat.getBeginOfText());
                 }
 
-                //Sampler sampler = initService.createCustomSampler(temperature, topP, System.currentTimeMillis());
-                Sampler sampler = initService.getSampler();
-                State state = createStateForModel(model);
-
-                final int[] tokenCount = {0};
+                if (model.shouldAddSystemPrompt() && systemMessage != null && !systemMessage.trim().isEmpty()) {
+                    promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage)));
+                }
 
-                // Streaming callback
-                IntConsumer tokenCallback = token -> {
-                    try {
-                        String tokenText = tokenizerService.decode(List.of(token));
-                        tokenCount[0]++;
+                promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message)));
+                promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
 
-                        // SseEmitter adds the SSE "data:" framing itself, so send only the payload
-                        String eventData = String.format(
-                                "{\"choices\":[{\"text\":\"%s\",\"index\":0,\"finish_reason\":null}]}",
-                                escapeJson(tokenText)
-                        );
+                // Handle reasoning tokens for streaming
+                if (model.shouldIncludeReasoning()) {
+                    List<Integer> thinkStartTokens = model.tokenizer().encode("<think>\n", model.tokenizer().getSpecialTokens().keySet());
+                    promptTokens.addAll(thinkStartTokens);
+                    emitter.send(SseEmitter.event().data("<think>\n")); // Output immediately
+                }
 
-                        emitter.send(SseEmitter.event().data(eventData));
+                Set<Integer> stopTokens = chatFormat.getStopTokens();
 
-                        if (tokenCount[0] % 10 == 0) {
-                            System.out.println("Streamed " + tokenCount[0] + " tokens");
+                final int[] tokenCount = {0};
+                long startTime = System.currentTimeMillis();
+                List<Integer> generatedTokens = model.generateTokens(
+                        state, 0, promptTokens, stopTokens, 150, sampler, false,
+                        token -> {
+                            try {
+                                // Only display tokens that should be displayed (like in your original)
+                                if (model.tokenizer().shouldDisplayToken(token)) {
+                                    String tokenText = model.tokenizer().decode(List.of(token));
+                                    emitter.send(SseEmitter.event().data(tokenText));
+                                    tokenCount[0]++;
+                                }
+                            } catch (Exception e) {
+                                emitter.completeWithError(e);
+                            }
                         }
+                );
 
-                    } catch (Exception e) {
-                        System.err.println("Error in streaming callback: " + e.getMessage());
-                        emitter.completeWithError(e);
-                    }
-                };
-
-                model.generateTokens(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenCallback);
-
-                // Send completion event
-                emitter.send(SseEmitter.event().data("[DONE]"));
+                long duration = System.currentTimeMillis() - startTime;
+                double tokensPerSecond = tokenCount[0] * 1000.0 / duration;
+                System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
+                        tokenCount[0], duration, tokensPerSecond);
+
+                emitter.send(SseEmitter.event().data("[DONE]"));
                 emitter.complete();
 
-                System.out.println("Streaming completion finished. Total tokens: " + tokenCount[0]);
-
             } catch (Exception e) {
-                System.err.println("Error in streaming generation: " + e.getMessage());
-                e.printStackTrace();
+                System.err.println("FAILED " + e.getMessage());
                 emitter.completeWithError(e);
             }
         });
     }
 
-    /**
-     * Create appropriate State subclass based on the model type
-     */
-    private State createStateForModel(Model model) {
-        try {
-            return model.createNewState();
-        } catch (Exception e) {
-            throw new RuntimeException("Failed to create state for model", e);
+    // Getters for other services to access the initialized components
+    public Options getOptions() {
+        if (options == null) {
+            throw new IllegalStateException("Model service not initialized yet");
         }
+        return options;
     }
 
-    private String escapeJson(String str) {
-        if (str == null) return "";
-        // Escape backslashes first so the escapes added below are not doubled
-        return str.replace("\\", "\\\\")
-                  .replace("\"", "\\\"")
-                  .replace("\n", "\\n")
-                  .replace("\r", "\\r")
-                  .replace("\t", "\\t");
+    public Model getModel() {
+        if (model == null) {
+            throw new IllegalStateException("Model service not initialized yet");
+        }
+        return model;
     }
 }
\ No newline at end of file
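A request sketch against the reworked /chat endpoint. Assumptions: service on localhost:8080; systemMessage is optional and only applied when the model supports a system prompt (per shouldAddSystemPrompt() above); the per-request sampling fields (maxTokens, temperature, topP) arrive in the next patch. The class name is hypothetical.

    import java.net.URI;
    import java.net.http.HttpClient;
    import java.net.http.HttpRequest;
    import java.net.http.HttpResponse;

    public class ChatClientSketch {
        public static void main(String[] args) throws Exception {
            HttpClient client = HttpClient.newHttpClient();
            String body = """
                    {"message": "Summarize what TornadoVM does in one sentence.",
                     "systemMessage": "You are a concise assistant."}""";
            HttpRequest request = HttpRequest.newBuilder()
                    .uri(URI.create("http://localhost:8080/chat"))
                    .header("Content-Type", "application/json")
                    .POST(HttpRequest.BodyPublishers.ofString(body))
                    .build();
            HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
            System.out.println(response.body()); // {"response": "..."}
        }
    }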
+
+    public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter) {
         CompletableFuture.runAsync(() -> {
             try {
-                System.out.println("Starting streaming completion generation...");
-
-                Model model = initService.getModel();
+                Sampler sampler = selectSampler(model.configuration().vocabularySize(), 0.7f, 0.9f, System.currentTimeMillis());
+                State state = model.createNewState();
-                List<Integer> promptTokens = tokenizerService.encode(prompt);
+                // Use proper chat format like in runInstructOnce
+                ChatFormat chatFormat = model.chatFormat();
+                List<Integer> promptTokens = new ArrayList<>();
-                Set<Integer> stopTokens = new HashSet<>();
-                if (stopSequences != null) {
-                    for (String stop : stopSequences) {
-                        stopTokens.addAll(tokenizerService.encode(stop));
-                    }
+                if (model.shouldAddBeginOfText()) {
+                    promptTokens.add(chatFormat.getBeginOfText());
                 }
-                //Sampler sampler = initService.createCustomSampler(temperature, topP, System.currentTimeMillis());
-                Sampler sampler = initService.getSampler();
-                State state = createStateForModel(model);
-
-                final int[] tokenCount = {0};
+                if (model.shouldAddSystemPrompt() && systemMessage != null && !systemMessage.trim().isEmpty()) {
+                    promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage)));
+                }
-                // Streaming callback
-                IntConsumer tokenCallback = token -> {
-                    try {
-                        String tokenText = tokenizerService.decode(List.of(token));
-                        tokenCount[0]++;
+                promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message)));
+                promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
-                        String eventData = String.format(
-                                "data: {\"choices\":[{\"text\":\"%s\",\"index\":0,\"finish_reason\":null}]}\n\n",
-                                escapeJson(tokenText)
-                        );
+                // Handle reasoning tokens for streaming
+                if (model.shouldIncludeReasoning()) {
+                    List<Integer> thinkStartTokens = model.tokenizer().encode("<think>\n", model.tokenizer().getSpecialTokens().keySet());
+                    promptTokens.addAll(thinkStartTokens);
+                    emitter.send(SseEmitter.event().data("<think>\n")); // Output immediately
+                }
-                        emitter.send(SseEmitter.event().data(eventData));
+                Set<Integer> stopTokens = chatFormat.getStopTokens();
-                        if (tokenCount[0] % 10 == 0) {
-                            System.out.println("Streamed " + tokenCount[0] + " tokens");
+                final int[] tokenCount = {0};
+                long startTime = System.currentTimeMillis();
+                List<Integer> generatedTokens = model.generateTokens(
+                        state, 0, promptTokens, stopTokens, 150, sampler, false,
+                        token -> {
+                            try {
+                                // Only forward tokens the tokenizer marks as displayable
+                                if (model.tokenizer().shouldDisplayToken(token)) {
+                                    String tokenText = model.tokenizer().decode(List.of(token));
+                                    emitter.send(SseEmitter.event().data(tokenText));
+                                    tokenCount[0]++;
+                                }
+                            } catch (Exception e) {
+                                emitter.completeWithError(e);
+                            }
+                        }
+                );
-            } catch (Exception e) {
-                System.err.println("Error in streaming callback: " + e.getMessage());
-                emitter.completeWithError(e);
-            }
-        };
-
-        model.generateTokens(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenCallback);
+                long duration = System.currentTimeMillis() - startTime;
+                double tokensPerSecond = tokenCount[0] * 1000.0 / duration;
+                System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
+                        tokenCount[0], duration, tokensPerSecond);
-                // Send completion event
-                emitter.send(SseEmitter.event().data("data: [DONE]\n\n"));
+                emitter.send(SseEmitter.event().data("[DONE]"));
                 emitter.complete();
-                System.out.println("Streaming completion finished. Total tokens: " + tokenCount[0]);
             } catch (Exception e) {
-                System.err.println("Error in streaming generation: " + e.getMessage());
-                e.printStackTrace();
+                System.err.println("FAILED " + e.getMessage());
                 emitter.completeWithError(e);
             }
         });
     }

-    /**
-     * Create appropriate State subclass based on the model type
-     */
-    private State createStateForModel(Model model) {
-        try {
-            return model.createNewState();
-        } catch (Exception e) {
-            throw new RuntimeException("Failed to create state for model", e);
+    // Getters for other services to access the initialized components
+    public Options getOptions() {
+        if (options == null) {
+            throw new IllegalStateException("Model service not initialized yet");
         }
+        return options;
     }

-    private String escapeJson(String str) {
-        if (str == null) return "";
-        return str.replace("\"", "\\\"")
-                .replace("\n", "\\n")
-                .replace("\r", "\\r")
-                .replace("\t", "\\t")
-                .replace("\\", "\\\\");
+    public Model getModel() {
+        if (model == null) {
+            throw new IllegalStateException("Model service not initialized yet");
+        }
+        return model;
     }
 }
\ No newline at end of file
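For the COMPLETED log line above, the rate arithmetic is tokens * 1000.0 / durationMillis: for example, 120 streamed tokens over 2400 ms gives 120 * 1000.0 / 2400 = 50.0 tok/s (numbers illustrative; duration is wall-clock time around the whole generation call).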
From 560e64116fe21739526fc2876fa8fb243f11212d Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Wed, 17 Sep 2025 19:49:42 +0300
Subject: [PATCH 06/12] [WIP] Update service logic

---
 .../api/controller/ChatController.java          | 39 +++++++++++++++++--
 .../gpullama3/api/service/LLMService.java       | 27 +++++++++++--
 2 files changed, 58 insertions(+), 8 deletions(-)

diff --git a/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java b/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java
index df1e036..70c95e9 100644
--- a/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java
+++ b/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java
@@ -19,31 +19,44 @@ public class ChatController {
     @PostMapping("/chat")
     public Map<String, String> chat(@RequestBody ChatRequest request) {
-        logRequest("NON_STREAMING", request, 150, 0.7, 0.9);
+        // Use request parameters with fallbacks to defaults
+        int maxTokens = request.getMaxTokens() != null ? request.getMaxTokens() : 150;
+        double temperature = request.getTemperature() != null ? request.getTemperature() : 0.7;
+        double topP = request.getTopP() != null ? request.getTopP() : 0.9;
+
+        logRequest("NON_STREAMING", request, maxTokens, temperature, topP);
 
         if (request.getMessage() == null || request.getMessage().trim().isEmpty()) {
             throw new IllegalArgumentException("Message cannot be empty");
         }
 
-        String response = llmService.generateResponse(request.getMessage(), request.getSystemMessage());
+        String response = llmService.generateResponse(request.getMessage(), request.getSystemMessage(),
+                maxTokens, temperature, topP);
         return Map.of("response", response);
     }
 
     @PostMapping(value = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
     public SseEmitter streamChat(@RequestBody ChatRequest request) {
-        logRequest("STREAMING", request, 150, 0.7, 0.9);
+        // Use request parameters with fallbacks to defaults
+        int maxTokens = request.getMaxTokens() != null ? request.getMaxTokens() : 150;
+        double temperature = request.getTemperature() != null ? request.getTemperature() : 0.7;
+        double topP = request.getTopP() != null ? request.getTopP() : 0.9;
+
+        logRequest("STREAMING", request, maxTokens, temperature, topP);
 
         if (request.getMessage() == null || request.getMessage().trim().isEmpty()) {
             throw new IllegalArgumentException("Message cannot be empty");
         }
 
         SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
-        llmService.generateStreamingResponse(request.getMessage(), request.getSystemMessage(), emitter);
+        llmService.generateStreamingResponse(request.getMessage(), request.getSystemMessage(),
+                emitter, maxTokens, temperature, topP);
         return emitter;
     }
 
+    @GetMapping("/health")
     public Map<String, String> health() {
         return Map.of("status", "healthy", "timestamp", String.valueOf(System.currentTimeMillis()));
     }
@@ -69,11 +82,29 @@ private String truncate(String text, int maxLength) {
     public static class ChatRequest {
         private String message;
         private String systemMessage;
+        private Integer maxTokens;
+        private Double temperature;
+        private Double topP;
+        private Long seed;
 
+        // Getters and Setters
         public String getMessage() { return message; }
         public void setMessage(String message) { this.message = message; }
 
         public String getSystemMessage() { return systemMessage; }
         public void setSystemMessage(String systemMessage) { this.systemMessage = systemMessage; }
+
+        public Integer getMaxTokens() { return maxTokens; }
+        public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; }
+
+        public Double getTemperature() { return temperature; }
+        public void setTemperature(Double temperature) { this.temperature = temperature; }
+
+        public Double getTopP() { return topP; }
+        public void setTopP(Double topP) { this.topP = topP; }
+
+        public Long getSeed() { return seed; }
+        public void setSeed(Long seed) { this.seed = seed; }
     }
+
 }
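With the new fields in place, a request can override the defaults per call. A minimal client sketch, assuming JDK 11+ java.net.http and the service on localhost:8080 as printed by llama-tornado; the JSON keys mirror the ChatRequest bean above:

    import java.net.URI;
    import java.net.http.HttpClient;
    import java.net.http.HttpRequest;
    import java.net.http.HttpResponse;

    public class ChatClient {
        public static void main(String[] args) throws Exception {
            String body = """
                    {"message": "What is TornadoVM?",
                     "systemMessage": "You are a concise assistant.",
                     "maxTokens": 200, "temperature": 0.8, "topP": 0.95}""";
            HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:8080/chat"))
                    .header("Content-Type", "application/json")
                    .POST(HttpRequest.BodyPublishers.ofString(body))
                    .build();
            HttpResponse<String> response = HttpClient.newHttpClient()
                    .send(request, HttpResponse.BodyHandlers.ofString());
            System.out.println(response.body()); // {"response":"..."}
        }
    }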
diff --git a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
index 87fda42..096295e 100644
--- a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
+++ b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
@@ -61,14 +61,22 @@ public void init() {
         }
     }
 
+    /**
+     * Generate response with default parameters.
+     */
     public String generateResponse(String message, String systemMessage) {
         return generateResponse(message, systemMessage, 150, 0.7, 0.9);
     }
 
     public String generateResponse(String message, String systemMessage, int maxTokens, double temperature, double topP) {
+        return generateResponse(message, systemMessage, maxTokens, temperature, topP, null);
+    }
+
+    public String generateResponse(String message, String systemMessage, int maxTokens, double temperature, double topP, Long seed) {
         try {
             // Create sampler and state like runInstructOnce
-            Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, System.currentTimeMillis());
+            long actualSeed = seed != null ? seed : System.currentTimeMillis();
+            Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, actualSeed);
             State state = model.createNewState();
 
@@ -115,7 +123,6 @@ public String generateResponse(String message, String systemMessage, int maxToke
         System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
                 generatedTokens.size(), duration, tokensPerSecond);
 
-
         String responseText = model.tokenizer().decode(generatedTokens);
 
@@ -132,9 +139,20 @@ public String generateResponse(String message, String systemMessage, int maxToke
     }
 
     public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter) {
+        generateStreamingResponse(message, systemMessage, emitter, 150, 0.7, 0.9);
+    }
+
+    public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter,
+                                          int maxTokens, double temperature, double topP) {
+        generateStreamingResponse(message, systemMessage, emitter, maxTokens, temperature, topP, null);
+    }
+
+    public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter,
+                                          int maxTokens, double temperature, double topP, Long seed) {
         CompletableFuture.runAsync(() -> {
             try {
-                Sampler sampler = selectSampler(model.configuration().vocabularySize(), 0.7f, 0.9f, System.currentTimeMillis());
+                long actualSeed = seed != null ? seed : System.currentTimeMillis();
+                Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, actualSeed);
                 State state = model.createNewState();
 
@@ -164,13 +182,14 @@ public void generateStreamingResponse(String message, String systemMessage, SseE
                 final int[] tokenCount = {0};
                 long startTime = System.currentTimeMillis();
                 List<Integer> generatedTokens = model.generateTokens(
-                        state, 0, promptTokens, stopTokens, 150, sampler, false,
+                        state, 0, promptTokens, stopTokens, maxTokens, sampler, false,
                         token -> {
                             try {
                                 // Only forward tokens the tokenizer marks as displayable
                                 if (model.tokenizer().shouldDisplayToken(token)) {
                                     String tokenText = model.tokenizer().decode(List.of(token));
                                     emitter.send(SseEmitter.event().data(tokenText));
+                                    //emitter.send(SseEmitter.event().comment("flush"));
                                     tokenCount[0]++;
                                 }
                             } catch (Exception e) {

From 514375398d62c1fdad7f353f61f4590705b2e528 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Thu, 18 Sep 2025 21:11:57 +0300
Subject: [PATCH 07/12] [WIP] Start adding logic for tornado path in service

---
 llama-tornado                                           | 6 ++++++
 src/main/java/org/beehive/gpullama3/Options.java        | 8 +++++++-
 .../org/beehive/gpullama3/api/service/LLMService.java   | 2 +-
 3 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/llama-tornado b/llama-tornado
index c618130..237cabf 100755
--- a/llama-tornado
+++ b/llama-tornado
@@ -248,6 +248,12 @@ class LlamaRunner:
         if args.service:
             print("Starting GPULlama3.java REST API Service...")
             print(f"Model: {args.model_path}")
+            # Display GPU/backend configuration
+            if args.use_gpu:
+                print(f"GPU Acceleration: Enabled ({args.backend.value.upper()} backend)")
+                print(f"GPU Memory: {args.gpu_memory}")
+            else:
+                print("GPU Acceleration: Disabled (CPU mode)")
             print("API endpoints available at:")
             print("  - http://localhost:8080/chat")
             print("  - http://localhost:8080/chat/stream")

diff --git a/src/main/java/org/beehive/gpullama3/Options.java b/src/main/java/org/beehive/gpullama3/Options.java
index 67d0ef3..9469c59 100644
--- a/src/main/java/org/beehive/gpullama3/Options.java
+++ b/src/main/java/org/beehive/gpullama3/Options.java
@@ -132,6 +132,7 @@ public static Options parseOptions(String[] args) {
     public static Options parseServiceOptions(String[] args) {
         Path modelPath = null;
         int maxTokens = 512; // Default context length
+        Boolean useTornadovm = null;
 
         for (int i = 0; i < args.length; i++) {
             String optionName = args[i];
@@ -152,11 +153,16 @@ public static Options parseServiceOptions(String[] args) {
             switch (optionName) {
                 case "--model", "-m" -> modelPath = Paths.get(nextArg);
                 case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg);
+                case "--use-tornadovm" -> useTornadovm = Boolean.parseBoolean(nextArg);
             }
         }
 
         require(modelPath != null, "Missing argument: --model is required");
 
+        if (useTornadovm == null) {
+            useTornadovm = getDefaultTornadoVM();
+        }
+
         // Create service-mode Options object
         return new Options(
                 modelPath,
                 maxTokens,
                 false, // stream - handled per request
                 false, // echo - not used in service
-                getDefaultTornadoVM(),
+                useTornadovm,
                 true
         );
     }
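A quick sketch of what the new parser accepts; the flags and defaults are exactly those in the switch above, and the model path is a placeholder:

    // Service-mode flags, as handled by Options.parseServiceOptions
    Options opts = Options.parseServiceOptions(new String[] {
            "--model", "/models/model.gguf",   // required; placeholder path
            "--max-tokens", "512",             // optional, default 512
            "--use-tornadovm", "true"          // optional; otherwise getDefaultTornadoVM()
    });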
diff --git a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
index 096295e..23e2c64 100644
--- a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
+++ b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
@@ -45,7 +45,7 @@ public void init() {
             // Step 2: Load model weights
             System.out.println("\nStep 2: Loading model...");
             System.out.println("Loading model from: " + options.modelPath());
-            model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
+            model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true, options.useTornadovm());
             System.out.println("✓ Model loaded successfully");
             System.out.println("  Model type: " + model.getClass().getSimpleName());
             System.out.println("  Vocabulary size: " + model.configuration().vocabularySize());

From 6b76b08421ba95fce8d94c44a98048ad0ccec0ba Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Thu, 18 Sep 2025 21:12:52 +0300
Subject: [PATCH 08/12] Minor changes

---
 .../gpullama3/api/controller/ChatController.java  | 10 ++++++++--
 .../beehive/gpullama3/api/service/LLMService.java | 14 +++-----------
 2 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java b/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java
index 70c95e9..ba40f27 100644
--- a/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java
+++ b/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java
@@ -50,8 +50,14 @@ public SseEmitter streamChat(@RequestBody ChatRequest request) {
         }
 
         SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
-        llmService.generateStreamingResponse(request.getMessage(), request.getSystemMessage(),
-                emitter, maxTokens, temperature, topP);
+        llmService.generateStreamingResponse(
+                request.getMessage(),
+                request.getSystemMessage(),
+                emitter,
+                maxTokens,
+                temperature,
+                topP,
+                request.getSeed());
         return emitter;
     }
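Streaming consumption sketch, assuming the same local service; SSE framing per the Spring SseEmitter above, where each event arrives as a data: line followed by a blank line:

    import java.net.URI;
    import java.net.http.HttpClient;
    import java.net.http.HttpRequest;
    import java.net.http.HttpResponse;

    public class StreamChatClient {
        public static void main(String[] args) throws Exception {
            String body = "{\"message\":\"Explain GPUs briefly\",\"seed\":42}";
            HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:8080/chat/stream"))
                    .header("Content-Type", "application/json")
                    .POST(HttpRequest.BodyPublishers.ofString(body))
                    .build();
            HttpClient.newHttpClient()
                    .send(request, HttpResponse.BodyHandlers.ofLines())
                    .body()
                    .filter(line -> line.startsWith("data:"))
                    .map(line -> line.substring(5))
                    .takeWhile(data -> !data.equals("[DONE]"))
                    .forEach(System.out::print); // tokens print as they arrive
        }
    }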
diff --git a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
index 23e2c64..7f6a44c 100644
--- a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
+++ b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
@@ -138,15 +138,6 @@ public String generateResponse(String message, String systemMessage, int maxToke
         }
     }
 
-    public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter) {
-        generateStreamingResponse(message, systemMessage, emitter, 150, 0.7, 0.9);
-    }
-
-    public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter,
-                                          int maxTokens, double temperature, double topP) {
-        generateStreamingResponse(message, systemMessage, emitter, maxTokens, temperature, topP, null);
-    }
-
     public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter,
                                           int maxTokens, double temperature, double topP, Long seed) {
         CompletableFuture.runAsync(() -> {
             try {
@@ -170,11 +161,12 @@ public void generateStreamingResponse(String message, String systemMessage, SseE
                 promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message)));
                 promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
 
-                // Handle reasoning tokens for streaming
+                // Include reasoning for Deepseek-R1-Distill-Qwen
                 if (model.shouldIncludeReasoning()) {
                     List<Integer> thinkStartTokens = model.tokenizer().encode("<think>\n", model.tokenizer().getSpecialTokens().keySet());
                     promptTokens.addAll(thinkStartTokens);
-                    emitter.send(SseEmitter.event().data("<think>\n")); // Output immediately
+                    // We are in streaming, immediately output the think start
+                    emitter.send(SseEmitter.event().data("<think>\n"));
                 }
 
                 Set<Integer> stopTokens = chatFormat.getStopTokens();

From 17a6275c500a5a4287147cf837db316f7ad33231 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Fri, 19 Sep 2025 14:36:07 +0300
Subject: [PATCH 09/12] Set tornado default to false for service

---
 src/main/java/org/beehive/gpullama3/Options.java | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/main/java/org/beehive/gpullama3/Options.java b/src/main/java/org/beehive/gpullama3/Options.java
index 9469c59..606621a 100644
--- a/src/main/java/org/beehive/gpullama3/Options.java
+++ b/src/main/java/org/beehive/gpullama3/Options.java
@@ -159,8 +159,9 @@ public static Options parseServiceOptions(String[] args) {
 
         require(modelPath != null, "Missing argument: --model is required");
 
+        // Do not use tornado by default
         if (useTornadovm == null) {
-            useTornadovm = getDefaultTornadoVM();
+            useTornadovm = false;
         }
 
         // Create service-mode Options object

From 50d7536adb28e32a60697dee842b3eb2822a0635 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Fri, 19 Sep 2025 15:15:04 +0300
Subject: [PATCH 10/12] Use a Consumer on a generated token

---
 .../gpullama3/api/service/LLMService.java | 39 +++++++++++--------
 1 file changed, 23 insertions(+), 16 deletions(-)

diff --git a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
index 7f6a44c..0f0db25 100644
--- a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
+++ b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
@@ -15,6 +15,7 @@ import java.util.List;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
+import java.util.function.IntConsumer;
 
 import static org.beehive.gpullama3.inference.sampler.Sampler.selectSampler;
 import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel;
@@ -172,23 +173,29 @@ public void generateStreamingResponse(String message, String systemMessage, SseE
                 Set<Integer> stopTokens = chatFormat.getStopTokens();
 
                 final int[] tokenCount = {0};
-                long startTime = System.currentTimeMillis();
-                List<Integer> generatedTokens = model.generateTokens(
-                        state, 0, promptTokens, stopTokens, maxTokens, sampler, false,
-                        token -> {
-                            try {
-                                // Only forward tokens the tokenizer marks as displayable
-                                if (model.tokenizer().shouldDisplayToken(token)) {
-                                    String tokenText = model.tokenizer().decode(List.of(token));
-                                    emitter.send(SseEmitter.event().data(tokenText));
-                                    //emitter.send(SseEmitter.event().comment("flush"));
-                                    tokenCount[0]++;
-                                }
-                            } catch (Exception e) {
-                                emitter.completeWithError(e);
-                            }
+                IntConsumer tokenConsumer = token -> {
+                    try {
+                        // Only forward tokens the tokenizer marks as displayable
+                        if (model.tokenizer().shouldDisplayToken(token)) {
+                            String tokenText = model.tokenizer().decode(List.of(token));
+                            emitter.send(SseEmitter.event().data(tokenText));
+                            //emitter.send(SseEmitter.event().comment("flush"));
+                            tokenCount[0]++;
                         }
-                );
+                    } catch (Exception e) {
+                        emitter.completeWithError(e);
+                    }
+                };
+
+
+                long startTime = System.currentTimeMillis();
+                if (options.useTornadovm()) {
+                    // GPU path
+                    throw new UnsupportedOperationException("Tornadovm is not supported");
+                } else {
+                    // CPU path
+                    model.generateTokens(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenConsumer);
+                }
 
                 long duration = System.currentTimeMillis() - startTime;
                 double tokensPerSecond = tokenCount[0] * 1000.0 / duration;
                 System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
                         tokenCount[0], duration, tokensPerSecond);
From c557b5ea4730834b11aaef364796fc48fd8fb427 Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Fri, 19 Sep 2025 18:21:48 +0300
Subject: [PATCH 11/12] Add get methods

---
 .../org/beehive/gpullama3/auxiliary/LastRunMetrics.java | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/src/main/java/org/beehive/gpullama3/auxiliary/LastRunMetrics.java b/src/main/java/org/beehive/gpullama3/auxiliary/LastRunMetrics.java
index 0d41180..6dfb0d5 100644
--- a/src/main/java/org/beehive/gpullama3/auxiliary/LastRunMetrics.java
+++ b/src/main/java/org/beehive/gpullama3/auxiliary/LastRunMetrics.java
@@ -21,6 +21,15 @@ public static void setMetrics(int tokens, double seconds) {
         latestMetrics = new LastRunMetrics(tokens, seconds);
     }
 
+    public static LastRunMetrics getMetrics() {
+        return latestMetrics;
+    }
+
+    public static String getMetricsString() {
+        double tokensPerSecond = latestMetrics.totalTokens() / latestMetrics.totalSeconds();
+        return String.format("\n\nachieved tok/s: %.2f. Tokens: %d, seconds: %.2f\n", tokensPerSecond, latestMetrics.totalTokens(), latestMetrics.totalSeconds());
+    }
+
     /**
      * Prints the metrics from the latest run to stderr
     */
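The new accessors expose whatever the generation loop last recorded via setMetrics(int, double). A consumer sketch using only the methods shown above; note that getMetrics() presumably returns null until a first run has completed, so guard it:

    LastRunMetrics metrics = LastRunMetrics.getMetrics();
    if (metrics != null) {
        double rate = metrics.totalTokens() / metrics.totalSeconds();
        System.out.printf("last run: %d tokens, %.2f tok/s%n", metrics.totalTokens(), rate);
    }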
From 93d0c93d75d6ad4f365452c1d8edb1abb1de924a Mon Sep 17 00:00:00 2001
From: Orion Papadakis
Date: Fri, 19 Sep 2025 18:23:13 +0300
Subject: [PATCH 12/12] Add GPU path to Rest API

---
 .../gpullama3/api/service/LLMService.java | 28 ++++++++++++++-----
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
index 0f0db25..9fca3e2 100644
--- a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
+++ b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
@@ -2,11 +2,13 @@
 import jakarta.annotation.PostConstruct;
 import org.beehive.gpullama3.Options;
+import org.beehive.gpullama3.auxiliary.LastRunMetrics;
 import org.beehive.gpullama3.inference.sampler.Sampler;
 import org.beehive.gpullama3.inference.state.State;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.model.format.ChatFormat;
 import org.beehive.gpullama3.model.loader.ModelLoader;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
 import org.springframework.boot.ApplicationArguments;
 import org.springframework.stereotype.Service;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
@@ -179,7 +181,6 @@ public void generateStreamingResponse(String message, String systemMessage, SseE
                         if (model.tokenizer().shouldDisplayToken(token)) {
                             String tokenText = model.tokenizer().decode(List.of(token));
                             emitter.send(SseEmitter.event().data(tokenText));
-                            //emitter.send(SseEmitter.event().comment("flush"));
                             tokenCount[0]++;
                         }
                     } catch (Exception e) {
@@ -187,20 +188,33 @@ public void generateStreamingResponse(String message, String systemMessage, SseE
                 };
 
+                // Initialize TornadoVM plan once per request if GPU path is enabled
+                TornadoVMMasterPlan tornadoVMPlan = null;
+                if (options.useTornadovm()) {
+                    tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
+                }
 
-                long startTime = System.currentTimeMillis();
+                // Select execution path
                 if (options.useTornadovm()) {
                     // GPU path
-                    throw new UnsupportedOperationException("Tornadovm is not supported");
+                    model.generateTokensGPU(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenConsumer, tornadoVMPlan);
                 } else {
                     // CPU path
                     model.generateTokens(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenConsumer);
                 }
 
-                long duration = System.currentTimeMillis() - startTime;
-                double tokensPerSecond = tokenCount[0] * 1000.0 / duration;
-                System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
-                        tokenCount[0], duration, tokensPerSecond);
+                LastRunMetrics metrics = LastRunMetrics.getMetrics();
+                double seconds = metrics.totalSeconds();
+                int tokens = metrics.totalTokens();
+                double tokensPerSecond = tokens / seconds;
+                System.out.printf("COMPLETED - [ achieved tok/s: %.2f. Tokens: %d, seconds: %.2f ]\n",
+                        tokensPerSecond, tokens, seconds);
+
+                // Send metrics as named event before [DONE]
+                String metricsString = LastRunMetrics.getMetricsString();
+                if (!metricsString.isEmpty()) {
+                    emitter.send(SseEmitter.event().name("metrics").data(metricsString));
+                }
 
                 emitter.send(SseEmitter.event().data("[DONE]"));
                 emitter.complete();
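With this final patch the stream ends with a named metrics event followed by the [DONE] sentinel. A consumer sketch, reusing the HttpRequest from the streaming example earlier; per SSE framing, an event: name line precedes its data: lines:

    HttpClient.newHttpClient()
            .send(request, HttpResponse.BodyHandlers.ofLines())
            .body()
            .forEach(line -> {
                if (line.startsWith("event:")) {
                    System.err.println("named event: " + line.substring(6).trim()); // "metrics"
                } else if (line.startsWith("data:")) {
                    System.out.print(line.substring(5)); // token text, metrics body, or [DONE]
                }
            });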