Skip to content

Commit 8e63862

Browse files
Move loadModel methods to dedicated model classes
1 parent 340b35e commit 8e63862

File tree

5 files changed

+122
-86
lines changed

5 files changed

+122
-86
lines changed

src/main/java/com/example/aot/AOT.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ private static PartialModel preLoadGGUF(String modelPath) {
4646
try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
4747
return new PartialModel(
4848
path.getFileName().toString(),
49-
ModelLoader.loadLlamaModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false),
49+
Llama.loadModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false), // TODO: needs proper handling for AOT
5050
gguf.getTensorDataOffset(),
5151
gguf.getTensorInfos()
5252
);

src/main/java/com/example/loader/weights/ModelLoader.java

Lines changed: 2 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
package com.example.loader.weights;
22

33
import com.example.LlamaApp;
4-
import com.example.auxiliary.Timer;
54
import com.example.core.model.GGMLType;
65
import com.example.core.model.GGUF;
76
import com.example.core.model.tensor.F16FloatTensor;
@@ -70,89 +69,10 @@ public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeig
7069
// initial load of metadata from gguf file
7170
GGUF gguf = GGUF.loadModel(ggufPath);
7271
FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ);
73-
7472
// detect model type
7573
ModelType modelType = detectModelType(gguf.getMetadata());
76-
System.out.println("Detected model type: " + modelType);
77-
78-
// load model (vocabulary, tokenizer, configuration, tensors, weights)
79-
return switch (modelType) {
80-
case LLAMA_3 -> loadLlamaModel(fileChannel, gguf, contextLength, loadWeights);
81-
case MISTRAL -> loadMistralModel(fileChannel, gguf, contextLength, loadWeights);
82-
default -> throw new UnsupportedOperationException("Unsupported model type: " + modelType);
83-
};
84-
}
85-
86-
public static Llama loadLlamaModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) throws IOException {
87-
try (var ignored = Timer.log("Load LlaMa model")) {
88-
Map<String, Object> metadata = gguf.getMetadata();
89-
90-
Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata);
91-
Tokenizer tokenizer = createLlama3Tokenizer(metadata, vocabulary);
92-
93-
LlamaConfiguration config = new LlamaConfiguration(
94-
(int) metadata.get("llama.embedding_length"),
95-
(int) metadata.get("llama.feed_forward_length"),
96-
(int) metadata.get("llama.block_count"),
97-
(int) metadata.get("llama.attention.head_count"),
98-
99-
metadata.containsKey("llama.attention.head_count_kv") ?
100-
(int) metadata.get("llama.attention.head_count_kv") :
101-
(int) metadata.get("llama.attention.head_count"),
102-
103-
vocabulary.size(),
104-
(int) metadata.get("llama.context_length"),
105-
(float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
106-
(float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
107-
).withContextLength(contextLength);
108-
109-
Weights weights = null;
110-
if (loadWeights) {
111-
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
112-
weights = loadWeights(tensorEntries, config);
113-
}
114-
return new Llama(config, tokenizer, weights);
115-
}
116-
}
117-
118-
public static Mistral loadMistralModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
119-
try (var ignored = Timer.log("Load Mistral model")) {
120-
Map<String, Object> metadata = gguf.getMetadata();
121-
122-
Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata);
123-
Tokenizer tokenizer = createMistralTokenizer(metadata, vocabulary);
124-
125-
int modelContextLength = (int) metadata.get("llama.context_length");
126-
if (contextLength < 0 || modelContextLength < contextLength) {
127-
contextLength = modelContextLength;
128-
}
129-
130-
MistralConfiguration config = new MistralConfiguration(
131-
(int) metadata.get("llama.embedding_length"),
132-
(int) metadata.get("llama.feed_forward_length"),
133-
(int) metadata.get("llama.block_count"),
134-
(int) metadata.get("llama.attention.head_count"),
135-
136-
metadata.containsKey("llama.attention.head_count_kv")
137-
? (int) metadata.get("llama.attention.head_count_kv")
138-
: (int) metadata.get("llama.attention.head_count"),
139-
140-
vocabulary.size(),
141-
contextLength,
142-
false,
143-
(float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
144-
(float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
145-
);
146-
147-
Weights weights = null;
148-
if (loadWeights) {
149-
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
150-
weights = loadWeights(tensorEntries, config);
151-
}
152-
return new Mistral(config, tokenizer, weights);
153-
} catch (IOException e) {
154-
throw new RuntimeException(e);
155-
}
74+
// model type-specific load
75+
return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights);
15676
}
15777

15878
public static Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
Lines changed: 25 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,33 @@
11
package com.example.model;
22

33
import com.example.core.model.GGUF;
4+
import com.example.model.llama.Llama;
5+
import com.example.model.mistral.Mistral;
46

57
import java.nio.channels.FileChannel;
68

79
public enum ModelType {
8-
LLAMA_3,
9-
MISTRAL,
10-
UNKNOWN;
10+
LLAMA_3 {
11+
@Override
12+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
13+
return Llama.loadModel(fileChannel, gguf, contextLength, loadWeights);
14+
}
15+
},
16+
17+
MISTRAL {
18+
@Override
19+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
20+
return Mistral.loadModel(fileChannel, gguf, contextLength, loadWeights);
21+
}
22+
},
23+
24+
UNKNOWN {
25+
@Override
26+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
27+
throw new UnsupportedOperationException("Cannot load unknown model type");
28+
}
29+
};
30+
31+
// Abstract method that each enum constant must implement
32+
public abstract Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights);
1133
}

src/main/java/com/example/model/llama/Llama.java

Lines changed: 44 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,21 @@
11
package com.example.model.llama;
22

3+
import com.example.auxiliary.Timer;
4+
import com.example.core.model.GGUF;
5+
import com.example.core.model.tensor.GGMLTensorEntry;
36
import com.example.model.Model;
47
import com.example.loader.weights.State;
58
import com.example.loader.weights.Weights;
69
import com.example.model.ModelType;
710
import com.example.tokenizer.impl.LlamaTokenizer;
811
import com.example.tokenizer.impl.Tokenizer;
12+
import com.example.tokenizer.vocabulary.Vocabulary;
13+
14+
import java.io.IOException;
15+
import java.nio.channels.FileChannel;
16+
import java.util.Map;
17+
18+
import static com.example.loader.weights.ModelLoader.loadWeights;
919

1020
public record Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model {
1121
private static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16);
@@ -32,5 +42,39 @@ public State createNewState(int batchsize) {
3242
return state;
3343
}
3444

45+
public static Llama loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
46+
try (var ignored = Timer.log("Load LlaMa model")) {
47+
Map<String, Object> metadata = gguf.getMetadata();
48+
49+
Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata);
50+
Tokenizer tokenizer = new LlamaTokenizer(metadata, vocabulary);
51+
52+
LlamaConfiguration config = new LlamaConfiguration(
53+
(int) metadata.get("llama.embedding_length"),
54+
(int) metadata.get("llama.feed_forward_length"),
55+
(int) metadata.get("llama.block_count"),
56+
(int) metadata.get("llama.attention.head_count"),
57+
58+
metadata.containsKey("llama.attention.head_count_kv") ?
59+
(int) metadata.get("llama.attention.head_count_kv") :
60+
(int) metadata.get("llama.attention.head_count"),
61+
62+
vocabulary.size(),
63+
(int) metadata.get("llama.context_length"),
64+
(float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
65+
(float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
66+
).withContextLength(contextLength);
67+
68+
Weights weights = null;
69+
if (loadWeights) {
70+
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
71+
weights = loadWeights(tensorEntries, config);
72+
}
73+
return new Llama(config, tokenizer, weights);
74+
} catch (IOException e) {
75+
throw new RuntimeException(e);
76+
}
77+
}
78+
3579
}
3680

src/main/java/com/example/model/mistral/Mistral.java

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,21 @@
11
package com.example.model.mistral;
22

3+
import com.example.auxiliary.Timer;
4+
import com.example.core.model.GGUF;
5+
import com.example.core.model.tensor.GGMLTensorEntry;
36
import com.example.model.Model;
47
import com.example.loader.weights.State;
58
import com.example.loader.weights.Weights;
69
import com.example.model.ModelType;
710
import com.example.tokenizer.impl.MistralTokenizer;
811
import com.example.tokenizer.impl.Tokenizer;
12+
import com.example.tokenizer.vocabulary.Vocabulary;
13+
14+
import java.io.IOException;
15+
import java.nio.channels.FileChannel;
16+
import java.util.Map;
17+
18+
import static com.example.loader.weights.ModelLoader.loadWeights;
919

1020
public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model {
1121

@@ -29,4 +39,44 @@ public State createNewState(int batchsize) {
2939
return state;
3040
}
3141

42+
public static Mistral loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
43+
try (var ignored = Timer.log("Load Mistral model")) {
44+
Map<String, Object> metadata = gguf.getMetadata();
45+
46+
Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata);
47+
Tokenizer tokenizer = new MistralTokenizer(metadata, vocabulary);
48+
49+
int modelContextLength = (int) metadata.get("llama.context_length");
50+
if (contextLength < 0 || modelContextLength < contextLength) {
51+
contextLength = modelContextLength;
52+
}
53+
54+
MistralConfiguration config = new MistralConfiguration(
55+
(int) metadata.get("llama.embedding_length"),
56+
(int) metadata.get("llama.feed_forward_length"),
57+
(int) metadata.get("llama.block_count"),
58+
(int) metadata.get("llama.attention.head_count"),
59+
60+
metadata.containsKey("llama.attention.head_count_kv")
61+
? (int) metadata.get("llama.attention.head_count_kv")
62+
: (int) metadata.get("llama.attention.head_count"),
63+
64+
vocabulary.size(),
65+
contextLength,
66+
false,
67+
(float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
68+
(float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
69+
);
70+
71+
Weights weights = null;
72+
if (loadWeights) {
73+
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
74+
weights = loadWeights(tensorEntries, config);
75+
}
76+
return new Mistral(config, tokenizer, weights);
77+
} catch (IOException e) {
78+
throw new RuntimeException(e);
79+
}
80+
}
81+
3282
}

0 commit comments

Comments (0)