Skip to content

Commit b72dd15

Browse files
authored
[Java] Add OrtCompiledModelCompatibility and minor updates for OrtEpDevices (#26028)
### Description Adds the Java bits mirroring #25878, and renames a few things in #25131 for uniformity with the other APIs. ### Motivation and Context Java API parity.
1 parent 69308a3 commit b72dd15

File tree

7 files changed

+225
-30
lines changed

7 files changed

+225
-30
lines changed

java/src/main/java/ai/onnxruntime/OrtEnvironment.java

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,30 @@ public List<OrtEpDevice> getEpDevices() throws OrtException {
488488
return Collections.unmodifiableList(devicesList);
489489
}
490490

491+
/**
492+
* Checks the supplied model info string against the list of {@link OrtEpDevice}s to see if the
493+
* model is compatible.
494+
*
495+
* @param epDevices The EP-Device tuples to use.
496+
* @param modelInfo The model info string to check.
497+
* @return The model compatibility.
498+
* @throws OrtException If the call failed.
499+
*/
500+
public OrtCompiledModelCompatibility getModelCompatibilityForEpDevices(
501+
List<OrtEpDevice> epDevices, String modelInfo) throws OrtException {
502+
if (epDevices == null || epDevices.isEmpty()) {
503+
throw new IllegalArgumentException("Must supply at least one OrtEpDevice");
504+
}
505+
long[] deviceHandles = new long[epDevices.size()];
506+
for (int i = 0; i < epDevices.size(); i++) {
507+
deviceHandles[i] = epDevices.get(i).getNativeHandle();
508+
}
509+
510+
int output =
511+
getModelCompatibilityForEpDevices(OnnxRuntime.ortApiHandle, deviceHandles, modelInfo);
512+
return OrtCompiledModelCompatibility.mapFromInt(output);
513+
}
514+
491515
/**
492516
* Creates the native object.
493517
*
@@ -556,6 +580,18 @@ private static native void unregisterExecutionProviderLibrary(
556580
*/
557581
private static native long[] getEpDevices(long apiHandle, long nativeHandle) throws OrtException;
558582

583+
/**
584+
* Checks if a model is compatible with the supplied list of EP device handles.
585+
*
586+
* @param apiHandle The API handle to use.
587+
* @param epHandles An array of OrtEpDevice handles.
588+
* @param modelInfo The model info string.
589+
* @return An int representing the {@link OrtCompiledModelCompatibility} value.
590+
* @throws OrtException If the call failed.
591+
*/
592+
private static native int getModelCompatibilityForEpDevices(
593+
long apiHandle, long[] epHandles, String modelInfo) throws OrtException;
594+
559595
/**
560596
* Closes the OrtEnvironment, frees the handle.
561597
*
@@ -580,6 +616,59 @@ private static native void setTelemetry(long apiHandle, long nativeHandle, boole
580616
@Override
581617
public void close() {}
582618

619+
/** Enum representing a compiled model's compatibility with a set of {@link OrtEpDevice}s. */
620+
public enum OrtCompiledModelCompatibility {
621+
/** The EP is not applicable for the model. */
622+
EP_NOT_APPLICABLE(0),
623+
/** The EP supports the model optimally. */
624+
EP_SUPPORTED_OPTIMAL(1),
625+
/** The EP supports the model, but the model would perform better if recompiled. */
626+
EP_SUPPORTED_PREFER_RECOMPILATION(2),
627+
/** The EP does not support the model. */
628+
EP_UNSUPPORTED(3);
629+
630+
private final int value;
631+
632+
private static final Logger logger =
633+
Logger.getLogger(OrtCompiledModelCompatibility.class.getName());
634+
private static final OrtCompiledModelCompatibility[] values =
635+
new OrtCompiledModelCompatibility[4];
636+
637+
static {
638+
for (OrtCompiledModelCompatibility ot : OrtCompiledModelCompatibility.values()) {
639+
values[ot.value] = ot;
640+
}
641+
}
642+
643+
OrtCompiledModelCompatibility(int value) {
644+
this.value = value;
645+
}
646+
647+
/**
648+
* Gets the native value associated with this model compatibility value.
649+
*
650+
* @return The native value.
651+
*/
652+
public int getValue() {
653+
return value;
654+
}
655+
656+
/**
657+
* Maps from the C API's int enum to the Java enum.
658+
*
659+
* @param logLevel The index of the Java enum.
660+
* @return The Java enum.
661+
*/
662+
public static OrtCompiledModelCompatibility mapFromInt(int logLevel) {
663+
if ((logLevel >= 0) && (logLevel < values.length)) {
664+
return values[logLevel];
665+
} else {
666+
logger.warning("Unknown model compatibility " + logLevel + " setting to EP_UNSUPPORTED");
667+
return EP_UNSUPPORTED;
668+
}
669+
}
670+
}
671+
583672
/**
584673
* Controls the global thread pools in the environment. Only used if the session is constructed
585674
* using an options with {@link OrtSession.SessionOptions#disablePerSessionThreads()} set.

java/src/main/java/ai/onnxruntime/OrtEpDevice.java

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ public final class OrtEpDevice {
2424
*/
2525
OrtEpDevice(long nativeHandle) {
2626
this.nativeHandle = nativeHandle;
27-
this.epName = getName(OnnxRuntime.ortApiHandle, nativeHandle);
28-
this.epVendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle);
29-
String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle);
27+
this.epName = getEpName(OnnxRuntime.ortApiHandle, nativeHandle);
28+
this.epVendor = getEpVendor(OnnxRuntime.ortApiHandle, nativeHandle);
29+
String[][] metadata = getEpMetadata(OnnxRuntime.ortApiHandle, nativeHandle);
3030
this.epMetadata = OrtUtil.convertToMap(metadata);
31-
String[][] options = getOptions(OnnxRuntime.ortApiHandle, nativeHandle);
31+
String[][] options = getEpOptions(OnnxRuntime.ortApiHandle, nativeHandle);
3232
this.epOptions = OrtUtil.convertToMap(options);
3333
this.device = new OrtHardwareDevice(getDeviceHandle(OnnxRuntime.ortApiHandle, nativeHandle));
3434
}
@@ -43,38 +43,38 @@ long getNativeHandle() {
4343
}
4444

4545
/**
46-
* Gets the EP name.
46+
* Gets the Execution Provider name.
4747
*
4848
* @return The EP name.
4949
*/
50-
public String getName() {
50+
public String getEpName() {
5151
return epName;
5252
}
5353

5454
/**
55-
* Gets the vendor name.
55+
* Gets the Execution Provider vendor name.
5656
*
57-
* @return The vendor name.
57+
* @return The EP vendor name.
5858
*/
59-
public String getVendor() {
59+
public String getEpVendor() {
6060
return epVendor;
6161
}
6262

6363
/**
64-
* Gets an unmodifiable view on the EP metadata.
64+
* Gets an unmodifiable view on the Execution Provider metadata.
6565
*
6666
* @return The EP metadata.
6767
*/
68-
public Map<String, String> getMetadata() {
68+
public Map<String, String> getEpMetadata() {
6969
return epMetadata;
7070
}
7171

7272
/**
73-
* Gets an unmodifiable view on the EP options.
73+
* Gets an unmodifiable view on the Execution Provider options.
7474
*
7575
* @return The EP options.
7676
*/
77-
public Map<String, String> getOptions() {
77+
public Map<String, String> getEpOptions() {
7878
return epOptions;
7979
}
8080

@@ -105,13 +105,13 @@ public String toString() {
105105
+ '}';
106106
}
107107

108-
private static native String getName(long apiHandle, long nativeHandle);
108+
private static native String getEpName(long apiHandle, long nativeHandle);
109109

110-
private static native String getVendor(long apiHandle, long nativeHandle);
110+
private static native String getEpVendor(long apiHandle, long nativeHandle);
111111

112-
private static native String[][] getMetadata(long apiHandle, long nativeHandle);
112+
private static native String[][] getEpMetadata(long apiHandle, long nativeHandle);
113113

114-
private static native String[][] getOptions(long apiHandle, long nativeHandle);
114+
private static native String[][] getEpOptions(long apiHandle, long nativeHandle);
115115

116116
private static native long getDeviceHandle(long apiHandle, long nativeHandle);
117117
}

java/src/main/native/OrtJniUtil.c

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2019, 2025 Oracle and/or its affiliates. All rights reserved.
33
* Licensed under the MIT License.
44
*/
55
#include <jni.h>
@@ -109,6 +109,42 @@ jint convertFromOrtSparseFormat(OrtSparseFormat format) {
109109
}
110110
}
111111

112+
/**
113+
* Must be kept in sync with convertToCompiledModelCompatibility.
114+
*/
115+
jint convertFromCompiledModelCompatibility(OrtCompiledModelCompatibility compat) {
116+
switch (compat) {
117+
case OrtCompiledModelCompatibility_EP_NOT_APPLICABLE:
118+
return 0;
119+
case OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL:
120+
return 1;
121+
case OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION:
122+
return 2;
123+
case OrtCompiledModelCompatibility_EP_UNSUPPORTED:
124+
return 3;
125+
default:
126+
// if this value is observed the enum has changed and the code should be updated.
127+
return -1;
128+
}
129+
}
130+
131+
/**
132+
* Must be kept in sync with convertFromCompiledModelCompatibility.
133+
*/
134+
OrtCompiledModelCompatibility convertToCompiledModelCompatibility(jint compat) {
135+
switch (compat) {
136+
case 0:
137+
return OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
138+
case 1:
139+
return OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
140+
case 2:
141+
return OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
142+
case 3:
143+
default:
144+
return OrtCompiledModelCompatibility_EP_UNSUPPORTED;
145+
}
146+
}
147+
112148
/**
113149
* Must be kept in sync with convertToONNXDataFormat
114150
*/

java/src/main/native/OrtJniUtil.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved.
33
* Licensed under the MIT License.
44
*/
55
#include <jni.h>
@@ -34,6 +34,10 @@ OrtSparseFormat convertToOrtSparseFormat(jint format);
3434

3535
jint convertFromOrtSparseFormat(OrtSparseFormat format);
3636

37+
jint convertFromCompiledModelCompatibility(OrtCompiledModelCompatibility compat);
38+
39+
OrtCompiledModelCompatibility convertToCompiledModelCompatibility(jint compat);
40+
3741
jint convertFromONNXDataFormat(ONNXTensorElementDataType type);
3842

3943
ONNXTensorElementDataType convertToONNXDataFormat(jint type);

java/src/main/native/ai_onnxruntime_OrtEnvironment.c

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved.
33
* Licensed under the MIT License.
44
*/
55
#include <jni.h>
@@ -130,6 +130,42 @@ JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OrtEnvironment_getEpDevices
130130
}
131131
}
132132

133+
/*
134+
* Class: ai_onnxruntime_OrtEnvironment
135+
* Method: getModelCompatibilityForEpDevices
136+
* Signature: (J[JLjava/lang/String;)I
137+
*/
138+
JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtEnvironment_getModelCompatibilityForEpDevices
139+
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlongArray epHandles, jstring modelInfo) {
140+
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
141+
const OrtApi* api = (const OrtApi*) apiHandle;
142+
143+
// convert pointers for EpDevice handles
144+
jsize deviceCount = (*jniEnv)->GetArrayLength(jniEnv, epHandles);
145+
const OrtEpDevice** devicePtrs = allocarray(deviceCount, sizeof(OrtEpDevice *));
146+
jlong* deviceHandleElements = (*jniEnv)->GetLongArrayElements(jniEnv, epHandles, NULL);
147+
for (jsize i = 0; i < deviceCount; i++) {
148+
devicePtrs[i] = (OrtEpDevice*) deviceHandleElements[i];
149+
}
150+
(*jniEnv)->ReleaseLongArrayElements(jniEnv, epHandles, deviceHandleElements, JNI_ABORT);
151+
152+
// get utf-8 string
153+
const char* modelStr = (*jniEnv)->GetStringUTFChars(jniEnv, modelInfo, NULL);
154+
155+
OrtCompiledModelCompatibility compatibility;
156+
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetModelCompatibilityForEpDevices(devicePtrs, deviceCount, modelStr, &compatibility));
157+
158+
// cleanup
159+
(*jniEnv)->ReleaseStringUTFChars(jniEnv, modelInfo, modelStr);
160+
free((void*)devicePtrs);
161+
if (code != ORT_OK) {
162+
return -1;
163+
} else {
164+
jint returnVal = convertFromCompiledModelCompatibility(compatibility);
165+
return returnVal;
166+
}
167+
}
168+
133169
/*
134170
* Class: ai_onnxruntime_OrtEnvironment
135171
* Method: close

java/src/main/native/ai_onnxruntime_OrtEpDevice.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
* Method: getName
1313
* Signature: (JJ)Ljava/lang/String;
1414
*/
15-
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getName
15+
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getEpName
1616
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) {
1717
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
1818
const OrtApi* api = (const OrtApi*) apiHandle;
@@ -27,7 +27,7 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getName
2727
* Method: getVendor
2828
* Signature: (JJ)Ljava/lang/String;
2929
*/
30-
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getVendor
30+
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getEpVendor
3131
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) {
3232
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
3333
const OrtApi* api = (const OrtApi*) apiHandle;
@@ -42,7 +42,7 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getVendor
4242
* Method: getMetadata
4343
* Signature: (JJ)[[Ljava/lang/String;
4444
*/
45-
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getMetadata
45+
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getEpMetadata
4646
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) {
4747
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
4848
const OrtApi* api = (const OrtApi*) apiHandle;
@@ -57,7 +57,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getMetadata
5757
* Method: getOptions
5858
* Signature: (JJ)[[Ljava/lang/String;
5959
*/
60-
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getOptions
60+
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getEpOptions
6161
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) {
6262
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
6363
const OrtApi* api = (const OrtApi*) apiHandle;

0 commit comments

Comments
 (0)