Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 58 additions & 23 deletions llama-tornado
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,14 @@ class LlamaRunner:
]
)

# Choose main class based on mode
main_class = "org.beehive.gpullama3.api.LLMApiApplication" if args.service else "org.beehive.gpullama3.LlamaApp"

module_config.extend(
[
"-cp",
self._find_llama_jar(),
"org.beehive.gpullama3.LlamaApp",
main_class,
]
)
cmd.extend(module_config)
Expand All @@ -200,33 +203,36 @@ class LlamaRunner:

def _add_llama_args(self, cmd: List[str], args: argparse.Namespace) -> List[str]:
"""Add LLaMA-specific arguments to the command."""

# For service mode, only pass the model path and max-tokens
if hasattr(args, 'service') and args.service:
llama_args = [
"--model", args.model_path,
"--max-tokens", str(args.max_tokens),
]
return cmd + llama_args

llama_args = [
"-m",
args.model_path,
"--temperature",
str(args.temperature),
"--top-p",
str(args.top_p),
"--seed",
str(args.seed),
"--max-tokens",
str(args.max_tokens),
"--stream",
str(args.stream).lower(),
"--echo",
str(args.echo).lower(),
"--model", args.model_path,
"--temperature", str(args.temperature),
"--top-p", str(args.top_p),
"--seed", str(args.seed),
"--max-tokens", str(args.max_tokens),
"--stream", str(args.stream).lower(),
"--echo", str(args.echo).lower(),
"--instruct" # Both modes use instruct
]

if args.prompt:
llama_args.extend(["-p", args.prompt])
# Only add prompt-related args for standalone mode
if not hasattr(args, 'service') or not args.service:
if hasattr(args, 'prompt') and args.prompt:
llama_args.extend(["-p", args.prompt])

if args.system_prompt:
llama_args.extend(["-sp", args.system_prompt])
if hasattr(args, 'system_prompt') and args.system_prompt:
llama_args.extend(["-sp", args.system_prompt])

if args.interactive:
llama_args.append("--interactive")
elif args.instruct:
llama_args.append("--instruct")
if hasattr(args, 'interactive') and args.interactive:
llama_args[-1] = "--interactive" # Replace --instruct

return cmd + llama_args

Expand All @@ -238,6 +244,30 @@ class LlamaRunner:
cmd = self._build_base_command(args)
cmd = self._add_llama_args(cmd, args)

# Show service-specific information
if args.service:
print("Starting GPULlama3.java REST API Service...")
print(f"Model: {args.model_path}")
# Display GPU/backend configuration
if args.use_gpu:
print(f"GPU Acceleration: Enabled ({args.backend.value.upper()} backend)")
print(f"GPU Memory: {args.gpu_memory}")
else:
print("GPU Acceleration: Disabled (CPU mode)")
print("API endpoints available at:")
print(" - http://localhost:8080/chat")
print(" - http://localhost:8080/chat/stream")
print(" - http://localhost:8080/health")
print("")
print("Example usage:")
print(' curl -X POST http://localhost:8080/chat \\')
print(' -H "Content-Type: application/json" \\')
print(' -d \'{"message": "Hello!"}\'')
print("")
print("Press Ctrl+C to stop the service")
print("-" * 60)


# Print command if requested (before verbose output)
if args.show_command:
print("Full Java command:")
Expand Down Expand Up @@ -390,6 +420,11 @@ def create_parser() -> argparse.ArgumentParser:
default=True,
help="Run in instruction mode (default)",
)
mode_group.add_argument(
"--service",
action="store_true",
help="Run as REST API service instead of standalone application"
)

# Hardware configuration
hw_group = parser.add_argument_group("Hardware Configuration")
Expand Down
23 changes: 23 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
<maven.compiler.source>21</maven.compiler.source>
<maven.compiler.target>21</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<spring.boot.version>3.2.0</spring.boot.version>
<jakarta.version>3.0.0</jakarta.version>
<jackson.version>2.16.1</jackson.version>
</properties>

<dependencies>
Expand All @@ -32,6 +35,26 @@
<artifactId>tornado-runtime</artifactId>
<version>1.1.2-dev</version>
</dependency>

<!-- Spring Boot Starter Web -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
<version>${spring.boot.version}</version>
</dependency>

<dependency>
<groupId>jakarta.annotation</groupId>
<artifactId>jakarta.annotation-api</artifactId>
<version>${jakarta.version}</version>
</dependency>

<!-- For JSON processing -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
</dependencies>

<build>
Expand Down
64 changes: 60 additions & 4 deletions src/main/java/org/beehive/gpullama3/Options.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
import java.nio.file.Paths;

public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive, float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo,
boolean useTornadovm) {
boolean useTornadovm, boolean serviceMode) {

public static final int DEFAULT_MAX_TOKENS = 1024;

/**
 * Validates the record components.
 *
 * <p>In service mode the prompt arrives per HTTP request, so it is not
 * required at construction time; in standalone instruct mode a prompt (or
 * interactive mode) is mandatory.
 *
 * @throws IllegalArgumentException if a constraint is violated (via {@code require})
 */
public Options {
    // Skip prompt validation in service mode: prompts are per-request there.
    if (!serviceMode) {
        require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
    }
    require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
    require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
}
Expand Down Expand Up @@ -61,7 +64,7 @@ public static Options getDefaultOptions() {
boolean echo = false;
boolean useTornadoVM = getDefaultTornadoVM();

return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM);
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM, false);
}

public static Options parseOptions(String[] args) {
Expand Down Expand Up @@ -123,6 +126,59 @@ public static Options parseOptions(String[] args) {
useTornadovm = getDefaultTornadoVM();
}

return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm);
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm, false);
}

/**
 * Parses the reduced argument set accepted in REST service mode.
 *
 * <p>Recognized options: {@code --model}/{@code -m} (required),
 * {@code --max-tokens}/{@code -n} (default 512) and
 * {@code --use-tornadovm} (default false). Both {@code --opt value} and
 * {@code --opt=value} forms are accepted; unrecognized options are ignored.
 *
 * @param args raw command-line arguments
 * @return an {@link Options} instance with {@code serviceMode == true};
 *         per-request parameters (prompt, temperature, ...) carry defaults
 *         that the API layer overrides on each request
 */
public static Options parseServiceOptions(String[] args) {
    Path modelPath = null;
    int maxTokens = 512; // default context length when not specified
    Boolean useTornadovm = null;

    int i = 0;
    while (i < args.length) {
        String name = args[i];
        require(name.startsWith("-"), "Invalid option %s", name);

        String value;
        int eq = name.indexOf('=');
        if (eq >= 0) {
            // --opt=value form: split on the first '='.
            value = name.substring(eq + 1);
            name = name.substring(0, eq);
            i += 1;
        } else {
            if (i + 1 >= args.length) {
                // Trailing flag with no value: ignore it, matching the
                // lenient behavior of the separate-token form.
                i += 1;
                continue;
            }
            value = args[i + 1];
            i += 2; // consumed both the flag and its value
        }

        // Only this subset is honored in service mode; everything else is ignored.
        switch (name) {
            case "--model", "-m" -> modelPath = Paths.get(value);
            case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(value);
            case "--use-tornadovm" -> useTornadovm = Boolean.parseBoolean(value);
        }
    }

    require(modelPath != null, "Missing argument: --model <path> is required");

    // TornadoVM acceleration is opt-in for the service.
    boolean tornado = useTornadovm != null && useTornadovm;

    return new Options(
            modelPath,
            null,              // prompt - supplied per request
            null,              // systemPrompt - supplied per request
            null,              // suffix - unused in service mode
            false,             // interactive - not applicable
            0.7f,              // temperature default, overridden per request
            0.9f,              // top-p default, overridden per request
            System.nanoTime(), // seed default
            maxTokens,
            false,             // stream - handled per request
            false,             // echo - unused in service mode
            tornado,
            true               // serviceMode
    );
}
}
16 changes: 16 additions & 0 deletions src/main/java/org/beehive/gpullama3/api/LLMApiApplication.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.beehive.gpullama3.api;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

/**
 * Spring Boot entry point for the GPULlama3 REST API service.
 * Launched (instead of LlamaApp) when llama-tornado is invoked with --service.
 */
@SpringBootApplication(scanBasePackages = "org.beehive.gpullama3")
public class LLMApiApplication {

    public static void main(String[] args) {
        System.out.println("Starting TornadoVM LLM API Server...");
        System.out.println("Command line arguments: " + String.join(" ", args));

        // Hand the raw args to Spring; option parsing/validation happens
        // downstream — presumably Options.parseServiceOptions via the
        // service layer, not in this class. TODO(review): confirm.
        SpringApplication.run(LLMApiApplication.class, args);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package org.beehive.gpullama3.api.config;

import org.beehive.gpullama3.api.service.LLMService;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.Options;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Bridges the already-initialized {@link LLMService} into the Spring
 * context by republishing its Model and Options as injectable beans.
 */
@Configuration
public class ModelConfiguration {

    /**
     * Publishes the {@link Options} held by the initialized service as a
     * Spring bean so other components can inject it directly.
     */
    @Bean
    public Options options(LLMService llmService) {
        return llmService.getOptions();
    }

    /**
     * Publishes the loaded {@link Model} from the initialized service as a
     * Spring bean for injection elsewhere in the application.
     */
    @Bean
    public Model model(LLMService llmService) {
        return llmService.getModel();
    }
}
Loading