Skip to content

Commit 6132b64

Browse files
committed
Refactor TornadoVMMasterPlan to simplify constructor and add NVIDIA scheduler decision logic
1 parent 928b8c4 commit 6132b64

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import com.example.loader.weights.State;
55
import com.example.model.Configuration;
66
import com.example.model.Model;
7+
import com.example.model.ModelType;
78
import uk.ac.manchester.tornado.api.GridScheduler;
89
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
910
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
@@ -12,6 +13,7 @@
1213
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
1314

1415
import java.util.List;
16+
import java.util.Locale;
1517

1618
public class TornadoVMMasterPlan {
1719
private static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));
@@ -22,9 +24,9 @@ public class TornadoVMMasterPlan {
2224
public TornadoExecutionPlan executionPlan;
2325
List<ImmutableTaskGraph> taskGraphs;
2426

25-
public TornadoVMMasterPlan(State state, Model model, boolean isNvidia) {
27+
public TornadoVMMasterPlan(State state, Model model) {
2628
TornadoVMLayerPlanner tornadoVMLayerPlanner = new TornadoVMLayerPlanner(state, model);
27-
Tuple2<List<ImmutableTaskGraph>, GridScheduler> tornadoVMPlan = isNvidia
29+
Tuple2<List<ImmutableTaskGraph>, GridScheduler> tornadoVMPlan = shouldUseNvidiaScheduler(model)
2830
? tornadoVMLayerPlanner.setupTornadoForwardPlanLayered()
2931
: tornadoVMLayerPlanner.setupTornadoForwardPlanLayeredNonNvidia();
3032
this.taskGraphs = tornadoVMPlan.getFirst();
@@ -57,9 +59,7 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
5759
}
5860

5961
// 1. Pre-allocate the TornadoVM plan
60-
TornadoRuntime coreRuntime = TornadoRuntimeProvider.getTornadoRuntime();
61-
boolean isNvidia = coreRuntime.getBackend(0).getDefaultDevice().getPlatformName().toLowerCase().contains("nvidia");
62-
TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan(state, model, isNvidia);
62+
TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan(state, model);
6363

6464
// Record time after plan creation
6565
if (ENABLE_TORNADOVM_INIT_TIME) {
@@ -89,6 +89,28 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
8989
return tornadoVMPlan;
9090
}
9191

92+
/**
93+
* Determines whether the NVIDIA-specific scheduler should be used based on the current
94+
* hardware backend and the model type.
95+
* <p>
96+
* The scheduler is used only if the runtime is targeting an NVIDIA backend and the model
97+
* is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is
98+
* {@code MISTRAL}, the NVIDIA-specific scheduler should not be used.
99+
*
100+
* @param model the model whose type may affect the scheduler decision
101+
* @return {@code true} if the NVIDIA-specific scheduler should be used; {@code false} otherwise
102+
*/
103+
public boolean shouldUseNvidiaScheduler(Model model) {
104+
TornadoRuntime runtime = TornadoRuntimeProvider.getTornadoRuntime();
105+
String platformName = runtime.getBackend(0).getDefaultDevice().getPlatformName().toLowerCase(Locale.ROOT);
106+
107+
boolean isNvidia = platformName.contains("nvidia");
108+
boolean isMistral = model.getModelType() == ModelType.MISTRAL;
109+
110+
return !isNvidia || isMistral;
111+
}
112+
113+
92114
/**
93115
* Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration.
94116
* This method processes the transformer layers in sequence for a particular token position in the context

0 commit comments

Comments
 (0)