Skip to content

CASSJAVA-97: Let users inject an ID for each request and write to the custom payload #2037

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: 4.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions core/revapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -7386,6 +7386,11 @@
"old": "method <T extends java.lang.Number> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(java.lang.Class<T>)",
"new": "method <T> com.datastax.oss.driver.api.core.type.reflect.GenericType<com.datastax.oss.driver.api.core.data.CqlVector<T>> com.datastax.oss.driver.api.core.type.reflect.GenericType<T>::vectorOf(java.lang.Class<T>)",
"justification": "JAVA-3143: Extend driver vector support to arbitrary subtypes and fix handling of variable length types (OSS C* 5.0)"
},
{
"code": "java.method.addedToInterface",
"new": "method com.datastax.oss.driver.api.core.tracker.DistributedTraceIdGenerator com.datastax.oss.driver.api.core.context.DriverContext::getDistributedTraceIdGenerator()",
"justification": "DistributedRequestID"
}
]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,21 @@ public enum DefaultDriverOption implements DriverOption {
*
* <p>Value-type: boolean
*/
SSL_ALLOW_DNS_REVERSE_LOOKUP_SAN("advanced.ssl-engine-factory.allow-dns-reverse-lookup-san");
SSL_ALLOW_DNS_REVERSE_LOOKUP_SAN("advanced.ssl-engine-factory.allow-dns-reverse-lookup-san"),

/**
* The class of session-wide component that generates distributed trace IDs.
*
* <p>Value-type: {@link String}
*/
DISTRIBUTED_TRACE_ID_GENERATOR_CLASS("advanced.distributed-tracing.id-generator.class"),

/**
* If not empty, the driver will write the distributed trace ID to this key in the custom payload
*
* <p>Value-type: {@link String}
*/
DISTRIBUTED_TRACE_ID_CUSTOM_PAYLOAD_KEY("advanced.distributed-tracing.custom-payload-with-key");

private final String path;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,9 @@ protected static void fillWithDriverDefaults(OptionsMap map) {
map.put(TypedDriverOption.REQUEST_TRACE_INTERVAL, Duration.ofMillis(3));
map.put(TypedDriverOption.REQUEST_TRACE_CONSISTENCY, "ONE");
map.put(TypedDriverOption.REQUEST_LOG_WARNINGS, true);
map.put(
TypedDriverOption.DISTRIBUTED_TRACE_ID_GENERATOR_CLASS, "NoopDistributedTraceIdGenerator");
map.put(TypedDriverOption.DISTRIBUTED_TRACE_ID_CUSTOM_PAYLOAD_KEY, "");
map.put(TypedDriverOption.GRAPH_PAGING_ENABLED, "AUTO");
map.put(TypedDriverOption.GRAPH_CONTINUOUS_PAGING_PAGE_SIZE, requestPageSize);
map.put(TypedDriverOption.GRAPH_CONTINUOUS_PAGING_MAX_PAGES, continuousMaxPages);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,18 @@ public String toString() {
new TypedDriverOption<>(
DefaultDriverOption.REQUEST_TRACKER_CLASSES, GenericType.listOf(String.class));

/** The class of a session-wide component that generates distributed trace IDs. */
public static final TypedDriverOption<String> DISTRIBUTED_TRACE_ID_GENERATOR_CLASS =
new TypedDriverOption<>(
DefaultDriverOption.DISTRIBUTED_TRACE_ID_GENERATOR_CLASS, GenericType.STRING);

/**
* If not empty, the driver will write the distributed trace ID to this key in the custom payload
*/
public static final TypedDriverOption<String> DISTRIBUTED_TRACE_ID_CUSTOM_PAYLOAD_KEY =
new TypedDriverOption<>(
DefaultDriverOption.DISTRIBUTED_TRACE_ID_CUSTOM_PAYLOAD_KEY, GenericType.STRING);

/** Whether to log successful requests. */
public static final TypedDriverOption<Boolean> REQUEST_LOGGER_SUCCESS_ENABLED =
new TypedDriverOption<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import com.datastax.oss.driver.api.core.specex.SpeculativeExecutionPolicy;
import com.datastax.oss.driver.api.core.ssl.SslEngineFactory;
import com.datastax.oss.driver.api.core.time.TimestampGenerator;
import com.datastax.oss.driver.api.core.tracker.DistributedTraceIdGenerator;
import com.datastax.oss.driver.api.core.tracker.RequestTracker;
import edu.umd.cs.findbugs.annotations.NonNull;
import java.util.Map;
Expand Down Expand Up @@ -139,6 +140,10 @@ default SpeculativeExecutionPolicy getSpeculativeExecutionPolicy(@NonNull String
@NonNull
RequestTracker getRequestTracker();

/** @return The driver's distributed trace ID generator; never {@code null}. */
@NonNull
DistributedTraceIdGenerator getDistributedTraceIdGenerator();

/** @return The driver's request throttler; never {@code null}. */
@NonNull
RequestThrottler getRequestThrottler();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.datastax.oss.driver.api.core.metadata.NodeStateListener;
import com.datastax.oss.driver.api.core.metadata.schema.SchemaChangeListener;
import com.datastax.oss.driver.api.core.ssl.SslEngineFactory;
import com.datastax.oss.driver.api.core.tracker.DistributedTraceIdGenerator;
import com.datastax.oss.driver.api.core.tracker.RequestTracker;
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.api.core.type.codec.registry.MutableCodecRegistry;
Expand Down Expand Up @@ -59,6 +60,7 @@ public static Builder builder() {
private final NodeStateListener nodeStateListener;
private final SchemaChangeListener schemaChangeListener;
private final RequestTracker requestTracker;
private final DistributedTraceIdGenerator distributedTraceIdGenerator;
private final Map<String, String> localDatacenters;
private final Map<String, Predicate<Node>> nodeFilters;
private final Map<String, NodeDistanceEvaluator> nodeDistanceEvaluators;
Expand All @@ -77,6 +79,7 @@ private ProgrammaticArguments(
@Nullable NodeStateListener nodeStateListener,
@Nullable SchemaChangeListener schemaChangeListener,
@Nullable RequestTracker requestTracker,
@Nullable DistributedTraceIdGenerator distributedTraceIdGenerator,
@NonNull Map<String, String> localDatacenters,
@NonNull Map<String, Predicate<Node>> nodeFilters,
@NonNull Map<String, NodeDistanceEvaluator> nodeDistanceEvaluators,
Expand All @@ -94,6 +97,7 @@ private ProgrammaticArguments(
this.nodeStateListener = nodeStateListener;
this.schemaChangeListener = schemaChangeListener;
this.requestTracker = requestTracker;
this.distributedTraceIdGenerator = distributedTraceIdGenerator;
this.localDatacenters = localDatacenters;
this.nodeFilters = nodeFilters;
this.nodeDistanceEvaluators = nodeDistanceEvaluators;
Expand Down Expand Up @@ -128,6 +132,11 @@ public RequestTracker getRequestTracker() {
return requestTracker;
}

@Nullable
public DistributedTraceIdGenerator getDistributedTraceIdGenerator() {
return distributedTraceIdGenerator;
}

@NonNull
public Map<String, String> getLocalDatacenters() {
return localDatacenters;
Expand Down Expand Up @@ -196,6 +205,7 @@ public static class Builder {
private NodeStateListener nodeStateListener;
private SchemaChangeListener schemaChangeListener;
private RequestTracker requestTracker;
private DistributedTraceIdGenerator distributedTraceIdGenerator;
private ImmutableMap.Builder<String, String> localDatacentersBuilder = ImmutableMap.builder();
private final ImmutableMap.Builder<String, Predicate<Node>> nodeFiltersBuilder =
ImmutableMap.builder();
Expand Down Expand Up @@ -294,6 +304,13 @@ public Builder addRequestTracker(@NonNull RequestTracker requestTracker) {
return this;
}

@NonNull
public Builder withDistributedTraceIdGenerator(
@Nullable DistributedTraceIdGenerator distributedTraceIdGenerator) {
this.distributedTraceIdGenerator = distributedTraceIdGenerator;
return this;
}

@NonNull
public Builder withLocalDatacenter(
@NonNull String profileName, @NonNull String localDatacenter) {
Expand Down Expand Up @@ -417,6 +434,7 @@ public ProgrammaticArguments build() {
nodeStateListener,
schemaChangeListener,
requestTracker,
distributedTraceIdGenerator,
localDatacentersBuilder.build(),
nodeFiltersBuilder.build(),
nodeDistanceEvaluatorsBuilder.build(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import com.datastax.oss.driver.api.core.metadata.schema.SchemaChangeListener;
import com.datastax.oss.driver.api.core.ssl.ProgrammaticSslEngineFactory;
import com.datastax.oss.driver.api.core.ssl.SslEngineFactory;
import com.datastax.oss.driver.api.core.tracker.DistributedTraceIdGenerator;
import com.datastax.oss.driver.api.core.tracker.RequestTracker;
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.api.core.type.codec.registry.MutableCodecRegistry;
Expand Down Expand Up @@ -318,6 +319,19 @@ public SelfT addRequestTracker(@NonNull RequestTracker requestTracker) {
return self;
}

/**
* Registers a distributed trace ID generator The driver will use the distributed trace ID in the
* logs So that users can correlate logs about the same request from different loggers.
*
* @param distributedTraceIdGenerator
* @return
*/
@NonNull
public SelfT withDistributedTraceIdGenerator(
@NonNull DistributedTraceIdGenerator distributedTraceIdGenerator) {
this.programmaticArgumentsBuilder.withDistributedTraceIdGenerator(distributedTraceIdGenerator);
return self;
}
/**
* Registers an authentication provider to use with the session.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datastax.oss.driver.api.core.tracker;

import com.datastax.oss.driver.api.core.session.Request;
import edu.umd.cs.findbugs.annotations.NonNull;

public interface DistributedTraceIdGenerator {
String getSessionRequestId(@NonNull Request statement);

String getNodeRequestId(@NonNull Request statement, @NonNull String sessionRequestId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import com.datastax.oss.driver.api.core.specex.SpeculativeExecutionPolicy;
import com.datastax.oss.driver.api.core.ssl.SslEngineFactory;
import com.datastax.oss.driver.api.core.time.TimestampGenerator;
import com.datastax.oss.driver.api.core.tracker.DistributedTraceIdGenerator;
import com.datastax.oss.driver.api.core.tracker.RequestTracker;
import com.datastax.oss.driver.api.core.type.codec.TypeCodec;
import com.datastax.oss.driver.api.core.type.codec.registry.CodecRegistry;
Expand Down Expand Up @@ -221,6 +222,7 @@ public class DefaultDriverContext implements InternalDriverContext {
private final LazyReference<NodeStateListener> nodeStateListenerRef;
private final LazyReference<SchemaChangeListener> schemaChangeListenerRef;
private final LazyReference<RequestTracker> requestTrackerRef;
private final LazyReference<DistributedTraceIdGenerator> distributedTraceIdGeneratorRef;
private final LazyReference<Optional<AuthProvider>> authProviderRef;
private final LazyReference<List<LifecycleListener>> lifecycleListenersRef =
new LazyReference<>("lifecycleListeners", this::buildLifecycleListeners, cycleDetector);
Expand Down Expand Up @@ -282,6 +284,13 @@ public DefaultDriverContext(
this.requestTrackerRef =
new LazyReference<>(
"requestTracker", () -> buildRequestTracker(requestTrackerFromBuilder), cycleDetector);
this.distributedTraceIdGeneratorRef =
new LazyReference<>(
"distributedTraceIdGenerator",
() ->
buildDistributedTraceIdGenerator(
programmaticArguments.getDistributedTraceIdGenerator()),
cycleDetector);
this.sslEngineFactoryRef =
new LazyReference<>(
"sslEngineFactory",
Expand Down Expand Up @@ -709,6 +718,23 @@ protected RequestTracker buildRequestTracker(RequestTracker requestTrackerFromBu
}
}

protected DistributedTraceIdGenerator buildDistributedTraceIdGenerator(
DistributedTraceIdGenerator distributedTraceIdGenerator) {
return (distributedTraceIdGenerator != null)
? distributedTraceIdGenerator
: Reflection.buildFromConfig(
this,
DefaultDriverOption.DISTRIBUTED_TRACE_ID_GENERATOR_CLASS,
DistributedTraceIdGenerator.class,
"com.datastax.oss.driver.internal.core.tracker")
.orElseThrow(
() ->
new IllegalArgumentException(
String.format(
"Missing distributed trace ID generator, check your configuration (%s)",
DefaultDriverOption.DISTRIBUTED_TRACE_ID_GENERATOR_CLASS)));
}

protected Optional<AuthProvider> buildAuthProvider(AuthProvider authProviderFromBuilder) {
return (authProviderFromBuilder != null)
? Optional.of(authProviderFromBuilder)
Expand Down Expand Up @@ -973,6 +999,12 @@ public RequestTracker getRequestTracker() {
return requestTrackerRef.get();
}

@NonNull
@Override
public DistributedTraceIdGenerator getDistributedTraceIdGenerator() {
return distributedTraceIdGeneratorRef.get();
}

@Nullable
@Override
public String getLocalDatacenter(@NonNull String profileName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import com.datastax.oss.driver.api.core.servererrors.WriteTimeoutException;
import com.datastax.oss.driver.api.core.session.throttling.RequestThrottler;
import com.datastax.oss.driver.api.core.session.throttling.Throttled;
import com.datastax.oss.driver.api.core.tracker.DistributedTraceIdGenerator;
import com.datastax.oss.driver.api.core.tracker.RequestTracker;
import com.datastax.oss.driver.internal.core.adminrequest.ThrottledAdminRequestHandler;
import com.datastax.oss.driver.internal.core.adminrequest.UnexpectedResponseException;
Expand Down Expand Up @@ -78,8 +79,10 @@
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.AbstractMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
Expand Down Expand Up @@ -125,12 +128,14 @@ public class CqlRequestHandler implements Throttled {
private final List<NodeResponseCallback> inFlightCallbacks;
private final RequestThrottler throttler;
private final RequestTracker requestTracker;
private final DistributedTraceIdGenerator distributedTraceIdGenerator;
private final SessionMetricUpdater sessionMetricUpdater;
private final DriverExecutionProfile executionProfile;

// The errors on the nodes that were already tried (lazily initialized on the first error).
// We don't use a map because nodes can appear multiple times.
private volatile List<Map.Entry<Node, Throwable>> errors;
private final String customPayloadKey;

protected CqlRequestHandler(
Statement<?> statement,
Expand All @@ -139,7 +144,8 @@ protected CqlRequestHandler(
String sessionLogPrefix) {

this.startTimeNanos = System.nanoTime();
this.logPrefix = sessionLogPrefix + "|" + this.hashCode();
this.distributedTraceIdGenerator = context.getDistributedTraceIdGenerator();
this.logPrefix = this.distributedTraceIdGenerator.getSessionRequestId(statement);
LOG.trace("[{}] Creating new handler for request {}", logPrefix, statement);

this.initialStatement = statement;
Expand Down Expand Up @@ -170,6 +176,11 @@ protected CqlRequestHandler(

this.timer = context.getNettyOptions().getTimer();
this.executionProfile = Conversions.resolveExecutionProfile(initialStatement, context);

this.customPayloadKey =
this.executionProfile.getString(
DefaultDriverOption.DISTRIBUTED_TRACE_ID_CUSTOM_PAYLOAD_KEY);

Duration timeout = Conversions.resolveRequestTimeout(statement, executionProfile);
this.scheduledTimeout = scheduleTimeout(timeout);

Expand Down Expand Up @@ -248,6 +259,16 @@ private void sendRequest(
if (result.isDone()) {
return;
}
String nodeRequestId = this.distributedTraceIdGenerator.getNodeRequestId(statement, logPrefix);
if (!this.customPayloadKey.isEmpty()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not missing else block here?

// We cannot do statement.getCustomPayload().put() because the default empty map is abstract
// But this will create new Statement instance for every request. We might want to optimize
// this
Map<String, ByteBuffer> existingMap = new HashMap<>(statement.getCustomPayload());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Statement is by design immutable. Maybe a nicer way would be to create method StatementBuilder.from(Statement) where you could create builder again based on statement. The code would look like: StatementBuilder.from(statement).addCustomPayload(...).build().

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can copy just the payload, not the whole statement:

    Map<String, ByteBuffer> customPayload = statement.getCustomPayload();
    if (!this.customPayloadKey.isEmpty()) {
      customPayload =
          NullAllowingImmutableMap.<String, ByteBuffer>builder()
              .putAll(customPayload)
              .put(
                  this.customPayloadKey,
                  ByteBuffer.wrap(nodeRequestId.getBytes(StandardCharsets.UTF_8)))
              .build();
    }

Then modify line 307 like so:

       channel
-          .write(message, statement.isTracing(), statement.getCustomPayload(), nodeResponseCallback)
+          .write(message, statement.isTracing(), customPayload, nodeResponseCallback)
           .addListener(nodeResponseCallback);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This solves the concurrency problem, but it also means the subsequent setFinalError(statement...), NodeResponseCallback(statement,...), and RequestTracker invocations do not have the statement with the actual custom payload.

existingMap.put(
this.customPayloadKey, ByteBuffer.wrap(nodeRequestId.getBytes(StandardCharsets.UTF_8)));
statement = statement.setCustomPayload(existingMap);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overriding custom payload here is not thread-safe. If client application executes the same statement instance multiple times concurrently (not a good use-case, but still possible), we do not guarantee how this map will be changed. Maybe indeed, there is no other way than make a shallow copy of the statement. Will think about it.

  /**
   * Sets the custom payload to use for execution.
   *
   * <p>All the driver's built-in statement implementations are immutable, and return a new instance
   * from this method. However custom implementations may choose to be mutable and return the same
   * instance.
   *
   * <p>Note that it's your responsibility to provide a thread-safe map. This can be achieved with a
   * concurrent or immutable implementation, or by making it effectively immutable (meaning that
   * it's never modified after being set on the statement).
   */
  @NonNull
  @CheckReturnValue
  SelfT setCustomPayload(@NonNull Map<String, ByteBuffer> newCustomPayload);

}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the wrong place to do this. In most cases we haven't even selected the node yet; note that this happens immediately below where we poll the query plan if no node is explicitly set in the request. Assuming we update the request ID generation logic to correctly account for the target node the setting of custom payload fields should happen after we determine which node we're actually sending to.

Node node = retriedNode;
DriverChannel channel = null;
if (node == null || (channel = session.getChannel(node, logPrefix)) == null) {
Expand Down Expand Up @@ -276,7 +297,7 @@ private void sendRequest(
currentExecutionIndex,
retryCount,
scheduleNextExecution,
logPrefix);
nodeRequestId);
Message message = Conversions.toMessage(statement, executionProfile, context);
channel
.write(message, statement.isTracing(), statement.getCustomPayload(), nodeResponseCallback)
Expand Down Expand Up @@ -489,7 +510,7 @@ private NodeResponseCallback(
this.execution = execution;
this.retryCount = retryCount;
this.scheduleNextExecution = scheduleNextExecution;
this.logPrefix = logPrefix + "|" + execution;
this.logPrefix = logPrefix;
}

// this gets invoked once the write completes.
Expand Down
Loading