Commit 35b993e

Extend logic for Qwen2
1 parent: c4562ad · commit: 35b993e

4 files changed: +28 −15 lines

src/main/java/org/beehive/gpullama3/model/Model.java

Lines changed: 1 addition & 1 deletion

@@ -164,7 +164,7 @@ default void runInstructOnce(Sampler sampler, Options options) {
 
         List<Integer> promptTokens = new ArrayList<>();
 
-        if (!getModelType().equals(ModelType.QWEN_3) && !getModelType().equals(ModelType.PHI_3)) {
+        if (!getModelType().equals(ModelType.QWEN_3) && !getModelType().equals(ModelType.QWEN_2) && !getModelType().equals(ModelType.PHI_3)) {
             promptTokens.add(chatFormat.getBeginOfText());
         }
 
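
In isolation, the changed guard controls whether an explicit begin-of-text token is prepended ahead of the chat-formatted prompt; with this commit, Qwen2 joins Qwen3 and Phi-3 in skipping it, presumably because those chat templates supply their own leading special tokens. Below is a minimal, self-contained sketch of the same guard pattern; the ModelKind enum, the placeholder token id, and buildPromptTokens are illustrative stand-ins, not the repository's Model/ChatFormat API.

import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified sketch of the begin-of-text guard shown in the diff above.
// ModelKind, BEGIN_OF_TEXT and the token values are illustrative stand-ins.
public class BeginOfTextSketch {

    enum ModelKind { LLAMA_3, MISTRAL, QWEN_2, QWEN_3, PHI_3 }

    static final int BEGIN_OF_TEXT = 1; // placeholder token id

    static List<Integer> buildPromptTokens(ModelKind kind, List<Integer> chatTemplateTokens) {
        List<Integer> promptTokens = new ArrayList<>();
        // Qwen2, Qwen3 and Phi-3 skip the explicit begin-of-text token,
        // mirroring the guard introduced in this commit.
        if (kind != ModelKind.QWEN_3 && kind != ModelKind.QWEN_2 && kind != ModelKind.PHI_3) {
            promptTokens.add(BEGIN_OF_TEXT);
        }
        promptTokens.addAll(chatTemplateTokens);
        return promptTokens;
    }

    public static void main(String[] args) {
        System.out.println(buildPromptTokens(ModelKind.LLAMA_3, List.of(10, 11))); // [1, 10, 11]
        System.out.println(buildPromptTokens(ModelKind.QWEN_2, List.of(10, 11)));  // [10, 11]
    }
}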

src/main/java/org/beehive/gpullama3/model/ModelType.java

Lines changed: 8 additions & 0 deletions

@@ -4,6 +4,7 @@
 import org.beehive.gpullama3.model.loader.LlamaModelLoader;
 import org.beehive.gpullama3.model.loader.MistralModelLoader;
 import org.beehive.gpullama3.model.loader.Phi3ModelLoader;
+import org.beehive.gpullama3.model.loader.Qwen2ModelLoader;
 import org.beehive.gpullama3.model.loader.Qwen3ModelLoader;
 
 import java.nio.channels.FileChannel;

@@ -35,6 +36,13 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo
         }
     },
 
+    QWEN_2 {
+        @Override
+        public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
+            return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
+        }
+    },
+
     QWEN_3 {
         @Override
         public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
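
The new QWEN_2 constant follows the same enum-with-constant-specific-body pattern as its neighbours: every ModelType constant overrides loadModel and delegates to its dedicated loader. A simplified, standalone sketch of that pattern follows; the nested Model record and the String path argument are placeholders for the real FileChannel/GGUF-based signature.

// Hypothetical sketch of the per-constant loadModel dispatch used by ModelType.
// Model and the String argument stand in for the repository's real types.
public class ModelTypeSketch {

    record Model(String description) {}

    enum ModelType {
        QWEN_2 {
            @Override
            public Model loadModel(String ggufPath) {
                return new Model("Qwen2 weights loaded from " + ggufPath);
            }
        },
        QWEN_3 {
            @Override
            public Model loadModel(String ggufPath) {
                return new Model("Qwen3 weights loaded from " + ggufPath);
            }
        };

        // each constant supplies its own implementation, so callers never switch on the type
        public abstract Model loadModel(String ggufPath);
    }

    public static void main(String[] args) {
        System.out.println(ModelType.QWEN_2.loadModel("qwen2-7b-instruct.gguf").description());
    }
}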

src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 2 additions & 0 deletions

@@ -60,6 +60,8 @@ private static ModelType detectModelType(Map<String, Object> metadata) {
             return ModelType.MISTRAL;
         } else if (lowerName.contains("llama")) {
             return ModelType.LLAMA_3;
+        } else if (lowerName.contains("qwen2") || lowerName.contains("deepseek r1 distill")) {
+            return ModelType.QWEN_2;
         } else if (lowerName.contains("qwen3")) {
             return ModelType.QWEN_3;
         } else if (lowerName.contains("phi3")) {
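
detectModelType resolves the model family by substring-matching the lowercased model name taken from the GGUF metadata, and this commit routes both "qwen2" and "deepseek r1 distill" names to QWEN_2, apparently because the R1 distills handled here are Qwen2-based checkpoints. The standalone sketch below mirrors that detection chain; the "general.name" metadata key and the ModelKind enum are illustrative assumptions, not the repository's exact types.

import java.util.Locale;
import java.util.Map;

// Standalone sketch of the name-based detection chain in detectModelType.
// The "general.name" key and ModelKind are assumed/illustrative.
public class DetectModelTypeSketch {

    enum ModelKind { LLAMA_3, MISTRAL, QWEN_2, QWEN_3, PHI_3, UNKNOWN }

    static ModelKind detectModelType(Map<String, Object> metadata) {
        String name = (String) metadata.getOrDefault("general.name", "");
        String lowerName = name.toLowerCase(Locale.ROOT);
        if (lowerName.contains("mistral")) {
            return ModelKind.MISTRAL;
        } else if (lowerName.contains("llama")) {
            return ModelKind.LLAMA_3;
        } else if (lowerName.contains("qwen2") || lowerName.contains("deepseek r1 distill")) {
            // DeepSeek R1 distill names are routed to the Qwen2 path, as in the commit
            return ModelKind.QWEN_2;
        } else if (lowerName.contains("qwen3")) {
            return ModelKind.QWEN_3;
        } else if (lowerName.contains("phi3")) {
            return ModelKind.PHI_3;
        }
        return ModelKind.UNKNOWN;
    }

    public static void main(String[] args) {
        Map<String, Object> metadata = Map.of("general.name", "DeepSeek R1 Distill Qwen 7B");
        System.out.println(detectModelType(metadata)); // QWEN_2
    }
}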

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java

Lines changed: 17 additions & 14 deletions

@@ -93,7 +93,21 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
     }
 
     /**
-     * Determines whether the NVIDIA-specific scheduler should be used based on the current hardware backend and the model type.
+     * Dispatcher method to select the TornadoVMLayerPlanner for the model.
+     */
+    TornadoVMLayerPlanner createPlanner(State state, Model model) {
+        return switch (model.getModelType()) {
+            case LLAMA_3, MISTRAL -> new TornadoVMLayerPlanner(state, model);
+            case PHI_3 -> new Phi3TornadoVMLayerPlanner((Phi3State) state, model);
+            case QWEN_2 -> throw new UnsupportedOperationException("TornadoVM QWEN 2 not supported");
+            case QWEN_3 -> new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model);
+            case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type");
+        };
+    }
+
+    /**
+     * Determines whether the NVIDIA-specific scheduler should be used based on the current
+     * hardware backend and the model type.
      * <p>
      * The scheduler is used only if the runtime is targeting an NVIDIA backend and the model is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is {@code MISTRAL}, the
      * NVIDIA-specific scheduler should not be used.

@@ -115,19 +129,8 @@ public static boolean shouldUseNvidiaScheduler(Model model) {
     }
 
     /**
-     * Dispatcher method to select the TornadoVMLayerPlanner for the model.
-     */
-    TornadoVMLayerPlanner createPlanner(State state, Model model) {
-        return switch (model.getModelType()) {
-            case LLAMA_3, MISTRAL -> new TornadoVMLayerPlanner(state, model);
-            case QWEN_3 -> new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model);
-            case PHI_3 -> new Phi3TornadoVMLayerPlanner((Phi3State) state, model);
-            case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type");
-        };
-    }
-
-    /**
-     * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. This method processes the transformer layers in sequence for a particular token position in the context
+     * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration.
+     * This method processes the transformer layers in sequence for a particular token position in the context
      * window.
      *
      * <p>The execution happens in three phases:
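
Finally, the createPlanner dispatcher is moved ahead of shouldUseNvidiaScheduler and gains a QWEN_2 case that fails fast, since no Qwen2 TornadoVM planner exists yet. The self-contained sketch below shows the same exhaustive-switch dispatch idea; the Planner interface and the per-model planner records are placeholders for the TornadoVMLayerPlanner hierarchy and its State/Model arguments.

// Simplified sketch of the exhaustive switch dispatch used by createPlanner.
// Planner and the planner records are placeholders, not the repository's classes.
public class PlannerDispatchSketch {

    enum ModelKind { LLAMA_3, MISTRAL, QWEN_2, QWEN_3, PHI_3, UNKNOWN }

    interface Planner {}
    record DefaultPlanner() implements Planner {}
    record Phi3Planner() implements Planner {}
    record Qwen3Planner() implements Planner {}

    static Planner createPlanner(ModelKind kind) {
        // a switch expression over an enum must cover every constant, so an
        // unsupported model type fails loudly instead of falling through silently
        return switch (kind) {
            case LLAMA_3, MISTRAL -> new DefaultPlanner();
            case PHI_3 -> new Phi3Planner();
            case QWEN_2 -> throw new UnsupportedOperationException("TornadoVM QWEN 2 not supported");
            case QWEN_3 -> new Qwen3Planner();
            case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type");
        };
    }

    public static void main(String[] args) {
        System.out.println(createPlanner(ModelKind.QWEN_3).getClass().getSimpleName()); // Qwen3Planner
    }
}

Because the switch expression covers every enum constant, adding a new model type without updating the dispatcher becomes a compile-time error rather than a silent fallthrough, which is presumably why the QWEN_2 case throws explicitly instead of being omitted.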
