Skip to content

Commit 8e0bfbf

Browse files
committed
Fix: missing lock acquisition
1 parent 0dfa526 commit 8e0bfbf

File tree

2 files changed

+31
-29
lines changed

2 files changed

+31
-29
lines changed

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,13 @@ protected CuVSResources createNew() {
211211
@Override
212212
public void finishedComputation(ManagedCuVSResources resources) {
213213
// Allow acquire to return possibly blocked resources
214-
enoughResourcesCondition.signalAll();
214+
try {
215+
lock.lock();
216+
assert resources.locked;
217+
enoughResourcesCondition.signalAll();
218+
} finally {
219+
lock.unlock();
220+
}
215221
}
216222

217223
@Override

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/NVML.java

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -30,42 +30,38 @@ class NVML {
3030
/**
3131
* nvmlReturn_t nvmlInit_v2 ( void )
3232
*/
33-
static final MethodHandle nvmlInit_v2$mh = Linker.nativeLinker().downcallHandle(
34-
findOrThrow("nvmlInit_v2"),
35-
FunctionDescriptor.of(ValueLayout.JAVA_INT)
36-
);
33+
static final MethodHandle nvmlInit_v2$mh = Linker.nativeLinker()
34+
.downcallHandle(findOrThrow("nvmlInit_v2"), FunctionDescriptor.of(ValueLayout.JAVA_INT));
3735

3836
/**
3937
* nvmlReturn_t nvmlShutdown ( void )
4038
*/
41-
static final MethodHandle nvmlShutdown$mh = Linker.nativeLinker().downcallHandle(
42-
findOrThrow("nvmlShutdown"),
43-
FunctionDescriptor.of(ValueLayout.JAVA_INT)
44-
);
39+
static final MethodHandle nvmlShutdown$mh = Linker.nativeLinker()
40+
.downcallHandle(findOrThrow("nvmlShutdown"), FunctionDescriptor.of(ValueLayout.JAVA_INT));
4541

4642
/**
4743
* const DECLDIR char* nvmlErrorString ( nvmlReturn_t result )
4844
*/
49-
static final MethodHandle nvmlErrorString$mh = Linker.nativeLinker().downcallHandle(
50-
findOrThrow("nvmlErrorString"),
51-
FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.JAVA_INT)
52-
);
45+
static final MethodHandle nvmlErrorString$mh = Linker.nativeLinker()
46+
.downcallHandle(findOrThrow("nvmlErrorString"), FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.JAVA_INT));
5347

5448
/**
5549
* nvmlReturn_t nvmlDeviceGetHandleByIndex_v2 ( unsigned int index, nvmlDevice_t* device )
5650
*/
57-
static final MethodHandle nvmlDeviceGetHandleByIndex_v2$mh = Linker.nativeLinker().downcallHandle(
58-
findOrThrow("nvmlDeviceGetHandleByIndex_v2"),
59-
FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT, ValueLayout.ADDRESS)
60-
);
51+
static final MethodHandle nvmlDeviceGetHandleByIndex_v2$mh = Linker.nativeLinker()
52+
.downcallHandle(
53+
findOrThrow("nvmlDeviceGetHandleByIndex_v2"),
54+
FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT, ValueLayout.ADDRESS)
55+
);
6156

6257
/**
6358
* nvmlReturn_t nvmlDeviceGetUtilizationRates ( nvmlDevice_t device, nvmlUtilization_t* utilization )
6459
*/
65-
static final MethodHandle nvmlDeviceGetUtilizationRates$mh = Linker.nativeLinker().downcallHandle(
66-
findOrThrow("nvmlDeviceGetUtilizationRates"),
67-
FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS)
68-
);
60+
static final MethodHandle nvmlDeviceGetUtilizationRates$mh = Linker.nativeLinker()
61+
.downcallHandle(
62+
findOrThrow("nvmlDeviceGetUtilizationRates"),
63+
FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS)
64+
);
6965

7066
public static class nvmlUtilization_t {
7167

@@ -85,7 +81,7 @@ public static GroupLayout layout() {
8581
return $LAYOUT;
8682
}
8783

88-
private static final ValueLayout.OfInt gpu$LAYOUT = (ValueLayout.OfInt)$LAYOUT.select(groupElement("gpu"));
84+
private static final ValueLayout.OfInt gpu$LAYOUT = (ValueLayout.OfInt) $LAYOUT.select(groupElement("gpu"));
8985

9086
/**
9187
* Getter for field: gpu
@@ -95,7 +91,7 @@ public static int gpu(MemorySegment struct) {
9591
return struct.get(gpu$LAYOUT, 0);
9692
}
9793

98-
private static final ValueLayout.OfInt memory$LAYOUT = (ValueLayout.OfInt)$LAYOUT.select(groupElement("memory"));
94+
private static final ValueLayout.OfInt memory$LAYOUT = (ValueLayout.OfInt) $LAYOUT.select(groupElement("memory"));
9995

10096
/**
10197
* Getter for field: memory
@@ -113,7 +109,7 @@ private static MemorySegment findOrThrow(String symbol) {
113109
public static void nvmlInit_v2() {
114110
int res;
115111
try {
116-
res = (int)nvmlInit_v2$mh.invokeExact();
112+
res = (int) nvmlInit_v2$mh.invokeExact();
117113
} catch (Throwable ex$) {
118114
throw new AssertionError("should not reach here", ex$);
119115
}
@@ -125,7 +121,7 @@ public static void nvmlInit_v2() {
125121
public static void nvmlShutdown() {
126122
int res;
127123
try {
128-
res = (int)nvmlShutdown$mh.invokeExact();
124+
res = (int) nvmlShutdown$mh.invokeExact();
129125
} catch (Throwable ex$) {
130126
throw new AssertionError("should not reach here", ex$);
131127
}
@@ -139,7 +135,7 @@ public static MemorySegment nvmlDeviceGetHandleByIndex_v2(int index) {
139135
MemorySegment nvmlDevice;
140136
try (var localArena = Arena.ofConfined()) {
141137
MemorySegment devicePtr = localArena.allocate(ValueLayout.ADDRESS);
142-
res = (int)nvmlDeviceGetHandleByIndex_v2$mh.invokeExact(index,devicePtr);
138+
res = (int) nvmlDeviceGetHandleByIndex_v2$mh.invokeExact(index, devicePtr);
143139
nvmlDevice = devicePtr.get(ValueLayout.ADDRESS, 0);
144140
} catch (Throwable ex$) {
145141
throw new AssertionError("should not reach here", ex$);
@@ -152,8 +148,8 @@ public static MemorySegment nvmlDeviceGetHandleByIndex_v2(int index) {
152148

153149
public static void nvmlDeviceGetUtilizationRates(MemorySegment nvmlDevice, MemorySegment nvmlUtilizationPtr) {
154150
int res;
155-
try {
156-
res = (int)nvmlDeviceGetUtilizationRates$mh.invokeExact(nvmlDevice, nvmlUtilizationPtr);
151+
try {
152+
res = (int) nvmlDeviceGetUtilizationRates$mh.invokeExact(nvmlDevice, nvmlUtilizationPtr);
157153
} catch (Throwable ex$) {
158154
throw new AssertionError("should not reach here", ex$);
159155
}
@@ -172,7 +168,7 @@ public static String nvmlErrorString(int result) {
172168
if (seg.equals(MemorySegment.NULL)) {
173169
return "no last error text";
174170
}
175-
return MemorySegmentUtil.getString(seg,0);
171+
return MemorySegmentUtil.getString(seg, 0);
176172
} catch (Throwable ex$) {
177173
throw new AssertionError("should not reach here", ex$);
178174
}

0 commit comments

Comments
 (0)