Skip to content

Commit 2e1626c

Browse files
authored
[jvm-packages] Small cleanup for GPU QDM. (dmlc#11179)
1 parent a97f379 commit 2e1626c

File tree

7 files changed

+246
-175
lines changed

7 files changed

+246
-175
lines changed

jvm-packages/create_jni.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,17 @@ def native_build(cli_args: argparse.Namespace) -> None:
9595
CONFIG["USE_DLOPEN_NCCL"] = "OFF"
9696

9797
args = ["-D{0}:BOOL={1}".format(k, v) for k, v in CONFIG.items()]
98+
if sys.platform != "win32":
99+
try:
100+
subprocess.check_call(["ninja", "--version"])
101+
args.append("-GNinja")
102+
except FileNotFoundError:
103+
pass
98104

99105
# if enviorment set GPU_ARCH_FLAG
100106
gpu_arch_flag = os.getenv("GPU_ARCH_FLAG", None)
101107
if gpu_arch_flag is not None:
102-
args.append("%s" % gpu_arch_flag)
108+
args.append("-DCMAKE_CUDA_ARCHITECTURES=%s" % gpu_arch_flag)
103109

104110
with cd(build_dir):
105111
lib_dir = os.path.join(os.pardir, "lib")

jvm-packages/xgboost4j-spark-gpu/src/main/java/ml/dmlc/xgboost4j/java/QuantileDMatrix.java

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
Copyright (c) 2021-2024 by Contributors
2+
Copyright (c) 2021-2025 by Contributors
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -15,7 +15,40 @@
1515
*/
1616
package ml.dmlc.xgboost4j.java;
1717

18+
import java.io.IOException;
1819
import java.util.Iterator;
20+
import java.util.Map;
21+
22+
import com.fasterxml.jackson.core.JsonGenerator;
23+
import com.fasterxml.jackson.core.JsonProcessingException;
24+
import com.fasterxml.jackson.databind.JsonSerializer;
25+
import com.fasterxml.jackson.databind.ObjectMapper;
26+
import com.fasterxml.jackson.databind.SerializerProvider;
27+
import com.fasterxml.jackson.databind.module.SimpleModule;
28+
29+
class F64NaNSerializer extends JsonSerializer<Double> {
30+
@Override
31+
public void serialize(Double value, JsonGenerator gen,
32+
SerializerProvider serializers) throws IOException {
33+
if (value.isNaN()) {
34+
gen.writeRawValue("NaN"); // Write NaN without quotes
35+
} else {
36+
gen.writeNumber(value);
37+
}
38+
}
39+
}
40+
41+
class F32NaNSerializer extends JsonSerializer<Float> {
42+
@Override
43+
public void serialize(Float value, JsonGenerator gen,
44+
SerializerProvider serializers) throws IOException {
45+
if (value.isNaN()) {
46+
gen.writeRawValue("NaN"); // Write NaN without quotes
47+
} else {
48+
gen.writeNumber(value);
49+
}
50+
}
51+
}
1952

2053
/**
2154
* QuantileDMatrix will only be used to train
@@ -112,8 +145,23 @@ public void setGroup(int[] group) throws XGBoostError {
112145
}
113146

114147
private String getConfig(float missing, int maxBin, int nthread) {
115-
return String.format("{\"missing\":%f,\"max_bin\":%d,\"nthread\":%d}",
116-
missing, maxBin, nthread);
148+
Map<String, Object> conf = new java.util.HashMap<>();
149+
conf.put("missing", missing);
150+
conf.put("max_bin", maxBin);
151+
conf.put("nthread", nthread);
152+
ObjectMapper mapper = new ObjectMapper();
153+
154+
// Handle NaN values. Jackson by default serializes NaN values into strings.
155+
SimpleModule module = new SimpleModule();
156+
module.addSerializer(Double.class, new F64NaNSerializer());
157+
module.addSerializer(Float.class, new F32NaNSerializer());
158+
mapper.registerModule(module);
159+
160+
try {
161+
String config = mapper.writeValueAsString(conf);
162+
return config;
163+
} catch (JsonProcessingException e) {
164+
throw new RuntimeException("Failed to serialize configuration", e);
165+
}
117166
}
118-
119167
}

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Communicator.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import java.io.Serializable;
44
import java.nio.ByteBuffer;
55
import java.nio.ByteOrder;
6-
import java.util.LinkedList;
7-
import java.util.List;
86
import java.util.Map;
97

108
import com.fasterxml.jackson.core.JsonProcessingException;
Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,31 @@
1+
/**
2+
* Copyright 2014-2025, XGBoost Contributors
3+
*/
14
#ifndef JVM_UTILS_H_
25
#define JVM_UTILS_H_
36

4-
#define JVM_CHECK_CALL(__expr) \
5-
{ \
6-
int __errcode = (__expr); \
7-
if (__errcode != 0) { \
8-
return __errcode; \
9-
} \
7+
#include <jni.h>
8+
9+
#include "xgboost/logging.h" // for Check
10+
11+
#define JVM_CHECK_CALL(__expr) \
12+
{ \
13+
int __errcode = (__expr); \
14+
if (__errcode != 0) { \
15+
return __errcode; \
16+
} \
1017
}
1118

12-
JavaVM*& GlobalJvm();
13-
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle);
19+
JavaVM *&GlobalJvm();
20+
void setHandle(JNIEnv *jenv, jlongArray jhandle, void *handle);
21+
22+
template <typename T>
23+
T CheckJvmCall(T const &v, JNIEnv *jenv) {
24+
if (!v) {
25+
CHECK(jenv->ExceptionOccurred());
26+
jenv->ExceptionDescribe();
27+
}
28+
return v;
29+
}
1430

1531
#endif // JVM_UTILS_H_

0 commit comments

Comments
 (0)