@@ -742,6 +742,18 @@ public Device withDefaultDevice(String deviceName) {
742742 return Device .fromName (deviceName , Engine .getEngine (engineName ));
743743 }
744744
745+ private boolean hasModelFile (Path modelDir , String prefix , String ... extensions ) {
746+ for (String extension : extensions ) {
747+ if (Files .isRegularFile (modelDir .resolve (prefix + extension ))) {
748+ return true ;
749+ }
750+ if (Files .isRegularFile (modelDir .resolve ("model" + extension ))) {
751+ return true ;
752+ }
753+ }
754+ return false ;
755+ }
756+
745757 private String inferEngine () throws ModelException {
746758 String eng = prop .getProperty ("engine" );
747759 if (eng != null ) {
@@ -768,23 +780,19 @@ private String inferEngine() throws ModelException {
768780 return zoo .getSupportedEngines ().iterator ().next ();
769781 } else if (isTorchServeModel ()) {
770782 return "Python" ;
771- } else if (Files .isRegularFile (modelDir .resolve (prefix + ".pt" ))
772- || Files .isRegularFile (modelDir .resolve ("model.pt" ))) {
783+ } else if (hasModelFile (modelDir , prefix , ".pt" )) {
773784 return "PyTorch" ;
774785 } else if (Files .isRegularFile (modelDir .resolve ("config.pbtxt" ))) {
775786 return "TritonServer" ;
776787 } else if (Files .isRegularFile (modelDir .resolve ("saved_model.pb" ))) {
777788 return "TensorFlow" ;
778- } else if (Files .isRegularFile (modelDir .resolve (prefix + ".onnx" ))
779- || Files .isRegularFile (modelDir .resolve ("model.onnx" ))) {
789+ } else if (hasModelFile (modelDir , prefix , ".onnx" )) {
780790 return "OnnxRuntime" ;
781- } else if (Files .isRegularFile (modelDir .resolve (prefix + ".json" ))
782- || Files .isRegularFile (modelDir .resolve (prefix + ".xgb" ))
783- || Files .isRegularFile (modelDir .resolve (prefix + ".bst" ))
784- || Files .isRegularFile (modelDir .resolve ("model.json" ))
785- || Files .isRegularFile (modelDir .resolve ("model.bst" ))
786- || Files .isRegularFile (modelDir .resolve ("model.xgb" ))) {
791+ } else if (hasModelFile (modelDir , prefix , ".json" , ".xgb" , ".bst" )) {
787792 return "XGBoost" ;
793+ } else if (hasModelFile (
794+ modelDir , prefix , ".skops" , ".joblib" , ".pkl" , ".pickle" , ".cloudpkl" )) {
795+ return "Python" ;
788796 } else if (isPythonModel (prefix )) {
789797 // TODO: How to differentiate Rust model from Python
790798 return "Python" ;
0 commit comments