Skip to content

Commit a0dafe0

Browse files
Refactor ModelLoader: move loadModel() from LlamaApp to ModelLoader class
1 parent 61fd9b5 commit a0dafe0

File tree

3 files changed

+31
-39
lines changed

3 files changed

+31
-39
lines changed

src/main/java/org/beehive/gpullama3/LlamaApp.java

Lines changed: 2 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -2,56 +2,20 @@
22

33
import org.beehive.gpullama3.aot.AOT;
44
import org.beehive.gpullama3.auxiliary.LastRunMetrics;
5-
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
6-
import org.beehive.gpullama3.inference.sampler.CategoricalSampler;
75
import org.beehive.gpullama3.inference.sampler.Sampler;
8-
import org.beehive.gpullama3.inference.sampler.ToppSampler;
96
import org.beehive.gpullama3.model.Model;
107
import org.beehive.gpullama3.model.loader.ModelLoader;
11-
import org.beehive.gpullama3.tornadovm.FloatArrayUtils;
12-
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
138

149
import java.io.IOException;
15-
import java.util.random.RandomGenerator;
16-
import java.util.random.RandomGeneratorFactory;
1710

1811
import static org.beehive.gpullama3.inference.sampler.Sampler.createSampler;
12+
import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel;
13+
1914
public class LlamaApp {
2015
// Configuration flags for hardware acceleration and optimizations
2116
public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration
22-
public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
2317
public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "true")); // Show performance metrics in interactive mode
2418

25-
26-
/**
27-
* Loads the language model based on the given options.
28-
* <p>
29-
* If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model. Otherwise, loads the model from the specified path using the model loader.
30-
* </p>
31-
*
32-
* @param options
33-
* the parsed CLI options containing model path and max token limit
34-
* @return the loaded {@link Model} instance
35-
* @throws IOException
36-
* if the model fails to load
37-
* @throws IllegalStateException
38-
* if AOT loading is enabled but the preloaded model is unavailable
39-
*/
40-
private static Model loadModel(Options options) throws IOException {
41-
if (USE_AOT) {
42-
Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
43-
if (model == null) {
44-
throw new IllegalStateException("Failed to load precompiled AOT model.");
45-
}
46-
return model;
47-
}
48-
return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true, options.useTornadovm());
49-
}
50-
51-
private static Sampler createSampler(Model model, Options options) {
52-
return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
53-
}
54-
5519
private static void runSingleInstruction(Model model, Sampler sampler, Options options) {
5620
String response = model.runInstructOnce(sampler, options);
5721
System.out.println(response);

src/main/java/org/beehive/gpullama3/inference/sampler/Sampler.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -103,7 +103,7 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
103103
return sampler;
104104
}
105105

106-
static Sampler createSampler(Model model, Options options) {
106+
public static Sampler createSampler(Model model, Options options) {
107107
return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
108108
}
109109

src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
package org.beehive.gpullama3.model.loader;
22

33
import org.beehive.gpullama3.Options;
4+
import org.beehive.gpullama3.aot.AOT;
45
import org.beehive.gpullama3.core.model.GGMLType;
56
import org.beehive.gpullama3.core.model.GGUF;
67
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
@@ -35,6 +36,8 @@
3536

3637
public abstract class ModelLoader {
3738

39+
public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
40+
3841
protected FileChannel fileChannel;
3942
protected GGUF gguf;
4043
protected int contextLength;
@@ -76,6 +79,31 @@ private static ModelType detectModelType(Map<String, Object> metadata) {
7679
return ModelType.UNKNOWN;
7780
}
7881

82+
/**
83+
* Loads the language model based on the given options.
84+
* <p>
85+
* If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model. Otherwise, loads the model from the specified path using the model loader.
86+
* </p>
87+
*
88+
* @param options
89+
* the parsed CLI options containing model path and max token limit
90+
* @return the loaded {@link Model} instance
91+
* @throws IOException
92+
* if the model fails to load
93+
* @throws IllegalStateException
94+
* if AOT loading is enabled but the preloaded model is unavailable
95+
*/
96+
public static Model loadModel(Options options) throws IOException {
97+
if (USE_AOT) {
98+
Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
99+
if (model == null) {
100+
throw new IllegalStateException("Failed to load precompiled AOT model.");
101+
}
102+
return model;
103+
}
104+
return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true, options.useTornadovm());
105+
}
106+
79107
public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeights, boolean useTornadovm) throws IOException {
80108
// initial load of metadata from gguf file
81109
GGUF gguf = GGUF.loadModel(ggufPath);

0 commit comments

Comments
 (0)