From df4726032157568783f97863ff3a4d1a37184cad Mon Sep 17 00:00:00 2001 From: yuqianwei Date: Mon, 3 Jun 2024 00:00:24 +0800 Subject: [PATCH 1/2] feat: support jedis v3 --- arex-agent/pom.xml | 5 + arex-instrumentation/pom.xml | 1 + .../jedis/v2/JedisModuleInstrumentation.java | 2 +- .../redis/arex-jedis-v3/pom.xml | 29 + .../jedis/v3/JedisFactoryInstrumentation.java | 82 +++ .../jedis/v3/JedisModuleInstrumentation.java | 24 + .../io/arex/inst/jedis/v3/JedisWrapper.java | 567 ++++++++++++++++++ .../v3/JedisFactoryInstrumentationTest.java | 63 ++ .../arex/inst/jedis/v3/JedisWrapperTest.java | 146 +++++ 9 files changed, 918 insertions(+), 1 deletion(-) create mode 100644 arex-instrumentation/redis/arex-jedis-v3/pom.xml create mode 100644 arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisFactoryInstrumentation.java create mode 100644 arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisModuleInstrumentation.java create mode 100644 arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisWrapper.java create mode 100644 arex-instrumentation/redis/arex-jedis-v3/src/test/java/io/arex/inst/jedis/v3/JedisFactoryInstrumentationTest.java create mode 100644 arex-instrumentation/redis/arex-jedis-v3/src/test/java/io/arex/inst/jedis/v3/JedisWrapperTest.java diff --git a/arex-agent/pom.xml b/arex-agent/pom.xml index f0b37ba78..b17715eed 100644 --- a/arex-agent/pom.xml +++ b/arex-agent/pom.xml @@ -86,6 +86,11 @@ arex-jedis-v2 ${project.version} + + ${project.groupId} + arex-jedis-v3 + ${project.version} + ${project.groupId} arex-jedis-v4 diff --git a/arex-instrumentation/pom.xml b/arex-instrumentation/pom.xml index be2210f66..592c1e859 100644 --- a/arex-instrumentation/pom.xml +++ b/arex-instrumentation/pom.xml @@ -28,6 +28,7 @@ database/arex-database-mongo redis/arex-redis-common redis/arex-jedis-v2 + redis/arex-jedis-v3 redis/arex-jedis-v4 redis/arex-lettuce-v5 redis/arex-lettuce-v6 diff --git a/arex-instrumentation/redis/arex-jedis-v2/src/main/java/io/arex/inst/jedis/v2/JedisModuleInstrumentation.java b/arex-instrumentation/redis/arex-jedis-v2/src/main/java/io/arex/inst/jedis/v2/JedisModuleInstrumentation.java index 8c8e7d846..d23177b9b 100644 --- a/arex-instrumentation/redis/arex-jedis-v2/src/main/java/io/arex/inst/jedis/v2/JedisModuleInstrumentation.java +++ b/arex-instrumentation/redis/arex-jedis-v2/src/main/java/io/arex/inst/jedis/v2/JedisModuleInstrumentation.java @@ -14,7 +14,7 @@ public class JedisModuleInstrumentation extends ModuleInstrumentation { public JedisModuleInstrumentation() { super("jedis-v2", ModuleDescription.builder() - .name("Jedis").supportFrom(ComparableVersion.of("2.0")).supportTo(ComparableVersion.of("3.99")).build()); + .name("Jedis").supportFrom(ComparableVersion.of("2.0")).supportTo(ComparableVersion.of("2.99")).build()); // todo: check this version } diff --git a/arex-instrumentation/redis/arex-jedis-v3/pom.xml b/arex-instrumentation/redis/arex-jedis-v3/pom.xml new file mode 100644 index 000000000..ca1d8c03f --- /dev/null +++ b/arex-instrumentation/redis/arex-jedis-v3/pom.xml @@ -0,0 +1,29 @@ + + + + arex-instrumentation-parent + io.arex + ${revision} + ../../pom.xml + + 4.0.0 + + arex-jedis-v3 + + + + redis.clients + jedis + 3.10.0 + provided + + + io.arex + arex-redis-common + ${project.version} + compile + + + diff --git a/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisFactoryInstrumentation.java b/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisFactoryInstrumentation.java new file mode 100644 index 000000000..2ed6e44c3 --- /dev/null +++ b/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisFactoryInstrumentation.java @@ -0,0 +1,82 @@ +package io.arex.inst.jedis.v3; + +import io.arex.inst.extension.MethodInstrumentation; +import io.arex.inst.extension.TypeInstrumentation; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; +import org.apache.commons.pool2.PooledObject; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.Jedis; +import redis.clients.jedis.exceptions.JedisException; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSocketFactory; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import static java.util.Collections.singletonList; +import static net.bytebuddy.matcher.ElementMatchers.*; + +public class JedisFactoryInstrumentation extends TypeInstrumentation { + @Override + public ElementMatcher typeMatcher() { + return named("redis.clients.jedis.JedisFactory"); + } + + @Override + public List methodAdvices() { + return singletonList(new MethodInstrumentation( + isMethod().and(named("makeObject")).and(takesArguments(0)), + this.getClass().getName() + "$MakeObjectAdvice")); + } + + @SuppressWarnings("unused") + public static class MakeObjectAdvice { + + @Advice.OnMethodEnter(skipOn = Advice.OnNonDefaultValue.class, suppress = Throwable.class) + public static Jedis onEnter(@Advice.FieldValue("hostAndPort") AtomicReference hostAndPort, + @Advice.FieldValue("connectionTimeout") Integer connectionTimeout, + @Advice.FieldValue("soTimeout") Integer soTimeout, + @Advice.FieldValue("ssl") Boolean ssl, + @Advice.FieldValue("sslSocketFactory") SSLSocketFactory sslSocketFactory, + @Advice.FieldValue("sslParameters") SSLParameters sslParameters, + @Advice.FieldValue("hostnameVerifier") HostnameVerifier hostnameVerifier) { + final HostAndPort hp = hostAndPort.get(); + return new JedisWrapper(hp.getHost(), hp.getPort(), connectionTimeout, soTimeout, + ssl, sslSocketFactory, sslParameters, hostnameVerifier); + } + + // todo: change instrumentation: JedisFactory -> DefaultPoolObject + // need throw JedisException, not suppress throwable + @Advice.OnMethodExit + public static void onExit(@Advice.Enter Jedis jedis, + @Advice.FieldValue("password") String password, + @Advice.FieldValue("database") Integer database, + @Advice.FieldValue("clientName") String clientName, + @Advice.Return(readOnly = false) PooledObject result) throws Exception { + if (jedis == null) { + return; + } + try { + jedis.connect(); + if (password != null) { + jedis.auth(password); + } + if (database != 0) { + jedis.select(database); + } + if (clientName != null) { + jedis.clientSetname(clientName); + } + result = new DefaultPooledObject(jedis); + } catch (JedisException jex) { + jedis.close(); + throw jex; + } + } + } + +} diff --git a/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisModuleInstrumentation.java b/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisModuleInstrumentation.java new file mode 100644 index 000000000..de500be33 --- /dev/null +++ b/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisModuleInstrumentation.java @@ -0,0 +1,24 @@ +package io.arex.inst.jedis.v3; + +import com.google.auto.service.AutoService; +import io.arex.inst.extension.ModuleDescription; +import io.arex.inst.extension.ModuleInstrumentation; +import io.arex.inst.extension.TypeInstrumentation; +import io.arex.agent.bootstrap.model.ComparableVersion; + +import java.util.List; + +import static java.util.Collections.singletonList; + +@AutoService(ModuleInstrumentation.class) +public class JedisModuleInstrumentation extends ModuleInstrumentation { + public JedisModuleInstrumentation() { + super("jedis-v3", ModuleDescription.builder() + .name("Jedis").supportFrom(ComparableVersion.of("3.0")).supportTo(ComparableVersion.of("3.99")).build()); + } + + @Override + public List instrumentationTypes() { + return singletonList(new JedisFactoryInstrumentation()); + } +} diff --git a/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisWrapper.java b/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisWrapper.java new file mode 100644 index 000000000..19e8851df --- /dev/null +++ b/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisWrapper.java @@ -0,0 +1,567 @@ +package io.arex.inst.jedis.v3; + +import io.arex.agent.bootstrap.model.MockResult; +import io.arex.agent.bootstrap.util.ArrayUtils; +import io.arex.inst.redis.common.RedisExtractor; +import io.arex.inst.redis.common.RedisKeyUtil; +import io.arex.inst.runtime.context.ContextManager; +import io.arex.inst.runtime.context.RepeatedCollectManager; +import io.arex.inst.runtime.serializer.Serializer; +import redis.clients.jedis.Jedis; +import redis.clients.jedis.params.GetExParams; +import redis.clients.jedis.params.SetParams; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSocketFactory; +import java.util.*; +import java.util.concurrent.Callable; + +public class JedisWrapper extends Jedis { + private final String url; + + public JedisWrapper(final String host, final int port, final int connectionTimeout, final int soTimeout, + final boolean ssl, final SSLSocketFactory sslSocketFactory, final SSLParameters sslParameters, + final HostnameVerifier hostnameVerifier) { + super(host, port, connectionTimeout, soTimeout, ssl, sslSocketFactory, sslParameters, hostnameVerifier); + this.url = host + ":" + port; + } + + public JedisWrapper(String host, int port, int timeout) { + super(host, port, timeout); + this.url = host + ":" + port; + } + + @Override + public String set(String key, String value) { + return this.call("set", key, () -> super.set(key, value), null); + } + + @Override + public String set(String key, String value, SetParams params) { + return this.call("set", key, params.toString(), () -> super.set(key, value, params), null); + } + + @Override + public String get(String key) { + return call("get", key, () -> super.get(key), null); + } + + @Override + public String getDel(String key) { + return call("getDel", key, () -> super.getDel(key), null); + } + + @Override + public String getEx(String key, GetExParams params) { + return call("getEx", key, params.toString(), () -> super.getEx(key, params), null); + } + + @Override + public Long exists(String... keys) { + return call("exists", RedisKeyUtil.generate(keys), () -> super.exists(keys), 0L); + } + + @Override + public Boolean exists(String key) { + return call("exists", key, () -> super.exists(key), false); + } + + @Override + public Long del(String... keys) { + return call("del", RedisKeyUtil.generate(keys), () -> super.del(keys), 0L); + } + + @Override + public Long del(String key) { + return call("del", key, () -> super.del(key), 0L); + } + + @Override + public Long unlink(String... keys) { + return call("unlink", RedisKeyUtil.generate(keys), () -> super.unlink(keys), 0L); + } + + @Override + public Long unlink(String key) { + return call("unlink", key, () -> super.unlink(key), 0L); + } + + @Override + public String type(String key) { + return call("type", key, () -> super.type(key), "none"); + } + + @Override + public Set keys(String pattern) { + return call("keys", pattern, () -> super.keys(pattern), Collections.EMPTY_SET); + } + + @Override + public String rename(byte[] oldkey, byte[] newkey) { + return call("rename", RedisKeyUtil.generate(oldkey, newkey), () -> super.rename(oldkey, newkey), null); + } + + @Override + public Long renamenx(byte[] oldkey, byte[] newkey) { + return call("renamenx", RedisKeyUtil.generate(oldkey, newkey), () -> super.renamenx(oldkey, newkey), 0L); + } + + @Override + public Long expire(String key, long seconds) { + return call("expire", key, () -> super.expire(key, seconds), 0L); + } + + @Override + public Long expireAt(String key, long unixTime) { + return call("expireAt", key, () -> super.expireAt(key, unixTime), 0L); + } + + @Override + public Long ttl(String key) { + return call("ttl", key, () -> super.ttl(key), -1L); + } + + @Override + public String getSet(String key, String value) { + return call("getSet", key, () -> super.getSet(key, value), null); + } + + @Override + public List mget(String... keys) { + return call("mget", RedisKeyUtil.generate(keys), () -> super.mget(keys), Collections.EMPTY_LIST); + } + + @Override + public Long setnx(String key, String value) { + return call("setnx", key, () -> super.setnx(key, value), 0L); + } + + //@Override + //public String setex(String key, long seconds, String value) { + // return call("setex", key, () -> super.setex(key, seconds, value), null); + //} + + @Override + public String mset(String... keysvalues) { + return call("mset", keysvalues, () -> super.mset(keysvalues), null); + } + + @Override + public Long msetnx(String... keysvalues) { + return call("msetnx", keysvalues, () -> super.msetnx(keysvalues), 0L); + } + + @Override + public Long decrBy(String key, long integer) { + return call("decrBy", key, () -> super.decrBy(key, integer), 0L); + } + + @Override + public Long decr(String key) { + return call("decr", key, () -> super.decr(key), 0L); + } + + @Override + public Long incrBy(String key, long integer) { + return call("incrBy", key, () -> super.incrBy(key, integer), 0L); + } + + @Override + public Double incrByFloat(String key, double value) { + return call("incrByFloat", key, () -> super.incrByFloat(key, value), 0d); + } + + @Override + public Long incr(String key) { + return call("incr", key, () -> super.incr(key), 0L); + } + + public Long append(String key, String value) { + return call("append", key, () -> super.append(key, value), 0L); + } + + @Override + public String substr(String key, int start, int end) { + return call("substr", key, RedisKeyUtil.generate("start", String.valueOf(start), "end", String.valueOf(end)), + () -> super.substr(key, start, end), null); + } + + @Override + public Long hset(String key, String field, String value) { + return call("hset", key, field, () -> super.hset(key, field, value), 0L); + } + + @Override + public Long hset(final String key, final Map hash) { + return call("hset", key, hash.keySet(), () -> super.hset(key, hash), 0L); + } + + @Override + public String hget(String key, String field) { + return call("hget", key, field, () -> super.hget(key, field), null); + } + + @Override + public Long hsetnx(String key, String field, String value) { + return call("hsetnx", key, field, () -> super.hsetnx(key, field, value), 0L); + } + + @Override + public String hmset(String key, Map hash) { + return call("hmset", key, Serializer.serialize(hash.keySet()), () -> super.hmset(key, hash), + null); + } + + @Override + public List hmget(String key, String... fields) { + return call("hmget", key, RedisKeyUtil.generate(fields), () -> super.hmget(key, fields), + Collections.EMPTY_LIST); + } + + @Override + public Long hincrBy(String key, String field, long value) { + return call("hincrBy", key, field, () -> super.hincrBy(key, field, value), 0L); + } + + @Override + public Double hincrByFloat(String key, String field, double value) { + return call("hincrByFloat", key, field, () -> super.hincrByFloat(key, field, value), 0d); + } + + @Override + public Boolean hexists(String key, String field) { + return call("hexists", key, field, () -> super.hexists(key, field), false); + } + + @Override + public Long hdel(String key, String... fields) { + return call("hdel", key, RedisKeyUtil.generate(fields), () -> super.hdel(key, fields), 0L); + } + + @Override + public Long hlen(String key) { + return call("hlen", key, () -> super.hlen(key), 0L); + } + + @Override + public Set hkeys(String key) { + return call("hkeys", key, () -> super.hkeys(key), Collections.EMPTY_SET); + } + + @Override + public List hvals(String key) { + return call("hvals", key, () -> super.hvals(key), Collections.EMPTY_LIST); + } + + @Override + public Map hgetAll(String key) { + return call("hgetAll", key, () -> super.hgetAll(key), Collections.EMPTY_MAP); + } + + @Override + public Long llen(String key) { + return call("llen", key, () -> super.llen(key), 0L); + } + + @Override + public List lrange(String key, long start, long end) { + return call("lrange", key, RedisKeyUtil.generate("start", String.valueOf(start), "end", String.valueOf(end)), + () -> super.lrange(key, start, end), Collections.EMPTY_LIST); + } + + @Override + public String ltrim(String key, long start, long end) { + return call("ltrim", key, RedisKeyUtil.generate("start", String.valueOf(start), "end", String.valueOf(end)), + () -> super.ltrim(key, start, end), null); + } + + @Override + public String lindex(String key, long index) { + return call("lindex", key, RedisKeyUtil.generate("index", String.valueOf(index)), + () -> super.lindex(key, index), null); + } + + @Override + public String lset(String key, long index, String value) { + return call("lset", key, RedisKeyUtil.generate("index", String.valueOf(index)), + () -> super.lset(key, index, value), null); + } + + @Override + public String lpop(String key) { + return call("lpop", key, () -> super.lpop(key), null); + } + + @Override + public String rpop(String key) { + return call("rpop", key, () -> super.rpop(key), null); + } + + @Override + public String spop(String key) { + return call("spop", key, () -> super.spop(key), null); + } + + @Override + public Set spop(String key, long count) { + return call("spop", RedisKeyUtil.generate(key, String.valueOf(count)), () -> super.spop(key, count), + Collections.EMPTY_SET); + } + + @Override + public Long scard(String key) { + return call("scard", key, () -> super.scard(key), 0L); + } + + @Override + public Set sinter(String... keys) { + return call("sinter", RedisKeyUtil.generate(keys), () -> super.sinter(keys), Collections.EMPTY_SET); + } + + @Override + public Set sunion(String... keys) { + return call("sunion", RedisKeyUtil.generate(keys), () -> super.sunion(keys), Collections.EMPTY_SET); + } + + @Override + public Set sdiff(String... keys) { + return call("sdiff", RedisKeyUtil.generate(keys), () -> super.sdiff(keys), Collections.EMPTY_SET); + } + + @Override + public String srandmember(String key) { + return call("srandmember", key, () -> super.srandmember(key), null); + } + + @Override + public List srandmember(String key, int count) { + return call("srandmember", key, RedisKeyUtil.generate("count", String.valueOf(count)), + () -> super.srandmember(key, count), Collections.EMPTY_LIST); + } + + @Override + public Long zcard(String key) { + return call("zcard", key, () -> super.zcard(key), 0L); + } + + @Override + public Long strlen(String key) { + return call("strlen", key, () -> super.strlen(key), 0L); + } + + @Override + public Long persist(String key) { + return call("persist", key, () -> super.persist(key), 0L); + } + + @Override + public Long setrange(String key, long offset, String value) { + return call("setrange", key, RedisKeyUtil.generate("offset", String.valueOf(offset)), + () -> super.setrange(key, offset, value), 0L); + } + + @Override + public String getrange(String key, long startOffset, long endOffset) { + return call("getrange", key, RedisKeyUtil.generate( + RedisKeyUtil.generate("startOffset", String.valueOf(startOffset), "endOffset", String.valueOf(endOffset))), + () -> super.getrange(key, startOffset, endOffset), null); + } + + @Override + public Long pttl(String key) { + return call("pttl", key, () -> super.pttl(key), 0L); + } + + @Override + public String psetex(String key, long milliseconds, String value) { + return call("psetex", key, value, () -> super.psetex(key, milliseconds, value), null); + } + + @Override + public byte[] substr(byte[] key, int start, int end) { + return call("substr", key, + RedisKeyUtil.generate("start", String.valueOf(start), "end", String.valueOf(end)), + () -> super.substr(key, start, end), null); + } + + @Override + public Long hset(byte[] key, byte[] field, byte[] value) { + return call("hset", key, field, () -> super.hset(key, field, value), 0L); + } + + @Override + public Long hset(byte[] key, Map hash) { + return call("hset", key, hash.keySet(), () -> super.hset(key, hash), 0L); + } + + @Override + public byte[] hget(byte[] key, byte[] field) { + return call("hget", key, field, () -> super.hget(key, field), null); + } + + @Override + public Long hdel(byte[] key, byte[]... fields) { + return call("hdel", key, RedisKeyUtil.generate(fields), () -> super.hdel(key, fields), 0L); + } + + @Override + public List hvals(byte[] key) { + return call("hvals", key, () -> super.hvals(key), Collections.EMPTY_LIST); + } + + @Override + public Map hgetAll(byte[] key) { + return call("hgetAll", key, () -> super.hgetAll(key), Collections.EMPTY_MAP); + } + + @Override + public Long pexpire(String key, long milliseconds) { + return call("pexpire", key, () -> super.pexpire(key, milliseconds), 0L); + } + + @Override + public Long pexpireAt(String key, long millisecondsTimestamp) { + return call("pexpireAt", key, () -> super.pexpireAt(key, millisecondsTimestamp), 0L); + } + + @Override + public byte[] get(final byte[] key) { + return call("get", key, () -> super.get(key), null); + } + + @Override + public Long exists(final byte[]... keys) { + return call("exists", RedisKeyUtil.generate(keys), () -> super.exists(keys), 0L); + } + + @Override + public Boolean exists(final byte[] key) { + return call("exists", key, () -> super.exists(key), false); + } + + @Override + public String type(final byte[] key) { + return call("type", key, () -> super.type(key), "none"); + } + + @Override + public byte[] getSet(final byte[] key, final byte[] value) { + return call("getSet", key, () -> super.getSet(key, value), null); + } + + @Override + public List mget(final byte[]... keys) { + return call("mget", RedisKeyUtil.generate(keys), () -> super.mget(keys), Collections.EMPTY_LIST); + } + + @Override + public Long setnx(final byte[] key, final byte[] value) { + return call("setnx", key, () -> super.setnx(key, value), 0L); + } + + @Override + public String setex(byte[] key, int seconds, byte[] value) { + return call("setex", key, () -> super.setex(key, seconds, value), + null); + } + + @Override + public Long unlink(byte[]... keys) { + return call("unlink", RedisKeyUtil.generate(keys), () -> super.unlink(keys), 0L); + } + + @Override + public Long unlink(byte[] key) { + return call("unlink", key, () -> super.unlink(key), 0L); + } + + @Override + public String rename(String oldkey, String newkey) { + return call("rename", RedisKeyUtil.generate(oldkey, newkey), () -> super.rename(oldkey, newkey), null); + } + + @Override + public Long renamenx(String oldkey, String newkey) { + return call("renamenx", RedisKeyUtil.generate(oldkey, newkey), () -> super.renamenx(oldkey, newkey), 0L); + } + + @Override + public String ping() { + return call("ping", "", () -> super.ping(), null); + } + + @Override + public byte[] ping(byte[] message) { + return call("ping", message, () -> super.ping(message), null); + } + + @Override + public String ping(String message) { + return call("ping", message, () -> super.ping(message), null); + } + + /** + * mset/msetnx + */ + private U call(String command, String[] keysValues, Callable callable, U defaultValue) { + if (ArrayUtils.isEmpty(keysValues)) { + return defaultValue; + } + + if (keysValues.length == 2) { + return call(command, keysValues[0], null, callable, defaultValue); + } + + StringBuilder keyBuilder = new StringBuilder(keysValues[0]); + for (int i = 2; i < keysValues.length; i++) { + if (i % 2 == 0) { + keyBuilder.append(';').append(keysValues[i]); + } + } + + return call(command, keyBuilder.toString(), null, callable, defaultValue); + } + + private U call(String command, Object key, Callable callable, U defaultValue) { + return call(command, key, null, callable, defaultValue); + } + + private U call(String command, Object key, Object field, Callable callable, U defaultValue) { + if (ContextManager.needRecord()) { + RepeatedCollectManager.enter(); + } + if (ContextManager.needReplay()) { + RedisExtractor extractor = new RedisExtractor(this.url, command, key, field); + MockResult mockResult = extractor.replay(); + if (mockResult.notIgnoreMockResult()) { + if (mockResult.getThrowable() instanceof RuntimeException) { + throw (RuntimeException) mockResult.getThrowable(); + } + return mockResult.getResult() == null ? defaultValue : (U) mockResult.getResult(); + } + } + + U result; + try { + result = callable.call(); + } catch (Exception e) { + if (ContextManager.needRecord() && RepeatedCollectManager.exitAndValidate()) { + RedisExtractor extractor = new RedisExtractor(this.url, command, key, field); + extractor.record(e); + } + + if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } + + return defaultValue; + } + + if (ContextManager.needRecord() && RepeatedCollectManager.exitAndValidate()) { + RedisExtractor extractor = new RedisExtractor(this.url, command, key, field); + extractor.record(result); + } + return result; + } +} diff --git a/arex-instrumentation/redis/arex-jedis-v3/src/test/java/io/arex/inst/jedis/v3/JedisFactoryInstrumentationTest.java b/arex-instrumentation/redis/arex-jedis-v3/src/test/java/io/arex/inst/jedis/v3/JedisFactoryInstrumentationTest.java new file mode 100644 index 000000000..6ea444f6c --- /dev/null +++ b/arex-instrumentation/redis/arex-jedis-v3/src/test/java/io/arex/inst/jedis/v3/JedisFactoryInstrumentationTest.java @@ -0,0 +1,63 @@ +package io.arex.inst.jedis.v3; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.MockedConstruction; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.Jedis; +import redis.clients.jedis.exceptions.JedisException; + +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.*; + +@ExtendWith(MockitoExtension.class) +class JedisFactoryInstrumentationTest { + static JedisFactoryInstrumentation target = null; + + @BeforeAll + static void setUp() { + target = new JedisFactoryInstrumentation(); + } + + @AfterAll + static void tearDown() { + target = null; + } + + @Test + void typeMatcher() { + assertNotNull(target.typeMatcher()); + } + + @Test + void methodAdvices() { + assertNotNull(target.methodAdvices()); + } + + @Test + void onEnter() { + HostAndPort hostAndPort = Mockito.mock(HostAndPort.class); + Mockito.when(hostAndPort.getPort()).thenReturn(0); + AtomicReference hostAndPortAR = new AtomicReference<>(hostAndPort); + try (MockedConstruction mocked = Mockito.mockConstruction(JedisWrapper.class, (mock, context) -> { + })) { + assertNotNull(JedisFactoryInstrumentation.MakeObjectAdvice.onEnter(hostAndPortAR, 0, + 0, false, null, null, null)); + } + } + + @Test + void onExit() throws Exception { + JedisFactoryInstrumentation.MakeObjectAdvice.onExit(null, null, null, null, null); + Jedis jedis = Mockito.mock(Jedis.class); + JedisFactoryInstrumentation.MakeObjectAdvice.onExit(jedis, "mock", 1, "mock", null); + Mockito.doThrow(new JedisException("")).when(jedis).connect(); + assertThrows(JedisException.class, () -> JedisFactoryInstrumentation.MakeObjectAdvice.onExit( + jedis, "mock", 1, "mock", null)); + } +} \ No newline at end of file diff --git a/arex-instrumentation/redis/arex-jedis-v3/src/test/java/io/arex/inst/jedis/v3/JedisWrapperTest.java b/arex-instrumentation/redis/arex-jedis-v3/src/test/java/io/arex/inst/jedis/v3/JedisWrapperTest.java new file mode 100644 index 000000000..313c5fc6b --- /dev/null +++ b/arex-instrumentation/redis/arex-jedis-v3/src/test/java/io/arex/inst/jedis/v3/JedisWrapperTest.java @@ -0,0 +1,146 @@ +package io.arex.inst.jedis.v3; + +import io.arex.agent.bootstrap.model.MockResult; +import io.arex.inst.redis.common.RedisExtractor; +import io.arex.inst.runtime.context.ContextManager; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.*; +import org.mockito.junit.jupiter.MockitoExtension; +import redis.clients.jedis.*; +import redis.clients.jedis.params.SetParams; + +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSocketFactory; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.function.Predicate; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +@ExtendWith(MockitoExtension.class) +class JedisWrapperTest { + @Mock + SSLSocketFactory factory; + @Mock + SSLParameters parameters; + @Mock + HostnameVerifier verifier; + @InjectMocks + JedisWrapper target = new JedisWrapper("", 0, 0, 0, false, factory, parameters, verifier); + static Client client; + + @BeforeAll + static void setUp() { + Mockito.mockConstruction(Client.class, (mock, context) -> { + client = mock; + }); + Mockito.mockStatic(ContextManager.class); + } + + @AfterAll + static void tearDown() { + client = null; + Mockito.clearAllCaches(); + } + + @Test + void callWithEmptyKeysValuesReturnsDefault() { + long result = target.msetnx( new String[]{}); + assertEquals(0, result); + } + + @Test + void callWithTwoKeysValuesReturnsCallableResult() { + Mockito.when(ContextManager.needRecord()).thenReturn(false); + Mockito.when(client.getIntegerReply()).thenReturn(1L); + try (MockedConstruction mocked = Mockito.mockConstruction(RedisExtractor.class, (mock, context) -> { + })) { + long result = target.msetnx("key", "value"); + assertEquals(1L, result); + + result = target.msetnx("key1", "value1", "key2", "value2", "key3", "value3"); + assertEquals(1L, result); + + result = target.exists("key1", "key2", "key3"); + assertEquals(1L, result); + } catch (Exception e) { + assertThrows(NullPointerException.class, () -> { + throw e; + }); + } + } + + @ParameterizedTest + @MethodSource("callCase") + void call(Runnable mocker, Predicate predicate) { + mocker.run(); + try (MockedConstruction mocked = Mockito.mockConstruction(RedisExtractor.class, (mock, context) -> { + System.out.println("mock RedisExtractor"); + Mockito.when(mock.replay()).thenReturn(MockResult.success(null)); + })) { + String result = target.hget("key", "field"); + assertTrue(predicate.test(result)); + } catch (Exception e) { + assertThrows(NullPointerException.class, () -> { + throw e; + }); + } + } + + static Stream callCase() { + Runnable mocker1 = () -> { + Mockito.when(ContextManager.needReplay()).thenReturn(true); + }; + Runnable mocker2 = () -> { + Mockito.when(ContextManager.needReplay()).thenReturn(false); + Mockito.when(ContextManager.needRecord()).thenReturn(true); + Mockito.when(client.getBulkReply()).thenThrow(new NullPointerException()); + }; + Runnable mocker3 = () -> { + Mockito.when(client.getBulkReply()).thenReturn("mock"); + }; + Predicate predicate1 = Objects::isNull; + Predicate predicate2 = "mock"::equals; + return Stream.of( + arguments(mocker1, predicate1), + arguments(mocker2, predicate1), + arguments(mocker3, predicate2) + ); + } + + @Test + void testApi() { + assertDoesNotThrow(() -> target.expire("key".getBytes(), 1)); + assertDoesNotThrow(() -> target.append("key".getBytes(), "value".getBytes())); + assertDoesNotThrow(() -> target.substr("key".getBytes(), 1, 2)); + assertDoesNotThrow(() -> target.hset("key".getBytes(), "field".getBytes(), "value".getBytes())); + Map hash = new HashMap<>(); + assertDoesNotThrow(() -> target.hset("key".getBytes(), hash)); + Map hash1 = new HashMap<>(); + assertDoesNotThrow(() -> target.hset("key", hash1)); + assertDoesNotThrow(() -> target.hget("key".getBytes(), "value".getBytes())); + assertDoesNotThrow(() -> target.hdel("key".getBytes(), "value".getBytes())); + assertDoesNotThrow(() -> target.hvals("key".getBytes())); + assertDoesNotThrow(() -> target.hgetAll("key".getBytes())); + assertDoesNotThrow(() -> target.set("key", "value")); + assertDoesNotThrow(() -> target.set("key", "value", new SetParams().ex(10))); + assertDoesNotThrow(() -> target.get("key".getBytes())); + assertDoesNotThrow(() -> target.exists("key".getBytes())); + assertDoesNotThrow(() -> target.type("key".getBytes())); + assertDoesNotThrow(() -> target.getSet("key".getBytes(), "value".getBytes())); + assertDoesNotThrow(() -> target.setnx("key".getBytes(), "value".getBytes())); + assertDoesNotThrow(() -> target.setex("key".getBytes(), 1, "value".getBytes())); + assertDoesNotThrow(() -> target.unlink("key".getBytes())); + assertDoesNotThrow(() -> target.ping("key".getBytes())); + } +} From 1e9885aaee10a0318442b4446ff0b9b22e0ed4b5 Mon Sep 17 00:00:00 2001 From: yuqianwei Date: Tue, 24 Dec 2024 20:15:35 +0800 Subject: [PATCH 2/2] feat: support jedis v3#496 --- .../io/arex/inst/jedis/v3/JedisWrapper.java | 8 +- .../arex/inst/jedis/v3/JedisWrapperTest.java | 91 +++++++++++++++++-- 2 files changed, 86 insertions(+), 13 deletions(-) diff --git a/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisWrapper.java b/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisWrapper.java index 19e8851df..112f2f84f 100644 --- a/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisWrapper.java +++ b/arex-instrumentation/redis/arex-jedis-v3/src/main/java/io/arex/inst/jedis/v3/JedisWrapper.java @@ -137,10 +137,10 @@ public Long setnx(String key, String value) { return call("setnx", key, () -> super.setnx(key, value), 0L); } - //@Override - //public String setex(String key, long seconds, String value) { - // return call("setex", key, () -> super.setex(key, seconds, value), null); - //} + @Override + public String setex(String key, long seconds, String value) { + return call("setex", key, () -> super.setex(key, seconds, value), null); + } @Override public String mset(String... keysvalues) { diff --git a/arex-instrumentation/redis/arex-jedis-v3/src/test/java/io/arex/inst/jedis/v3/JedisWrapperTest.java b/arex-instrumentation/redis/arex-jedis-v3/src/test/java/io/arex/inst/jedis/v3/JedisWrapperTest.java index 313c5fc6b..337555442 100644 --- a/arex-instrumentation/redis/arex-jedis-v3/src/test/java/io/arex/inst/jedis/v3/JedisWrapperTest.java +++ b/arex-instrumentation/redis/arex-jedis-v3/src/test/java/io/arex/inst/jedis/v3/JedisWrapperTest.java @@ -13,6 +13,7 @@ import org.mockito.*; import org.mockito.junit.jupiter.MockitoExtension; import redis.clients.jedis.*; +import redis.clients.jedis.params.GetExParams; import redis.clients.jedis.params.SetParams; import javax.net.ssl.HostnameVerifier; @@ -120,27 +121,99 @@ static Stream callCase() { @Test void testApi() { - assertDoesNotThrow(() -> target.expire("key".getBytes(), 1)); - assertDoesNotThrow(() -> target.append("key".getBytes(), "value".getBytes())); - assertDoesNotThrow(() -> target.substr("key".getBytes(), 1, 2)); - assertDoesNotThrow(() -> target.hset("key".getBytes(), "field".getBytes(), "value".getBytes())); Map hash = new HashMap<>(); assertDoesNotThrow(() -> target.hset("key".getBytes(), hash)); Map hash1 = new HashMap<>(); + assertDoesNotThrow(() -> target.set("key", "value")); + assertDoesNotThrow(() -> target.set("key", "value", new SetParams().ex(10))); + assertDoesNotThrow(() -> target.get("key")); + assertDoesNotThrow(() -> target.getDel("key")); + assertDoesNotThrow(() -> target.getEx("key", new GetExParams().ex(10))); + assertDoesNotThrow(() -> target.exists("key1", "key2")); + assertDoesNotThrow(() -> target.exists("key")); + assertDoesNotThrow(() -> target.del("key1", "key2")); + assertDoesNotThrow(() -> target.del("key")); + assertDoesNotThrow(() -> target.unlink("key1", "key2")); + assertDoesNotThrow(() -> target.unlink("key")); + assertDoesNotThrow(() -> target.type("key")); + assertDoesNotThrow(() -> target.keys("key")); + assertDoesNotThrow(() -> target.rename("key1".getBytes(), "key2".getBytes())); + assertDoesNotThrow(() -> target.renamenx("key1".getBytes(), "key2".getBytes())); + assertDoesNotThrow(() -> target.expire("key", 1L)); + assertDoesNotThrow(() -> target.expireAt("key", 1L)); + assertDoesNotThrow(() -> target.ttl("key")); + assertDoesNotThrow(() -> target.getSet("key", "value")); + assertDoesNotThrow(() -> target.mget("key1", "key2")); + assertDoesNotThrow(() -> target.setnx("key", "value")); + assertDoesNotThrow(() -> target.setex("key", 1L, "value")); + assertDoesNotThrow(() -> target.mset("key1", "value1", "key2", "value2")); + assertDoesNotThrow(() -> target.msetnx("key1", "value1", "key2", "value2")); + assertDoesNotThrow(() -> target.decrBy("key", 1L)); + assertDoesNotThrow(() -> target.decr("key")); + assertDoesNotThrow(() -> target.incrBy("key", 1L)); + assertDoesNotThrow(() -> target.incrByFloat("key", 1.1D)); + assertDoesNotThrow(() -> target.incr("key")); + assertDoesNotThrow(() -> target.append("key", "value")); + assertDoesNotThrow(() -> target.substr("key", 1, 2)); + assertDoesNotThrow(() -> target.hset("key", "field", "value")); assertDoesNotThrow(() -> target.hset("key", hash1)); - assertDoesNotThrow(() -> target.hget("key".getBytes(), "value".getBytes())); - assertDoesNotThrow(() -> target.hdel("key".getBytes(), "value".getBytes())); + assertDoesNotThrow(() -> target.hget("key", "field")); + assertDoesNotThrow(() -> target.hsetnx("key", "field", "value")); + assertDoesNotThrow(() -> target.hmset("key", hash1)); + assertDoesNotThrow(() -> target.hmget("key", "field")); + assertDoesNotThrow(() -> target.hincrBy("key", "field", 1L)); + assertDoesNotThrow(() -> target.hincrByFloat("key", "field", 1.1D)); + assertDoesNotThrow(() -> target.hexists("key", "field")); + assertDoesNotThrow(() -> target.hdel("key", "field")); + assertDoesNotThrow(() -> target.hlen("key")); + assertDoesNotThrow(() -> target.hkeys("key")); + assertDoesNotThrow(() -> target.hvals("key")); + assertDoesNotThrow(() -> target.hgetAll("key")); + assertDoesNotThrow(() -> target.llen("key")); + assertDoesNotThrow(() -> target.lrange("key", 1L, 2L)); + assertDoesNotThrow(() -> target.ltrim("key", 1L, 2L)); + assertDoesNotThrow(() -> target.lindex("key", 1L)); + assertDoesNotThrow(() -> target.lset("key", 1L, "value")); + assertDoesNotThrow(() -> target.lpop("key")); + assertDoesNotThrow(() -> target.rpop("key")); + assertDoesNotThrow(() -> target.spop("key")); + assertDoesNotThrow(() -> target.spop("key", 1L)); + assertDoesNotThrow(() -> target.scard("key")); + assertDoesNotThrow(() -> target.sinter("key1", "key2")); + assertDoesNotThrow(() -> target.sunion("key1", "key2")); + assertDoesNotThrow(() -> target.sdiff("key1", "key2")); + assertDoesNotThrow(() -> target.srandmember("key")); + assertDoesNotThrow(() -> target.srandmember("key", 1)); + assertDoesNotThrow(() -> target.zcard("key")); + assertDoesNotThrow(() -> target.strlen("key")); + assertDoesNotThrow(() -> target.persist("key")); + assertDoesNotThrow(() -> target.setrange("key", 1L, "value")); + assertDoesNotThrow(() -> target.getrange("key", 1L, 2L)); + assertDoesNotThrow(() -> target.pttl("key")); + assertDoesNotThrow(() -> target.psetex("key", 1L, "value")); + assertDoesNotThrow(() -> target.substr("key".getBytes(), 1, 2)); + assertDoesNotThrow(() -> target.hset("key".getBytes(), "field".getBytes(), "value".getBytes())); + assertDoesNotThrow(() -> target.hset("key".getBytes(), hash)); + assertDoesNotThrow(() -> target.hget("key".getBytes(), "field".getBytes())); + assertDoesNotThrow(() -> target.hdel("key".getBytes(), "field".getBytes())); assertDoesNotThrow(() -> target.hvals("key".getBytes())); assertDoesNotThrow(() -> target.hgetAll("key".getBytes())); - assertDoesNotThrow(() -> target.set("key", "value")); - assertDoesNotThrow(() -> target.set("key", "value", new SetParams().ex(10))); + assertDoesNotThrow(() -> target.pexpire("key", 1L)); + assertDoesNotThrow(() -> target.pexpireAt("key", 1L)); assertDoesNotThrow(() -> target.get("key".getBytes())); + assertDoesNotThrow(() -> target.exists("key1".getBytes(), "key2".getBytes())); assertDoesNotThrow(() -> target.exists("key".getBytes())); assertDoesNotThrow(() -> target.type("key".getBytes())); assertDoesNotThrow(() -> target.getSet("key".getBytes(), "value".getBytes())); + assertDoesNotThrow(() -> target.mget("key1".getBytes(), "key2".getBytes())); assertDoesNotThrow(() -> target.setnx("key".getBytes(), "value".getBytes())); assertDoesNotThrow(() -> target.setex("key".getBytes(), 1, "value".getBytes())); + assertDoesNotThrow(() -> target.unlink("key1".getBytes(), "key2".getBytes())); assertDoesNotThrow(() -> target.unlink("key".getBytes())); - assertDoesNotThrow(() -> target.ping("key".getBytes())); + assertDoesNotThrow(() -> target.rename("key", "key2")); + assertDoesNotThrow(() -> target.renamenx("key", "key2")); + assertDoesNotThrow(() -> target.ping()); + assertDoesNotThrow(() -> target.ping("message".getBytes())); + assertDoesNotThrow(() -> target.ping("message")); } }