Skip to content

Commit a2ec8cc

Browse files
Refactor for Mistral integration with abstractions
1 parent 89d5aa3 commit a2ec8cc

19 files changed

+102 additions, -57 deletions (lines changed)

src/main/java/com/example/LlamaApp.java

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,8 @@
55
import com.example.inference.CategoricalSampler;
66
import com.example.inference.Sampler;
77
import com.example.inference.ToppSampler;
8-
import com.example.inference.engine.impl.Model;
9-
import com.example.inference.engine.impl.Options;
8+
import com.example.model.Model;
109
import com.example.loader.weights.ModelLoader;
11-
import com.example.loader.weights.State;
1210
import com.example.tornadovm.FloatArrayUtils;
1311
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
1412

src/main/java/com/example/inference/engine/impl/Options.java renamed to src/main/java/com/example/Options.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package com.example.inference.engine.impl;
1+
package com.example;
22

33
import java.io.PrintStream;
44
import java.nio.file.Path;

src/main/java/com/example/aot/AOT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import com.example.auxiliary.Timer;
44
import com.example.core.model.GGUF;
55
import com.example.core.model.tensor.GGMLTensorEntry;
6-
import com.example.inference.engine.impl.Model;
7-
import com.example.inference.engine.impl.Options;
8-
import com.example.inference.engine.impl.llama.Llama;
6+
import com.example.model.Model;
7+
import com.example.Options;
8+
import com.example.model.llama.Llama;
99
import com.example.loader.weights.ModelLoader;
1010
import com.example.loader.weights.Weights;
1111

src/main/java/com/example/auxiliary/format/ChatFormat.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@
44

55
public interface ChatFormat {
66

7+
int getBeginOfText();
8+
79
}

src/main/java/com/example/auxiliary/format/LlamaChatFormat.java

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010

1111
public class LlamaChatFormat implements ChatFormat {
1212

13-
final LlamaTokenizer tokenizer;
14-
public final int beginOfText;
15-
final int endHeader;
16-
final int startHeader;
17-
final int endOfTurn;
18-
final int endOfText;
19-
final int endOfMessage;
20-
final Set<Integer> stopTokens;
13+
protected final LlamaTokenizer tokenizer;
14+
protected final int beginOfText;
15+
protected final int endHeader;
16+
protected final int startHeader;
17+
protected final int endOfTurn;
18+
protected final int endOfText;
19+
protected final int endOfMessage;
20+
protected final Set<Integer> stopTokens;
2121

2222
public LlamaChatFormat(LlamaTokenizer tokenizer) {
2323
this.tokenizer = tokenizer;
@@ -69,6 +69,11 @@ public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<LlamaC
6969
return tokens;
7070
}
7171

72+
@Override
73+
public int getBeginOfText() {
74+
return beginOfText;
75+
}
76+
7277
public record Message(LlamaChatFormat.Role role, String content) {
7378
}
7479

src/main/java/com/example/auxiliary/format/MistralChatFormat.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ public MistralChatFormat(MistralTokenizer tokenizer) {
4040
this.middle = specialTokens.getOrDefault("[MIDDLE]", unknownToken);
4141
}
4242

43+
@Override
4344
public int getBeginOfText() { return beginOfText; }
4445

4546
public Set<Integer> getStopTokens() {

src/main/java/com/example/loader/weights/ModelLoader.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
import com.example.core.model.tensor.Q4_0FloatTensor;
1111
import com.example.core.model.tensor.Q8_0FloatTensor;
1212
import com.example.core.types.Pair;
13-
import com.example.inference.engine.impl.Configuration;
14-
import com.example.inference.engine.impl.Model;
15-
import com.example.inference.engine.impl.llama.LlamaConfiguration;
16-
import com.example.inference.engine.impl.llama.Llama;
17-
import com.example.inference.engine.impl.mistral.Mistral;
18-
import com.example.inference.engine.impl.mistral.MistralConfiguration;
13+
import com.example.model.Configuration;
14+
import com.example.model.Model;
15+
import com.example.model.llama.LlamaConfiguration;
16+
import com.example.model.llama.Llama;
17+
import com.example.model.mistral.Mistral;
18+
import com.example.model.mistral.MistralConfiguration;
1919
import com.example.inference.operation.RoPE;
2020
import com.example.tokenizer.impl.LlamaTokenizer;
2121
import com.example.tokenizer.impl.MistralTokenizer;

src/main/java/com/example/loader/weights/State.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import com.example.core.model.tensor.ArrayFloatTensor;
44
import com.example.core.model.tensor.FloatTensor;
5-
import com.example.inference.engine.impl.Configuration;
5+
import com.example.model.Configuration;
66
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
77
import uk.ac.manchester.tornado.api.types.arrays.IntArray;
88

src/main/java/com/example/inference/engine/impl/Configuration.java renamed to src/main/java/com/example/model/Configuration.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package com.example.inference.engine.impl;
1+
package com.example.model;
22

33
public interface Configuration {
44

src/main/java/com/example/inference/engine/impl/Model.java renamed to src/main/java/com/example/model/Model.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
package com.example.inference.engine.impl;
1+
package com.example.model;
22

33
import com.example.inference.Sampler;
4+
import com.example.Options;
5+
import com.example.loader.weights.ModelLoader.ModelType;
46
import com.example.loader.weights.State;
57
import com.example.loader.weights.Weights;
68
import com.example.tokenizer.impl.Tokenizer;
@@ -15,6 +17,8 @@ public interface Model {
1517
Tokenizer tokenizer();
1618
Weights weights();
1719

20+
ModelType getModelType();
21+
1822
List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens,
1923
int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan);
2024
/**

0 commit comments

Comments
 (0)