diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java index be6efdd67be..21b6a0610fd 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java @@ -96,20 +96,9 @@ public void testModuleLoadForwardExplicit() throws IOException{ assertTrue(results[0].isTensor()); } - @Test + @Test(expected = RuntimeException.class) public void testModuleLoadNonExistantFile() throws IOException{ Module module = Module.load(getTestFilePath(MISSING_FILE_NAME)); - - EValue[] results = module.forward(); - assertEquals(null, results); - } - - @Test - public void testModuleLoadMethodNonExistantFile() throws IOException{ - Module module = Module.load(getTestFilePath(MISSING_FILE_NAME)); - - int loadMethod = module.loadMethod(FORWARD_METHOD); - assertEquals(loadMethod, ACCESS_FAILED); } @Test @@ -146,11 +135,11 @@ public void testForwardOnDestroyedModule() throws IOException{ assertEquals(loadMethod, OK); module.destroy(); - + EValue[] results = module.forward(); assertEquals(0, results.length); } - + @Test public void testForwardFromMultipleThreads() throws InterruptedException, IOException { Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); @@ -169,7 +158,7 @@ public void run() { assertTrue(results[0].isTensor()); completed.incrementAndGet(); } catch (InterruptedException e) { - + } } }; diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index f3f543dc2a8..cb544dd6a37 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -11,6 +11,7 @@ import android.util.Log; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; +import java.io.File; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import org.pytorch.executorch.annotations.Experimental; @@ -52,6 +53,10 @@ public static Module load(final String modelPath, int loadMode) { if (!NativeLoader.isInitialized()) { NativeLoader.init(new SystemDelegate()); } + File modelFile = new File(modelPath); + if (!modelFile.canRead() || !modelFile.isFile()) { + throw new RuntimeException("Cannot load model path " + modelPath); + } return new Module(new NativePeer(modelPath, loadMode)); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index f845937be41..69e302edf78 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -12,6 +12,7 @@ import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; +import java.io.File; import org.pytorch.executorch.annotations.Experimental; /** @@ -41,9 +42,26 @@ public class LlmModule { private static native HybridData initHybrid( int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath); + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * data path. + */ + public LlmModule( + int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) { + File modelFile = new File(modulePath); + if (!modelFile.canRead() || !modelFile.isFile()) { + throw new RuntimeException("Cannot load model path " + modulePath); + } + File tokenizerFile = new File(tokenizerPath); + if (!tokenizerFile.canRead() || !tokenizerFile.isFile()) { + throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath); + } + mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataPath); + } + /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */ public LlmModule(String modulePath, String tokenizerPath, float temperature) { - mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null); + this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null); } /** @@ -51,23 +69,22 @@ public LlmModule(String modulePath, String tokenizerPath, float temperature) { * path. */ public LlmModule(String modulePath, String tokenizerPath, float temperature, String dataPath) { - mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath); + this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath); } /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */ public LlmModule(int modelType, String modulePath, String tokenizerPath, float temperature) { - mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, null); + this(modelType, modulePath, tokenizerPath, temperature, null); } /** Constructs a LLM Module for a model with the given LlmModuleConfig */ public LlmModule(LlmModuleConfig config) { - mHybridData = - initHybrid( - config.getModelType(), - config.getModulePath(), - config.getTokenizerPath(), - config.getTemperature(), - config.getDataPath()); + this( + config.getModelType(), + config.getModulePath(), + config.getTokenizerPath(), + config.getTemperature(), + config.getDataPath()); } public void resetNative() {