diff --git a/examples/src/main/java/io/milvus/v2/ClientPoolDemo.java b/examples/src/main/java/io/milvus/v2/ClientPoolDemo.java
new file mode 100644
index 000000000..810da3148
--- /dev/null
+++ b/examples/src/main/java/io/milvus/v2/ClientPoolDemo.java
@@ -0,0 +1,273 @@
+package io.milvus.v2;
+
+import com.google.gson.Gson;
+import com.google.gson.JsonObject;
+import io.milvus.pool.MilvusClientV2Pool;
+import io.milvus.pool.PoolConfig;
+import io.milvus.v1.CommonUtils;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+import io.milvus.v2.common.ConsistencyLevel;
+import io.milvus.v2.common.DataType;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.AddFieldReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.collection.request.HasCollectionReq;
+import io.milvus.v2.service.vector.request.InsertReq;
+import io.milvus.v2.service.vector.request.QueryReq;
+import io.milvus.v2.service.vector.request.SearchReq;
+import io.milvus.v2.service.vector.request.data.FloatVec;
+import io.milvus.v2.service.vector.response.InsertResp;
+import io.milvus.v2.service.vector.response.QueryResp;
+import io.milvus.v2.service.vector.response.SearchResp;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+
+public class ClientPoolDemo {
+    private static final String ServerUri = "http://localhost:19530";
+    private static final String CollectionName = "java_sdk_example_pool_demo";
+    private static final String IDFieldName = "id";
+    private static final String VectorFieldName = "vector";
+    private static final String TextFieldName = "text";
+    private static final int DIM = 256;
+    private static final String DemoKey = "for_demo";
+
+    private static final MilvusClientV2Pool pool;
+
+    static {
+        ConnectConfig defaultConnectConfig = ConnectConfig.builder()
+                .uri(ServerUri)
+                .build();
+        // read this issue for more details about the pool configurations:
+        // https://github.com/milvus-io/milvus-sdk-java/issues/1577
+        PoolConfig poolConfig = PoolConfig.builder()
+                .minIdlePerKey(1)
+                .maxIdlePerKey(2)
+                .maxTotalPerKey(5)
+                .maxBlockWaitDuration(Duration.ofSeconds(5L)) // getClient() will wait 5 seconds if no idle client available
+                .build();
+        try {
+            pool = new MilvusClientV2Pool(poolConfig, defaultConnectConfig);
+            System.out.printf("Pool is created with config:%n%s%n", poolConfig);
+
+            // prepare the pool to pre-create some clients according to minIdlePerKey.
+            // this is like a warmup that reduces the cost of the first getClient() call
+            pool.preparePool(DemoKey);
+        } catch (ClassNotFoundException | NoSuchMethodException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private static void createCollection(boolean recreate, long rowCount) {
+        System.out.println("========== createCollection() ==========");
+        MilvusClientV2 client = null;
+        try {
+            client = pool.getClient(DemoKey);
+            if (client == null) {
+                System.out.println("Cannot get client for key:" + DemoKey);
+                return;
+            }
+
+            if (recreate) {
+                client.dropCollection(DropCollectionReq.builder()
+                        .collectionName(CollectionName)
+                        .build());
+            } else if (client.hasCollection(HasCollectionReq.builder()
+                    .collectionName(CollectionName)
+                    .build())) {
+                return;
+            }
+
+            CreateCollectionReq.CollectionSchema collectionSchema = CreateCollectionReq.CollectionSchema.builder()
+                    .build();
+            collectionSchema.addField(AddFieldReq.builder()
+                    .fieldName(IDFieldName)
+                    .dataType(DataType.Int64)
+                    .isPrimaryKey(true)
+                    .autoID(true)
+                    .build());
+            collectionSchema.addField(AddFieldReq.builder()
+                    .fieldName(VectorFieldName)
+                    .dataType(DataType.FloatVector)
+                    .dimension(DIM)
+                    .build());
+            collectionSchema.addField(AddFieldReq.builder()
+                    .fieldName(TextFieldName)
+                    .dataType(DataType.VarChar)
+                    .maxLength(1024)
+                    .build());
+
+            List<IndexParam> indexes = new ArrayList<>();
+            indexes.add(IndexParam.builder()
+                    .fieldName(VectorFieldName)
+                    .indexType(IndexParam.IndexType.FLAT)
+                    .metricType(IndexParam.MetricType.COSINE)
+                    .build());
+
+            CreateCollectionReq requestCreate = CreateCollectionReq.builder()
+                    .collectionName(CollectionName)
+                    .collectionSchema(collectionSchema)
+                    .indexParams(indexes)
+                    .consistencyLevel(ConsistencyLevel.BOUNDED)
+                    .build();
+            client.createCollection(requestCreate);
+
+            insertData(rowCount);
+        } finally {
+            pool.returnClient(DemoKey, client);
+        }
+    }
+
+    private static void insertData(long rowCount) {
+        System.out.println("========== insertData() ==========");
+        MilvusClientV2 client = null;
+        try {
+            client = pool.getClient(DemoKey);
+            if (client == null) {
+                System.out.println("Cannot get client for key:" + DemoKey);
+                return;
+            }
+
+            Gson gson = new Gson();
+            long inserted = 0L;
+            while (inserted < rowCount) {
+                long batch = 1000L;
+                if (rowCount - inserted < batch) {
+                    batch = rowCount - inserted;
+                }
+                List<JsonObject> rows = new ArrayList<>();
+                for (long i = 0; i < batch; i++) {
+                    JsonObject row = new JsonObject();
+                    row.add(VectorFieldName, gson.toJsonTree(CommonUtils.generateFloatVector(DIM)));
+                    row.addProperty(TextFieldName, "text_" + i);
+                    rows.add(row);
+                }
+                InsertResp resp = client.insert(InsertReq.builder()
+                        .collectionName(CollectionName)
+                        .data(rows)
+                        .build());
+                inserted += resp.getInsertCnt();
+                System.out.println("Inserted count:" + resp.getInsertCnt());
+            }
+
+            QueryResp countR = client.query(QueryReq.builder()
+                    .collectionName(CollectionName)
+                    .outputFields(Collections.singletonList("count(*)"))
+                    .consistencyLevel(ConsistencyLevel.STRONG)
+                    .build());
+            System.out.printf("%d rows persisted%n", (long) countR.getQueryResults().get(0).getEntity().get("count(*)"));
+        } finally {
+            pool.returnClient(DemoKey, client);
+        }
+    }
+
+    private static void search() {
+        MilvusClientV2 client = null;
+        try {
+            client = pool.getClient(DemoKey);
+            while (client == null) {
+                try {
+                    // getClient() might return null if no client becomes available within
+                    // maxBlockWaitDuration while the pool is full.
+                    // retry until it returns a client.
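+                    // a short sleep between retries avoids busy-spinning while the pool is
+                    // exhausted; 100ms is an arbitrary choice for this demo
+                    TimeUnit.MILLISECONDS.sleep(100);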
+                    client = pool.getClient(DemoKey);
+                } catch (Exception e) {
+                    System.out.printf("Failed to get client, will retry, error: %s%n", e.getMessage());
+                }
+            }
+
+//            long start = System.currentTimeMillis();
+            FloatVec vector = new FloatVec(CommonUtils.generateFloatVector(DIM));
+            SearchResp resp = client.search(SearchReq.builder()
+                    .collectionName(CollectionName)
+                    .limit(10)
+                    .data(Collections.singletonList(vector))
+                    .annsField(VectorFieldName)
+                    .outputFields(Collections.singletonList(TextFieldName))
+                    .build());
+//            System.out.printf("search time cost: %dms%n", System.currentTimeMillis() - start);
+        } finally {
+            pool.returnClient(DemoKey, client);
+        }
+    }
+
+    private static void printPoolState() {
+        System.out.println("========== printPoolState() ==========");
+        System.out.printf("%d idle clients and %d active clients%n",
+                pool.getIdleClientNumber(DemoKey), pool.getActiveClientNumber(DemoKey));
+        System.out.printf("%.2f clients fetched per second%n", pool.fetchClientPerSecond(DemoKey));
+    }
+
+    private static void concurrentSearch(int threadCount, int requestCount) {
+        System.out.println("\n======================================================================");
+        System.out.println("======================= ConcurrentSearch =============================");
+        System.out.println("======================================================================");
+
+        AtomicLong totalTimeCostMs = new AtomicLong(0L);
+        class Worker implements Runnable {
+            @Override
+            public void run() {
+                long start = System.currentTimeMillis();
+                search();
+                long end = System.currentTimeMillis();
+                totalTimeCostMs.addAndGet(end - start);
+            }
+        }
+
+        try {
+            long start = System.currentTimeMillis();
+            ExecutorService executor = Executors.newFixedThreadPool(threadCount);
+            for (int i = 0; i < requestCount; i++) {
+                Runnable worker = new Worker();
+                executor.execute(worker);
+            }
+            executor.shutdown();
+
+            // as the requests run, more active clients will be created
+            boolean done = false;
+            while (!done) {
+                printPoolState();
+                done = executor.awaitTermination(1, TimeUnit.SECONDS);
+            }
+
+            long timeGapMs = System.currentTimeMillis() - start;
+            float avgQPS = (float) requestCount * 1000 / timeGapMs;
+            long avgLatency = totalTimeCostMs.get() / requestCount;
+            System.out.printf("%n%d requests done in %.1f seconds, average QPS: %.1f, average latency: %dms%n%n",
+                    requestCount, (float) timeGapMs / 1000, avgQPS, avgLatency);
+
+            // after all requests are done, the active clients are retired until eventually only one
+            // idle client is left. this just demonstrates that the pool can automatically destroy
+            // idle clients; in practice you can close the pool directly without waiting.
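+            // the wait below assumes the cache's periodic check (about once per second) retires
+            // clients within a few seconds once the load drops; the eviction pace also depends on
+            // minEvictableIdleDuration in PoolConfig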
+            while (pool.getActiveClientNumber(DemoKey) > 1) {
+                TimeUnit.SECONDS.sleep(1);
+                printPoolState();
+            }
+        } catch (Exception e) {
+            System.err.println("concurrentSearch() failed: " + e);
+        }
+    }
+
+    public static void main(String[] args) throws InterruptedException {
+        long rowCount = 10000;
+        createCollection(true, rowCount);
+
+        int threadCount = 50;
+        int requestCount = 10000;
+        concurrentSearch(threadCount, requestCount);
+
+        // do it again with a higher load
+        threadCount = 100;
+        requestCount = 20000;
+        concurrentSearch(threadCount, requestCount);
+
+        pool.close();
+    }
+}
diff --git a/examples/src/main/java/io/milvus/v2/ClientPoolExample.java b/examples/src/main/java/io/milvus/v2/ClientPoolExample.java
index 35f99dcdf..a4959d6e1 100644
--- a/examples/src/main/java/io/milvus/v2/ClientPoolExample.java
+++ b/examples/src/main/java/io/milvus/v2/ClientPoolExample.java
@@ -47,7 +47,7 @@ import java.util.List;
 
 public class ClientPoolExample {
-    public static String serverUri = "http://localhost:19530";
+    public static String ServerUri = "http://localhost:19530";
     public static String CollectionName = "java_sdk_example_pool_v2";
     public static String VectorFieldName = "vector";
     public static int DIM = 128;
@@ -95,7 +95,7 @@ public static void createDatabases(MilvusClientV2Pool pool) {
         // the ClientPool will use different config to create client to connect to specific database
         for (String dbName : dbNames) {
             ConnectConfig config = ConnectConfig.builder()
-                    .uri(serverUri)
+                    .uri(ServerUri)
                     .dbName(dbName)
                     .build();
             pool.configForKey(dbName, config);
@@ -288,13 +288,13 @@ public static void dropDatabases(MilvusClientV2Pool pool) {
 
     public static void main(String[] args) throws InterruptedException {
         ConnectConfig defaultConfig = ConnectConfig.builder()
-                .uri(serverUri)
+                .uri(ServerUri)
                 .build();
 
         // read this issue for more details about the pool configurations:
         // https://github.com/milvus-io/milvus-sdk-java/issues/1577
         PoolConfig poolConfig = PoolConfig.builder()
-                .maxIdlePerKey(10) // max idle clients per key
-                .maxTotalPerKey(50) // max total(idle + active) clients per key
+                .maxIdlePerKey(1) // max idle clients per key
+                .maxTotalPerKey(5) // max total (idle + active) clients per key
                 .maxTotal(1000) // max total clients for all keys
                 .maxBlockWaitDuration(Duration.ofSeconds(5L)) // getClient() will wait 5 seconds if no idle client available
                 .minEvictableIdleDuration(Duration.ofSeconds(10L)) // if number of idle clients is larger than maxIdlePerKey, redundant idle clients will be evicted after 10 seconds
@@ -340,7 +340,7 @@ public static void main(String[] args) throws InterruptedException {
 
         long end = System.currentTimeMillis();
         System.out.printf("%d insert requests and %d search requests finished in %.3f seconds%n",
-                threadCount * repeatRequests * 3, threadCount * repeatRequests * 3, (end - start) * 0.001);
+                threadCount * repeatRequests * dbNames.size(), threadCount * repeatRequests * dbNames.size(), (end - start) * 0.001);
 
         printClientNumber(pool);
         pool.clear(); // clear idle clients
diff --git a/sdk-core/src/main/java/io/milvus/pool/ClientCache.java b/sdk-core/src/main/java/io/milvus/pool/ClientCache.java
new file mode 100644
index 000000000..db35fe282
--- /dev/null
+++ b/sdk-core/src/main/java/io/milvus/pool/ClientCache.java
@@ -0,0 +1,334 @@
+package io.milvus.pool;
+
+import io.milvus.v2.exception.ErrorCode;
+import io.milvus.v2.exception.MilvusClientException;
+import org.apache.commons.pool2.impl.GenericKeyedObjectPool;
+import org.jetbrains.annotations.NotNull;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Objects;
+import java.util.concurrent.*;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+
+public class ClientCache<T> {
+    public static final int THRESHOLD_INCREASE = 100;
+    public static final int THRESHOLD_DECREASE = 50;
+
+    private static final Logger logger = LoggerFactory.getLogger(ClientCache.class);
+    private final String key;
+    private final GenericKeyedObjectPool<String, T> clientPool;
+    private final CopyOnWriteArrayList<ClientWrapper<T>> activeClientList = new CopyOnWriteArrayList<>();
+    private final CopyOnWriteArrayList<ClientWrapper<T>> retireClientList = new CopyOnWriteArrayList<>();
+    private final ScheduledExecutorService scheduler;
+    private final AtomicLong totalCallNumber = new AtomicLong(0L);
+    private final Lock clientListLock;
+    private long lastCheckMs = 0L;
+    private float fetchClientPerSecond = 0.0F;
+
+    protected ClientCache(String key, GenericKeyedObjectPool<String, T> pool) {
+        this.key = key;
+        this.clientPool = pool;
+        this.clientListLock = new ReentrantLock(true);
+
+        ThreadFactory threadFactory = new ThreadFactory() {
+            @Override
+            public Thread newThread(@NotNull Runnable r) {
+                Thread t = new Thread(r);
+                t.setPriority(Thread.MAX_PRIORITY); // set the highest priority for the timer
+                return t;
+            }
+        };
+        this.scheduler = Executors.newScheduledThreadPool(1, threadFactory);
+
+        startTimer(1000L);
+    }
+
+    public void preparePool() {
+        try {
+            // preparePool() creates minIdlePerKey MilvusClient objects in advance and puts the
+            // pre-created clients into activeClientList
+            clientPool.preparePool(this.key);
+            int minIdlePerKey = clientPool.getMinIdlePerKey();
+            for (int i = 0; i < minIdlePerKey; i++) {
+                activeClientList.add(new ClientWrapper<>(clientPool.borrowObject(this.key)));
+            }
+
+            if (logger.isDebugEnabled()) {
+                logger.debug("ClientCache key: {} cache clients: {} ", key, activeClientList.size());
+                logger.debug("Pool initialize idle: {} active: {} ", clientPool.getNumIdle(key), clientPool.getNumActive(key));
+            }
+//            System.out.printf("Key: %s, cache client: %d%n", key, activeClientList.size());
+//            System.out.printf("Pool idle %d, active %d%n", clientPool.getNumIdle(key), clientPool.getNumActive(key));
+        } catch (Exception e) {
+            logger.error("Failed to prepare pool {}, exception: ", key, e);
+            throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e);
+        }
+    }
+
+    // this method is called at a fixed interval and does the following tasks:
+    // - if QPS is high, borrow clients from the pool and put them into activeClientList
+    // - if QPS is low, pick a client from activeClientList and put it into retireClientList
+    //
+    // Most gRPC implementations use a single long-lived HTTP/2 connection. Each HTTP/2 connection
+    // has a limit on the number of concurrent streams, which is 100 by default. When the number of
+    // active RPCs on the connection reaches this limit, additional RPCs are queued in the client
+    // and must wait for active RPCs to finish before they are sent.
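+    // A worked example of the thresholds below: 1000 borrow calls per second spread over 5 cached
+    // clients gives perClientPerSecond = 200, which exceeds THRESHOLD_INCREASE, so the cache grows;
+    // 100 calls per second over 5 clients gives 20, below THRESHOLD_DECREASE, so one client is retired.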
+    //
+    // Treat qps >= THRESHOLD_INCREASE as high, qps <= THRESHOLD_DECREASE as low
+    private void checkQPS() {
+        if (activeClientList.isEmpty()) {
+            // reset the last check time point
+            lastCheckMs = System.currentTimeMillis();
+            return;
+        }
+
+        long totalCallNum = totalCallNumber.get();
+        float perClientCall = (float) totalCallNum / activeClientList.size();
+        long timeGapMs = System.currentTimeMillis() - lastCheckMs;
+        if (timeGapMs == 0) {
+            timeGapMs = 1; // avoid division by zero
+        }
+        float perClientPerSecond = perClientCall * 1000 / timeGapMs;
+        this.fetchClientPerSecond = (float) (totalCallNum * 1000) / timeGapMs;
+        if (logger.isDebugEnabled()) {
+            logger.debug("ClientCache key: {} fetchClientPerSecond: {} perClientPerSecond: {}, cached clients: {}",
+                    key, fetchClientPerSecond, perClientPerSecond, activeClientList.size());
+            logger.debug("Pool idle: {} active: {} ", clientPool.getNumIdle(key), clientPool.getNumActive(key));
+        }
+//        System.out.printf("Key: %s, fetchClientPerSecond: %.2f, perClientPerSecond: %.2f, cache client: %d%n", key, fetchClientPerSecond, perClientPerSecond, activeClientList.size());
+//        System.out.printf("Pool idle %d, active %d%n", clientPool.getNumIdle(key), clientPool.getNumActive(key));
+
+        // reset the counter and the last check time point
+        totalCallNumber.set(0L);
+        lastCheckMs = System.currentTimeMillis();
+
+        if (perClientPerSecond >= THRESHOLD_INCREASE) {
+            // try to create more clients to bring perClientPerSecond back under THRESHOLD_INCREASE.
+            // add no more than 3 clients at a time since the qps could change while we're adding
+            // new clients; the next call of checkQPS() will add more clients if perClientPerSecond
+            // is still high
+            int expectedNum = (int) Math.ceil((double) totalCallNum / THRESHOLD_INCREASE);
+            int moreNum = expectedNum - activeClientList.size();
+            if (moreNum > 3) {
+                moreNum = 3;
+            }
+
+            for (int k = 0; k < moreNum; k++) {
+                T client = fetchFromPool();
+                // if the pool has reached MaxTotalPerKey, the new client is null
+                if (client == null) {
+                    break;
+                }
+
+                ClientWrapper<T> wrapper = new ClientWrapper<>(client);
+                activeClientList.add(wrapper);
+
+                if (logger.isDebugEnabled()) {
+                    logger.debug("ClientCache key: {} borrows a client", key);
+                }
+//                System.out.printf("Key: %s borrows a client%n", key);
+            }
+        }
+
+        if (activeClientList.size() > 1 && perClientPerSecond <= THRESHOLD_DECREASE) {
+            // if activeClientList has only one client, there is no need to retire it.
+            // otherwise, retire the client with the maximum load
+            int maxLoad = -1000;
+            int maxIndex = -1;
+            for (int i = 0; i < activeClientList.size(); i++) {
+                ClientWrapper<T> wrapper = activeClientList.get(i);
+                int refCount = wrapper.getRefCount();
+                if (refCount > maxLoad) {
+                    maxLoad = refCount;
+                    maxIndex = i;
+                }
+            }
+            if (maxIndex >= 0) {
+                ClientWrapper<T> wrapper = activeClientList.get(maxIndex);
+                activeClientList.remove(maxIndex);
+                retireClientList.add(wrapper);
+            }
+        }
+
+        // return the retired clients to the pool once their ref count drops to zero
+        returnRetiredClients();
+    }
+
+    private void returnRetiredClients() {
+        retireClientList.removeIf(wrapper -> {
+            if (wrapper.getRefCount() <= 0) {
+                returnToPool(wrapper.getClient());
+
+                if (logger.isDebugEnabled()) {
+                    logger.debug("ClientCache key: {} returns a client", key);
+                }
+//                System.out.printf("Key: %s returns a client%n", key);
+                return true;
+            }
+            return false;
+        });
+    }
+
+    private void startTimer(long interval) {
+        if (interval < 1000L) {
+            interval = 1000L; // the minimum interval is 1000ms
+        }
+
+        lastCheckMs = System.currentTimeMillis();
+        scheduler.scheduleAtFixedRate(new Runnable() {
+            @Override
+            public void run() {
+                checkQPS();
+            }
+        }, interval, interval, TimeUnit.MILLISECONDS);
+    }
+
+    public void stopTimer() {
+        scheduler.shutdown();
+    }
+
+    public T getClient() {
+        if (activeClientList.isEmpty()) {
+            // multiple threads can run into this section. add a lock to ensure only one thread
+            // fetches the first client object. this section is entered only once, so the lock
+            // doesn't noticeably affect performance
+            clientListLock.lock();
+            try {
+                if (activeClientList.isEmpty()) {
+                    T client = fetchFromPool();
+                    if (client == null) {
+                        // no need to count totalCallNumber if we cannot fetch a client
+                        return null; // reached MaxTotalPerKey?
+                    }
+                    ClientWrapper<T> wrapper = new ClientWrapper<>(client);
+                    activeClientList.add(wrapper);
+                    totalCallNumber.incrementAndGet(); // count totalCallNumber only when a client is successfully fetched
+                    return wrapper.getClient();
+                }
+            } finally {
+                clientListLock.unlock();
+            }
+        }
+
+        // round-robin is not a good choice because activeClientList changes occasionally.
+        // here we return the client with the minimum load. iterating a CopyOnWriteArrayList is
+        // fast, and activeClientList is typically a small list since a dozen clients can handle
+        // thousands of qps, so the loop is a cheap operation.
+        int minLoad = Integer.MAX_VALUE;
+        ClientWrapper<T> wrapper = null;
+        for (ClientWrapper<T> tempWrapper : activeClientList) {
+            if (tempWrapper.getRefCount() < minLoad) {
+                minLoad = tempWrapper.getRefCount();
+                wrapper = tempWrapper;
+            }
+        }
+        if (wrapper == null) {
+            // should not happen: the "if (activeClientList.isEmpty())" section has already ensured
+            // there is at least one client in activeClientList
+            wrapper = activeClientList.get(0);
+        }
+
+        totalCallNumber.incrementAndGet(); // count totalCallNumber only when a client is successfully fetched
+        return wrapper.getClient();
+    }
+
+    public void returnClient(T grpcClient) {
+        // iterating a CopyOnWriteArrayList is thread safe.
+        // this method only decrements the ref count; the checkQPS timer will retire clients accordingly
+        for (ClientWrapper<T> wrapper : activeClientList) {
+            if (wrapper.equals(grpcClient)) {
+                wrapper.returnClient();
+                return;
+            }
+        }
+        for (ClientWrapper<T> wrapper : retireClientList) {
+            if (wrapper.equals(grpcClient)) {
+                wrapper.returnClient();
+                return;
+            }
+        }
+    }
+
+    private T fetchFromPool() {
+        try {
+            // have the borrowed clients reached MaxTotalPerKey?
+            if (activeClientList.size() + retireClientList.size() >= clientPool.getMaxTotalPerKey()) {
+                return null;
+            }
+            // TODO: how to check whether the borrowed clients exceed MaxTotal?
+            // if the number of borrowed clients is less than MaxTotalPerKey but the total borrowed
+            // clients of all keys exceeds MaxTotal, clientPool.borrowObject() will throw an
+            // exception "Timeout waiting for idle object".
+            return clientPool.borrowObject(this.key);
+        } catch (Exception e) {
+            // the pool might throw a timeout exception if it could not get a client within
+            // PoolConfig.maxBlockWaitDuration.
+            // fetchFromPool() is for internal use; return null here and let the caller handle it.
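+            // callers treat null as "pool exhausted": getClient() propagates the null, and the
+            // application decides whether to retry or fail fast, as in the ClientPoolDemo example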
+            logger.error("Failed to get client, exception: ", e);
+            return null;
+        }
+    }
+
+    private void returnToPool(T grpcClient) {
+        try {
+            clientPool.returnObject(this.key, grpcClient);
+        } catch (Exception e) {
+            // the pool might throw an exception if the key doesn't exist or the grpcClient doesn't
+            // belong to this pool. returnToPool() is for internal use and the client must be in
+            // this pool, so mute the exception
+            logger.error("Failed to return client, exception: ", e);
+        }
+    }
+
+    public float fetchClientPerSecond() {
+        return this.fetchClientPerSecond;
+    }
+
+    private static class ClientWrapper<T> {
+        private final T client;
+        private final AtomicInteger refCount = new AtomicInteger(0);
+
+        public ClientWrapper(T client) {
+            this.client = client;
+        }
+
+        @Override
+        public int hashCode() {
+            // the hash code of ClientWrapper is equal to the MilvusClient hash code
+            return this.client.hashCode();
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (this == obj) return true;
+
+            if (obj == null) {
+                return false;
+            }
+
+            // obj is a ClientWrapper
+            if (this.getClass() == obj.getClass()) {
+                return Objects.equals(this.client, ((ClientWrapper<?>) obj).client);
+            }
+
+            // obj is a MilvusClient
+            if (this.client != null && this.client.getClass() == obj.getClass()) {
+                return Objects.equals(this.client, obj);
+            }
+            return false;
+        }
+
+        public T getClient() {
+            // note: fetching the client increments the ref count; every getClient() call must be
+            // paired with a returnClient() call
+            this.refCount.incrementAndGet();
+            return this.client;
+        }
+
+        public void returnClient() {
+            this.refCount.decrementAndGet();
+        }
+
+        public int getRefCount() {
+            return refCount.get();
+        }
+    }
+}
diff --git a/sdk-core/src/main/java/io/milvus/pool/ClientPool.java b/sdk-core/src/main/java/io/milvus/pool/ClientPool.java
index 5a6f6debe..0e324ee0b 100644
--- a/sdk-core/src/main/java/io/milvus/pool/ClientPool.java
+++ b/sdk-core/src/main/java/io/milvus/pool/ClientPool.java
@@ -1,29 +1,34 @@
 package io.milvus.pool;
 
-import io.milvus.v2.exception.ErrorCode;
-import io.milvus.v2.exception.MilvusClientException;
 import org.apache.commons.pool2.impl.GenericKeyedObjectPool;
 import org.apache.commons.pool2.impl.GenericKeyedObjectPoolConfig;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+
 public class ClientPool<C, T> {
     protected static final Logger logger = LoggerFactory.getLogger(ClientPool.class);
     protected GenericKeyedObjectPool<String, T> clientPool;
     protected PoolConfig config;
     protected PoolClientFactory<C, T> clientFactory;
+    private final ConcurrentMap<String, ClientCache<T>> clientsCache = new ConcurrentHashMap<>();
+    private final Lock cacheMapLock = new ReentrantLock(true);
 
     protected ClientPool() {
     }
 
-    protected ClientPool(PoolConfig config, PoolClientFactory clientFactory) {
+    protected ClientPool(PoolConfig config, PoolClientFactory<C, T> clientFactory) {
         this.config = config;
         this.clientFactory = clientFactory;
-        GenericKeyedObjectPoolConfig poolConfig = new GenericKeyedObjectPoolConfig();
+        GenericKeyedObjectPoolConfig<T> poolConfig = new GenericKeyedObjectPoolConfig<>();
        poolConfig.setMaxIdlePerKey(config.getMaxIdlePerKey());
        poolConfig.setMinIdlePerKey(config.getMinIdlePerKey());
        poolConfig.setMaxTotal(config.getMaxTotal());
@@ -44,8 +49,8 @@ public void configForKey(String key, C config) {
         this.clientFactory.configForKey(key, config);
     }
 
-    public C removeConfig(String key) {
-        return this.clientFactory.removeConfig(key);
+    public void removeConfig(String key) {
+        this.clientFactory.removeConfig(key);
     }
 
     public Set<String> configKeys() {
@@ -57,8 +62,24 @@ public C getConfig(String key) {
     }
 
     /**
-     * Get a client object which is idle from the pool.
-     * Once the client is held by the caller, it will be marked as active and cannot be fetched by another caller.
+     * Create minIdlePerKey clients for the pool of the key.
+     * Calling this method before business traffic starts reduces the latency of the first getClient() call.
+     */
+    public void preparePool(String key) {
+        ClientCache<T> cache = getCache(key);
+        if (cache != null) {
+            cache.preparePool();
+        }
+    }
+
+    /**
+     * Get a client object from the cache. If the cache is empty, it will fetch a client from the underlying pool.
+     * The cache maintains a list of clients. The cache increases the ref-count of a client when it is fetched
+     * by getClient(), and decreases the ref-count when it is returned by returnClient().
+     * The cache balances callers across multiple clients according to the ref-count of each client: getClient()
+     * returns the client with the minimum ref-count to the caller.
+     * If the average ref-count of the clients is smaller than a threshold, the cache will retire the client with
+     * the maximum ref-count, wait for its ref-count to reach zero, and return it to the underlying pool.
      * If the number of clients hits the MaxTotalPerKey value, this method will be blocked for MaxBlockWaitDuration.
      * If no idle client available after MaxBlockWaitDuration, this method will return a null object to caller.
      *
@@ -66,13 +87,36 @@ public C getConfig(String key) {
      * @return MilvusClient or MilvusClientV2
      */
     public T getClient(String key) {
-        try {
-            return clientPool.borrowObject(key);
-        } catch (Exception e) {
-            // the pool might return timeout exception if it could not get a client in PoolConfig.maxBlockWaitDuration
-            logger.error("Failed to get client, exception: ", e);
-            throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e);
+        ClientCache<T> cache = getCache(key);
+        if (cache == null) {
+            logger.error("Not able to create a client cache for key: {}", key);
+            return null;
         }
+        return cache.getClient();
+    }
+
+    private ClientCache<T> getCache(String key) {
+        ClientCache<T> cache = clientsCache.get(key);
+        if (cache == null) {
+            // If clientsCache doesn't contain this key, multiple threads might run into this section.
+            // Although ConcurrentMap.putIfAbsent() is an atomic action, we don't intend to allow multiple
+            // threads to create multiple ClientCache objects at this line, so we add a lock here.
+            // Only the first thread that obtains the lock runs into put(); the others are blocked
+            // and get the object after obtaining the lock.
+            // This section is entered only once for each key, so the lock basically doesn't affect performance.
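+            // an equivalent alternative: ConcurrentHashMap.computeIfAbsent(key, k -> new ClientCache<>(k, clientPool))
+            // also guarantees the ClientCache is created only once per key; the explicit lock below
+            // achieves the same effect with a double-checked pattern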
+            cacheMapLock.lock();
+            try {
+                if (!clientsCache.containsKey(key)) {
+                    cache = new ClientCache<>(key, clientPool);
+                    clientsCache.put(key, cache);
+                } else {
+                    cache = clientsCache.get(key);
+                }
+            } finally {
+                cacheMapLock.unlock();
+            }
+        }
+        return cache;
     }
 
     /**
@@ -84,12 +128,11 @@ public T getClient(String key) {
      * @param grpcClient the client object to return
      */
     public void returnClient(String key, T grpcClient) {
-        try {
-            clientPool.returnObject(key, grpcClient);
-        } catch (Exception e) {
-            // the pool might return exception if the key doesn't exist or the grpcClient doesn't belong to this pool
-            logger.error("Failed to return client, exception: " + e);
-            throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e);
+        ClientCache<T> cache = clientsCache.get(key);
+        if (cache != null) {
+            cache.returnClient(grpcClient);
+        } else {
+            logger.warn("No such key: {}", key);
         }
     }
 
@@ -99,6 +142,10 @@ public void returnClient(String key, T grpcClient) {
      */
     public void close() {
         if (clientPool != null && !clientPool.isClosed()) {
+            // what if clientPool and clientsCache are cleared but some clients have not been returned?
+            // after clear(), all the milvus clients are closed; if the user continues to call APIs with an
+            // unreturned client, the client will receive an io.grpc.Status.UNAVAILABLE error and retry the call
+            clear();
             clientPool.close();
             clientPool = null;
         }
@@ -106,10 +153,16 @@ public void close() {
 
     /**
      * Release/disconnect idle clients of all key groups.
-     *
     */
    public void clear() {
        if (clientPool != null && !clientPool.isClosed()) {
+            // what if clientPool and clientsCache are cleared but some clients have not been returned?
+            // after clear(), all the milvus clients are closed; if the user continues to call APIs with an
+            // unreturned client, the client will receive an io.grpc.Status.UNAVAILABLE error and retry the call
+            for (ClientCache<T> cache : clientsCache.values()) {
+                cache.stopTimer();
+            }
+            clientsCache.clear();
            clientPool.clear();
        }
    }
@@ -121,12 +174,21 @@ public void clear() {
      */
     public void clear(String key) {
         if (clientPool != null && !clientPool.isClosed()) {
+            // what if clientPool and clientsCache are cleared but some clients have not been returned?
+            // after clear(), all the milvus clients are closed; if the user continues to call APIs with an
+            // unreturned client, the client will receive an io.grpc.Status.UNAVAILABLE error and retry the call
+            ClientCache<T> cache = clientsCache.get(key);
+            if (cache != null) {
+                cache.stopTimer();
+            }
+            clientsCache.remove(key);
             clientPool.clear(key);
         }
     }
 
     /**
      * Return the number of idle clients of a key group
+     * Threadsafe method.
      *
      * @param key the key of a group
      */
@@ -136,6 +198,7 @@ public int getIdleClientNumber(String key) {
 
     /**
      * Return the number of active clients of a key group
+     * Threadsafe method.
      *
      * @param key the key of a group
      */
@@ -145,7 +208,7 @@ public int getActiveClientNumber(String key) {
 
     /**
      * Return the number of idle clients of all key group
-     *
+     * Threadsafe method.
      */
     public int getTotalIdleClientNumber() {
         return clientPool.getNumIdle();
@@ -153,9 +216,17 @@ public int getTotalIdleClientNumber() {
 
     /**
      * Return the number of active clients of all key group
-     *
+     * Threadsafe method.
      */
     public int getTotalActiveClientNumber() {
         return clientPool.getNumActive();
     }
+
+    public float fetchClientPerSecond(String key) {
+        ClientCache<T> cache = clientsCache.get(key);
+        if (cache != null) {
+            return cache.fetchClientPerSecond();
+        }
+        return 0.0F;
+    }
 }
diff --git a/sdk-core/src/main/java/io/milvus/pool/PoolClientFactory.java b/sdk-core/src/main/java/io/milvus/pool/PoolClientFactory.java
index 6f736507c..f289e8794 100644
--- a/sdk-core/src/main/java/io/milvus/pool/PoolClientFactory.java
+++ b/sdk-core/src/main/java/io/milvus/pool/PoolClientFactory.java
@@ -36,8 +36,8 @@ public PoolClientFactory(C configDefault, String clientClassName) throws ClassNo
         }
     }
 
-    public C configForKey(String key, C config) {
-        return configForKeys.put(key, config);
+    public void configForKey(String key, C config) {
+        configForKeys.put(key, config);
     }
 
     public C removeConfig(String key) {
@@ -55,6 +55,9 @@ public C getConfig(String key) {
     @Override
     public T create(String key) throws Exception {
         try {
+            if (logger.isDebugEnabled()) {
+                logger.debug("PoolClientFactory key: {} creates a client", key);
+            }
             C keyConfig = configForKeys.get(key);
             if (keyConfig == null) {
                 return (T) constructor.newInstance(this.configDefault);
@@ -74,6 +77,9 @@ public PooledObject<T> wrap(T client) {
 
     @Override
     public void destroyObject(String key, PooledObject<T> p) throws Exception {
+        if (logger.isDebugEnabled()) {
+            logger.debug("PoolClientFactory key: {} closes a client", key);
+        }
         T client = p.getObject();
         closeMethod.invoke(client, 3L);
     }
@@ -84,7 +90,7 @@ public boolean validateObject(String key, PooledObject<T> p) {
             T client = p.getObject();
             return (boolean) verifyMethod.invoke(client);
         } catch (Exception e) {
-            logger.error("Failed to validate client, exception: " + e);
+            logger.error("Failed to validate client, exception: ", e);
             throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e);
         }
     }
diff --git a/sdk-core/src/main/java/io/milvus/pool/PoolConfig.java b/sdk-core/src/main/java/io/milvus/pool/PoolConfig.java
index 77d08f6c6..47d8ef64b 100644
--- a/sdk-core/src/main/java/io/milvus/pool/PoolConfig.java
+++ b/sdk-core/src/main/java/io/milvus/pool/PoolConfig.java
@@ -130,9 +130,9 @@ public String toString() {
     }
 
     public static class Builder {
-        private int maxIdlePerKey = 10;
-        private int minIdlePerKey = 0;
-        private int maxTotalPerKey = 50;
+        private int minIdlePerKey = 1;
+        private int maxIdlePerKey = 2;
+        private int maxTotalPerKey = 5;
         private int maxTotal = 1000;
         private boolean blockWhenExhausted = true;
         private Duration maxBlockWaitDuration = Duration.ofSeconds(3L);
diff --git a/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java b/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java
index 82737206f..ef974a794 100644
--- a/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java
+++ b/sdk-core/src/test/java/io/milvus/client/MilvusClientDockerTest.java
@@ -3138,22 +3138,38 @@ void testClientPool() {
             ConnectParam connectParam = ConnectParam.newBuilder()
                     .withUri(milvus.getEndpoint())
                     .build();
+            int minIdlePerKey = 1;
+            int maxIdlePerKey = 2;
+            int maxTotalPerKey = 4;
             PoolConfig poolConfig = PoolConfig.builder()
+                    .minIdlePerKey(minIdlePerKey)
+                    .maxIdlePerKey(maxIdlePerKey)
+                    .maxTotalPerKey(maxTotalPerKey)
                     .build();
             MilvusClientV1Pool pool = new MilvusClientV1Pool(poolConfig, connectParam);
 
+            String key = "dummy";
+            pool.preparePool(key);
+            Assertions.assertEquals(minIdlePerKey, pool.getActiveClientNumber(key));
+
             List<Thread> threadList = new ArrayList<>();
-            int threadCount = 10;
-            int requestPerThread = 10;
-            String key = "192.168.1.1";
+            int threadCount = 20;
+            int requestPerThread = 1000;
             for (int k = 0; k < threadCount; k++) {
                 Thread t = new Thread(() -> {
                     for (int i = 0; i < requestPerThread; i++) {
-                        MilvusClient client = pool.getClient(key);
-                        R resp = client.getVersion();
-//                        System.out.printf("%d, %s%n", i, resp.getData().getVersion());
-                        System.out.printf("idle %d, active %d%n", pool.getIdleClientNumber(key), pool.getActiveClientNumber(key));
-                        pool.returnClient(key, client);
+                        MilvusClient client = null;
+                        try {
+                            client = pool.getClient(key);
+                            R resp = client.getVersion();
+                            Assertions.assertEquals(R.Status.Success.getCode(), resp.getStatus().intValue());
+//                            System.out.printf("%d, %s%n", i, resp.getData().getVersion());
+//                            System.out.printf("idle %d, active %d%n", pool.getIdleClientNumber(key), pool.getActiveClientNumber(key));
+                        } catch (Exception e) {
+                            System.out.printf("request failed: %s%n", e);
+                        } finally {
+                            pool.returnClient(key, client);
+                        }
                     }
                     System.out.printf("Thread %s finished%n", Thread.currentThread().getName());
                 });
@@ -3164,8 +3180,18 @@ void testClientPool() {
             for (Thread t : threadList) {
                 t.join();
             }
+            Assertions.assertEquals(maxTotalPerKey, pool.getActiveClientNumber(key));
+            Assertions.assertEquals(maxTotalPerKey, pool.getTotalActiveClientNumber());
+            System.out.printf("qps: %.2f%n", pool.fetchClientPerSecond(key));
 
-            System.out.printf("idle %d, active %d%n", pool.getIdleClientNumber(key), pool.getActiveClientNumber(key));
+            while (pool.getActiveClientNumber(key) > 1) {
+                TimeUnit.SECONDS.sleep(1);
+                System.out.printf("waiting idle %d, active %d%n", pool.getIdleClientNumber(key), pool.getActiveClientNumber(key));
+            }
+            Assertions.assertEquals(maxIdlePerKey, pool.getIdleClientNumber(key));
+            Assertions.assertEquals(maxIdlePerKey, pool.getTotalIdleClientNumber());
+            Assertions.assertEquals(1, pool.getActiveClientNumber(key));
+            Assertions.assertEquals(1, pool.getTotalActiveClientNumber());
             pool.close();
         } catch (Exception e) {
             System.out.println(e.getMessage());
diff --git a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java
index 7faac012e..38d4242a5 100644
--- a/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java
+++ b/sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java
@@ -85,6 +85,8 @@
 
 import java.nio.ByteBuffer;
 import java.util.*;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 
@@ -2970,13 +2972,42 @@ void testClientPool() {
                 .databaseName(dummyDb)
                 .build());
 
+        String collectionName = "test_pool_coll";
+        client.createCollection(CreateCollectionReq.builder()
+                .databaseName(dummyDb)
+                .collectionName(collectionName)
+                .autoID(true)
+                .primaryFieldName("id")
+                .vectorFieldName("vector")
+                .dimension(4)
+                .consistencyLevel(ConsistencyLevel.BOUNDED)
+                .enableDynamicField(false)
+                .build());
+
+        JsonObject row = new JsonObject();
+        row.add("vector", JsonUtils.toJsonTree(utils.generateFloatVector(4)));
+        client.insert(InsertReq.builder()
+                .databaseName(dummyDb)
+                .collectionName(collectionName)
+                .data(Collections.singletonList(row))
+                .build());
+        client.loadCollection(LoadCollectionReq.builder()
+                .databaseName(dummyDb)
+                .collectionName(collectionName)
+                .build());
+
         try {
             // the default connection config will connect to default db
             ConnectConfig connectConfig = ConnectConfig.builder()
                    .uri(milvus.getEndpoint())
-                    .rpcDeadlineMs(100L)
                     .build();
+            int minIdlePerKey = 1;
+            int maxIdlePerKey = 2;
+            int maxTotalPerKey = 4;
             PoolConfig poolConfig = PoolConfig.builder()
+                    .minIdlePerKey(minIdlePerKey)
+                    .maxIdlePerKey(maxIdlePerKey)
+                    .maxTotalPerKey(maxTotalPerKey)
                     .build();
             MilvusClientV2Pool pool = new MilvusClientV2Pool(poolConfig, connectConfig);
 
@@ -2989,38 +3020,76 @@ void testClientPool() {
             Set<String> keys = pool.configKeys();
             Assertions.assertTrue(keys.contains(dummyDb));
             ConnectConfig dummyConfig = pool.getConfig(dummyDb);
-            Assertions.assertEquals(dummyConfig.getDbName(), dummyDb);
+            Assertions.assertEquals(dummyDb, dummyConfig.getDbName());
 
-            List<Thread> threadList = new ArrayList<>();
-            int threadCount = 10;
-            int requestPerThread = 10;
-            String key = "default";
-            for (int k = 0; k < threadCount; k++) {
-                Thread t = new Thread(() -> {
-                    for (int i = 0; i < requestPerThread; i++) {
-                        MilvusClientV2 client = pool.getClient(key);
-                        String version = client.getServerVersion();
-//                        System.out.printf("%d, %s%n", i, version);
-                        Assertions.assertEquals(client.currentUsedDatabase(), "default");
-                        System.out.printf("idle %d, active %d%n", pool.getIdleClientNumber(key), pool.getActiveClientNumber(key));
-                        pool.returnClient(key, client);
+            pool.preparePool(dummyDb);
+            Assertions.assertEquals(minIdlePerKey, pool.getActiveClientNumber(dummyDb));
+
+            class Worker implements Runnable {
+                private int id = 0;
+
+                public Worker(int id) {
+                    this.id = id;
+                }
+
+                @Override
+                public void run() {
+                    MilvusClientV2 client = null;
+                    try {
+                        client = pool.getClient(dummyDb);
+                        Assertions.assertEquals(dummyDb, client.currentUsedDatabase());
+
+                        FloatVec vector = new FloatVec(utils.generateFloatVector(4));
+                        SearchResp resp = client.search(SearchReq.builder()
+                                .collectionName(collectionName)
+                                .limit(1)
+                                .data(Collections.singletonList(vector))
+                                .build());
+                        Assertions.assertEquals(1, resp.getSearchResults().size());
+
+                        if ((id + 1) % 10000 == 0) {
+                            System.out.printf("current qps: %.2f%n", pool.fetchClientPerSecond(dummyDb));
+                        }
+                    } catch (Exception e) {
+                        System.out.printf("request failed: %s%n", e);
+                    } finally {
+                        pool.returnClient(dummyDb, client);
                     }
-                    System.out.printf("Thread %s finished%n", Thread.currentThread().getName());
-                });
-                t.start();
-                threadList.add(t);
+                }
             }
-
-            for (Thread t : threadList) {
-                t.join();
+            long start = System.currentTimeMillis();
+            int threadCount = 20;
+            int requestCount = 50000;
+            ExecutorService executor = Executors.newFixedThreadPool(threadCount);
+            for (int i = 0; i < requestCount; i++) {
+                Runnable worker = new Worker(i);
+                executor.execute(worker);
+            }
+            executor.shutdown();
+            if (!executor.awaitTermination(100, TimeUnit.SECONDS)) {
+                System.err.println("Executor did not terminate in the specified time.");
+                Assertions.fail();
             }
+            Assertions.assertEquals(maxTotalPerKey, pool.getActiveClientNumber(dummyDb));
+            Assertions.assertEquals(maxTotalPerKey, pool.getTotalActiveClientNumber());
 
-            System.out.printf("idle %d, active %d%n", pool.getIdleClientNumber(key), pool.getActiveClientNumber(key));
+            long end = System.currentTimeMillis();
+            System.out.printf("time cost: %dms, average qps: %.2f%n", end - start, (float) requestCount * 1000 / (end - start));
+            System.out.printf("idle %d, active %d%n", pool.getIdleClientNumber(dummyDb), pool.getActiveClientNumber(dummyDb));
             System.out.printf("total idle %d, total active %d%n", pool.getTotalIdleClientNumber(), pool.getTotalActiveClientNumber());
 
+            while (pool.getActiveClientNumber(dummyDb) > 1) {
+                TimeUnit.SECONDS.sleep(1);
+                System.out.printf("waiting idle %d, active %d%n", pool.getIdleClientNumber(dummyDb), pool.getActiveClientNumber(dummyDb));
+            }
+            Assertions.assertEquals(maxIdlePerKey, pool.getIdleClientNumber(dummyDb));
+            Assertions.assertEquals(maxIdlePerKey, pool.getTotalIdleClientNumber());
+            Assertions.assertEquals(1, pool.getActiveClientNumber(dummyDb));
+            Assertions.assertEquals(1, pool.getTotalActiveClientNumber());
+
+            // get a client connected to the dummy db
             MilvusClientV2 dummyClient = pool.getClient(dummyDb);
-            Assertions.assertEquals(dummyClient.currentUsedDatabase(), dummyDb);
+            Assertions.assertEquals(dummyDb, dummyClient.currentUsedDatabase());
 
             pool.removeConfig(dummyDb);
             Assertions.assertNull(pool.getConfig(dummyDb));
             pool.close();