Skip to content

Commit bb71b95

Browse files
authored
Merge pull request #50 from mikepapadim/feat/handle_gpu_cpu_switch
Add `useTornadovm` flag to model loader to handle Builder option in Langchain4j
2 parents a78d3e9 + ea90462 commit bb71b95

File tree

11 files changed

+45
-41
lines changed

11 files changed

+45
-41
lines changed

src/main/java/org/beehive/gpullama3/LlamaApp.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ private static Model loadModel(Options options) throws IOException {
130130
}
131131
return model;
132132
}
133-
return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
133+
return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true, options.useTornadovm());
134134
}
135135

136136
private static Sampler createSampler(Model model, Options options) {

src/main/java/org/beehive/gpullama3/Options.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ static void require(boolean condition, String messageFormat, Object... args) {
2525
}
2626

2727
private static boolean getDefaultTornadoVM() {
28-
return Boolean.parseBoolean(System.getProperty("use.tornadovm", "false"));
28+
return Boolean.parseBoolean(System.getProperty("use.tornadovm", "true"));
2929
}
3030

3131
static void printUsage(PrintStream out) {

src/main/java/org/beehive/gpullama3/aot/AOT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ private static PartialModel preLoadGGUF(String modelPath) {
4848
}
4949
GGUF gguf = GGUF.loadModel(path);
5050
try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
51-
modelLoader = new LlamaModelLoader(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false);
51+
modelLoader = new LlamaModelLoader(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false, false);
5252
return new PartialModel(path.getFileName().toString(), modelLoader.loadModel(), // TODO: needs proper handling for AOT
5353
gguf.getTensorDataOffset(), gguf.getTensorInfos());
5454
}

src/main/java/org/beehive/gpullama3/model/Model.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.beehive.gpullama3.model;
22

3+
import dev.langchain4j.model.chat.request.ChatRequest;
34
import org.beehive.gpullama3.Options;
45
import org.beehive.gpullama3.auxiliary.LastRunMetrics;
56
import org.beehive.gpullama3.inference.sampler.Sampler;
@@ -92,7 +93,7 @@ default void runInteractive(Sampler sampler, Options options) {
9293
Scanner in = new Scanner(System.in);
9394

9495
// Initialize TornadoVM plan once at the beginning if GPU path is enabled
95-
if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan == null) {
96+
if (options.useTornadovm() && tornadoVMPlan == null) {
9697
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
9798
}
9899

@@ -131,7 +132,7 @@ default void runInteractive(Sampler sampler, Options options) {
131132
};
132133

133134
// Choose between GPU and CPU path based on configuration
134-
if (Options.getDefaultOptions().useTornadovm()) {
135+
if (options.useTornadovm()) {
135136
// GPU path using TornadoVM
136137
responseTokens = generateTokensGPU(state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
137138
options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
@@ -170,7 +171,7 @@ default void runInteractive(Sampler sampler, Options options) {
170171
}
171172
} finally {
172173
// Clean up TornadoVM resources when exiting the chat loop
173-
if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan != null) {
174+
if (options.useTornadovm() && tornadoVMPlan != null) {
174175
try {
175176
tornadoVMPlan.freeTornadoExecutionPlan();
176177
} catch (Exception e) {
@@ -201,7 +202,7 @@ default String runInstructOnce(Sampler sampler, Options options) {
201202
}
202203

203204
// Initialize TornadoVM plan once at the beginning if GPU path is enabled
204-
if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan == null) {
205+
if (options.useTornadovm() && tornadoVMPlan == null) {
205206
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
206207
}
207208

@@ -231,7 +232,7 @@ default String runInstructOnce(Sampler sampler, Options options) {
231232

232233
Set<Integer> stopTokens = chatFormat.getStopTokens();
233234

234-
if (Options.getDefaultOptions().useTornadovm()) {
235+
if (options.useTornadovm()) {
235236
// GPU path using TornadoVM - Call generateTokensGPU without the token consumer parameter
236237
responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
237238
} else {
@@ -275,7 +276,7 @@ default String runInstructOnceLangChain4J(Sampler sampler, Options options, Cons
275276
}
276277

277278
// Initialize TornadoVM plan once at the beginning if GPU path is enabled
278-
if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan == null) {
279+
if (options.useTornadovm() && tornadoVMPlan == null) {
279280
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
280281
}
281282

@@ -305,7 +306,7 @@ default String runInstructOnceLangChain4J(Sampler sampler, Options options, Cons
305306

306307
Set<Integer> stopTokens = chatFormat.getStopTokens();
307308

308-
if (Options.getDefaultOptions().useTornadovm()) {
309+
if (options.useTornadovm()) {
309310
// GPU path using TornadoVM Call generateTokensGPU without the token consumer parameter
310311
responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
311312
} else {
@@ -332,4 +333,5 @@ default String runInstructOnceLangChain4J(Sampler sampler, Options options, Cons
332333

333334
return responseText;
334335
}
336+
335337
}

src/main/java/org/beehive/gpullama3/model/ModelType.java

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,55 +24,55 @@
2424
public enum ModelType {
2525
LLAMA_3 {
2626
@Override
27-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
28-
return new LlamaModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
27+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
28+
return new LlamaModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel();
2929
}
3030
},
3131

3232
MISTRAL {
3333
@Override
34-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
35-
return new MistralModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
34+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
35+
return new MistralModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel();
3636
}
3737
},
3838

3939
QWEN_2 {
4040
@Override
41-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
42-
return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
41+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
42+
return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel();
4343
}
4444
},
4545

4646
QWEN_3 {
4747
@Override
48-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
49-
return new Qwen3ModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
48+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
49+
return new Qwen3ModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel();
5050
}
5151
},
5252

5353
DEEPSEEK_R1_DISTILL_QWEN {
5454
@Override
55-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
56-
return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
55+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
56+
return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel();
5757
}
5858
},
5959

6060
PHI_3 {
6161
@Override
62-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
63-
return new Phi3ModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
62+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
63+
return new Phi3ModelLoader(fileChannel, gguf, contextLength, loadWeights, useTornadovm).loadModel();
6464
}
6565
},
6666

6767
UNKNOWN {
6868
@Override
69-
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
69+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
7070
throw new UnsupportedOperationException("Cannot load unknown model type");
7171
}
7272
};
7373

7474
// Abstract method that each enum constant must implement
75-
public abstract Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights);
75+
public abstract Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm);
7676

7777
public boolean isDeepSeekR1() {
7878
return this == DEEPSEEK_R1_DISTILL_QWEN;

src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
public class LlamaModelLoader extends ModelLoader {
1919

20-
public LlamaModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
21-
super(fileChannel, gguf, contextLength, loadWeights);
20+
public LlamaModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadoVM) {
21+
super(fileChannel, gguf, contextLength, loadWeights, useTornadoVM);
2222
}
2323

2424
// @formatter:off

src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
public class MistralModelLoader extends ModelLoader {
1919

20-
public MistralModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
21-
super(fileChannel, gguf, contextLength, loadWeights);
20+
public MistralModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
21+
super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
2222
}
2323

2424
// @formatter:off

src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@ public abstract class ModelLoader {
3939
protected GGUF gguf;
4040
protected int contextLength;
4141
protected boolean loadWeights;
42+
protected boolean useTornadovm;
4243

43-
public ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
44+
public ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
4445
this.fileChannel = fileChannel;
4546
this.gguf = gguf;
4647
this.contextLength = contextLength;
4748
this.loadWeights = loadWeights;
49+
this.useTornadovm = useTornadovm;
4850
}
4951

5052
private static ModelType detectModelType(Map<String, Object> metadata) {
@@ -74,14 +76,14 @@ private static ModelType detectModelType(Map<String, Object> metadata) {
7476
return ModelType.UNKNOWN;
7577
}
7678

77-
public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException {
79+
public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeights, boolean useTornadovm) throws IOException {
7880
// initial load of metadata from gguf file
7981
GGUF gguf = GGUF.loadModel(ggufPath);
8082
FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ);
8183
// detect model type
8284
ModelType modelType = detectModelType(gguf.getMetadata());
8385
// model type-specific load
84-
return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights);
86+
return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
8587
}
8688

8789
public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
@@ -222,7 +224,7 @@ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configura
222224
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
223225
GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
224226

225-
if (Options.getDefaultOptions().useTornadovm()) {
227+
if (useTornadovm) {
226228
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
227229
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
228230
}

src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
import java.util.Map;
2828

2929
public class Phi3ModelLoader extends ModelLoader {
30-
public Phi3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
31-
super(fileChannel, gguf, contextLength, loadWeights);
30+
public Phi3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
31+
super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
3232
}
3333

3434
// @formatter:off
@@ -98,7 +98,7 @@ private Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configur
9898
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
9999
GGMLTensorEntry outputWeight = tensorEntries.get("output.weight"); // Phi3 always has separate output weight
100100

101-
if (Options.getDefaultOptions().useTornadovm()) {
101+
if (useTornadovm) {
102102
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
103103
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
104104
}

src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030

3131
public class Qwen2ModelLoader extends ModelLoader {
3232

33-
public Qwen2ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
34-
super(fileChannel, gguf, contextLength, loadWeights);
33+
public Qwen2ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
34+
super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
3535
}
3636

3737
@Override
@@ -96,7 +96,7 @@ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configura
9696
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
9797
GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
9898

99-
if (Options.getDefaultOptions().useTornadovm()) {
99+
if (useTornadovm) {
100100
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
101101
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
102102
}

0 commit comments

Comments (0)