Skip to content

Commit 6baf6eb

Browse files
authored
refactor: add DL model class (#722)
* refactor: add DL model class

* fix model url in example doc

* address comments

* fix failed ut

Signed-off-by: Yaliang Wu <[email protected]>
1 parent 50116bb commit 6baf6eb

File tree

6 files changed

+384
-277
lines changed

6 files changed

+384
-277
lines changed

common/src/main/java/org/opensearch/ml/common/FunctionName.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,15 @@ public static FunctionName from(String value) {
2525
throw new IllegalArgumentException("Wrong function name");
2626
}
2727
}
28+
29+
/**
30+
* Check if model is deep learning model.
31+
* @return true for deep learning model.
32+
*/
33+
public static boolean isDLModel(FunctionName functionName) {
34+
if (functionName == TEXT_EMBEDDING) {
35+
return true;
36+
}
37+
return false;
38+
}
2839
}

docs/model_serving_framework/text_embedding_model_examples.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ POST /_plugins/_ml/models/_upload
6363
"embedding_dimension": 384,
6464
"framework_type": "sentence_transformers"
6565
},
66-
"url": "https://github.com/opensearch-project/ml-commons/tree/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/all-MiniLM-L6-v2_torchscript_sentence-transformer.zip?raw=true"
66+
"url": "https://github.com/opensearch-project/ml-commons/raw/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/all-MiniLM-L6-v2_torchscript_sentence-transformer.zip?raw=true"
6767
}
6868
6969
# Sample response
@@ -262,7 +262,7 @@ POST /_plugins/_ml/models/_upload
262262
"pooling_mode":"mean",
263263
"normalize_result":"true"
264264
},
265-
"url": "https://github.com/opensearch-project/ml-commons/tree/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/all-MiniLM-L6-v2_torchscript_huggingface.zip?raw=true"
265+
"url": "https://github.com/opensearch-project/ml-commons/raw/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/all-MiniLM-L6-v2_torchscript_huggingface.zip?raw=true"
266266
}
267267
```
268268

@@ -289,6 +289,6 @@ POST /_plugins/_ml/models/_upload
289289
"pooling_mode":"mean",
290290
"normalize_result":"true"
291291
},
292-
"url": "https://github.com/opensearch-project/ml-commons/tree/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/all-MiniLM-L6-v2_onnx.zip?raw=true"
292+
"url": "https://github.com/opensearch-project/ml-commons/raw/2.x/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/text_embedding/all-MiniLM-L6-v2_onnx.zip?raw=true"
293293
}
294294
```
Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
package org.opensearch.ml.engine.algorithms;
2+
3+
import ai.djl.Application;
4+
import ai.djl.Device;
5+
import ai.djl.engine.Engine;
6+
import ai.djl.inference.Predictor;
7+
import ai.djl.modality.Input;
8+
import ai.djl.modality.Output;
9+
import ai.djl.repository.zoo.Criteria;
10+
import ai.djl.repository.zoo.ZooModel;
11+
import ai.djl.translate.TranslateException;
12+
import ai.djl.translate.Translator;
13+
import ai.djl.translate.TranslatorFactory;
14+
import ai.djl.util.ZipUtils;
15+
import lombok.extern.log4j.Log4j2;
16+
import org.apache.commons.io.FileUtils;
17+
import org.opensearch.ml.common.FunctionName;
18+
import org.opensearch.ml.common.MLModel;
19+
import org.opensearch.ml.common.dataset.MLInputDataset;
20+
import org.opensearch.ml.common.exception.MLException;
21+
import org.opensearch.ml.common.input.MLInput;
22+
import org.opensearch.ml.common.model.MLModelConfig;
23+
import org.opensearch.ml.common.output.MLOutput;
24+
import org.opensearch.ml.common.output.model.ModelResultFilter;
25+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
26+
import org.opensearch.ml.common.output.model.ModelTensors;
27+
import org.opensearch.ml.engine.MLEngine;
28+
import org.opensearch.ml.engine.ModelHelper;
29+
import org.opensearch.ml.engine.Predictable;
30+
31+
import java.io.File;
32+
import java.io.FileInputStream;
33+
import java.nio.file.Path;
34+
import java.security.AccessController;
35+
import java.security.PrivilegedActionException;
36+
import java.security.PrivilegedExceptionAction;
37+
import java.util.ArrayList;
38+
import java.util.List;
39+
import java.util.Map;
40+
import java.util.concurrent.atomic.AtomicInteger;
41+
42+
import static org.opensearch.ml.engine.ModelHelper.ONNX_ENGINE;
43+
import static org.opensearch.ml.engine.ModelHelper.ONNX_FILE_EXTENSION;
44+
import static org.opensearch.ml.engine.ModelHelper.PYTORCH_ENGINE;
45+
import static org.opensearch.ml.engine.ModelHelper.PYTORCH_FILE_EXTENSION;
46+
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
47+
48+
@Log4j2
49+
public abstract class DLModel implements Predictable {
50+
public static final String MODEL_ZIP_FILE = "model_zip_file";
51+
public static final String MODEL_HELPER = "model_helper";
52+
public static final String ML_ENGINE = "ml_engine";
53+
protected ModelHelper modelHelper;
54+
protected MLEngine mlEngine;
55+
protected String modelId;
56+
57+
protected Predictor<Input, Output>[] predictors;
58+
protected ZooModel[] models;
59+
protected Device[] devices;
60+
protected AtomicInteger nextDevice = new AtomicInteger(0);
61+
62+
@Override
63+
public MLOutput predict(MLInput mlInput, MLModel model) {
64+
throw new MLException("model not loaded");
65+
}
66+
67+
@Override
68+
public MLOutput predict(MLInput mlInput) {
69+
if (modelHelper == null || modelId == null) {
70+
throw new MLException("model not loaded");
71+
}
72+
try {
73+
return AccessController.doPrivileged((PrivilegedExceptionAction<ModelTensorOutput>) () -> {
74+
Thread.currentThread().setContextClassLoader(getClass().getClassLoader());
75+
if (predictors == null) {
76+
throw new MLException("model not loaded.");
77+
}
78+
return predict(modelId, mlInput.getInputDataset());
79+
});
80+
} catch (Throwable e) {
81+
String errorMsg = "Failed to inference " + mlInput.getAlgorithm() + " model: " + modelId;
82+
log.error(errorMsg, e);
83+
throw new MLException(errorMsg, e);
84+
}
85+
}
86+
87+
protected Predictor<Input, Output> getPredictor() {
88+
int currentDevice = nextDevice.getAndIncrement();
89+
if (currentDevice > devices.length - 1) {
90+
currentDevice = currentDevice % devices.length;
91+
nextDevice.set(currentDevice + 1);
92+
}
93+
return predictors[currentDevice];
94+
}
95+
96+
public abstract ModelTensorOutput predict(String modelId, MLInputDataset inputDataSet) throws TranslateException;
97+
98+
@Override
99+
public void initModel(MLModel model, Map<String, Object> params) {
100+
String engine;
101+
switch (model.getModelFormat()) {
102+
case TORCH_SCRIPT:
103+
engine = PYTORCH_ENGINE;
104+
break;
105+
case ONNX:
106+
engine = ONNX_ENGINE;
107+
break;
108+
default:
109+
throw new IllegalArgumentException("unsupported engine");
110+
}
111+
112+
File modelZipFile = (File)params.get(MODEL_ZIP_FILE);
113+
modelHelper = (ModelHelper)params.get(MODEL_HELPER);
114+
mlEngine = (MLEngine)params.get(ML_ENGINE);
115+
if (modelZipFile == null) {
116+
throw new IllegalArgumentException("model file is null");
117+
}
118+
if (modelHelper == null) {
119+
throw new IllegalArgumentException("model helper is null");
120+
}
121+
if (mlEngine == null) {
122+
throw new IllegalArgumentException("ML engine is null");
123+
}
124+
modelId = model.getModelId();
125+
if (modelId == null) {
126+
throw new IllegalArgumentException("model id is null");
127+
}
128+
if (!FunctionName.isDLModel(model.getAlgorithm())) {
129+
throw new IllegalArgumentException("wrong function name");
130+
}
131+
loadModel(
132+
modelZipFile,
133+
modelId,
134+
model.getName(),
135+
model.getVersion(),
136+
model.getModelConfig(),
137+
engine
138+
);
139+
}
140+
141+
@Override
142+
public void close() {
143+
if (modelHelper != null && modelId != null) {
144+
modelHelper.deleteFileCache(modelId);
145+
if (predictors != null) {
146+
closePredictors(predictors);
147+
predictors = null;
148+
}
149+
if (models != null) {
150+
closeModels(models);
151+
models = null;
152+
}
153+
}
154+
}
155+
156+
public abstract Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig);
157+
158+
public abstract TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig);
159+
160+
public Map<String, Object> getArguments(MLModelConfig modelConfig) {
161+
return null;
162+
}
163+
164+
public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {}
165+
166+
private void loadModel(File modelZipFile, String modelId, String modelName, String version,
167+
MLModelConfig modelConfig,
168+
String engine) {
169+
try {
170+
List<Predictor<Input, Output>> predictorList = new ArrayList<>();
171+
List<ZooModel<Input, Output>> modelList = new ArrayList<>();
172+
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
173+
ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
174+
try {
175+
System.setProperty("PYTORCH_PRECXX11", "true");
176+
System.setProperty("DJL_CACHE_DIR", mlEngine.getDjlCachePath().toAbsolutePath().toString());
177+
// DJL will read "/usr/java/packages/lib" if don't set "java.library.path". That will throw
178+
// access denied exception
179+
System.setProperty("java.library.path", mlEngine.getDjlCachePath().toAbsolutePath().toString());
180+
System.setProperty("ai.djl.pytorch.num_interop_threads", "1");
181+
System.setProperty("ai.djl.pytorch.num_threads", "1");
182+
Thread.currentThread().setContextClassLoader(ai.djl.Model.class.getClassLoader());
183+
Path modelPath = mlEngine.getModelCachePath(modelId, modelName, version);
184+
File pathFile = new File(modelPath.toUri());
185+
if (pathFile.exists()) {
186+
FileUtils.deleteDirectory(pathFile);
187+
}
188+
try (FileInputStream fileInputStream = new FileInputStream(modelZipFile)) {
189+
ZipUtils.unzip(fileInputStream, modelPath);
190+
}
191+
boolean findModelFile = false;
192+
for (File file : pathFile.listFiles()) {
193+
String name = file.getName();
194+
if (name.endsWith(PYTORCH_FILE_EXTENSION) || name.endsWith(ONNX_FILE_EXTENSION)) {
195+
if (findModelFile) {
196+
throw new IllegalArgumentException("found multiple models");
197+
}
198+
findModelFile = true;
199+
int dotIndex = name.lastIndexOf(".");
200+
String suffix = name.substring(dotIndex);
201+
String targetModelFileName = modelPath.getFileName().toString();
202+
if (!targetModelFileName.equals(name.substring(0, dotIndex))) {
203+
file.renameTo(new File(modelPath.resolve(targetModelFileName + suffix).toUri()));
204+
}
205+
}
206+
}
207+
devices = Engine.getEngine(engine).getDevices();
208+
for (int i = 0; i < devices.length; i++) {
209+
log.debug("load model {} on device {}: {}", modelId, i, devices[i]);
210+
Criteria.Builder<Input, Output> criteriaBuilder = Criteria.builder()
211+
.setTypes(Input.class, Output.class)
212+
.optApplication(Application.UNDEFINED)
213+
.optEngine(engine)
214+
.optDevice(devices[i])
215+
.optModelPath(modelPath);
216+
Translator translator = getTranslator(engine, modelConfig);
217+
TranslatorFactory translatorFactory = getTranslatorFactory(engine, modelConfig);
218+
if (translatorFactory != null) {
219+
criteriaBuilder.optTranslatorFactory(translatorFactory);
220+
} else if (translator != null) {
221+
criteriaBuilder.optTranslator(translator);
222+
}
223+
224+
Map<String, Object> arguments = getArguments(modelConfig);
225+
if (arguments != null && arguments.size() > 0) {
226+
for (Map.Entry<String,Object> entry : arguments.entrySet()) {
227+
criteriaBuilder.optArgument(entry.getKey(), entry.getValue());
228+
}
229+
}
230+
231+
Criteria<Input, Output> criteria = criteriaBuilder.build();
232+
ZooModel<Input, Output> model = criteria.loadModel();
233+
Predictor<Input, Output> predictor = model.newPredictor();
234+
predictorList.add(predictor);
235+
modelList.add(model);
236+
237+
// First request takes longer time. Predict once to warm up model.
238+
warmUp(predictor, modelId, modelConfig);
239+
}
240+
if (predictorList.size() > 0) {
241+
this.predictors = predictorList.toArray(new Predictor[0]);
242+
predictorList.clear();
243+
}
244+
if (modelList.size() > 0) {
245+
this.models = modelList.toArray(new ZooModel[0]);
246+
modelList.clear();
247+
}
248+
log.info("Model {} is successfully loaded on {} devices", modelId, devices.length);
249+
return null;
250+
} catch (Throwable e) {
251+
String errorMessage = "Failed to load model " + modelId;
252+
log.error(errorMessage, e);
253+
close();
254+
if (predictorList.size() > 0) {
255+
closePredictors(predictorList.toArray(new Predictor[0]));
256+
predictorList.clear();
257+
}
258+
if (modelList.size() > 0) {
259+
closeModels(modelList.toArray(new ZooModel[0]));
260+
modelList.clear();
261+
}
262+
throw new MLException(errorMessage, e);
263+
} finally {
264+
deleteFileQuietly(mlEngine.getLoadModelPath(modelId));
265+
Thread.currentThread().setContextClassLoader(contextClassLoader);
266+
}
267+
});
268+
} catch (PrivilegedActionException e) {
269+
String errorMsg = "Failed to load model " + modelId;
270+
log.error(errorMsg, e);
271+
throw new MLException(errorMsg, e);
272+
}
273+
}
274+
275+
protected void closePredictors(Predictor[] predictors) {
276+
log.debug("will close {} predictor for model {}", predictors.length, modelId);
277+
for (Predictor<Input, Output> predictor : predictors) {
278+
predictor.close();
279+
}
280+
}
281+
282+
protected void closeModels(ZooModel[] models) {
283+
log.debug("will close {} zoo model for model {}", models.length, modelId);
284+
for (ZooModel model : models) {
285+
model.close();
286+
}
287+
}
288+
289+
/**
290+
* Parse model output to model tensor output and apply result filter.
291+
* @param output model output
292+
* @param resultFilter result filter
293+
* @return model tensor output
294+
*/
295+
public ModelTensors parseModelTensorOutput(Output output, ModelResultFilter resultFilter) {
296+
if (output == null) {
297+
throw new MLException("No output generated");
298+
}
299+
byte[] bytes = output.getData().getAsBytes();
300+
ModelTensors tensorOutput = ModelTensors.fromBytes(bytes);
301+
if (resultFilter != null) {
302+
tensorOutput.filter(resultFilter);
303+
}
304+
return tensorOutput;
305+
}
306+
307+
}

0 commit comments

Comments
 (0)