Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import io.netty.channel.EventLoop;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRawRecord;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.resolver.ResolvedAddressTypes;
Expand All @@ -57,6 +58,7 @@
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.PromiseNotifier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -75,6 +77,7 @@
import java.util.List;
import java.util.Map;
import java.util.RandomAccess;
import java.util.function.Function;
import java.util.function.IntFunction;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -127,7 +130,7 @@ final class DefaultDnsClient implements DnsClient {
private static final Cancellable TERMINATED = () -> { };

private final EventLoopAwareNettyIoExecutor nettyIoExecutor;
private final DnsNameResolver resolver;
private final DnsNameResolverDelegate resolver;
private final MinTtlCache ttlCache;
private final long maxTTLNanos;
private final long ttlJitterNanos;
Expand Down Expand Up @@ -160,7 +163,8 @@ final class DefaultDnsClient implements DnsClient {
final ServiceDiscovererEvent.Status missingRecordStatus,
final boolean nxInvalidation,
final boolean tcpFallbackOnTimeout,
final String datagramChannelStrategy) {
final String datagramChannelStrategy,
@Nullable final Integer backupRequestDelay) {
this.srvConcurrency = srvConcurrency;
this.srvFilterDuplicateEvents = srvFilterDuplicateEvents;
// Implementation of this class expects to use only single EventLoop from IoExecutor
Expand Down Expand Up @@ -229,7 +233,9 @@ final class DefaultDnsClient implements DnsClient {
if (dnsServerAddressStreamProvider != null) {
builder.nameServerProvider(toNettyType(dnsServerAddressStreamProvider));
}
resolver = builder.build();
resolver = backupRequestDelay == null || backupRequestDelay <= 0 ? new DefaultResolver(builder.build()) :
new BackupRequestResolver(builder.build(), builder.consolidateCacheSize(0).build(),
eventLoop, backupRequestDelay);
this.resolutionTimeoutMillis = resolutionTimeout != null ? resolutionTimeout.toMillis() :
// Default value is chosen based on a combination of default "timeout" and "attempts" options of
// /etc/resolv.conf: https://man7.org/linux/man-pages/man5/resolv.conf.5.html
Expand Down Expand Up @@ -1149,4 +1155,112 @@ static SrvAddressRemovedException newInstance(Class<?> clazz, String method) {
return unknownStackTrace(new SrvAddressRemovedException(), clazz, method);
}
}

interface DnsNameResolverDelegate {
void close();

Future<List<InetAddress>> resolveAll(String name);

Future<List<DnsRecord>> resolveAll(DnsQuestion name);

long queryTimeoutMillis();
}

static final class DefaultResolver implements DnsNameResolverDelegate {
private final DnsNameResolver nettyResolver;

DefaultResolver(DnsNameResolver nettyResolver) {
this.nettyResolver = nettyResolver;
}

@Override
public void close() {
nettyResolver.close();
}

@Override
public Future<List<InetAddress>> resolveAll(String name) {
return nettyResolver.resolveAll(name);
}

@Override
public Future<List<DnsRecord>> resolveAll(DnsQuestion question) {
return nettyResolver.resolveAll(question);
}

@Override
public long queryTimeoutMillis() {
return nettyResolver.queryTimeoutMillis();
}
}

static final class BackupRequestResolver implements DnsNameResolverDelegate {

private final DnsNameResolver primaryResolver;
private final DnsNameResolver backupResolver;
private final EventLoop eventLoop;
private final int backupDelayMs;

BackupRequestResolver(DnsNameResolver primaryResolver, DnsNameResolver backupResolver,
EventLoop eventLoop, int backupDelayMs) {
this.primaryResolver = primaryResolver;
this.backupResolver = backupResolver;
this.eventLoop = eventLoop;
this.backupDelayMs = backupDelayMs;
}

@Override
public void close() {
try {
primaryResolver.close();
} finally {
backupResolver.close();
}
}

@Override
public Future<List<InetAddress>> resolveAll(String name) {
return withBackup(resolver -> resolver.resolveAll(name));
}

@Override
public Future<List<DnsRecord>> resolveAll(DnsQuestion name) {
return withBackup(resolver -> resolver.resolveAll(name));
}

@Override
public long queryTimeoutMillis() {
return primaryResolver.queryTimeoutMillis();
}

private <T> Future<T> withBackup(Function<? super DnsNameResolver, ? extends Future<T>> query) {
Future<T> primaryQuery = query.apply(primaryResolver);
if (primaryQuery.isDone()) {
return primaryQuery;
}
int backupDelay = backupDelayMs();
if (backupDelay <= 0) {
// no backup for this request
return primaryQuery;
}
Promise<T> result = eventLoop.newPromise();
Future<?> timer = eventLoop.schedule(() -> {
if (allowBackupRequest()) {
PromiseNotifier.cascade(false, query.apply(backupResolver), result);
}
}, backupDelay, MILLISECONDS);
primaryQuery.addListener(_unused -> timer.cancel(true));
PromiseNotifier.cascade(false, primaryQuery, result);
return result;
}

private boolean allowBackupRequest() {
// In the future we should make this predicated on a token bucket.
return true;
}

private int backupDelayMs() {
return backupDelayMs;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import static io.servicetalk.utils.internal.NumberUtils.ensureNonNegative;
import static io.servicetalk.utils.internal.NumberUtils.ensurePositive;
import static java.lang.Boolean.getBoolean;
import static java.lang.Integer.getInteger;
import static java.lang.Math.min;
import static java.lang.System.getProperty;
import static java.time.Duration.ofSeconds;
Expand All @@ -58,6 +59,9 @@ public final class DefaultDnsServiceDiscovererBuilder implements DnsServiceDisco

private static final Logger LOGGER = LoggerFactory.getLogger(DefaultDnsServiceDiscovererBuilder.class);

// Backup request static configuration: values > 0 mean allow a backup request with fixed delay, disabled otherwise.
private static final String DNS_BACKUP_REQUEST_DELAY_MS_PROPERTY =
"io.servicetalk.dns.discovery.netty.experimental.dnsBackupRequestDelayMs";
// FIXME: 0.43 - consider removing deprecated system properties.
// Those were introduced temporarily as a way for us to experiment with new Netty features.
// In the next major release, we should promote required features to builder API.
Expand All @@ -72,6 +76,8 @@ public final class DefaultDnsServiceDiscovererBuilder implements DnsServiceDisco
@Deprecated
private static final String NX_DOMAIN_INVALIDATES_PROPERTY = "io.servicetalk.dns.discovery.nxdomain.invalidation";

@Nullable
private static final Integer DNS_BACKUP_REQUEST_DELAY_MS = getInteger(DNS_BACKUP_REQUEST_DELAY_MS_PROPERTY);
private static final String DEFAULT_DATAGRAM_CHANNEL_STRATEGY =
getProperty(DATAGRAM_CHANNEL_STRATEGY_PROPERTY, "ChannelPerResolver");
private static final boolean DEFAULT_TCP_FALLBACK_ON_TIMEOUT = getBoolean(TCP_FALLBACK_ON_TIMEOUT_PROPERTY);
Expand Down Expand Up @@ -409,7 +415,7 @@ DnsClient build() {
srvHostNameRepeatInitialDelay, srvHostNameRepeatJitter, maxUdpPayloadSize, ndots, optResourceEnabled,
queryTimeout, resolutionTimeout, dnsResolverAddressTypes, localAddress, dnsServerAddressStreamProvider,
observer, missingRecordStatus, nxInvalidation,
DEFAULT_TCP_FALLBACK_ON_TIMEOUT, DEFAULT_DATAGRAM_CHANNEL_STRATEGY);
DEFAULT_TCP_FALLBACK_ON_TIMEOUT, DEFAULT_DATAGRAM_CHANNEL_STRATEGY, DNS_BACKUP_REQUEST_DELAY_MS);
return filterFactory == null ? rawClient : filterFactory.create(rawClient);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
import io.servicetalk.transport.netty.internal.EventLoopAwareNettyIoExecutor;
import io.servicetalk.utils.internal.DurationUtils;

import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.util.concurrent.Promise;
import org.apache.directory.server.dns.messages.RecordType;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
Expand All @@ -56,6 +59,8 @@
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -94,16 +99,23 @@
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

class DefaultDnsClientTest {
private static final Logger LOGGER = LoggerFactory.getLogger(DefaultDnsClientTest.class);
private static final int DEFAULT_TTL = 1;
private static final Duration DEFAULT_TIMEOUT = ofMillis(CI ? 500 : 100);
private static final int BACKUP_DELAY = CI ? 200 : 20;

@RegisterExtension
static final ExecutorExtension<TestExecutor> timerExecutor = ExecutorExtension.withTestExecutor()
Expand Down Expand Up @@ -145,9 +157,15 @@ void setup(UnaryOperator<DefaultDnsServiceDiscovererBuilder> builderFunction) th

@AfterEach
public void tearDown() throws Exception {
client.closeAsync().toFuture().get();
dnsServer.stop();
dnsServer2.stop();
if (client != null) {
client.closeAsync().toFuture().get();
}
if (dnsServer != null) {
dnsServer.stop();
}
if (dnsServer2 != null) {
dnsServer2.stop();
}
}

private static void advanceTime() throws Exception {
Expand Down Expand Up @@ -1225,6 +1243,78 @@ void testResolutionTimeout(RecordType recordType) throws Exception {
testTimeout(Duration.ZERO, DEFAULT_TIMEOUT, recordType);
}

@Test
void backupRequest() throws Exception {
DnsNameResolver primaryResolver = mock(DnsNameResolver.class);
DnsNameResolver backupResolver = mock(DnsNameResolver.class);
EventLoop eventLoop = ioExecutor.executor().eventLoopGroup().next();
Promise<List<InetAddress>> primaryPromise = eventLoop.newPromise();
when(primaryResolver.resolveAll("foo")).thenReturn(primaryPromise);

Promise<List<InetAddress>> backupPromise = eventLoop.newPromise();
when(backupResolver.resolveAll("foo")).thenReturn(backupPromise);

DefaultDnsClient.DnsNameResolverDelegate resolver = new DefaultDnsClient.BackupRequestResolver(
primaryResolver, backupResolver, eventLoop, BACKUP_DELAY);
Future<List<InetAddress>> resolveFuture = resolver.resolveAll("foo");
assertFalse(resolveFuture.isDone());
verify(primaryResolver, times(1)).resolveAll("foo");
verify(backupResolver, times(0)).resolveAll("foo");

// Wait 20 milliseconds.
eventLoop.schedule(() -> { }, BACKUP_DELAY, MILLISECONDS).get();

verify(primaryResolver, times(1)).resolveAll("foo");
verify(backupResolver, times(1)).resolveAll("foo");
List<InetAddress> result = new ArrayList<>();
backupPromise.setSuccess(result);
assertEquals(result, resolveFuture.get());
}

@Test
void noBackupRequestIfOriginalSucceeds() throws Exception {
DnsNameResolver primaryResolver = mock(DnsNameResolver.class);
DnsNameResolver backupResolver = mock(DnsNameResolver.class);
EventLoop eventLoop = ioExecutor.executor().eventLoopGroup().next();
Promise<List<InetAddress>> primaryPromise = eventLoop.newPromise();
when(primaryResolver.resolveAll("foo")).thenReturn(primaryPromise);
when(backupResolver.resolveAll("foo")).thenReturn(eventLoop.newPromise());

DefaultDnsClient.DnsNameResolverDelegate resolver = new DefaultDnsClient.BackupRequestResolver(
primaryResolver, backupResolver, eventLoop, BACKUP_DELAY);
Future<List<InetAddress>> resolveFuture = resolver.resolveAll("foo");
assertFalse(resolveFuture.isDone());
List<InetAddress> result = new ArrayList<>();
primaryPromise.trySuccess(result);
assertEquals(result, resolveFuture.get(BACKUP_DELAY, SECONDS));
// Wait for the timeout duration to be sure we only get one call.
eventLoop.schedule(() -> { }, BACKUP_DELAY, MILLISECONDS).get();
verify(primaryResolver, times(1)).resolveAll("foo");
verify(backupResolver, times(0)).resolveAll("foo");
}

@Test
void initialFailureWillNotResultInBackup() throws Exception {
DnsNameResolver primaryResolver = mock(DnsNameResolver.class);
DnsNameResolver backupResolver = mock(DnsNameResolver.class);
EventLoop eventLoop = ioExecutor.executor().eventLoopGroup().next();
Promise<List<InetAddress>> primaryPromise = eventLoop.newPromise();
when(primaryResolver.resolveAll("foo")).thenReturn(primaryPromise);
when(backupResolver.resolveAll("foo")).thenReturn(eventLoop.newPromise());

DefaultDnsClient.DnsNameResolverDelegate resolver = new DefaultDnsClient.BackupRequestResolver(
primaryResolver, backupResolver, eventLoop, BACKUP_DELAY);
Future<List<InetAddress>> resolveFuture = resolver.resolveAll("foo");
assertFalse(resolveFuture.isDone());

primaryPromise.tryFailure(new Exception("so sad"));
assertThrows(ExecutionException.class, resolveFuture::get);
// Wait for the timeout duration to be sure we only get one call.
eventLoop.schedule(() -> { }, BACKUP_DELAY, MILLISECONDS).get();
verify(primaryResolver, times(1)).resolveAll("foo");
verify(backupResolver, times(0)).resolveAll("foo");
}

void testTimeout(Duration queryTimeout, Duration resolutionTimeout, RecordType recordType) throws Exception {
setup(builder -> builder
.queryTimeout(queryTimeout)
Expand Down