package io.milvus.pool;

import io.milvus.v2.exception.ErrorCode;
import io.milvus.v2.exception.MilvusClientException;
import org.apache.commons.pool2.impl.GenericKeyedObjectPool;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Objects;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
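
/**
 * Caches client objects borrowed from a {@link GenericKeyedObjectPool} under a single key and
 * adjusts the number of cached clients according to the observed per-client QPS: a background
 * timer borrows more clients from the pool when the load per client is high, and retires
 * clients back to the pool when it is low.
 */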
public class ClientCache<T> {
    // per-client QPS at or above this threshold triggers borrowing more clients from the pool
    public static int THRESHOLD_INCREASE = 100;
    // per-client QPS at or below this threshold retires one client back to the pool
    public static int THRESHOLD_DECREASE = 50;

    private static final Logger logger = LoggerFactory.getLogger(ClientCache.class);
    private final String key;
    private final GenericKeyedObjectPool<String, T> clientPool;
    private final CopyOnWriteArrayList<ClientWrapper<T>> activeClientList = new CopyOnWriteArrayList<>();
    private final CopyOnWriteArrayList<ClientWrapper<T>> retireClientList = new CopyOnWriteArrayList<>();
    private final Timer timer = new Timer();
    private final AtomicLong totalCallNumber = new AtomicLong(0L);
    private final Lock clientListLock;
    private long lastCheckMs = 0L;
    private float qps = 0.0F;
    protected ClientCache(String key, GenericKeyedObjectPool<String, T> pool) {
        this.key = key;
        this.clientPool = pool;
        this.clientListLock = new ReentrantLock(true);

        startTimer(1000L);
    }
    public void preparePool() {
        try {
            // preparePool() creates minIdlePerKey client objects in advance; borrow the
            // pre-created clients and put them into activeClientList
            clientPool.preparePool(this.key);
            int minIdlePerKey = clientPool.getMinIdlePerKey();
            for (int i = 0; i < minIdlePerKey; i++) {
                activeClientList.add(new ClientWrapper<>(clientPool.borrowObject(this.key)));
            }

            if (logger.isDebugEnabled()) {
                logger.debug("ClientCache key: {} cache clients: {}", key, activeClientList.size());
                logger.debug("Pool initialize idle: {} active: {}", clientPool.getNumIdle(key), clientPool.getNumActive(key));
            }
        } catch (Exception e) {
            logger.error("Failed to prepare pool {}, exception: ", key, e);
            throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e);
        }
    }

    // This method is called at a fixed interval and does the following:
    // - if the per-client QPS is high, borrow clients from the pool and add them to activeClientList
    // - if the per-client QPS is low, move a client from activeClientList to retireClientList
    //
    // Most gRPC implementations use a single long-lived HTTP/2 connection per client. Each HTTP/2
    // connection has a limit on the number of concurrent streams, 100 by default. When the number
    // of active RPCs on the connection reaches this limit, additional RPCs are queued on the client
    // side and must wait for active RPCs to finish before they are sent.
    //
    // Treat per-client QPS >= THRESHOLD_INCREASE (100) as high, <= THRESHOLD_DECREASE (50) as low.
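    //
    // A worked example with assumed numbers: with a 1-second check interval and 3 cached clients,
    // 450 calls in the window give qps = 450 and perClientPerSecond = 150 >= 100, so the target
    // count is ceil(450 / 100) = 5 and 2 more clients are borrowed (capped at 3 per interval);
    // with 120 calls, perClientPerSecond = 40 <= 50, so the busiest client is moved to
    // retireClientList and is returned to the pool once its ref count drops to zero.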
    private void checkQPS() {
        if (activeClientList.isEmpty()) {
            // reset the last check time point
            lastCheckMs = System.currentTimeMillis();
            return;
        }

        long totalCallNum = totalCallNumber.get();
        float perClientCall = (float) totalCallNum / activeClientList.size();
        long timeGapMs = System.currentTimeMillis() - lastCheckMs;
        if (timeGapMs == 0) {
            timeGapMs = 1; // avoid division by zero
        }
        float perClientPerSecond = perClientCall * 1000 / timeGapMs;
        this.qps = (float) (totalCallNum * 1000) / timeGapMs;
        if (logger.isDebugEnabled()) {
            logger.debug("ClientCache key: {} qps: {} clientCallPerSecond: {}", key, qps, perClientPerSecond);
            logger.debug("Pool idle: {} active: {}", clientPool.getNumIdle(key), clientPool.getNumActive(key));
        }

        // reset the counter and the last check time point
        totalCallNumber.set(0L);
        lastCheckMs = System.currentTimeMillis();

        if (perClientPerSecond >= THRESHOLD_INCREASE) {
            // try to create more clients to bring perClientPerSecond back under THRESHOLD_INCREASE;
            // the call count is normalized to a per-second rate (qps) so the target is independent
            // of the actual check interval. Add no more than 3 clients at a time since the QPS
            // could change while we're adding new clients; the next call of checkQPS() will add
            // more clients if perClientPerSecond is still high
            int expectedNum = (int) Math.ceil(qps / THRESHOLD_INCREASE);
            int moreNum = expectedNum - activeClientList.size();
            if (moreNum > 3) {
                moreNum = 3;
            }

            for (int k = 0; k < moreNum; k++) {
                T client = fetchFromPool();
                // if the pool has reached MaxTotalPerKey, the new client is null
                if (client == null) {
                    break;
                }

                ClientWrapper<T> wrapper = new ClientWrapper<>(client);
                activeClientList.add(wrapper);

                if (logger.isDebugEnabled()) {
                    logger.debug("ClientCache key: {} borrows a client", key);
                }
            }
        }

        if (activeClientList.size() > 1 && perClientPerSecond <= THRESHOLD_DECREASE) {
            // if activeClientList has only one client, there is no need to retire it;
            // otherwise, retire the client with the highest ref count so that its in-flight
            // calls can drain while the remaining clients absorb the load
            int maxLoad = -1000;
            int maxIndex = -1;
            for (int i = 0; i < activeClientList.size(); i++) {
                ClientWrapper<T> wrapper = activeClientList.get(i);
                int refCount = wrapper.getRefCount();
                if (refCount > maxLoad) {
                    maxLoad = refCount;
                    maxIndex = i;
                }
            }
            if (maxIndex >= 0) {
                ClientWrapper<T> wrapper = activeClientList.get(maxIndex);
                activeClientList.remove(maxIndex);
                retireClientList.add(wrapper);
            }
        }

        // return retired clients to the pool once their ref counts drop to zero
        returnRetiredClients();
    }

    private void returnRetiredClients() {
        retireClientList.removeIf(wrapper -> {
            if (wrapper.getRefCount() <= 0) {
                returnToPool(wrapper.getClient());

                if (logger.isDebugEnabled()) {
                    logger.debug("ClientCache key: {} returns a client", key);
                }
                return true;
            }
            return false;
        });
    }

    private void startTimer(long interval) {
        if (interval < 1000L) {
            interval = 1000L; // minimum interval is 1000 ms
        }

        TimerTask task = new TimerTask() {
            @Override
            public void run() {
                Thread currentThread = Thread.currentThread();
                currentThread.setPriority(Thread.MAX_PRIORITY);

                checkQPS();
            }
        };

        lastCheckMs = System.currentTimeMillis();
        timer.schedule(task, interval, interval);
    }

    public void stopTimer() {
        timer.cancel();
    }

    public T getClient() {
        totalCallNumber.incrementAndGet();
        if (activeClientList.isEmpty()) {
            // multiple threads can reach this section; take a lock so that only one thread
            // fetches the first client object. This section is entered only once, so the lock
            // has no meaningful impact on overall performance
            clientListLock.lock();
            try {
                if (activeClientList.isEmpty()) {
                    T client = fetchFromPool();
                    if (client == null) {
                        return null; // the pool has reached MaxTotalPerKey
                    }
                    ClientWrapper<T> wrapper = new ClientWrapper<>(client);
                    activeClientList.add(wrapper);
                    return wrapper.getClient();
                }
            } finally {
                clientListLock.unlock();
            }
        }

        // round-robin is not a good choice because activeClientList changes occasionally.
        // Instead, return the minimum-load client. Iterating a CopyOnWriteArrayList is fast,
        // and activeClientList is typically small since a dozen clients can handle thousands
        // of QPS, so the loop should be a cheap operation.
        int minLoad = Integer.MAX_VALUE;
        ClientWrapper<T> wrapper = null;
        for (ClientWrapper<T> tempWrapper : activeClientList) {
            if (tempWrapper.getRefCount() < minLoad) {
                minLoad = tempWrapper.getRefCount();
                wrapper = tempWrapper;
            }
        }
        if (wrapper == null) {
            // should not happen; fall back to the first client
            wrapper = activeClientList.get(0);
        }

        return wrapper.getClient();
    }

    public void returnClient(T grpcClient) {
        // iterating a CopyOnWriteArrayList is thread-safe.
        // this method only decrements the wrapper's ref count; the checkQPS() timer
        // retires clients accordingly
        for (ClientWrapper<T> wrapper : activeClientList) {
            if (wrapper.equals(grpcClient)) {
                wrapper.returnClient();
                return;
            }
        }
        for (ClientWrapper<T> wrapper : retireClientList) {
            if (wrapper.equals(grpcClient)) {
                wrapper.returnClient();
                return;
            }
        }
    }

    private T fetchFromPool() {
        try {
            if (activeClientList.size() + retireClientList.size() >= clientPool.getMaxTotalPerKey()) {
                return null;
            }
            return clientPool.borrowObject(this.key);
        } catch (Exception e) {
            // the pool might throw a timeout exception if it could not get a client
            // within PoolConfig.maxBlockWaitDuration
            logger.error("Failed to get client, exception: ", e);
            return null; // return null and let the caller handle it
        }
    }

    private void returnToPool(T grpcClient) {
        try {
            clientPool.returnObject(this.key, grpcClient);
        } catch (Exception e) {
            // the pool might throw an exception if the key doesn't exist or the grpcClient
            // doesn't belong to this pool
            logger.error("Failed to return client, exception: ", e);
            throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e);
        }
    }

    public float fetchClientPerSecond() {
        // the total QPS observed during the most recent check interval
        return qps;
    }

    private static class ClientWrapper<T> {
        private final T client;
        private final AtomicInteger refCount = new AtomicInteger(0);

        public ClientWrapper(T client) {
            this.client = client;
        }

        @Override
        public int hashCode() {
            // the hash code of ClientWrapper is equal to the wrapped client's hash code
            return this.client.hashCode();
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) return true;

            if (obj == null) {
                return false;
            }

            // obj is another ClientWrapper
            if (this.getClass() == obj.getClass()) {
                return Objects.equals(this.client, ((ClientWrapper<?>) obj).client);
            }

            // obj is a raw client object; this asymmetric comparison lets returnClient()
            // look up a wrapper by the client it wraps
            if (this.client != null && this.client.getClass() == obj.getClass()) {
                return Objects.equals(this.client, obj);
            }
            return false;
        }

        public T getClient() {
            this.refCount.incrementAndGet();
            return this.client;
        }

        public void returnClient() {
            this.refCount.decrementAndGet();
        }

        public int getRefCount() {
            return refCount.get();
        }
    }
}
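
// A minimal usage sketch, not part of the original change: it assumes a recent commons-pool2
// (generic GenericKeyedObjectPoolConfig) and substitutes a trivial String-valued factory for a
// real Milvus client factory. The class name ClientCacheExample is hypothetical; it only relies
// on the protected ClientCache constructor being accessible from this package.
class ClientCacheExample {
    public static void main(String[] args) throws Exception {
        org.apache.commons.pool2.impl.GenericKeyedObjectPoolConfig<String> config =
                new org.apache.commons.pool2.impl.GenericKeyedObjectPoolConfig<>();
        config.setMinIdlePerKey(2);   // preparePool() pre-creates this many clients
        config.setMaxTotalPerKey(10); // fetchFromPool() returns null beyond this limit

        GenericKeyedObjectPool<String, String> pool = new GenericKeyedObjectPool<>(
                new org.apache.commons.pool2.BaseKeyedPooledObjectFactory<String, String>() {
                    private final AtomicInteger seq = new AtomicInteger();

                    @Override
                    public String create(String key) {
                        // a real factory would build a Milvus client for the endpoint here
                        return key + "-client-" + seq.incrementAndGet();
                    }

                    @Override
                    public org.apache.commons.pool2.PooledObject<String> wrap(String value) {
                        return new org.apache.commons.pool2.impl.DefaultPooledObject<>(value);
                    }
                }, config);

        ClientCache<String> cache = new ClientCache<>("demo-endpoint", pool);
        cache.preparePool();

        String client = cache.getClient(); // returns the least-loaded cached client
        // ... issue RPCs with the client ...
        cache.returnClient(client);        // decrements the wrapper's ref count

        cache.stopTimer(); // stop the background checkQPS() timer before shutdown
        pool.close();
    }
}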