diff --git a/llama-tornado b/llama-tornado index b59473f2..237cabf0 100755 --- a/llama-tornado +++ b/llama-tornado @@ -169,11 +169,14 @@ class LlamaRunner: ] ) + # Choose main class based on mode + main_class = "org.beehive.gpullama3.api.LLMApiApplication" if args.service else "org.beehive.gpullama3.LlamaApp" + module_config.extend( [ "-cp", self._find_llama_jar(), - "org.beehive.gpullama3.LlamaApp", + main_class, ] ) cmd.extend(module_config) @@ -200,33 +203,36 @@ class LlamaRunner: def _add_llama_args(self, cmd: List[str], args: argparse.Namespace) -> List[str]: """Add LLaMA-specific arguments to the command.""" + + # For service mode, only pass the model path and max-tokens + if hasattr(args, 'service') and args.service: + llama_args = [ + "--model", args.model_path, + "--max-tokens", str(args.max_tokens), + ] + return cmd + llama_args + llama_args = [ - "-m", - args.model_path, - "--temperature", - str(args.temperature), - "--top-p", - str(args.top_p), - "--seed", - str(args.seed), - "--max-tokens", - str(args.max_tokens), - "--stream", - str(args.stream).lower(), - "--echo", - str(args.echo).lower(), + "--model", args.model_path, + "--temperature", str(args.temperature), + "--top-p", str(args.top_p), + "--seed", str(args.seed), + "--max-tokens", str(args.max_tokens), + "--stream", str(args.stream).lower(), + "--echo", str(args.echo).lower(), + "--instruct" # Both modes use instruct ] - if args.prompt: - llama_args.extend(["-p", args.prompt]) + # Only add prompt-related args for standalone mode + if not hasattr(args, 'service') or not args.service: + if hasattr(args, 'prompt') and args.prompt: + llama_args.extend(["-p", args.prompt]) - if args.system_prompt: - llama_args.extend(["-sp", args.system_prompt]) + if hasattr(args, 'system_prompt') and args.system_prompt: + llama_args.extend(["-sp", args.system_prompt]) - if args.interactive: - llama_args.append("--interactive") - elif args.instruct: - llama_args.append("--instruct") + if hasattr(args, 'interactive') and args.interactive: + llama_args[-1] = "--interactive" # Replace --instruct return cmd + llama_args @@ -238,6 +244,30 @@ class LlamaRunner: cmd = self._build_base_command(args) cmd = self._add_llama_args(cmd, args) + # Show service-specific information + if args.service: + print("Starting GPULlama3.java REST API Service...") + print(f"Model: {args.model_path}") + # Display GPU/backend configuration + if args.use_gpu: + print(f"GPU Acceleration: Enabled ({args.backend.value.upper()} backend)") + print(f"GPU Memory: {args.gpu_memory}") + else: + print("GPU Acceleration: Disabled (CPU mode)") + print("API endpoints available at:") + print(" - http://localhost:8080/chat") + print(" - http://localhost:8080/chat/stream") + print(" - http://localhost:8080/health") + print("") + print("Example usage:") + print(' curl -X POST http://localhost:8080/chat \\') + print(' -H "Content-Type: application/json" \\') + print(' -d \'{"message": "Hello!"}\'') + print("") + print("Press Ctrl+C to stop the service") + print("-" * 60) + + # Print command if requested (before verbose output) if args.show_command: print("Full Java command:") @@ -390,6 +420,11 @@ def create_parser() -> argparse.ArgumentParser: default=True, help="Run in instruction mode (default)", ) + mode_group.add_argument( + "--service", + action="store_true", + help="Run as REST API service instead of standalone application" + ) # Hardware configuration hw_group = parser.add_argument_group("Hardware Configuration") diff --git a/pom.xml b/pom.xml index 624fc671..a031256d 100644 --- a/pom.xml +++ b/pom.xml @@ -12,6 +12,9 @@ 21 21 UTF-8 + 3.2.0 + 3.0.0 + 2.16.1 @@ -32,6 +35,26 @@ tornado-runtime 1.1.2-dev + + + + org.springframework.boot + spring-boot-starter-web + ${spring.boot.version} + + + + jakarta.annotation + jakarta.annotation-api + ${jakarta.version} + + + + + com.fasterxml.jackson.core + jackson-databind + ${jackson.version} + diff --git a/src/main/java/org/beehive/gpullama3/Options.java b/src/main/java/org/beehive/gpullama3/Options.java index 98924481..606621a7 100644 --- a/src/main/java/org/beehive/gpullama3/Options.java +++ b/src/main/java/org/beehive/gpullama3/Options.java @@ -5,12 +5,15 @@ import java.nio.file.Paths; public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive, float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo, - boolean useTornadovm) { + boolean useTornadovm, boolean serviceMode) { public static final int DEFAULT_MAX_TOKENS = 1024; public Options { - require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\""); + // Skip prompt validation in service mode + if (!serviceMode) { + require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\""); + } require(0 <= temperature, "Invalid argument: --temperature must be non-negative"); require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); } @@ -61,7 +64,7 @@ public static Options getDefaultOptions() { boolean echo = false; boolean useTornadoVM = getDefaultTornadoVM(); - return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM); + return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM, false); } public static Options parseOptions(String[] args) { @@ -123,6 +126,59 @@ public static Options parseOptions(String[] args) { useTornadovm = getDefaultTornadoVM(); } - return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm); + return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm, false); + } + + public static Options parseServiceOptions(String[] args) { + Path modelPath = null; + int maxTokens = 512; // Default context length + Boolean useTornadovm = null; + + for (int i = 0; i < args.length; i++) { + String optionName = args[i]; + require(optionName.startsWith("-"), "Invalid option %s", optionName); + + String nextArg; + if (optionName.contains("=")) { + String[] parts = optionName.split("=", 2); + optionName = parts[0]; + nextArg = parts[1]; + } else { + if (i + 1 >= args.length) continue; // Skip if no next arg + nextArg = args[i + 1]; + i += 1; // skip arg + } + + // Only parse these options in service mode + switch (optionName) { + case "--model", "-m" -> modelPath = Paths.get(nextArg); + case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg); + case "--use-tornadovm" -> useTornadovm = Boolean.parseBoolean(nextArg); + } + } + + require(modelPath != null, "Missing argument: --model is required"); + + // Do not use tornado by default + if (useTornadovm == null) { + useTornadovm = false; + } + + // Create service-mode Options object + return new Options( + modelPath, + null, // prompt - not used in service + null, // systemPrompt - handled per request + null, // suffix - not used + false, // interactive - not used in service + 0.7f, // temperature - default, overridden per request + 0.9f, // topp - default, overridden per request + System.nanoTime(), // seed - default + maxTokens, + false, // stream - handled per request + false, // echo - not used in service + useTornadovm, + true + ); } } diff --git a/src/main/java/org/beehive/gpullama3/api/LLMApiApplication.java b/src/main/java/org/beehive/gpullama3/api/LLMApiApplication.java new file mode 100644 index 00000000..2ff1d60d --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/api/LLMApiApplication.java @@ -0,0 +1,16 @@ +package org.beehive.gpullama3.api; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +@SpringBootApplication(scanBasePackages = "org.beehive.gpullama3") +public class LLMApiApplication { + + public static void main(String[] args) { + System.out.println("Starting TornadoVM LLM API Server..."); + System.out.println("Command line arguments: " + String.join(" ", args)); + + // Let Options.parseOptions() handle validation - no duplication + SpringApplication.run(LLMApiApplication.class, args); + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java b/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java new file mode 100644 index 00000000..72e5a624 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java @@ -0,0 +1,27 @@ +package org.beehive.gpullama3.api.config; + +import org.beehive.gpullama3.api.service.LLMService; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.Options; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class ModelConfiguration { + + /** + * Expose Model as a Spring bean using the initialized service + */ + @Bean + public Model model(LLMService llmService) { + return llmService.getModel(); + } + + /** + * Expose Options as a Spring bean using the initialized service + */ + @Bean + public Options options(LLMService llmService) { + return llmService.getOptions(); + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java b/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java new file mode 100644 index 00000000..ba40f274 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java @@ -0,0 +1,116 @@ +package org.beehive.gpullama3.api.controller; + +import org.beehive.gpullama3.api.service.LLMService; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.util.Map; + +@RestController +public class ChatController { + + @Autowired + private LLMService llmService; + + @PostMapping("/chat") + public Map chat(@RequestBody ChatRequest request) { + // Use request parameters with fallbacks to defaults + int maxTokens = request.getMaxTokens() != null ? request.getMaxTokens() : 150; + double temperature = request.getTemperature() != null ? request.getTemperature() : 0.7; + double topP = request.getTopP() != null ? request.getTopP() : 0.9; + + logRequest("NON_STREAMING", request, maxTokens, temperature, topP); + + if (request.getMessage() == null || request.getMessage().trim().isEmpty()) { + throw new IllegalArgumentException("Message cannot be empty"); + } + + String response = llmService.generateResponse(request.getMessage(), request.getSystemMessage(), + maxTokens, temperature, topP); + + return Map.of("response", response); + } + + @PostMapping(value = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public SseEmitter streamChat(@RequestBody ChatRequest request) { + // Use request parameters with fallbacks to defaults + int maxTokens = request.getMaxTokens() != null ? request.getMaxTokens() : 150; + double temperature = request.getTemperature() != null ? request.getTemperature() : 0.7; + double topP = request.getTopP() != null ? request.getTopP() : 0.9; + + logRequest("STREAMING", request, maxTokens, temperature, topP); + + if (request.getMessage() == null || request.getMessage().trim().isEmpty()) { + throw new IllegalArgumentException("Message cannot be empty"); + } + + SseEmitter emitter = new SseEmitter(Long.MAX_VALUE); + llmService.generateStreamingResponse( + request.getMessage(), + request.getSystemMessage(), + emitter, + maxTokens, + temperature, + topP, + request.getSeed()); + + return emitter; + } + + + @GetMapping("/health") + public Map health() { + return Map.of("status", "healthy", "timestamp", String.valueOf(System.currentTimeMillis())); + } + + private void logRequest(String type, ChatRequest request, int maxTokens, double temperature, double topP) { + System.out.printf("REQUEST [%s] user='%s' system='%s' maxTokens=%d temp=%.2f topP=%.2f%n", + type, + truncate(request.getMessage(), 100), + request.getSystemMessage() != null ? truncate(request.getSystemMessage(), 40) : "none", + maxTokens, + temperature, + topP + ); + } + + private String truncate(String text, int maxLength) { + if (text == null) return "null"; + return text.length() > maxLength ? text.substring(0, maxLength) + "..." : text; + } + + // Simple request class for custom parameters + public static class ChatRequest { + private String message; + private String systemMessage; + private Integer maxTokens; + private Double temperature; + private Double topP; + private Long seed; + + // Getters and Setters + public String getMessage() { return message; } + public void setMessage(String message) { this.message = message; } + + public String getSystemMessage() { return systemMessage; } + public void setSystemMessage(String systemMessage) { this.systemMessage = systemMessage; } + + public Integer getMaxTokens() { return maxTokens; } + public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } + + public Double getTemperature() { return temperature; } + public void setTemperature(Double temperature) { this.temperature = temperature; } + + public Double getTopP() { return topP; } + public void setTopP(Double topP) { this.topP = topP; } + + public Long getSeed() { return seed; } + public void setSeed(Long seed) { this.seed = seed; } + } + +} diff --git a/src/main/java/org/beehive/gpullama3/api/model/CompletionRequest.java b/src/main/java/org/beehive/gpullama3/api/model/CompletionRequest.java new file mode 100644 index 00000000..56a304e4 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/api/model/CompletionRequest.java @@ -0,0 +1,59 @@ +package org.beehive.gpullama3.api.model; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +public class CompletionRequest { + private String model = "gpullama3"; + private String prompt; + + @JsonProperty("max_tokens") + private Integer maxTokens = 100; + + private Double temperature = 0.7; + + @JsonProperty("top_p") + private Double topP = 0.9; + + @JsonProperty("stop") + private List stopSequences; + + private Boolean stream = false; + + // Constructors + public CompletionRequest() {} + + // Getters and Setters + public String getModel() { return model; } + public void setModel(String model) { this.model = model; } + + public String getPrompt() { return prompt; } + public void setPrompt(String prompt) { this.prompt = prompt; } + + public Integer getMaxTokens() { return maxTokens; } + public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } + + public Double getTemperature() { return temperature; } + public void setTemperature(Double temperature) { this.temperature = temperature; } + + public Double getTopP() { return topP; } + public void setTopP(Double topP) { this.topP = topP; } + + public List getStopSequences() { return stopSequences; } + public void setStopSequences(List stopSequences) { this.stopSequences = stopSequences; } + + public Boolean getStream() { return stream; } + public void setStream(Boolean stream) { this.stream = stream; } + + @Override + public String toString() { + return "CompletionRequest{" + + "model='" + model + '\'' + + ", prompt='" + (prompt != null ? prompt.substring(0, Math.min(50, prompt.length())) + "..." : null) + '\'' + + ", maxTokens=" + maxTokens + + ", temperature=" + temperature + + ", topP=" + topP + + ", stream=" + stream + + '}'; + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/api/model/CompletionResponse.java b/src/main/java/org/beehive/gpullama3/api/model/CompletionResponse.java new file mode 100644 index 00000000..b3cecf3a --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/api/model/CompletionResponse.java @@ -0,0 +1,93 @@ +package org.beehive.gpullama3.api.model; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +public class CompletionResponse { + private String id; + private String object = "text_completion"; + private Long created; + private String model; + private List choices; + private Usage usage; + + public static class Choice { + private String text; + private Integer index; + + @JsonProperty("finish_reason") + private String finishReason; + + public Choice() {} + + public Choice(String text, Integer index, String finishReason) { + this.text = text; + this.index = index; + this.finishReason = finishReason; + } + + // Getters and Setters + public String getText() { return text; } + public void setText(String text) { this.text = text; } + + public Integer getIndex() { return index; } + public void setIndex(Integer index) { this.index = index; } + + public String getFinishReason() { return finishReason; } + public void setFinishReason(String finishReason) { this.finishReason = finishReason; } + } + + public static class Usage { + @JsonProperty("prompt_tokens") + private Integer promptTokens; + + @JsonProperty("completion_tokens") + private Integer completionTokens; + + @JsonProperty("total_tokens") + private Integer totalTokens; + + public Usage() {} + + public Usage(Integer promptTokens, Integer completionTokens) { + this.promptTokens = promptTokens; + this.completionTokens = completionTokens; + this.totalTokens = promptTokens + completionTokens; + } + + // Getters and Setters + public Integer getPromptTokens() { return promptTokens; } + public void setPromptTokens(Integer promptTokens) { this.promptTokens = promptTokens; } + + public Integer getCompletionTokens() { return completionTokens; } + public void setCompletionTokens(Integer completionTokens) { this.completionTokens = completionTokens; } + + public Integer getTotalTokens() { return totalTokens; } + public void setTotalTokens(Integer totalTokens) { this.totalTokens = totalTokens; } + } + + // Constructors + public CompletionResponse() { + this.id = "cmpl-" + System.currentTimeMillis(); + this.created = System.currentTimeMillis() / 1000; + } + + // Getters and Setters + public String getId() { return id; } + public void setId(String id) { this.id = id; } + + public String getObject() { return object; } + public void setObject(String object) { this.object = object; } + + public Long getCreated() { return created; } + public void setCreated(Long created) { this.created = created; } + + public String getModel() { return model; } + public void setModel(String model) { this.model = model; } + + public List getChoices() { return choices; } + public void setChoices(List choices) { this.choices = choices; } + + public Usage getUsage() { return usage; } + public void setUsage(Usage usage) { this.usage = usage; } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java new file mode 100644 index 00000000..9fca3e2a --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java @@ -0,0 +1,243 @@ +package org.beehive.gpullama3.api.service; + +import jakarta.annotation.PostConstruct; +import org.beehive.gpullama3.Options; +import org.beehive.gpullama3.auxiliary.LastRunMetrics; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.model.loader.ModelLoader; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import org.springframework.boot.ApplicationArguments; +import org.springframework.stereotype.Service; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.function.IntConsumer; + +import static org.beehive.gpullama3.inference.sampler.Sampler.selectSampler; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel; + +@Service +public class LLMService { + + private final ApplicationArguments args; + + private Options options; + private Model model; + + public LLMService(ApplicationArguments args) { + this.args = args; + } + + @PostConstruct + public void init() { + try { + System.out.println("Initializing LLM service..."); + + // Step 1: Parse service options + System.out.println("Step 1: Parsing service options..."); + options = Options.parseServiceOptions(args.getSourceArgs()); + System.out.println("Model path: " + options.modelPath()); + System.out.println("Context length: " + options.maxTokens()); + + // Step 2: Load model weights + System.out.println("\nStep 2: Loading model..."); + System.out.println("Loading model from: " + options.modelPath()); + model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true, options.useTornadovm()); + System.out.println("āœ“ Model loaded successfully"); + System.out.println(" Model type: " + model.getClass().getSimpleName()); + System.out.println(" Vocabulary size: " + model.configuration().vocabularySize()); + System.out.println(" Context length: " + model.configuration().contextLength()); + + System.out.println("\nāœ“ Model service initialization completed successfully!"); + System.out.println("=== Ready to serve requests ===\n"); + + } catch (Exception e) { + System.err.println("āœ— Failed to initialize model service: " + e.getMessage()); + e.printStackTrace(); + throw new RuntimeException("Model initialization failed", e); + } + } + + /** + * Generate response with default parameters. + */ + public String generateResponse(String message, String systemMessage) { + return generateResponse(message, systemMessage, 150, 0.7, 0.9); + } + + public String generateResponse(String message, String systemMessage, int maxTokens, double temperature, double topP) { + return generateResponse(message, systemMessage, maxTokens, temperature, topP, null); + } + + public String generateResponse(String message, String systemMessage, int maxTokens, double temperature, double topP, Long seed) { + try { + // Create sampler and state like runInstructOnce + long actualSeed = seed != null ? seed : System.currentTimeMillis(); + Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, actualSeed); + State state = model.createNewState(); + + // Use model's ChatFormat + ChatFormat chatFormat = model.chatFormat(); + List promptTokens = new ArrayList<>(); + + // Add begin of text if needed + if (model.shouldAddBeginOfText()) { + promptTokens.add(chatFormat.getBeginOfText()); + } + + // Add system message properly formatted + if (model.shouldAddSystemPrompt() && systemMessage != null && !systemMessage.trim().isEmpty()) { + promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage))); + } + + // Add user message properly formatted + promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message))); + promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + + // Handle reasoning tokens if needed (for Deepseek-R1-Distill-Qwen) + if (model.shouldIncludeReasoning()) { + List thinkStartTokens = model.tokenizer().encode("\n", model.tokenizer().getSpecialTokens().keySet()); + promptTokens.addAll(thinkStartTokens); + } + + // Use proper stop tokens from chat format + Set stopTokens = chatFormat.getStopTokens(); + + long startTime = System.currentTimeMillis(); + + // Use CPU path for now (GPU path disabled as noted) + List generatedTokens = model.generateTokens( + state, 0, promptTokens, stopTokens, maxTokens, sampler, false, token -> {} + ); + + // Remove stop tokens if present + if (!generatedTokens.isEmpty() && stopTokens.contains(generatedTokens.getLast())) { + generatedTokens.removeLast(); + } + + long duration = System.currentTimeMillis() - startTime; + double tokensPerSecond = generatedTokens.size() * 1000.0 / duration; + System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n", + generatedTokens.size(), duration, tokensPerSecond); + + String responseText = model.tokenizer().decode(generatedTokens); + + // Add reasoning prefix for non-streaming if needed + if (model.shouldIncludeReasoning()) { + responseText = "\n" + responseText; + } + + return responseText; + + } catch (Exception e) { + System.err.println("FAILED " + e.getMessage()); + throw new RuntimeException("Failed to generate response", e); + } + } + + public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter, + int maxTokens, double temperature, double topP, Long seed) { + CompletableFuture.runAsync(() -> { + try { + long actualSeed = seed != null ? seed : System.currentTimeMillis(); + Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, actualSeed); + State state = model.createNewState(); + + // Use proper chat format like in runInstructOnce + ChatFormat chatFormat = model.chatFormat(); + List promptTokens = new ArrayList<>(); + + if (model.shouldAddBeginOfText()) { + promptTokens.add(chatFormat.getBeginOfText()); + } + + if (model.shouldAddSystemPrompt() && systemMessage != null && !systemMessage.trim().isEmpty()) { + promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage))); + } + + promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message))); + promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); + + // Include reasoning for Deepseek-R1-Distill-Qwen + if (model.shouldIncludeReasoning()) { + List thinkStartTokens = model.tokenizer().encode("\n", model.tokenizer().getSpecialTokens().keySet()); + promptTokens.addAll(thinkStartTokens); + // We are in streaming, immediately output the think start + emitter.send(SseEmitter.event().data("\n")); + } + + Set stopTokens = chatFormat.getStopTokens(); + + final int[] tokenCount = {0}; + IntConsumer tokenConsumer = token -> { + try { + // Only display tokens that should be displayed (like in your original) + if (model.tokenizer().shouldDisplayToken(token)) { + String tokenText = model.tokenizer().decode(List.of(token)); + emitter.send(SseEmitter.event().data(tokenText)); + tokenCount[0]++; + } + } catch (Exception e) { + emitter.completeWithError(e); + } + }; + + // Initialize TornadoVM plan once per request if GPU path is enabled + TornadoVMMasterPlan tornadoVMPlan = null; + if (options.useTornadovm()) { + tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model); + } + + // Select execution path + if (options.useTornadovm()) { + // GPU path + model.generateTokensGPU(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenConsumer, tornadoVMPlan); + } else { + // CPU path + model.generateTokens(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenConsumer); + } + + LastRunMetrics metrics = LastRunMetrics.getMetrics(); + double seconds = metrics.totalSeconds(); + int tokens = metrics.totalTokens(); + double tokensPerSecond = tokens / seconds; + System.out.printf("COMPLETED - [ achieved tok/s: %.2f. Tokens: %d, seconds: %.2f ]\n", + tokensPerSecond, tokens, seconds); + + // Send metrics as named event before [DONE] + String metricsString = LastRunMetrics.getMetricsString(); + if (!metricsString.isEmpty()) { + emitter.send(SseEmitter.event().name("metrics").data(metricsString)); + } + + emitter.send(SseEmitter.event().data("[DONE]")); + emitter.complete(); + + } catch (Exception e) { + System.err.println("FAILED " + e.getMessage()); + emitter.completeWithError(e); + } + }); + } + + // Getters for other services to access the initialized components + public Options getOptions() { + if (options == null) { + throw new IllegalStateException("Model service not initialized yet"); + } + return options; + } + + public Model getModel() { + if (model == null) { + throw new IllegalStateException("Model service not initialized yet"); + } + return model; + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/auxiliary/LastRunMetrics.java b/src/main/java/org/beehive/gpullama3/auxiliary/LastRunMetrics.java index 0d411801..6dfb0d58 100644 --- a/src/main/java/org/beehive/gpullama3/auxiliary/LastRunMetrics.java +++ b/src/main/java/org/beehive/gpullama3/auxiliary/LastRunMetrics.java @@ -21,6 +21,15 @@ public static void setMetrics(int tokens, double seconds) { latestMetrics = new LastRunMetrics(tokens, seconds); } + public static LastRunMetrics getMetrics() { + return latestMetrics; + } + + public static String getMetricsString() { + double tokensPerSecond = latestMetrics.totalTokens() / latestMetrics.totalSeconds(); + return String.format("\n\nachieved tok/s: %.2f. Tokens: %d, seconds: %.2f\n", tokensPerSecond, latestMetrics.totalTokens(), latestMetrics.totalSeconds()); + } + /** * Prints the metrics from the latest run to stderr */