4
4
import com .example .loader .weights .State ;
5
5
import com .example .model .Configuration ;
6
6
import com .example .model .Model ;
7
+ import com .example .model .ModelType ;
7
8
import uk .ac .manchester .tornado .api .GridScheduler ;
8
9
import uk .ac .manchester .tornado .api .ImmutableTaskGraph ;
9
10
import uk .ac .manchester .tornado .api .TornadoExecutionPlan ;
12
13
import uk .ac .manchester .tornado .api .types .arrays .FloatArray ;
13
14
14
15
import java .util .List ;
16
+ import java .util .Locale ;
15
17
16
18
public class TornadoVMMasterPlan {
17
19
private static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean .parseBoolean (System .getProperty ("llama.EnableTimingForTornadoVMInit" , "False" ));
@@ -22,9 +24,9 @@ public class TornadoVMMasterPlan {
22
24
public TornadoExecutionPlan executionPlan ;
23
25
List <ImmutableTaskGraph > taskGraphs ;
24
26
25
- public TornadoVMMasterPlan (State state , Model model , boolean isNvidia ) {
27
+ public TornadoVMMasterPlan (State state , Model model ) {
26
28
TornadoVMLayerPlanner tornadoVMLayerPlanner = new TornadoVMLayerPlanner (state , model );
27
- Tuple2 <List <ImmutableTaskGraph >, GridScheduler > tornadoVMPlan = isNvidia
29
+ Tuple2 <List <ImmutableTaskGraph >, GridScheduler > tornadoVMPlan = shouldUseNvidiaScheduler ( model )
28
30
? tornadoVMLayerPlanner .setupTornadoForwardPlanLayered ()
29
31
: tornadoVMLayerPlanner .setupTornadoForwardPlanLayeredNonNvidia ();
30
32
this .taskGraphs = tornadoVMPlan .getFirst ();
@@ -57,9 +59,7 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
57
59
}
58
60
59
61
// 1. Pre-allocate the TornadoVM plan
60
- TornadoRuntime coreRuntime = TornadoRuntimeProvider .getTornadoRuntime ();
61
- boolean isNvidia = coreRuntime .getBackend (0 ).getDefaultDevice ().getPlatformName ().toLowerCase ().contains ("nvidia" );
62
- TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan (state , model , isNvidia );
62
+ TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan (state , model );
63
63
64
64
// Record time after plan creation
65
65
if (ENABLE_TORNADOVM_INIT_TIME ) {
@@ -89,6 +89,29 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
89
89
return tornadoVMPlan ;
90
90
}
91
91
92
+ /**
93
+ * Determines whether the NVIDIA-specific scheduler should be used based on the current
94
+ * hardware backend and the model type.
95
+ * <p>
96
+ * The scheduler is used only if the runtime is targeting an NVIDIA backend and the model
97
+ * is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is
98
+ * {@code MISTRAL}, the NVIDIA-specific scheduler should not be used.
99
+ *
100
+ * @param model the model whose type may affect the scheduler decision
101
+ * @return {@code true} if the NVIDIA-specific scheduler should be used; {@code false} otherwise
102
+ */
103
+ public static boolean shouldUseNvidiaScheduler (Model model ) {
104
+ TornadoRuntime runtime = TornadoRuntimeProvider .getTornadoRuntime ();
105
+ String platformName = runtime .getBackend (0 ).getDefaultDevice ().getPlatformName ().toLowerCase (Locale .ROOT );
106
+
107
+ boolean isNvidia = platformName .contains ("nvidia" );
108
+ boolean isNotMistral = model .getModelType () != ModelType .MISTRAL ;
109
+
110
+ boolean result = isNvidia && isNotMistral ;
111
+
112
+ return result ;
113
+ }
114
+
92
115
/**
93
116
* Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration.
94
117
* This method processes the transformer layers in sequence for a particular token position in the context
0 commit comments