Skip to content

Commit 42940bd

Browse files
committed
Android Java throw exception if cannot load model or tokenizer
1 parent 24789c8 commit 42940bd

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import android.util.Log;
1212
import com.facebook.soloader.nativeloader.NativeLoader;
1313
import com.facebook.soloader.nativeloader.SystemDelegate;
14+
import java.nio.file.Files;
15+
import java.nio.file.Paths;
1416
import java.util.concurrent.locks.Lock;
1517
import java.util.concurrent.locks.ReentrantLock;
1618
import org.pytorch.executorch.annotations.Experimental;
@@ -52,6 +54,9 @@ public static Module load(final String modelPath, int loadMode) {
5254
if (!NativeLoader.isInitialized()) {
5355
NativeLoader.init(new SystemDelegate());
5456
}
57+
if (!Files.isReadable(Paths.get(modelPath))) {
58+
throw new RuntimeException("Cannot load model path " + modelPath);
59+
}
5560
return new Module(new NativePeer(modelPath, loadMode));
5661
}
5762

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import com.facebook.jni.annotations.DoNotStrip;
1313
import com.facebook.soloader.nativeloader.NativeLoader;
1414
import com.facebook.soloader.nativeloader.SystemDelegate;
15+
import java.nio.file.Files;
16+
import java.nio.file.Paths;
1517
import org.pytorch.executorch.annotations.Experimental;
1618

1719
/**
@@ -41,33 +43,47 @@ public class LlmModule {
4143
private static native HybridData initHybrid(
4244
int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath);
4345

46+
/**
47+
* Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
48+
* data path.
49+
*/
50+
public LlmModule(
51+
int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) {
52+
if (!Files.isReadable(Paths.get(modulePath))) {
53+
throw new RuntimeException("Cannot load model path " + modulePath);
54+
}
55+
if (!Files.isReadable(Paths.get(tokenizerPath))) {
56+
throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath);
57+
}
58+
mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataPath);
59+
}
60+
4461
/** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */
4562
public LlmModule(String modulePath, String tokenizerPath, float temperature) {
46-
mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null);
63+
this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null);
4764
}
4865

4966
/**
5067
* Constructs a LLM Module for a model with given model path, tokenizer, temperature and data
5168
* path.
5269
*/
5370
public LlmModule(String modulePath, String tokenizerPath, float temperature, String dataPath) {
54-
mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath);
71+
this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath);
5572
}
5673

5774
/** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
5875
public LlmModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
59-
mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, null);
76+
this(modelType, modulePath, tokenizerPath, temperature, null);
6077
}
6178

6279
/** Constructs a LLM Module for a model with the given LlmModuleConfig */
6380
public LlmModule(LlmModuleConfig config) {
64-
mHybridData =
65-
initHybrid(
66-
config.getModelType(),
67-
config.getModulePath(),
68-
config.getTokenizerPath(),
69-
config.getTemperature(),
70-
config.getDataPath());
81+
this(
82+
config.getModelType(),
83+
config.getModulePath(),
84+
config.getTokenizerPath(),
85+
config.getTemperature(),
86+
config.getDataPath());
7187
}
7288

7389
public void resetNative() {

0 commit comments

Comments
 (0)