Skip to content

Commit 733815b

Browse files
Move ModelType enum to dedicated file
1 parent 1640e90 commit 733815b

File tree

5 files changed

+21
-18
lines changed

5 files changed

+21
-18
lines changed

src/main/java/com/example/loader/weights/ModelLoader.java

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import com.example.core.types.Pair;
1313
import com.example.model.Configuration;
1414
import com.example.model.Model;
15+
import com.example.model.ModelType;
1516
import com.example.model.llama.LlamaConfiguration;
1617
import com.example.model.llama.Llama;
1718
import com.example.model.mistral.Mistral;
@@ -39,21 +40,13 @@
3940
import java.util.stream.Collectors;
4041
import java.util.stream.IntStream;
4142

42-
import static com.example.loader.weights.ModelLoader.ModelType.LLAMA_3;
43-
4443
public final class ModelLoader {
4544
private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2";
4645
private static final String TOKENIZER_MISTRAL_MODEL = "llama";
4746

4847
private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
4948
private static final String MISTRAL_PATTERN = "\\S+|\\s+";
5049

51-
public enum ModelType {
52-
LLAMA_3,
53-
MISTRAL,
54-
UNKNOWN
55-
}
56-
5750
private static ModelType detectModelType(Map<String, Object> metadata) {
5851
String name = (String) metadata.get("general.name");
5952
String tokenizerModel = (String) metadata.get("tokenizer.ggml.model");
@@ -65,23 +58,23 @@ private static ModelType detectModelType(Map<String, Object> metadata) {
6558
if (lowerName.contains("mistral")) {
6659
return ModelType.MISTRAL;
6760
} else if (lowerName.contains("llama")) {
68-
return LLAMA_3;
61+
return ModelType.LLAMA_3;
6962
}
7063
}
7164

7265
// Check by tokenizer model
7366
if (TOKENIZER_MISTRAL_MODEL.equals(tokenizerModel)) {
7467
return ModelType.MISTRAL;
7568
} else if (TOKENIZER_LLAMA_3_MODEL.equals(tokenizerModel)) {
76-
return LLAMA_3;
69+
return ModelType.LLAMA_3;
7770
}
7871

7972
// Check by vocabulary size as fallback
8073
if (vocabSize != null) {
8174
if (vocabSize == 32768) {
8275
return ModelType.MISTRAL;
8376
} else if (vocabSize == 128256) {
84-
return LLAMA_3;
77+
return ModelType.LLAMA_3;
8578
}
8679
}
8780

src/main/java/com/example/model/Model.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import com.example.inference.InferenceEngine;
66
import com.example.inference.sampler.Sampler;
77
import com.example.Options;
8-
import com.example.loader.weights.ModelLoader.ModelType;
98
import com.example.loader.weights.State;
109
import com.example.loader.weights.Weights;
1110
import com.example.tokenizer.impl.Tokenizer;
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package com.example.model;
2+
3+
import com.example.core.model.GGUF;
4+
5+
import java.nio.channels.FileChannel;
6+
7+
public enum ModelType {
8+
LLAMA_3,
9+
MISTRAL,
10+
UNKNOWN;
11+
}

src/main/java/com/example/model/llama/Llama.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package com.example.model.llama;
22

33
import com.example.model.Model;
4-
import com.example.loader.weights.ModelLoader;
54
import com.example.loader.weights.State;
65
import com.example.loader.weights.Weights;
6+
import com.example.model.ModelType;
77
import com.example.tokenizer.impl.LlamaTokenizer;
88
import com.example.tokenizer.impl.Tokenizer;
99

@@ -14,8 +14,8 @@ public record Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weigh
1414
private LlamaTokenizer getAsLlamaTokenizer() { return (LlamaTokenizer) tokenizer; }
1515

1616
@Override
17-
public ModelLoader.ModelType getModelType() {
18-
return ModelLoader.ModelType.LLAMA_3;
17+
public ModelType getModelType() {
18+
return ModelType.LLAMA_3;
1919
}
2020

2121
@Override

src/main/java/com/example/model/mistral/Mistral.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package com.example.model.mistral;
22

33
import com.example.model.Model;
4-
import com.example.loader.weights.ModelLoader;
54
import com.example.loader.weights.State;
65
import com.example.loader.weights.Weights;
6+
import com.example.model.ModelType;
77
import com.example.tokenizer.impl.MistralTokenizer;
88
import com.example.tokenizer.impl.Tokenizer;
99

@@ -13,8 +13,8 @@ public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, W
1313
private MistralTokenizer getAsMistralTokenizer() { return (MistralTokenizer) tokenizer; }
1414

1515
@Override
16-
public ModelLoader.ModelType getModelType() {
17-
return ModelLoader.ModelType.MISTRAL;
16+
public ModelType getModelType() {
17+
return ModelType.MISTRAL;
1818
}
1919

2020
public State createNewState() {

0 commit comments

Comments
 (0)