Skip to content

Commit abc1b2b

Browse files
Initial commit for Qwen2TornadoVMLayerPlanner
1 parent 09f1b4d commit abc1b2b

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package org.beehive.gpullama3.tornadovm;
2+
3+
import org.beehive.gpullama3.auxiliary.Tuple2;
4+
import org.beehive.gpullama3.inference.state.Qwen2State;
5+
import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
6+
import org.beehive.gpullama3.model.Model;
7+
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
8+
import uk.ac.manchester.tornado.api.GridScheduler;
9+
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
10+
import uk.ac.manchester.tornado.api.TaskGraph;
11+
12+
import java.util.List;
13+
14+
public class Qwen2TornadoVMLayerPlanner extends TornadoVMLayerPlanner<Qwen2State, Qwen2Configuration, Qwen2TornadoWeights> {
15+
16+
/**
17+
* Constructs a TornadoVMLayerPlanner for the given Llama model.
18+
*
19+
* @param state
20+
* The state object containing model tensors and buffers
21+
* @param model
22+
* The Llama model instance containing configuration and weights
23+
*/
24+
public Qwen2TornadoVMLayerPlanner(Qwen2State state, Model model) {
25+
super(state, model);
26+
}
27+
28+
@Override
29+
protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
30+
throw new UnsupportedOperationException("configureLayerDataTransfers Not supported yet.");
31+
}
32+
33+
@Override
34+
public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayered() {
35+
throw new UnsupportedOperationException("setupTornadoForwardPlanLayered Not supported yet.");
36+
}
37+
38+
@Override
39+
public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLayeredNonNvidia() {
40+
return setupTornadoForwardPlanLayered();
41+
}
42+
43+
private GridScheduler setupQwen2GridSchedulersLayeredNonNvidia() {
44+
throw new UnsupportedOperationException("setupQwen2GridSchedulersLayeredNonNvidia Not supported yet.");
45+
}
46+
}

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import org.beehive.gpullama3.auxiliary.Tuple2;
44
import org.beehive.gpullama3.inference.state.Phi3State;
5+
import org.beehive.gpullama3.inference.state.Qwen2State;
56
import org.beehive.gpullama3.inference.state.Qwen3State;
67
import org.beehive.gpullama3.inference.state.State;
78
import org.beehive.gpullama3.model.Configuration;
@@ -99,7 +100,7 @@ TornadoVMLayerPlanner createPlanner(State state, Model model) {
99100
return switch (model.getModelType()) {
100101
case LLAMA_3, MISTRAL -> new TornadoVMLayerPlanner(state, model);
101102
case PHI_3 -> new Phi3TornadoVMLayerPlanner((Phi3State) state, model);
102-
case QWEN_2, DEEPSEEK_R1_DISTILL_QWEN -> throw new UnsupportedOperationException("TornadoVM QWEN 2 not supported");
103+
case QWEN_2, DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2TornadoVMLayerPlanner((Qwen2State) state, model);
103104
case QWEN_3 -> new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model);
104105
case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type");
105106
};

0 commit comments

Comments
 (0)