Skip to content

Commit d1239eb

Browse files
Distinct Deepseek-R1-Distill-Qwen from Qwen2
1 parent 35b993e commit d1239eb

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

src/main/java/org/beehive/gpullama3/model/ModelType.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo
5050
}
5151
},
5252

53+
DEEPSEEK_R1_DISTILL_QWEN {
54+
@Override
55+
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
56+
return new Qwen2ModelLoader(fileChannel, gguf, contextLength, loadWeights).loadModel();
57+
}
58+
},
59+
5360
PHI_3 {
5461
@Override
5562
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) {
@@ -66,4 +73,8 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo
6673

6774
// Abstract method that each enum constant must implement
6875
public abstract Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights);
76+
77+
public boolean isDeepSeekR1() {
78+
return this == DEEPSEEK_R1_DISTILL_QWEN;
79+
}
6980
}

src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,12 @@ private static ModelType detectModelType(Map<String, Object> metadata) {
6060
return ModelType.MISTRAL;
6161
} else if (lowerName.contains("llama")) {
6262
return ModelType.LLAMA_3;
63-
} else if (lowerName.contains("qwen2") || lowerName.contains("deepseek r1 distill")) {
63+
} else if (lowerName.contains("qwen2")) {
6464
return ModelType.QWEN_2;
6565
} else if (lowerName.contains("qwen3")) {
6666
return ModelType.QWEN_3;
67+
} else if (lowerName.contains("deepseek r1 distill")) {
68+
return ModelType.DEEPSEEK_R1_DISTILL_QWEN;
6769
} else if (lowerName.contains("phi3")) {
6870
return ModelType.PHI_3;
6971
}

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ TornadoVMLayerPlanner createPlanner(State state, Model model) {
9999
return switch (model.getModelType()) {
100100
case LLAMA_3, MISTRAL -> new TornadoVMLayerPlanner(state, model);
101101
case PHI_3 -> new Phi3TornadoVMLayerPlanner((Phi3State) state, model);
102-
case QWEN_2 -> throw new UnsupportedOperationException("TornadoVM QWEN 2 not supported");
102+
case QWEN_2, DEEPSEEK_R1_DISTILL_QWEN -> throw new UnsupportedOperationException("TornadoVM QWEN 2 not supported");
103103
case QWEN_3 -> new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model);
104104
case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type");
105105
};

0 commit comments

Comments
 (0)