Skip to content

Commit e6796d1

Browse files
committed
Optimize client pool
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent eb10051 commit e6796d1

File tree

4 files changed

+507
-44
lines changed

4 files changed

+507
-44
lines changed
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
package io.milvus.pool;
2+
3+
import io.milvus.v2.exception.ErrorCode;
4+
import io.milvus.v2.exception.MilvusClientException;
5+
import org.apache.commons.pool2.impl.GenericKeyedObjectPool;
6+
import org.slf4j.Logger;
7+
import org.slf4j.LoggerFactory;
8+
9+
import java.util.Objects;
10+
import java.util.Timer;
11+
import java.util.TimerTask;
12+
import java.util.concurrent.CopyOnWriteArrayList;
13+
import java.util.concurrent.atomic.AtomicInteger;
14+
import java.util.concurrent.atomic.AtomicLong;
15+
import java.util.concurrent.locks.Lock;
16+
import java.util.concurrent.locks.ReentrantLock;
17+
18+
public class ClientCache<T> {
19+
public static int THRESHOLD_INCREASE = 100;
20+
public static int THRESHOLD_DECREASE = 50;
21+
22+
private static final Logger logger = LoggerFactory.getLogger(ClientCache.class);
23+
private final String key;
24+
private final GenericKeyedObjectPool<String, T> clientPool;
25+
private final CopyOnWriteArrayList<ClientWrapper<T>> activeClientList = new CopyOnWriteArrayList<>();
26+
private final CopyOnWriteArrayList<ClientWrapper<T>> retireClientList = new CopyOnWriteArrayList<>();
27+
private final Timer timer = new Timer();
28+
private final AtomicLong totalCallNumber = new AtomicLong(0L);
29+
private final Lock clientListLock;
30+
private long lastCheckMs = 0L;
31+
private float qps = 0.0F;
32+
33+
protected ClientCache(String key, GenericKeyedObjectPool<String, T> pool) {
34+
this.key = key;
35+
this.clientPool = pool;
36+
this.clientListLock = new ReentrantLock(true);
37+
38+
startTimer(1000L);
39+
}
40+
41+
public void preparePool() {
42+
try {
43+
// preparePool() will create minIdlePerKey MilvusClient objects in advance, put the pre-created clients
44+
// into activeClientList
45+
clientPool.preparePool(this.key);
46+
int minIdlePerKey = clientPool.getMinIdlePerKey();
47+
for (int i = 0; i < minIdlePerKey; i++) {
48+
activeClientList.add(new ClientWrapper<>(clientPool.borrowObject(this.key)));
49+
}
50+
51+
if (logger.isDebugEnabled()) {
52+
logger.debug("ClientCache key: {} cache clients: {} ", key, activeClientList.size());
53+
logger.debug("Pool initialize idle: {} active: {} ", clientPool.getNumIdle(key), clientPool.getNumActive(key));
54+
}
55+
// System.out.printf("Key: %s, cache client: %d%n", key, activeClientList.size());
56+
// System.out.printf("Pool idle %d, active %d%n", clientPool.getNumIdle(key), clientPool.getNumActive(key));
57+
} catch (Exception e) {
58+
logger.error("Failed to prepare pool {}, exception: ", key, e);
59+
throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e);
60+
}
61+
}
62+
63+
// this method is called in an interval, it does the following tasks:
64+
// - if QPS is high, borrow client from the pool and put into activeClientList
65+
// - if QPS is low, pick a client from activeClientList and put into retireClientList
66+
//
67+
// Most of gRPC implementations uses a single long-lived HTTP/2 connection, each HTTP/2 connections have a limit
68+
// on the number of concurrent streams which is default 100. When the number of active RPCs on the connection
69+
// reaches this limit, additional RPCs are queued in the client and must wait for active RPCs to finish
70+
// before they are sent.
71+
//
72+
// Treat qps >= 75 as high, qps <= 50 as low
73+
private void checkQPS() {
74+
if (activeClientList.isEmpty()) {
75+
// reset the last check time point
76+
lastCheckMs = System.currentTimeMillis();
77+
return;
78+
}
79+
80+
long totalCallNum = totalCallNumber.get();
81+
float perClientCall = (float) totalCallNum / activeClientList.size();
82+
long timeGapMs = System.currentTimeMillis() - lastCheckMs;
83+
if (timeGapMs == 0) {
84+
timeGapMs = 1; // avoid zero
85+
}
86+
float perClientPerSecond = perClientCall * 1000 / timeGapMs;
87+
this.qps = (float) (totalCallNum * 1000) / timeGapMs;
88+
if (logger.isDebugEnabled()) {
89+
90+
logger.info("ClientCache key: {} qps: {} clientCallPerSecond: {}", key, qps, perClientPerSecond);
91+
logger.debug("Pool idle: {} active: {} ", clientPool.getNumIdle(key), clientPool.getNumActive(key));
92+
}
93+
// System.out.printf("Key: %s, QPS: %.2f, PCPS: %.2f, cache client: %d%n", key, qps, perClientPerSecond, activeClientList.size());
94+
// System.out.printf("Pool idle %d, active %d%n", clientPool.getNumIdle(key), clientPool.getNumActive(key));
95+
96+
// reset the counter and the last check time point
97+
totalCallNumber.set(0L);
98+
lastCheckMs = System.currentTimeMillis();
99+
100+
if (perClientPerSecond >= THRESHOLD_INCREASE) {
101+
// try to create more clients to reduce the perClientPerSecond to under THRESHOLD_INCREASE
102+
// add no more than 3 clients since the qps could change during we're adding new clients
103+
// the next call of checkQPS() will add more clients if the perClientPerSecond is still high
104+
int expectedNum = (int) Math.ceil((double) totalCallNum / THRESHOLD_INCREASE);
105+
int moreNum = expectedNum - activeClientList.size();
106+
if (moreNum > 3) {
107+
moreNum = 3;
108+
}
109+
110+
for (int k = 0; k < moreNum; k++) {
111+
T client = fetchFromPool();
112+
// if the pool reaches MaxTotalPerKey, the new client is null
113+
if (client == null) {
114+
break;
115+
}
116+
117+
ClientWrapper<T> wrapper = new ClientWrapper<>(client);
118+
activeClientList.add(wrapper);
119+
120+
if (logger.isDebugEnabled()) {
121+
logger.info("ClientCache key: {} borrows a client", key);
122+
}
123+
// System.out.printf("Key: %s borrows a client%n", key);
124+
}
125+
}
126+
127+
if (activeClientList.size() > 1 && perClientPerSecond <= THRESHOLD_DECREASE) {
128+
// if activeClientList has only one client, no need to retire it
129+
// otherwise, retire the max load client
130+
int maxLoad = -1000;
131+
int maxIndex = -1;
132+
for (int i = 0; i < activeClientList.size(); i++) {
133+
ClientWrapper<T> wrapper = activeClientList.get(i);
134+
int refCount = wrapper.getRefCount();
135+
if (refCount > maxLoad) {
136+
maxLoad = refCount;
137+
maxIndex = i;
138+
}
139+
}
140+
if (maxIndex >= 0) {
141+
ClientWrapper<T> wrapper = activeClientList.get(maxIndex);
142+
activeClientList.remove(maxIndex);
143+
retireClientList.add(wrapper);
144+
}
145+
}
146+
147+
// return the retired client to pool if ref count is zero
148+
returnRetiredClients();
149+
}
150+
151+
private void returnRetiredClients() {
152+
retireClientList.removeIf(wrapper -> {
153+
if (wrapper.getRefCount() <= 0) {
154+
returnToPool(wrapper.getClient());
155+
156+
if (logger.isDebugEnabled()) {
157+
logger.info("ClientCache key: {} returns a client", key);
158+
}
159+
// System.out.printf("Key: %s returns a client%n", key);
160+
return true;
161+
}
162+
return false;
163+
});
164+
}
165+
166+
private void startTimer(long interval) {
167+
if (interval < 1000L) {
168+
interval = 1000L; // min 1000
169+
}
170+
171+
TimerTask task = new TimerTask() {
172+
@Override
173+
public void run() {
174+
Thread currentThread = Thread.currentThread();
175+
currentThread.setPriority(Thread.MAX_PRIORITY);
176+
177+
checkQPS();
178+
}
179+
};
180+
181+
lastCheckMs = System.currentTimeMillis();
182+
timer.schedule(task, interval, interval);
183+
}
184+
185+
public void stopTimer() {
186+
timer.cancel();
187+
}
188+
189+
public T getClient() {
190+
totalCallNumber.incrementAndGet();
191+
if (activeClientList.isEmpty()) {
192+
// multiple threads can run into this section, add a lock to ensure only one thread can fetch the first
193+
// client object, this section is entered only one time, the lock doesn't affect major performance
194+
clientListLock.lock();
195+
try {
196+
if (activeClientList.isEmpty()) {
197+
T client = fetchFromPool();
198+
if (client == null) {
199+
return null; // reach MaxTotalPerKey?
200+
}
201+
ClientWrapper<T> wrapper = new ClientWrapper<>(client);
202+
activeClientList.add(wrapper);
203+
return wrapper.getClient();
204+
}
205+
} finally {
206+
clientListLock.unlock();
207+
}
208+
}
209+
210+
// round-robin is not a good choice because the activeClientList is occasionally changed.
211+
// here we return the minimum load client, the for loop of CopyOnWriteArrayList is high performance
212+
// typically, the activeClientList is not a large list since a dozen of clients can take thousands of qps,
213+
// I suppose the loop is a cheap operation.
214+
int minLoad = Integer.MAX_VALUE;
215+
ClientWrapper<T> wrapper = null;
216+
for (ClientWrapper<T> tempWrapper : activeClientList) {
217+
if (tempWrapper.getRefCount() < minLoad) {
218+
minLoad = tempWrapper.getRefCount();
219+
wrapper = tempWrapper;
220+
}
221+
}
222+
if (wrapper == null) {
223+
// should not be here
224+
wrapper = activeClientList.get(0);
225+
}
226+
227+
return wrapper.getClient();
228+
}
229+
230+
public void returnClient(T grpcClient) {
231+
// for loop of CopyOnWriteArrayList is thread safe
232+
// this method only decrement the call number, the checkQPS timer will retire client accordingly
233+
for (ClientWrapper<T> wrapper : activeClientList) {
234+
if (wrapper.equals(grpcClient)) {
235+
wrapper.returnClient();
236+
return;
237+
}
238+
}
239+
for (ClientWrapper<T> wrapper : retireClientList) {
240+
if (wrapper.equals(grpcClient)) {
241+
wrapper.returnClient();
242+
return;
243+
}
244+
}
245+
}
246+
247+
private T fetchFromPool() {
248+
try {
249+
if (activeClientList.size() + retireClientList.size() >= clientPool.getMaxTotalPerKey()) {
250+
return null;
251+
}
252+
return clientPool.borrowObject(this.key);
253+
} catch (Exception e) {
254+
// the pool might return timeout exception if it could not get a client in PoolConfig.maxBlockWaitDuration
255+
logger.error("Failed to get client, exception: ", e);
256+
return null; // return null, let the ClientCache to handle
257+
}
258+
}
259+
260+
private void returnToPool(T grpcClient) {
261+
try {
262+
clientPool.returnObject(this.key, grpcClient);
263+
} catch (Exception e) {
264+
// the pool might return exception if the key doesn't exist or the grpcClient doesn't belong to this pool
265+
logger.error("Failed to return client, exception: " + e);
266+
throw new MilvusClientException(ErrorCode.CLIENT_ERROR, e);
267+
}
268+
}
269+
270+
public float fetchClientPerSecond() {
271+
return qps;
272+
}
273+
274+
private static class ClientWrapper<T> {
275+
private T client;
276+
private AtomicInteger refCount = new AtomicInteger(0);
277+
278+
public ClientWrapper(T client) {
279+
this.client = client;
280+
}
281+
282+
@Override
283+
public int hashCode() {
284+
// the hash code of ClientWrapper is equal to MilvusClient hash code
285+
return this.client.hashCode();
286+
}
287+
288+
@Override
289+
public boolean equals(Object obj) {
290+
if (this == obj) return true;
291+
292+
if (obj == null) {
293+
return false;
294+
}
295+
296+
// obj is ClientWrapper
297+
if (this.getClass() == obj.getClass()) {
298+
return Objects.equals(this.client, ((ClientWrapper<?>) obj).client);
299+
}
300+
301+
// obj is MilvusClient
302+
if (this.client != null && this.client.getClass() == obj.getClass()) {
303+
return Objects.equals(this.client, obj);
304+
}
305+
return false;
306+
}
307+
308+
public T getClient() {
309+
this.refCount.incrementAndGet();
310+
return this.client;
311+
}
312+
313+
public void returnClient() {
314+
this.refCount.decrementAndGet();
315+
}
316+
317+
public int getRefCount() {
318+
return refCount.get();
319+
}
320+
}
321+
}

0 commit comments

Comments
 (0)