Skip to content

Commit ccb2243

Browse files
authored
Update build option for training in java to enable_training_api (#15638)
### Description Updating the build option for enabling training in java builds from ENABLE_TRAINING -> ENABLE_TRAINING_APIS. In the native codebase ENABLE_TRAINING is used for enabling full training and ENABLE_TRAINING_APIS is used for creating the lte builds with training apis. Making the change to sync the naming convention across all the language bindings. It was a bit confusing to see ENABLE_TRAINING when debugging the android build failures for training. Making this change just to improve readability of logs during debugging. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 686fd3c commit ccb2243

File tree

5 files changed

+10
-8
lines changed

5 files changed

+10
-8
lines changed

cmake/onnxruntime_java.cmake

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ file(GLOB onnxruntime4j_native_src
4747
"${JAVA_ROOT}/src/main/native/*.c"
4848
"${JAVA_ROOT}/src/main/native/*.h"
4949
"${REPO_ROOT}/include/onnxruntime/core/session/*.h"
50-
"${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h"
50+
"${REPO_ROOT}/orttraining/orttraining/training_api/include/*.h"
5151
)
5252
# Build the JNI library
5353
onnxruntime_add_shared_library_module(onnxruntime4j_jni ${onnxruntime4j_native_src})
@@ -193,7 +193,7 @@ endif()
193193
# Append relevant native build flags to gradle command
194194
set(GRADLE_ARGS ${GRADLE_ARGS} ${ORT_PROVIDER_FLAGS})
195195
if (onnxruntime_ENABLE_TRAINING_APIS)
196-
set(GRADLE_ARGS ${GRADLE_ARGS} "-DENABLE_TRAINING=1")
196+
set(GRADLE_ARGS ${GRADLE_ARGS} "-DENABLE_TRAINING_APIS=1")
197197
endif()
198198

199199
message(STATUS "GRADLE_ARGS: ${GRADLE_ARGS}")
@@ -208,6 +208,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Android")
208208
# Copy onnxruntime.so and onnxruntime4j_jni.so for building Android AAR package
209209
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime> ${ANDROID_PACKAGE_ABI_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime>)
210210
add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:onnxruntime4j_jni> ${ANDROID_PACKAGE_ABI_DIR}/$<TARGET_LINKER_FILE_NAME:onnxruntime4j_jni>)
211+
211212
# Generate the Android AAR package
212213
add_custom_command(TARGET onnxruntime4j_jni
213214
POST_BUILD
@@ -217,6 +218,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Android")
217218
-b build-android.gradle -c settings-android.gradle
218219
-DjniLibsDir=${ANDROID_PACKAGE_JNILIBS_DIR} -DbuildDir=${ANDROID_PACKAGE_OUTPUT_DIR}
219220
-DminSdkVer=${ANDROID_MIN_SDK} -DheadersDir=${ANDROID_HEADERS_DIR}
221+
$<$<BOOL:${onnxruntime_ENABLE_TRAINING_APIS}>:-DENABLE_TRAINING_APIS=1>
220222
--stacktrace
221223
WORKING_DIRECTORY ${JAVA_ROOT})
222224

cmake/onnxruntime_java_unittests.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ FILE(TO_NATIVE_PATH ${BIN_DIR} BINDIR_NATIVE_PATH)
88
message(STATUS "GRADLE_TEST_EP_FLAGS: ${ORT_PROVIDER_FLAGS}")
99
if (onnxruntime_ENABLE_TRAINING_APIS)
1010
message(STATUS "Running ORT Java training tests")
11-
execute_process(COMMAND cmd /C ${GRADLE_NATIVE_PATH} --console=plain cmakeCheck -DcmakeBuildDir=${BINDIR_NATIVE_PATH} -Dorg.gradle.daemon=false ${ORT_PROVIDER_FLAGS} -DENABLE_TRAINING=1
11+
execute_process(COMMAND cmd /C ${GRADLE_NATIVE_PATH} --console=plain cmakeCheck -DcmakeBuildDir=${BINDIR_NATIVE_PATH} -Dorg.gradle.daemon=false ${ORT_PROVIDER_FLAGS} -DENABLE_TRAINING_APIS=1
1212
WORKING_DIRECTORY ${REPO_ROOT}/java
1313
RESULT_VARIABLE HAD_ERROR)
1414
else()

cmake/onnxruntime_unittests.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1448,7 +1448,7 @@ if (NOT onnxruntime_BUILD_WEBASSEMBLY)
14481448
${JAVA_NATIVE_TEST_DIR}/$<TARGET_LINKER_FILE_NAME:custom_op_library>)
14491449
if (onnxruntime_ENABLE_TRAINING_APIS)
14501450
message(STATUS "Running Java inference and training tests")
1451-
add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR} ${ORT_PROVIDER_FLAGS} -DENABLE_TRAINING=1
1451+
add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR} ${ORT_PROVIDER_FLAGS} -DENABLE_TRAINING_APIS=1
14521452
WORKING_DIRECTORY ${REPO_ROOT}/java)
14531453
else()
14541454
message(STATUS "Running Java inference tests only")

java/build.gradle

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ version = rootProject.file('../VERSION_NUMBER').text.trim()
1919
def cmakeBuildDir = System.properties['cmakeBuildDir']
2020
def useCUDA = System.properties['USE_CUDA']
2121
def useROCM = System.properties['USE_ROCM']
22-
def enableTraining = System.properties['ENABLE_TRAINING']
22+
def enableTrainingApis = System.properties['ENABLE_TRAINING_APIS']
2323
def cmakeJavaDir = "${cmakeBuildDir}/java"
2424
def cmakeNativeLibDir = "${cmakeJavaDir}/native-lib"
2525
def cmakeNativeJniDir = "${cmakeJavaDir}/native-jni"
@@ -29,7 +29,7 @@ def cmakeBuildOutputDir = "${cmakeJavaDir}/build"
2929
def mavenUser = System.properties['mavenUser']
3030
def mavenPwd = System.properties['mavenPwd']
3131

32-
def tmpArtifactId = enableTraining == null ? project.name : project.name + "-training"
32+
def tmpArtifactId = enableTrainingApis == null ? project.name : project.name + "-training"
3333
def mavenArtifactId = (useCUDA == null && useROCM == null) ? tmpArtifactId : tmpArtifactId + "_gpu"
3434

3535
java {
@@ -176,7 +176,7 @@ test {
176176
if (cmakeBuildDir != null) {
177177
workingDir cmakeBuildDir
178178
}
179-
systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'JAVA_FULL_TEST', 'ENABLE_TRAINING'])
179+
systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS'])
180180
testLogging {
181181
events "passed", "skipped", "failed"
182182
showStandardStreams = true

java/src/test/java/ai/onnxruntime/TrainingTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
2222

2323
/** Tests for the ORT training apis. */
24-
@EnabledIfSystemProperty(named = "ENABLE_TRAINING", matches = "1")
24+
@EnabledIfSystemProperty(named = "ENABLE_TRAINING_APIS", matches = "1")
2525
public class TrainingTest {
2626

2727
private static final OrtEnvironment env = OrtEnvironment.getEnvironment();

0 commit comments

Comments
 (0)