Skip to content

Commit ba036a2

Browse files
committed
Add GPU utilization info to CuVSResourceManager via NVML
1 parent abe4245 commit ba036a2

File tree

6 files changed

+282
-23
lines changed

6 files changed

+282
-23
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.core.Strings;
1717
import org.elasticsearch.xpack.gpu.GPUSupport;
1818

19+
import java.lang.foreign.Arena;
1920
import java.nio.file.Path;
2021
import java.util.Objects;
2122
import java.util.concurrent.locks.Condition;
@@ -65,34 +66,55 @@ static CuVSResourceManager pooling() {
6566
return PoolingCuVSResourceManager.INSTANCE;
6667
}
6768

69+
@FunctionalInterface
70+
interface GpuInfoFunction {
71+
long get(CuVSResources resources);
72+
}
73+
6874
/**
6975
* A manager that maintains a pool of resources.
7076
*/
7177
class PoolingCuVSResourceManager implements CuVSResourceManager {
7278

7379
/** A multiplier on input data to account for intermediate and output data size required while processing it */
7480
static final double GPU_COMPUTATION_MEMORY_FACTOR = 2.0;
81+
static final int GPU_UTILIZATION_MAX_PERCENT = 80;
7582
static final int MAX_RESOURCES = 2;
83+
static final GPUInfoProvider gpuInfoProvider = CuVSProvider.provider().gpuInfoProvider();
7684
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(
7785
MAX_RESOURCES,
78-
CuVSProvider.provider().gpuInfoProvider()
86+
res ->gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes(),
87+
res ->gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes(),
88+
PoolingCuVSResourceManager::getGpuUtilizationPercent
7989
);
8090

8191
private final ManagedCuVSResources[] pool;
8292
private final int capacity;
83-
private final GPUInfoProvider gpuInfoProvider;
8493
private int createdCount;
8594

86-
ReentrantLock lock = new ReentrantLock();
87-
Condition enoughMemoryCondition = lock.newCondition();
95+
private final GpuInfoFunction totalMemoryInBytesProvider;
96+
private final GpuInfoFunction freeMemoryInBytesProvider;
97+
private final GpuInfoFunction gpuUtilizationPercentProvider;
8898

89-
public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) {
99+
ReentrantLock lock = new ReentrantLock();
100+
Condition enoughResourcesCondition = lock.newCondition();
101+
102+
public PoolingCuVSResourceManager(
103+
int capacity,
104+
GpuInfoFunction totalMemoryInBytesProvider,
105+
GpuInfoFunction freeMemoryInBytesProvider,
106+
GpuInfoFunction gpuUtilizationPercentProvider
107+
) {
108+
this.totalMemoryInBytesProvider = totalMemoryInBytesProvider;
109+
this.freeMemoryInBytesProvider = freeMemoryInBytesProvider;
110+
this.gpuUtilizationPercentProvider = gpuUtilizationPercentProvider;
90111
if (capacity < 1 || capacity > MAX_RESOURCES) {
91112
throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES);
92113
}
93114
this.capacity = capacity;
94-
this.gpuInfoProvider = gpuInfoProvider;
95115
this.pool = new ManagedCuVSResources[MAX_RESOURCES];
116+
117+
NVML.nvmlInit_v2();
96118
}
97119

98120
private ManagedCuVSResources getResourceFromPool() {
@@ -130,35 +152,38 @@ public ManagedCuVSResources acquire(int numVectors, int dims) throws Interrupted
130152
ManagedCuVSResources res = null;
131153
while (allConditionsMet == false) {
132154
res = getResourceFromPool();
133-
// If no resource in the pool is locked, short circuit to avoid livelock
134-
if (numLockedResources() == 0) {
135-
break;
136-
}
155+
137156
final boolean enoughMemory;
157+
final boolean enoughComputation;
138158
if (res != null) {
159+
// If no resource in the pool is locked, short circuit to avoid livelock
160+
if (numLockedResources() == 0) {
161+
break;
162+
}
163+
139164
// Check resources availability
140-
// Memory
141165
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims);
142-
if (requiredMemoryInBytes > gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes()) {
166+
if (requiredMemoryInBytes > totalMemoryInBytesProvider.get(res)) {
143167
throw new IllegalArgumentException(
144168
Strings.format(
145169
"Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%dMB]",
146170
numVectors,
147171
dims,
148-
gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes() / (1024L * 1024L)
172+
totalMemoryInBytesProvider.get(res) / (1024L * 1024L)
149173
)
150174
);
151175
}
152-
enoughMemory = requiredMemoryInBytes <= gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes();
176+
enoughMemory = requiredMemoryInBytes <= freeMemoryInBytesProvider.get(res);
177+
enoughComputation = gpuUtilizationPercentProvider.get(res) < GPU_UTILIZATION_MAX_PERCENT;
153178
} else {
154179
enoughMemory = false;
155-
}
156-
if (enoughMemory == false) {
157-
enoughMemoryCondition.await();
180+
enoughComputation = false;
158181
}
159182

160-
// TODO: add enoughComputation / enoughComputationCondition here
161-
allConditionsMet = enoughMemory; // && enoughComputation
183+
allConditionsMet = enoughMemory && enoughComputation;
184+
if (allConditionsMet == false) {
185+
enoughResourcesCondition.await();
186+
}
162187
}
163188
res.locked = true;
164189
return res;
@@ -171,15 +196,24 @@ private long estimateRequiredMemory(int numVectors, int dims) {
171196
return (long)(GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * Float.BYTES);
172197
}
173198

199+
private static int getGpuUtilizationPercent(CuVSResources resources) {
200+
try (var localArena = Arena.ofConfined()) {
201+
var deviceHandle = NVML.nvmlDeviceGetHandleByIndex_v2(resources.deviceId());
202+
var nvmlUtilizationPtr = localArena.allocate(NVML.nvmlUtilization_t.layout());
203+
NVML.nvmlDeviceGetUtilizationRates(deviceHandle, nvmlUtilizationPtr);
204+
return NVML.nvmlUtilization_t.gpu(nvmlUtilizationPtr);
205+
}
206+
}
207+
174208
// visible for testing
175209
protected CuVSResources createNew() {
176210
return GPUSupport.cuVSResourcesOrNull(true);
177211
}
178212

179213
@Override
180214
public void finishedComputation(ManagedCuVSResources resources) {
181-
// currently does nothing, but could allow acquire to return possibly blocked resources
182-
// something like enoughComputationCondition.signalAll()?
215+
// Allow acquire to return possibly blocked resources
216+
enoughResourcesCondition.signalAll();
183217
}
184218

185219
@Override
@@ -188,7 +222,7 @@ public void release(ManagedCuVSResources resources) {
188222
lock.lock();
189223
assert resources.locked;
190224
resources.locked = false;
191-
enoughMemoryCondition.signalAll();
225+
enoughResourcesCondition.signalAll();
192226
} finally {
193227
lock.unlock();
194228
}
@@ -201,6 +235,7 @@ public void shutdown() {
201235
assert res != null;
202236
res.delegate.close();
203237
}
238+
NVML.nvmlShutdown();
204239
}
205240
}
206241

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.gpu.codec;
9+
10+
import java.lang.foreign.MemorySegment;
11+
12+
/**
13+
* Utility methods to act on MemorySegment apis which have changed in subsequent JDK releases.
14+
*/
15+
class MemorySegmentUtil {
16+
static String getString(MemorySegment segment, long offset) {
17+
return segment.getUtf8String(offset);
18+
}
19+
}
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.gpu.codec;
9+
10+
import java.lang.foreign.Arena;
11+
import java.lang.foreign.FunctionDescriptor;
12+
import java.lang.foreign.GroupLayout;
13+
import java.lang.foreign.Linker;
14+
import java.lang.foreign.MemoryLayout;
15+
import java.lang.foreign.MemorySegment;
16+
import java.lang.foreign.SymbolLookup;
17+
import java.lang.foreign.ValueLayout;
18+
import java.lang.invoke.MethodHandle;
19+
20+
import static java.lang.foreign.MemoryLayout.PathElement.groupElement;
21+
22+
class NVML {
23+
24+
private static final SymbolLookup SYMBOL_LOOKUP = SymbolLookup.libraryLookup("libnvidia-ml.so.1", Arena.ofAuto())
25+
.or(SymbolLookup.loaderLookup())
26+
.or(Linker.nativeLinker().defaultLookup());
27+
28+
public static final int NVML_SUCCESS = 0;
29+
30+
/**
31+
* nvmlReturn_t nvmlInit_v2 ( void )
32+
*/
33+
static final MethodHandle nvmlInit_v2$mh = Linker.nativeLinker().downcallHandle(
34+
findOrThrow("nvmlInit_v2"),
35+
FunctionDescriptor.of(ValueLayout.JAVA_INT)
36+
);
37+
38+
/**
39+
* nvmlReturn_t nvmlShutdown ( void )
40+
*/
41+
static final MethodHandle nvmlShutdown$mh = Linker.nativeLinker().downcallHandle(
42+
findOrThrow("nvmlShutdown"),
43+
FunctionDescriptor.of(ValueLayout.JAVA_INT)
44+
);
45+
46+
/**
47+
* const DECLDIR char* nvmlErrorString ( nvmlReturn_t result )
48+
*/
49+
static final MethodHandle nvmlErrorString$mh = Linker.nativeLinker().downcallHandle(
50+
findOrThrow("nvmlErrorString"),
51+
FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.JAVA_INT)
52+
);
53+
54+
/**
55+
* nvmlReturn_t nvmlDeviceGetHandleByIndex_v2 ( unsigned int index, nvmlDevice_t* device )
56+
*/
57+
static final MethodHandle nvmlDeviceGetHandleByIndex_v2$mh = Linker.nativeLinker().downcallHandle(
58+
findOrThrow("nvmlDeviceGetHandleByIndex_v2"),
59+
FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT, ValueLayout.ADDRESS)
60+
);
61+
62+
/**
63+
* nvmlReturn_t nvmlDeviceGetUtilizationRates ( nvmlDevice_t device, nvmlUtilization_t* utilization )
64+
*/
65+
static final MethodHandle nvmlDeviceGetUtilizationRates$mh = Linker.nativeLinker().downcallHandle(
66+
findOrThrow("nvmlDeviceGetUtilizationRates"),
67+
FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS)
68+
);
69+
70+
public static class nvmlUtilization_t {
71+
72+
nvmlUtilization_t() {
73+
// Should not be called directly
74+
}
75+
76+
private static final GroupLayout $LAYOUT = MemoryLayout.structLayout(
77+
ValueLayout.JAVA_INT.withName("gpu"),
78+
ValueLayout.JAVA_INT.withName("memory")
79+
);
80+
81+
/**
82+
* The layout of this struct
83+
*/
84+
public static GroupLayout layout() {
85+
return $LAYOUT;
86+
}
87+
88+
private static final ValueLayout.OfInt gpu$LAYOUT = (ValueLayout.OfInt)$LAYOUT.select(groupElement("gpu"));
89+
90+
/**
91+
* Getter for field: gpu
92+
* Percent of time over the past sample period during which one or more kernels was executing on the GPU.
93+
*/
94+
public static int gpu(MemorySegment struct) {
95+
return struct.get(gpu$LAYOUT, 0);
96+
}
97+
98+
private static final ValueLayout.OfInt memory$LAYOUT = (ValueLayout.OfInt)$LAYOUT.select(groupElement("memory"));
99+
100+
/**
101+
* Getter for field: memory
102+
* Percent of time over the past sample period during which global (device) memory was being read or written.
103+
*/
104+
public static int memory(MemorySegment struct) {
105+
return struct.get(memory$LAYOUT, 4);
106+
}
107+
}
108+
109+
private static MemorySegment findOrThrow(String symbol) {
110+
return SYMBOL_LOOKUP.find(symbol).orElseThrow(() -> new UnsatisfiedLinkError("unresolved symbol: " + symbol));
111+
}
112+
113+
public static void nvmlInit_v2() {
114+
int res;
115+
try {
116+
res = (int)nvmlInit_v2$mh.invokeExact();
117+
} catch (Throwable ex$) {
118+
throw new AssertionError("should not reach here", ex$);
119+
}
120+
if (res != NVML_SUCCESS) {
121+
throw buildException(res);
122+
}
123+
}
124+
125+
public static void nvmlShutdown() {
126+
int res;
127+
try {
128+
res = (int)nvmlShutdown$mh.invokeExact();
129+
} catch (Throwable ex$) {
130+
throw new AssertionError("should not reach here", ex$);
131+
}
132+
if (res != NVML_SUCCESS) {
133+
throw buildException(res);
134+
}
135+
}
136+
137+
public static MemorySegment nvmlDeviceGetHandleByIndex_v2(int index) {
138+
int res;
139+
MemorySegment nvmlDevice;
140+
try (var localArena = Arena.ofConfined()) {
141+
MemorySegment devicePtr = localArena.allocate(ValueLayout.ADDRESS);
142+
res = (int)nvmlDeviceGetHandleByIndex_v2$mh.invokeExact(index,devicePtr);
143+
nvmlDevice = devicePtr.get(ValueLayout.ADDRESS, 0);
144+
} catch (Throwable ex$) {
145+
throw new AssertionError("should not reach here", ex$);
146+
}
147+
if (res != NVML_SUCCESS) {
148+
throw buildException(res);
149+
}
150+
return nvmlDevice;
151+
}
152+
153+
public static void nvmlDeviceGetUtilizationRates(MemorySegment nvmlDevice, MemorySegment nvmlUtilizationPtr) {
154+
int res;
155+
try {
156+
res = (int)nvmlDeviceGetUtilizationRates$mh.invokeExact(nvmlDevice, nvmlUtilizationPtr);
157+
} catch (Throwable ex$) {
158+
throw new AssertionError("should not reach here", ex$);
159+
}
160+
if (res != NVML_SUCCESS) {
161+
throw buildException(res);
162+
}
163+
}
164+
165+
private static RuntimeException buildException(int res) {
166+
return new RuntimeException("Error invoking NVML: " + res + "[" + nvmlErrorString(res) + "]");
167+
}
168+
169+
public static String nvmlErrorString(int result) {
170+
try {
171+
var seg = (MemorySegment) nvmlErrorString$mh.invokeExact(result);
172+
if (seg.equals(MemorySegment.NULL)) {
173+
return "no last error text";
174+
}
175+
return MemorySegmentUtil.getString(seg,0);
176+
} catch (Throwable ex$) {
177+
throw new AssertionError("should not reach here", ex$);
178+
}
179+
}
180+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
com.nvidia.cuvs:
22
- load_native_libraries
3+
org.elasticsearch.gpu:
4+
- load_native_libraries
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.gpu.codec;
9+
10+
import java.lang.foreign.MemorySegment;
11+
12+
/**
13+
* Utility methods to act on MemorySegment apis which have changed in subsequent JDK releases.
14+
*/
15+
class MemorySegmentUtil {
16+
static String getString(MemorySegment segment, long offset) {
17+
return segment.getString(offset);
18+
}
19+
}

0 commit comments

Comments
 (0)