
Commit 3b4bb62

Apply a formatter pass
1 parent 72a2b8b commit 3b4bb62

15 files changed (+69, -66 lines)
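
A note before the per-file hunks: several files below gain // @formatter:off and // @formatter:on pairs around hand-formatted regions. These are formatter toggle comments, a convention recognized by both the IntelliJ IDEA and Eclipse formatters (the commit does not say which tool ran this pass); code between the markers is skipped by the formatter, so deliberate alignment survives future passes. A minimal sketch of the effect:

public class FormatterToggleDemo {
    // @formatter:off
    // Everything down to the matching toggle keeps its hand-made layout;
    // without this fence, a formatter pass would reflow the alignment below.
    private static final int[][] IDENTITY = {
            { 1, 0, 0 },
            { 0, 1, 0 },
            { 0, 0, 1 },
    };
    // @formatter:on

    // From here on, normal formatting rules apply again.
}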

src/main/java/com/example/inference/InferenceCore.java

Lines changed: 0 additions & 1 deletion
@@ -195,5 +195,4 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i
         return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
     }
 
-
 }

src/main/java/com/example/inference/InferenceEngine.java

Lines changed: 4 additions & 19 deletions
@@ -46,15 +46,8 @@ private InferenceEngine() {
      * @param onTokenGenerated callback, if non-null, it's called every time a token is inferred e.g. it's not called when ingesting prompt tokens
      * @return list of generated/inferred tokens, including the stop token, if any e.g. does not include any token from the prompt
      */
-    public static List<Integer> generateTokens(Model model,
-                                               State state,
-                                               int startPosition,
-                                               List<Integer> promptTokens,
-                                               Set<Integer> stopTokens,
-                                               int maxTokens,
-                                               Sampler sampler,
-                                               boolean echo,
-                                               IntConsumer onTokenGenerated) {
+    public static List<Integer> generateTokens(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
+            IntConsumer onTokenGenerated) {
         // Start timing the whole process
         long startNanos = System.nanoTime();
         long inferenceStartNanos = 0;

@@ -129,16 +122,8 @@ public static List<Integer> generateTokens(Model model,
         return generatedTokens;
     }
 
-    public static List<Integer> generateTokensGPU(Model model,
-                                                  State state,
-                                                  int startPosition,
-                                                  List<Integer> promptTokens,
-                                                  Set<Integer> stopTokens,
-                                                  int maxTokens,
-                                                  Sampler sampler,
-                                                  boolean echo,
-                                                  IntConsumer onTokenGenerated,
-                                                  TornadoVMMasterPlan tornadoVMPlan) {
+    public static List<Integer> generateTokensGPU(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
+            IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
         // === Setup and Initialization ===
         long startNanos = System.nanoTime();
         long inferenceStartNanos = 0;
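
The collapsed declaration above is the full signature of the CPU inference entry point; as its javadoc notes, onTokenGenerated fires only for inferred tokens, not while prompt tokens are ingested. A hypothetical call site matching that signature (model, state, promptTokens, stopTokens, and sampler are assumed to be already constructed; names are illustrative):

// Hypothetical usage sketch; assumes model, state, promptTokens,
// stopTokens, and sampler already exist in scope.
List<Integer> generated = InferenceEngine.generateTokens(
        model, state, /* startPosition */ 0,
        promptTokens, stopTokens, /* maxTokens */ 256,
        sampler, /* echo */ false,
        tokenId -> System.out.printf("generated token id: %d%n", tokenId));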

src/main/java/com/example/loader/weights/ModelLoader.java

Lines changed: 2 additions & 2 deletions
@@ -105,7 +105,7 @@ public static Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Co
     }
 
     private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-                                                  GGMLTensorEntry outputWeight) {
+            GGMLTensorEntry outputWeight) {
         return new Weights(
                 // Load directly to TornadoVM format
                 loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),

@@ -124,7 +124,7 @@ private static Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tenso
      * Creates weights in standard format only
      */
     private static Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-                                                 GGMLTensorEntry outputWeight) {
+            GGMLTensorEntry outputWeight) {
         return new Weights(loadQuantized(tokenEmbeddings), loadArrayOfFloatBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),

src/main/java/com/example/model/Model.java

Lines changed: 9 additions & 9 deletions
@@ -21,12 +21,15 @@
 
 public interface Model {
     Configuration configuration();
+
     Tokenizer tokenizer();
+
     Weights weights();
 
     ModelType getModelType();
 
     State createNewState();
+
     State createNewState(int batchsize);
 
     /**

@@ -85,14 +88,12 @@ default void runInteractive(Sampler sampler, Options options) {
         // Choose between GPU and CPU path based on configuration
         if (USE_TORNADOVM) {
             // GPU path using TornadoVM
-            responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition,
-                    conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
+            responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
                     options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
         } else {
             // CPU path
-            responseTokens = InferenceEngine.generateTokens(this, state, startPosition,
-                    conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
-                    options.maxTokens(), sampler, options.echo(), tokenConsumer);
+            responseTokens = InferenceEngine.generateTokens(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
+                    sampler, options.echo(), tokenConsumer);
         }
 
         // Include stop token in the prompt history, but not in the response displayed to the user.

@@ -164,11 +165,10 @@ default void runInstructOnce(Sampler sampler, Options options) {
         if (USE_TORNADOVM) {
             tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
             // Call generateTokensGPU without the token consumer parameter
-            responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens,
-                    options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+            responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null,
+                    tornadoVMPlan);
         } else {
-            responseTokens = InferenceEngine.generateTokens(this, state, 0, promptTokens, stopTokens,
-                    options.maxTokens(), sampler, options.echo(), tokenConsumer);
+            responseTokens = InferenceEngine.generateTokens(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
         }
 
         if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {

src/main/java/com/example/model/format/ChatFormat.java

Lines changed: 3 additions & 0 deletions
@@ -19,8 +19,11 @@ static ChatFormat create(Object tokenizer) {
     }
 
     List<Integer> encodeHeader(Message message);
+
     List<Integer> encodeMessage(Message message);
+
     int getBeginOfText();
+
     Set<Integer> getStopTokens();
 
     /**

src/main/java/com/example/model/format/LlamaChatFormat.java

Lines changed: 6 additions & 2 deletions
@@ -31,10 +31,14 @@ public LlamaChatFormat(LlamaTokenizer tokenizer) {
     }
 
     @Override
-    public int getBeginOfText() { return beginOfText; }
+    public int getBeginOfText() {
+        return beginOfText;
+    }
 
     @Override
-    public Set<Integer> getStopTokens() { return stopTokens; }
+    public Set<Integer> getStopTokens() {
+        return stopTokens;
+    }
 
     @Override
     public List<Integer> encodeHeader(Message message) {

src/main/java/com/example/model/format/MistralChatFormat.java

Lines changed: 6 additions & 2 deletions
@@ -41,10 +41,14 @@ public MistralChatFormat(MistralTokenizer tokenizer) {
     }
 
     @Override
-    public int getBeginOfText() { return beginOfText; }
+    public int getBeginOfText() {
+        return beginOfText;
+    }
 
     @Override
-    public Set<Integer> getStopTokens() { return Set.of(endOfText); }
+    public Set<Integer> getStopTokens() {
+        return Set.of(endOfText);
+    }
 
     @Override
     public List<Integer> encodeHeader(Message message) {

src/main/java/com/example/model/llama/Llama.java

Lines changed: 5 additions & 1 deletion
@@ -21,7 +21,9 @@ public record Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weigh
     private static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16);
 
     /* For explicit use */
-    private LlamaTokenizer getAsLlamaTokenizer() { return (LlamaTokenizer) tokenizer; }
+    private LlamaTokenizer getAsLlamaTokenizer() {
+        return (LlamaTokenizer) tokenizer;
+    }
 
     @Override
     public ModelType getModelType() {

@@ -42,6 +44,7 @@ public State createNewState(int batchsize) {
         return state;
     }
 
+    // @formatter:off
     public static Llama loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
         try (var ignored = Timer.log("Load LlaMa model")) {
             Map<String, Object> metadata = gguf.getMetadata();

@@ -75,6 +78,7 @@ public static Llama loadModel(FileChannel fileChannel, GGUF gguf, int contextLen
             throw new RuntimeException(e);
         }
     }
+    // @formatter:on
 
 }

src/main/java/com/example/model/llama/LlamaConfiguration.java

Lines changed: 4 additions & 6 deletions
@@ -2,11 +2,8 @@
 
 import com.example.model.Configuration;
 
-public record LlamaConfiguration(
-        int dim, int hiddenDim, int numberOfLayers, int numberOfHeads,
-        int numberOfKeyValueHeads, int vocabularySize, int contextLength,
-        float rmsNormEps, float ropeTheta
-) implements Configuration {
+public record LlamaConfiguration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, float rmsNormEps, float ropeTheta)
+        implements Configuration {
 
     public int headSize() {
         return dim / numberOfHeads;

@@ -17,7 +14,6 @@ public int kvDim() {
         return dim * numberOfKeyValueHeads / numberOfHeads;
     }
 
-
     /** Multiplier for key/value sharing in multi-query attention */
     public int kvMul() {
         return numberOfHeads / numberOfKeyValueHeads;

@@ -30,6 +26,7 @@ public int kvMul() {
      * @return A new Configuration instance with updated context length,
      *         or the current instance if newContextLength is negative
      */
+    // @formatter:off
     public LlamaConfiguration withContextLength(int newContextLength) {
         if (newContextLength < 0) {
             return this; // no change

@@ -46,5 +43,6 @@ public LlamaConfiguration withContextLength(int newContextLength) {
                 this.ropeTheta
         );
     }
+    // @formatter:on
 }
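
For context on why withContextLength is fenced off from the formatter: the fragments visible in the two hunks above (the guard clause and the trailing this.ropeTheta argument on its own line) suggest a hand-formatted, one-component-per-line constructor call that a line-joining pass would otherwise collapse. A hypothetical reconstruction, using only the record components declared in this file's diff; newContextLength replaces the stored contextLength and everything else is copied:

// Hypothetical reconstruction from the fragments shown in this diff.
public LlamaConfiguration withContextLength(int newContextLength) {
    if (newContextLength < 0) {
        return this; // no change
    }
    // One component per line; the @formatter:off fence preserves this layout.
    return new LlamaConfiguration(
            this.dim,
            this.hiddenDim,
            this.numberOfLayers,
            this.numberOfHeads,
            this.numberOfKeyValueHeads,
            this.vocabularySize,
            newContextLength, // the only changed component
            this.rmsNormEps,
            this.ropeTheta
    );
}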

src/main/java/com/example/model/mistral/Mistral.java

Lines changed: 5 additions & 1 deletion
@@ -20,7 +20,9 @@
 public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model {
 
     /* For explicit use */
-    private MistralTokenizer getAsMistralTokenizer() { return (MistralTokenizer) tokenizer; }
+    private MistralTokenizer getAsMistralTokenizer() {
+        return (MistralTokenizer) tokenizer;
+    }
 
     @Override
     public ModelType getModelType() {

@@ -39,6 +41,7 @@ public State createNewState(int batchsize) {
         return state;
     }
 
+    // @formatter:off
    public static Mistral loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
         try (var ignored = Timer.log("Load Mistral model")) {
             Map<String, Object> metadata = gguf.getMetadata();

@@ -78,5 +81,6 @@ public static Mistral loadModel(FileChannel fileChannel, GGUF gguf, int contextL
             throw new RuntimeException(e);
         }
     }
+    // @formatter:on
 
 }
