Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.core.Strings;
import org.elasticsearch.xpack.gpu.GPUSupport;

import java.lang.foreign.Arena;
import java.nio.file.Path;
import java.util.Objects;
import java.util.concurrent.locks.Condition;
Expand Down Expand Up @@ -63,34 +64,55 @@ static CuVSResourceManager pooling() {
return PoolingCuVSResourceManager.INSTANCE;
}

/**
 * Reads a single numeric GPU metric for a given set of cuVS resources.
 * <p>
 * The pooling resource manager takes three of these (total memory in bytes, free memory in bytes and
 * GPU utilization percent) as constructor arguments, so the source of each metric can be substituted.
 */
@FunctionalInterface
interface GpuInfoFunction {
/**
 * @param resources the resources whose GPU is queried
 * @return the metric value; the unit depends on which metric this function provides
 */
long get(CuVSResources resources);
}

/**
* A manager that maintains a pool of resources.
*/
class PoolingCuVSResourceManager implements CuVSResourceManager {

/** A multiplier on input data to account for intermediate and output data size required while processing it */
static final double GPU_COMPUTATION_MEMORY_FACTOR = 2.0;
static final int GPU_UTILIZATION_MAX_PERCENT = 80;
static final int MAX_RESOURCES = 2;
static final GPUInfoProvider gpuInfoProvider = CuVSProvider.provider().gpuInfoProvider();
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(
MAX_RESOURCES,
CuVSProvider.provider().gpuInfoProvider()
res -> gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes(),
res -> gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes(),
PoolingCuVSResourceManager::getGpuUtilizationPercent
);

private final ManagedCuVSResources[] pool;
private final int capacity;
private final GPUInfoProvider gpuInfoProvider;
private int createdCount;

private final GpuInfoFunction totalMemoryInBytesProvider;
private final GpuInfoFunction freeMemoryInBytesProvider;
private final GpuInfoFunction gpuUtilizationPercentProvider;

ReentrantLock lock = new ReentrantLock();
Condition enoughResourcesCondition = lock.newCondition();

public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) {
public PoolingCuVSResourceManager(
int capacity,
GpuInfoFunction totalMemoryInBytesProvider,
GpuInfoFunction freeMemoryInBytesProvider,
GpuInfoFunction gpuUtilizationPercentProvider
) {
this.totalMemoryInBytesProvider = totalMemoryInBytesProvider;
this.freeMemoryInBytesProvider = freeMemoryInBytesProvider;
this.gpuUtilizationPercentProvider = gpuUtilizationPercentProvider;
if (capacity < 1 || capacity > MAX_RESOURCES) {
throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES);
}
this.capacity = capacity;
this.gpuInfoProvider = gpuInfoProvider;
this.pool = new ManagedCuVSResources[MAX_RESOURCES];

NVML.nvmlInit_v2();
}

private ManagedCuVSResources getResourceFromPool() {
Expand Down Expand Up @@ -130,29 +152,33 @@ public ManagedCuVSResources acquire(int numVectors, int dims) throws Interrupted
res = getResourceFromPool();

final boolean enoughMemory;
final boolean enoughComputation;
if (res != null) {
// If no resource in the pool is locked, short circuit to avoid livelock
if (numLockedResources() == 0) {
break;
}

// Check resources availability
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims);
if (requiredMemoryInBytes > gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes()) {
if (requiredMemoryInBytes > totalMemoryInBytesProvider.get(res)) {
throw new IllegalArgumentException(
Strings.format(
"Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%dMB]",
numVectors,
dims,
gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes() / (1024L * 1024L)
totalMemoryInBytesProvider.get(res) / (1024L * 1024L)
)
);
}
enoughMemory = requiredMemoryInBytes <= gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes();
enoughMemory = requiredMemoryInBytes <= freeMemoryInBytesProvider.get(res);
enoughComputation = gpuUtilizationPercentProvider.get(res) < GPU_UTILIZATION_MAX_PERCENT;
} else {
enoughMemory = false;
enoughComputation = false;
}
// TODO: add enoughComputation / enoughComputationCondition here
allConditionsMet = enoughMemory; // && enoughComputation

allConditionsMet = enoughMemory && enoughComputation;
if (allConditionsMet == false) {
enoughResourcesCondition.await();
}
Expand All @@ -168,15 +194,30 @@ private long estimateRequiredMemory(int numVectors, int dims) {
return (long) (GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * Float.BYTES);
}

/**
 * Asks NVML for the current utilization rates of the device backing {@code resources} and returns
 * the {@code gpu} field of the result.
 *
 * @param resources the resources whose device id selects the NVML device
 * @return the {@code gpu} utilization value reported by NVML
 */
private static int getGpuUtilizationPercent(CuVSResources resources) {
    // The utilization struct only needs to live for the duration of this call.
    try (Arena scratch = Arena.ofConfined()) {
        final var device = NVML.nvmlDeviceGetHandleByIndex_v2(resources.deviceId());
        final var utilization = scratch.allocate(NVML.nvmlUtilization_t.layout());
        NVML.nvmlDeviceGetUtilizationRates(device, utilization);
        final int gpuPercent = NVML.nvmlUtilization_t.gpu(utilization);
        return gpuPercent;
    }
}

/**
 * Creates a new set of cuVS resources by delegating to {@link GPUSupport#cuVSResourcesOrNull(boolean)}.
 * Visible for testing: {@code protected} so a subclass can override it.
 *
 * @return whatever {@code GPUSupport.cuVSResourcesOrNull(true)} returns
 */
protected CuVSResources createNew() {
    final CuVSResources created = GPUSupport.cuVSResourcesOrNull(true);
    return created;
}

/**
 * Signals that the computation using {@code resources} has finished, waking every thread that is
 * waiting on {@code enoughResourcesCondition} so it can re-evaluate whether enough GPU memory and
 * computation capacity are now available.
 *
 * @param resources the resources whose computation finished; asserted to still be locked
 */
@Override
public void finishedComputation(ManagedCuVSResources resources) {
    // Acquire the lock BEFORE entering the try block: if lock() were to fail inside the try,
    // the finally clause would call unlock() on a lock this thread does not hold and throw
    // IllegalMonitorStateException, masking the original failure.
    lock.lock();
    try {
        assert resources.locked;
        // Allow acquire to return possibly blocked resources
        enoughResourcesCondition.signalAll();
    } finally {
        lock.unlock();
    }
}

@Override
Expand All @@ -198,6 +239,7 @@ public void shutdown() {
assert res != null;
res.delegate.close();
}
NVML.nvmlShutdown();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.gpu.codec;

import java.lang.foreign.MemorySegment;

/**
 * Utility methods to act on MemorySegment apis which have changed in subsequent JDK releases.
 * <p>
 * This variant delegates to {@code MemorySegment#getUtf8String(long)}.
 */
class MemorySegmentUtil {
    /**
     * Reads a string from {@code segment} starting at {@code offset}.
     *
     * @param segment the segment to read from
     * @param offset  the byte offset at which the string starts
     * @return the decoded string
     */
    static String getString(MemorySegment segment, long offset) {
        final String decoded = segment.getUtf8String(offset);
        return decoded;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.gpu.codec;

import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.GroupLayout;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SymbolLookup;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;

import static java.lang.foreign.MemoryLayout.PathElement.groupElement;

/**
 * Hand-written Foreign Function &amp; Memory bindings for the few NVIDIA Management Library (NVML)
 * entry points used here: init/shutdown, error-string lookup, device-handle lookup by index and
 * utilization-rate queries.
 * <p>
 * Every wrapper converts a non-{@link #NVML_SUCCESS} return code into a {@link RuntimeException}
 * carrying the numeric code and the text returned by {@code nvmlErrorString}.
 */
class NVML {

    // Symbols are resolved from libnvidia-ml first, then the loader lookup, then the linker's default lookup.
    private static final SymbolLookup SYMBOL_LOOKUP = SymbolLookup.libraryLookup("libnvidia-ml.so.1", Arena.ofAuto())
        .or(SymbolLookup.loaderLookup())
        .or(Linker.nativeLinker().defaultLookup());

    public static final int NVML_SUCCESS = 0;

    /**
     * nvmlReturn_t nvmlInit_v2 ( void )
     */
    static final MethodHandle nvmlInit_v2$mh = Linker.nativeLinker()
        .downcallHandle(findOrThrow("nvmlInit_v2"), FunctionDescriptor.of(ValueLayout.JAVA_INT));

    /**
     * nvmlReturn_t nvmlShutdown ( void )
     */
    static final MethodHandle nvmlShutdown$mh = Linker.nativeLinker()
        .downcallHandle(findOrThrow("nvmlShutdown"), FunctionDescriptor.of(ValueLayout.JAVA_INT));

    /**
     * const DECLDIR char* nvmlErrorString ( nvmlReturn_t result )
     */
    static final MethodHandle nvmlErrorString$mh = Linker.nativeLinker()
        .downcallHandle(findOrThrow("nvmlErrorString"), FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.JAVA_INT));

    /**
     * nvmlReturn_t nvmlDeviceGetHandleByIndex_v2 ( unsigned int index, nvmlDevice_t* device )
     */
    static final MethodHandle nvmlDeviceGetHandleByIndex_v2$mh = Linker.nativeLinker()
        .downcallHandle(
            findOrThrow("nvmlDeviceGetHandleByIndex_v2"),
            FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT, ValueLayout.ADDRESS)
        );

    /**
     * nvmlReturn_t nvmlDeviceGetUtilizationRates ( nvmlDevice_t device, nvmlUtilization_t* utilization )
     */
    static final MethodHandle nvmlDeviceGetUtilizationRates$mh = Linker.nativeLinker()
        .downcallHandle(
            findOrThrow("nvmlDeviceGetUtilizationRates"),
            FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS)
        );

    /**
     * Layout and field accessors for the native {@code nvmlUtilization_t} struct: two consecutive
     * 32-bit integers named {@code gpu} and {@code memory}.
     */
    public static class nvmlUtilization_t {

        nvmlUtilization_t() {
            // Should not be called directly
        }

        private static final GroupLayout $LAYOUT = MemoryLayout.structLayout(
            ValueLayout.JAVA_INT.withName("gpu"),
            ValueLayout.JAVA_INT.withName("memory")
        );

        /**
         * The layout of this struct
         */
        public static GroupLayout layout() {
            return $LAYOUT;
        }

        private static final ValueLayout.OfInt gpu$LAYOUT = (ValueLayout.OfInt) $LAYOUT.select(groupElement("gpu"));

        /**
         * Getter for field: gpu
         * Percent of time over the past sample period during which one or more kernels was executing on the GPU.
         */
        public static int gpu(MemorySegment struct) {
            return struct.get(gpu$LAYOUT, 0);
        }

        private static final ValueLayout.OfInt memory$LAYOUT = (ValueLayout.OfInt) $LAYOUT.select(groupElement("memory"));

        /**
         * Getter for field: memory
         * Percent of time over the past sample period during which global (device) memory was being read or written.
         */
        public static int memory(MemorySegment struct) {
            // "memory" is the second 32-bit field of the struct, hence byte offset 4.
            return struct.get(memory$LAYOUT, 4);
        }
    }

    /**
     * Resolves {@code symbol} through {@link #SYMBOL_LOOKUP}.
     *
     * @throws UnsatisfiedLinkError if the symbol cannot be found
     */
    private static MemorySegment findOrThrow(String symbol) {
        return SYMBOL_LOOKUP.find(symbol).orElseThrow(() -> new UnsatisfiedLinkError("unresolved symbol: " + symbol));
    }

    /**
     * Calls {@code nvmlInit_v2}.
     *
     * @throws RuntimeException if NVML returns anything other than {@link #NVML_SUCCESS}
     */
    public static void nvmlInit_v2() {
        int res;
        try {
            res = (int) nvmlInit_v2$mh.invokeExact();
        } catch (Throwable ex$) {
            throw new AssertionError("should not reach here", ex$);
        }
        if (res != NVML_SUCCESS) {
            throw buildException(res);
        }
    }

    /**
     * Calls {@code nvmlShutdown}.
     *
     * @throws RuntimeException if NVML returns anything other than {@link #NVML_SUCCESS}
     */
    public static void nvmlShutdown() {
        int res;
        try {
            res = (int) nvmlShutdown$mh.invokeExact();
        } catch (Throwable ex$) {
            throw new AssertionError("should not reach here", ex$);
        }
        if (res != NVML_SUCCESS) {
            throw buildException(res);
        }
    }

    /**
     * Calls {@code nvmlDeviceGetHandleByIndex_v2} and returns the device handle written by NVML.
     *
     * @param index the device index passed to NVML
     * @return the {@code nvmlDevice_t} handle, as an address-only segment
     * @throws RuntimeException if NVML returns anything other than {@link #NVML_SUCCESS}
     */
    public static MemorySegment nvmlDeviceGetHandleByIndex_v2(int index) {
        int res;
        MemorySegment nvmlDevice;
        // The out-pointer slot only needs to live until the handle value has been copied out of it.
        try (var localArena = Arena.ofConfined()) {
            MemorySegment devicePtr = localArena.allocate(ValueLayout.ADDRESS);
            res = (int) nvmlDeviceGetHandleByIndex_v2$mh.invokeExact(index, devicePtr);
            nvmlDevice = devicePtr.get(ValueLayout.ADDRESS, 0);
        } catch (Throwable ex$) {
            throw new AssertionError("should not reach here", ex$);
        }
        if (res != NVML_SUCCESS) {
            throw buildException(res);
        }
        return nvmlDevice;
    }

    /**
     * Calls {@code nvmlDeviceGetUtilizationRates}, filling the struct that {@code nvmlUtilizationPtr} points to.
     *
     * @param nvmlDevice         a device handle, e.g. from {@link #nvmlDeviceGetHandleByIndex_v2(int)}
     * @param nvmlUtilizationPtr a segment laid out as {@link nvmlUtilization_t#layout()}
     * @throws RuntimeException if NVML returns anything other than {@link #NVML_SUCCESS}
     */
    public static void nvmlDeviceGetUtilizationRates(MemorySegment nvmlDevice, MemorySegment nvmlUtilizationPtr) {
        int res;
        try {
            res = (int) nvmlDeviceGetUtilizationRates$mh.invokeExact(nvmlDevice, nvmlUtilizationPtr);
        } catch (Throwable ex$) {
            throw new AssertionError("should not reach here", ex$);
        }
        if (res != NVML_SUCCESS) {
            throw buildException(res);
        }
    }

    /** Builds the exception thrown for a failing NVML return code, including NVML's own description of it. */
    private static RuntimeException buildException(int res) {
        return new RuntimeException("Error invoking NVML: " + res + "[" + nvmlErrorString(res) + "]");
    }

    /**
     * Calls {@code nvmlErrorString} and decodes the returned C string.
     *
     * @param result an NVML return code
     * @return the text NVML associates with {@code result}, or {@code "no last error text"} if NVML returned NULL
     */
    public static String nvmlErrorString(int result) {
        try {
            var seg = (MemorySegment) nvmlErrorString$mh.invokeExact(result);
            if (seg.equals(MemorySegment.NULL)) {
                return "no last error text";
            }
            // The handle's return layout is a plain ADDRESS (no target layout), so the segment handed
            // back by the downcall has zero length and any read from it fails with
            // IndexOutOfBoundsException. Widen it so the NUL-terminated string can actually be read.
            return MemorySegmentUtil.getString(seg.reinterpret(Long.MAX_VALUE), 0);
        } catch (Throwable ex$) {
            throw new AssertionError("should not reach here", ex$);
        }
    }
}
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
com.nvidia.cuvs:
- load_native_libraries
org.elasticsearch.gpu:
- load_native_libraries
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.gpu.codec;

import java.lang.foreign.MemorySegment;

/**
 * Utility methods to act on MemorySegment apis which have changed in subsequent JDK releases.
 * <p>
 * This variant delegates to {@code MemorySegment#getString(long)}.
 */
class MemorySegmentUtil {
    /**
     * Reads a string from {@code segment} starting at {@code offset}.
     *
     * @param segment the segment to read from
     * @param offset  the byte offset at which the string starts
     * @return the decoded string
     */
    static String getString(MemorySegment segment, long offset) {
        final String decoded = segment.getString(offset);
        return decoded;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ static class MockPoolingCuVSResourceManager extends CuVSResourceManager.PoolingC
}

private MockPoolingCuVSResourceManager(int capacity, List<Long> allocationList) {
super(capacity, new MockGPUInfoProvider(() -> freeMemoryFunction(allocationList)));
super(capacity, res -> TOTAL_DEVICE_MEMORY_IN_BYTES, res -> freeMemoryFunction(allocationList), res -> 50);
this.allocations = allocationList;
}

Expand Down