diff --git a/x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/GPUIndexIT.java b/x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/GPUIndexIT.java index 53e83dac30d4e..f35e1f1d6b659 100644 --- a/x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/GPUIndexIT.java +++ b/x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/GPUIndexIT.java @@ -15,6 +15,7 @@ import org.elasticsearch.search.vectors.KnnSearchBuilder; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.xpack.gpu.GPUPlugin; +import org.elasticsearch.xpack.gpu.GPUSupport; import java.util.Collection; import java.util.List; @@ -32,6 +33,7 @@ protected Collection> nodePlugins() { } public void testBasic() { + assumeTrue("cuvs not supported", GPUSupport.isSupported(false)); final int dims = randomIntBetween(4, 128); final int[] numDocs = new int[] { randomIntBetween(1, 100), 1, 2, randomIntBetween(1, 100) }; createIndex(dims); @@ -45,6 +47,7 @@ public void testBasic() { } public void testSearchWithoutGPU() { + assumeTrue("cuvs not supported", GPUSupport.isSupported(false)); final int dims = randomIntBetween(4, 128); final int numDocs = randomIntBetween(1, 500); createIndex(dims); diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUPlugin.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUPlugin.java index 8f093ac553112..eb7d3b4f594d2 100644 --- a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUPlugin.java +++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUPlugin.java @@ -6,8 +6,6 @@ */ package org.elasticsearch.xpack.gpu; -import com.nvidia.cuvs.CuVSResources; - import org.elasticsearch.common.util.FeatureFlag; import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -34,17 +32,16 @@ public VectorsFormatProvider getVectorsFormatProvider() { + "]" ); } - CuVSResources resources = GPUVectorsFormat.cuVSResourcesOrNull(true); - if (resources == null) { + if (GPUSupport.isSupported(true) == false) { throw new IllegalArgumentException( "[index.vectors.indexing.use_gpu] was set to [true], but GPU resources are not accessible on the node." ); } return new GPUVectorsFormat(); } - if ((gpuMode == IndexSettings.GpuMode.AUTO) + if (gpuMode == IndexSettings.GpuMode.AUTO && vectorIndexTypeSupported(indexOptions.getType()) - && GPUVectorsFormat.cuVSResourcesOrNull(false) != null) { + && GPUSupport.isSupported(false)) { return new GPUVectorsFormat(); } } diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUSupport.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUSupport.java new file mode 100644 index 0000000000000..67fd97faec259 --- /dev/null +++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/GPUSupport.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.gpu; + +import com.nvidia.cuvs.CuVSResources; + +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; + +public class GPUSupport { + + private static final Logger LOG = LogManager.getLogger(GPUSupport.class); + + /** Tells whether the platform supports cuvs. */ + public static boolean isSupported(boolean logError) { + try (var resources = cuVSResourcesOrNull(logError)) { + if (resources != null) { + return true; + } + } + return false; + } + + /** Returns a resources if supported, otherwise null. */ + public static CuVSResources cuVSResourcesOrNull(boolean logError) { + try { + var resources = CuVSResources.create(); + return resources; + } catch (UnsupportedOperationException uoe) { + if (logError) { + String msg = ""; + if (uoe.getMessage() == null) { + msg = "Runtime Java version: " + Runtime.version().feature(); + } else { + msg = ": " + uoe.getMessage(); + } + LOG.warn("GPU based vector indexing is not supported on this platform or java version; " + msg); + } + } catch (Throwable t) { + if (logError) { + if (t instanceof ExceptionInInitializerError ex) { + t = ex.getCause(); + } + LOG.warn("Exception occurred during creation of cuvs resources. " + t); + } + } + return null; + } +} diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java new file mode 100644 index 0000000000000..e7977f28c9c22 --- /dev/null +++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManager.java @@ -0,0 +1,153 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.gpu.codec; + +import com.nvidia.cuvs.CuVSResources; + +import org.elasticsearch.xpack.gpu.GPUSupport; + +import java.nio.file.Path; +import java.util.Objects; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; + +/** + * A manager of {@link com.nvidia.cuvs.CuVSResources}. There is one manager per GPU. + * + *

All access to GPU resources is mediated through a manager. A manager helps coordinate usage threads to: + *

+ * + *

Fundamentally, a resource is used in compute and memory bound operations. The former occurs prior to the latter, e.g. + * index build (compute), followed by a copy/process of the newly built index (memory). The manager allows the resource + * user to indicate that compute is complete before releasing the resources. This can help improve parallelism of compute + * on the GPU - allowing the next compute operation to proceed before releasing the resources. + * + */ +public interface CuVSResourceManager { + + /** + * Acquires a resource from the manager. + * + *

A manager can use the given parameters, numVectors and dims, to estimate the potential + * effect on GPU memory and compute usage to determine whether to give out + * another resource or wait for a resources to be returned before giving out another. + */ + // numVectors and dims are currently unused, but could be used along with GPU metadata, + // memory, generation, etc, when acquiring for 10M x 1536 dims, or 100,000 x 128 dims, + // to give out a resources or not. + ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException; + + /** Marks the resources as finished with regard to compute. */ + void finishedComputation(ManagedCuVSResources resources); + + /** Returns the given resource to the manager. */ + void release(ManagedCuVSResources resources); + + /** Shuts down the manager, releasing all open resources. */ + void shutdown(); + + /** Returns the system-wide pooling manager. */ + static CuVSResourceManager pooling() { + return PoolingCuVSResourceManager.INSTANCE; + } + + /** + * A manager that maintains a pool of resources. + */ + class PoolingCuVSResourceManager implements CuVSResourceManager { + + static final int MAX_RESOURCES = 2; + static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(MAX_RESOURCES); + + final BlockingQueue pool; + final int capacity; + int createdCount; + + public PoolingCuVSResourceManager(int capacity) { + if (capacity < 1 || capacity > MAX_RESOURCES) { + throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES); + } + this.capacity = capacity; + this.pool = new ArrayBlockingQueue<>(capacity); + } + + @Override + public ManagedCuVSResources acquire(int numVectors, int dims) throws InterruptedException { + ManagedCuVSResources res = pool.poll(); + if (res != null) { + return res; + } + synchronized (this) { + if (createdCount < capacity) { + createdCount++; + return new ManagedCuVSResources(Objects.requireNonNull(createNew())); + } + } + // Otherwise, wait for one to be released + return pool.take(); + } + + // visible for testing + protected CuVSResources createNew() { + return GPUSupport.cuVSResourcesOrNull(true); + } + + @Override + public void finishedComputation(ManagedCuVSResources resources) { + // currently does nothing, but could allow acquire to return possibly blocked resources + } + + @Override + public void release(ManagedCuVSResources resources) { + var added = pool.offer(Objects.requireNonNull(resources)); + assert added : "Failed to release resource back to pool"; + } + + @Override + public void shutdown() { + for (ManagedCuVSResources res : pool) { + res.delegate.close(); + } + pool.clear(); + } + } + + /** A managed resource. Cannot be closed. */ + final class ManagedCuVSResources implements CuVSResources { + + final CuVSResources delegate; + + ManagedCuVSResources(CuVSResources resources) { + this.delegate = resources; + } + + @Override + public ScopedAccess access() { + return delegate.access(); + } + + @Override + public void close() { + throw new UnsupportedOperationException("this resource is managed, cannot be closed by clients"); + } + + @Override + public Path tempDirectory() { + return null; + } + + @Override + public String toString() { + return "ManagedCuVSResources[delegate=" + delegate + "]"; + } + } +} diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUToHNSWVectorsWriter.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUToHNSWVectorsWriter.java index a90fdf866643f..163e2277137ed 100644 --- a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUToHNSWVectorsWriter.java +++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUToHNSWVectorsWriter.java @@ -9,7 +9,6 @@ import com.nvidia.cuvs.CagraIndex; import com.nvidia.cuvs.CagraIndexParams; -import com.nvidia.cuvs.CuVSResources; import com.nvidia.cuvs.Dataset; import org.apache.lucene.codecs.CodecUtil; @@ -68,7 +67,7 @@ final class GPUToHNSWVectorsWriter extends KnnVectorsWriter { private static final long SHALLOW_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(GPUToHNSWVectorsWriter.class); private static final int LUCENE99_HNSW_DIRECT_MONOTONIC_BLOCK_SHIFT = 16; - private final CuVSResources cuVSResources; + private final CuVSResourceManager cuVSResourceManager; private final SegmentWriteState segmentWriteState; private final IndexOutput meta, vectorIndex; private final int M; @@ -78,10 +77,15 @@ final class GPUToHNSWVectorsWriter extends KnnVectorsWriter { private final List fields = new ArrayList<>(); private boolean finished; - GPUToHNSWVectorsWriter(CuVSResources cuVSResources, SegmentWriteState state, int M, int beamWidth, FlatVectorsWriter flatVectorWriter) - throws IOException { - assert cuVSResources != null : "CuVSResources must not be null"; - this.cuVSResources = cuVSResources; + GPUToHNSWVectorsWriter( + CuVSResourceManager cuVSResourceManager, + SegmentWriteState state, + int M, + int beamWidth, + FlatVectorsWriter flatVectorWriter + ) throws IOException { + assert cuVSResourceManager != null : "CuVSResources must not be null"; + this.cuVSResourceManager = cuVSResourceManager; this.M = M; this.flatVectorWriter = flatVectorWriter; this.beamWidth = beamWidth; @@ -267,42 +271,52 @@ private String buildGPUIndex(VectorSimilarityFunction similarityFunction, Datase .withMetric(distanceType) .build(); - // build index on GPU - long startTime = System.nanoTime(); - var index = CagraIndex.newBuilder(cuVSResources).withDataset(dataset).withIndexParams(params).build(); - if (logger.isDebugEnabled()) { - logger.debug("Carga index created in: {} ms; #num vectors: {}", (System.nanoTime() - startTime) / 1_000_000.0, dataset.size()); - } - - // TODO: do serialization through MemorySegment instead of a temp file - // serialize index for CPU consumption to the hnwslib format - startTime = System.nanoTime(); - IndexOutput tempCagraHNSW = null; - boolean success = false; + var cuVSResources = cuVSResourceManager.acquire(dataset.size(), dataset.dimensions()); try { - tempCagraHNSW = segmentWriteState.directory.createTempOutput( - vectorIndex.getName(), - "cagra_hnws_temp", - segmentWriteState.context - ); - var tempCagraHNSWOutputStream = new IndexOutputOutputStream(tempCagraHNSW); - index.serializeToHNSW(tempCagraHNSWOutputStream); + long startTime = System.nanoTime(); + var indexBuilder = CagraIndex.newBuilder(cuVSResources).withDataset(dataset).withIndexParams(params); + var index = indexBuilder.build(); + cuVSResourceManager.finishedComputation(cuVSResources); if (logger.isDebugEnabled()) { - logger.debug("Carga index serialized to hnswlib format in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0); + logger.debug( + "Carga index created in: {} ms; #num vectors: {}", + (System.nanoTime() - startTime) / 1_000_000.0, + dataset.size() + ); } - success = true; - } finally { - index.destroyIndex(); - if (success) { - org.elasticsearch.core.IOUtils.close(tempCagraHNSW); - } else { - if (tempCagraHNSW != null) { - IOUtils.closeWhileHandlingException(tempCagraHNSW); - org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempCagraHNSW.getName()); + + // TODO: do serialization through MemorySegment instead of a temp file + // serialize index for CPU consumption to the hnwslib format + startTime = System.nanoTime(); + IndexOutput tempCagraHNSW = null; + boolean success = false; + try { + tempCagraHNSW = segmentWriteState.directory.createTempOutput( + vectorIndex.getName(), + "cagra_hnws_temp", + segmentWriteState.context + ); + var tempCagraHNSWOutputStream = new IndexOutputOutputStream(tempCagraHNSW); + index.serializeToHNSW(tempCagraHNSWOutputStream); + if (logger.isDebugEnabled()) { + logger.debug("Carga index serialized to hnswlib format in: {} ms", (System.nanoTime() - startTime) / 1_000_000.0); + } + success = true; + } finally { + index.destroyIndex(); + if (success) { + org.elasticsearch.core.IOUtils.close(tempCagraHNSW); + } else { + if (tempCagraHNSW != null) { + IOUtils.closeWhileHandlingException(tempCagraHNSW); + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(segmentWriteState.directory, tempCagraHNSW.getName()); + } } } + return tempCagraHNSW.getName(); + } finally { + cuVSResourceManager.release(cuVSResources); } - return tempCagraHNSW.getName(); } @SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)") diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsFormat.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsFormat.java index c87b5538e2a8b..8620795ddff41 100644 --- a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsFormat.java +++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsFormat.java @@ -7,8 +7,6 @@ package org.elasticsearch.xpack.gpu.codec; -import com.nvidia.cuvs.CuVSResources; - import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; @@ -48,18 +46,21 @@ public class GPUVectorsFormat extends KnnVectorsFormat { FlatVectorScorerUtil.getLucene99FlatVectorsScorer() ); + final CuVSResourceManager cuVSResourceManager; + public GPUVectorsFormat() { + this(CuVSResourceManager.pooling()); + } + + public GPUVectorsFormat(CuVSResourceManager cuVSResourceManager) { super(NAME); + this.cuVSResourceManager = cuVSResourceManager; } @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - CuVSResources cuVSResources = cuVSResourcesOrNull(true); - if (cuVSResources == null) { - throw new IllegalArgumentException("GPU based vector indexing is not supported on this platform or java version"); - } return new GPUToHNSWVectorsWriter( - cuVSResources, + cuVSResourceManager, state, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, @@ -81,30 +82,4 @@ public int getMaxDimensions(String fieldName) { public String toString() { return NAME + "()"; } - - /** Tells whether the platform supports cuvs. */ - public static CuVSResources cuVSResourcesOrNull(boolean logError) { - try { - var resources = CuVSResources.create(); - return resources; - } catch (UnsupportedOperationException uoe) { - if (logError) { - String msg = ""; - if (uoe.getMessage() == null) { - msg = "Runtime Java version: " + Runtime.version().feature(); - } else { - msg = ": " + uoe.getMessage(); - } - LOG.warn("GPU based vector indexing is not supported on this platform or java version; " + msg); - } - } catch (Throwable t) { - if (logError) { - if (t instanceof ExceptionInInitializerError ex) { - t = ex.getCause(); - } - LOG.warn("Exception occurred during creation of cuvs resources. " + t); - } - } - return null; - } } diff --git a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java new file mode 100644 index 0000000000000..a5bac96cc3b51 --- /dev/null +++ b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/CuVSResourceManagerTests.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.gpu.codec; + +import com.nvidia.cuvs.CuVSResources; + +import org.elasticsearch.test.ESTestCase; + +import java.nio.file.Path; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; + +public class CuVSResourceManagerTests extends ESTestCase { + + public void testBasic() throws InterruptedException { + var mgr = new MockPoolingCuVSResourceManager(2); + var res1 = mgr.acquire(0, 0); + var res2 = mgr.acquire(0, 0); + assertThat(res1.toString(), containsString("id=0")); + assertThat(res2.toString(), containsString("id=1")); + mgr.release(res1); + mgr.release(res2); + res1 = mgr.acquire(0, 0); + res2 = mgr.acquire(0, 0); + assertThat(res1.toString(), containsString("id=0")); + assertThat(res2.toString(), containsString("id=1")); + mgr.release(res1); + mgr.release(res2); + mgr.shutdown(); + } + + public void testBlocking() throws Exception { + var mgr = new MockPoolingCuVSResourceManager(2); + var res1 = mgr.acquire(0, 0); + var res2 = mgr.acquire(0, 0); + + AtomicReference holder = new AtomicReference<>(); + Thread t = new Thread(() -> { + try { + var res3 = mgr.acquire(0, 0); + holder.set(res3); + } catch (InterruptedException e) { + throw new AssertionError(e); + } + }); + t.start(); + Thread.sleep(1_000); + assertNull(holder.get()); + mgr.release(randomFrom(res1, res2)); + t.join(); + assertThat(holder.get().toString(), anyOf(containsString("id=0"), containsString("id=1"))); + mgr.shutdown(); + } + + public void testManagedResIsNotClosable() throws Exception { + var mgr = new MockPoolingCuVSResourceManager(1); + var res = mgr.acquire(0, 0); + assertThrows(UnsupportedOperationException.class, () -> res.close()); + mgr.release(res); + mgr.shutdown(); + } + + public void testDoubleRelease() throws InterruptedException { + var mgr = new MockPoolingCuVSResourceManager(2); + var res1 = mgr.acquire(0, 0); + var res2 = mgr.acquire(0, 0); + mgr.release(res1); + mgr.release(res2); + assertThrows(AssertionError.class, () -> mgr.release(randomFrom(res1, res2))); + mgr.shutdown(); + } + + static class MockPoolingCuVSResourceManager extends CuVSResourceManager.PoolingCuVSResourceManager { + + final AtomicInteger idGenerator = new AtomicInteger(); + + MockPoolingCuVSResourceManager(int capacity) { + super(capacity); + } + + @Override + protected CuVSResources createNew() { + return new MockCuVSResources(idGenerator.getAndIncrement()); + } + } + + static class MockCuVSResources implements CuVSResources { + + final int id; + + MockCuVSResources(int id) { + this.id = id; + } + + @Override + public ScopedAccess access() { + throw new UnsupportedOperationException(); + } + + @Override + public void close() {} + + @Override + public Path tempDirectory() { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() { + return "MockCuVSResources[id=" + id + "]"; + } + } +} diff --git a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsTests.java b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsTests.java index e86270d73695b..bfc4ee6d48d0d 100644 --- a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsTests.java +++ b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsTests.java @@ -12,6 +12,7 @@ import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.store.MemorySegmentAccessInput; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.gpu.GPUSupport; import org.junit.Before; import java.lang.foreign.MemorySegment; @@ -27,9 +28,7 @@ public class DatasetUtilsTests extends ESTestCase { @Before public void setup() { // TODO: abstract out setup in to common GPUTestcase assumeTrue("cuvs runtime only supported on 22 or greater, your JDK is " + Runtime.version(), Runtime.version().feature() >= 22); - try (var resources = GPUVectorsFormat.cuVSResourcesOrNull(false)) { - assumeTrue("cuvs not supported", resources != null); - } + assumeTrue("cuvs not supported", GPUSupport.isSupported(false)); datasetUtils = DatasetUtils.getInstance(); } diff --git a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUDenseVectorFieldMapperTests.java b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUDenseVectorFieldMapperTests.java index fd8db09461e39..1a1ecf9f85ec4 100644 --- a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUDenseVectorFieldMapperTests.java +++ b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUDenseVectorFieldMapperTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.xpack.gpu.GPUPlugin; +import org.elasticsearch.xpack.gpu.GPUSupport; import org.junit.Before; import java.io.IOException; @@ -29,7 +30,7 @@ public class GPUDenseVectorFieldMapperTests extends AbstractDenseVectorFieldMapp @Before public void setup() { - assumeTrue("cuvs not supported", GPUVectorsFormat.cuVSResourcesOrNull(false) != null); + assumeTrue("cuvs not supported", GPUSupport.isSupported(false)); } @Override diff --git a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsFormatTests.java b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsFormatTests.java index 99a823277506b..b1fcc1b01b58e 100644 --- a/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsFormatTests.java +++ b/x-pack/plugin/gpu/src/test/java/org/elasticsearch/xpack/gpu/codec/GPUVectorsFormatTests.java @@ -14,6 +14,7 @@ import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.util.TestUtil; import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.xpack.gpu.GPUSupport; import org.junit.BeforeClass; public class GPUVectorsFormatTests extends BaseKnnVectorsFormatTestCase { @@ -25,7 +26,7 @@ public class GPUVectorsFormatTests extends BaseKnnVectorsFormatTestCase { @BeforeClass public static void beforeClass() { - assumeTrue("cuvs not supported", GPUVectorsFormat.cuVSResourcesOrNull(false) != null); + assumeTrue("cuvs not supported", GPUSupport.isSupported(false)); } static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new GPUVectorsFormat());