Skip to content

Commit cde76dc

Browse files
committed
Refactor TornadoVMMasterPlan to improve NVIDIA model compatibility logic
1 parent 6132b64 commit cde76dc

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,12 @@ public boolean shouldUseNvidiaScheduler(Model model) {
105105
String platformName = runtime.getBackend(0).getDefaultDevice().getPlatformName().toLowerCase(Locale.ROOT);
106106

107107
boolean isNvidia = platformName.contains("nvidia");
108-
boolean isMistral = model.getModelType() == ModelType.MISTRAL;
108+
boolean isNotMistral = model.getModelType() != ModelType.MISTRAL;
109109

110-
return !isNvidia || isMistral;
111-
}
110+
boolean result = isNvidia && isNotMistral;
112111

112+
return result;
113+
}
113114

114115
/**
115116
* Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration.

0 commit comments

Comments
 (0)