Skip to content

Commit eae16b2

Browse files
authored
unwrap ForwardingSubchannel during Picks (#12658)
This PR ensures that Load Balancing (LB) policies unwrap `ForwardingSubchannel` instances before returning them in a `PickResult`. **Rationale:** Currently, the identity of a subchannel is "awkward" because decorators break object identity. This forces the core channel to use internal workarounds like `getInternalSubchannel()` to find the underlying implementation. Removing these wrappers during the pick process is a critical prerequisite for deleting Subchannel Attributes. By enforcing unwrapping, `ManagedChannelImpl` can rely on the fact that a returned subchannel is the same instance it originally created. This allows the channel to use strongly-typed fields for state management (via "blind casting") rather than abusing attributes to re-discover information that should already be known. This also paves the way for the eventual removal of the `getInternalSubchannel()` internal API. **New APIs:** To ensure we don't "drop data on the floor" during the unwrapping process, this PR adds two new non-static APIs to PickResult: - copyWithSubchannel() - copyWithStreamTracerFactory() Unlike static factory methods, these instance methods follow a "copy-and-update" pattern that preserves all existing pick-level metadata (such as authority overrides or drop status) while only swapping the specific field required.
1 parent d9320ee commit eae16b2

File tree

11 files changed

+244
-58
lines changed

11 files changed

+244
-58
lines changed

api/src/main/java/io/grpc/LoadBalancer.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,8 @@ private PickResult(
632632
* stream is created at all in some cases.
633633
* @since 1.3.0
634634
*/
635+
// TODO(shivaspeaks): Need to deprecate old APIs and create new ones,
636+
// per https://github.com/grpc/grpc-java/issues/12662.
635637
public static PickResult withSubchannel(
636638
Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory) {
637639
return new PickResult(
@@ -661,6 +663,28 @@ public static PickResult withSubchannel(Subchannel subchannel) {
661663
return withSubchannel(subchannel, null);
662664
}
663665

666+
/**
667+
* Creates a new {@code PickResult} with the given {@code subchannel},
668+
* but retains all other properties from this {@code PickResult}.
669+
*
670+
* @since 1.80.0
671+
*/
672+
public PickResult copyWithSubchannel(Subchannel subchannel) {
673+
return new PickResult(checkNotNull(subchannel, "subchannel"), streamTracerFactory,
674+
status, drop, authorityOverride);
675+
}
676+
677+
/**
678+
* Creates a new {@code PickResult} with the given {@code streamTracerFactory},
679+
* but retains all other properties from this {@code PickResult}.
680+
*
681+
* @since 1.80.0
682+
*/
683+
public PickResult copyWithStreamTracerFactory(
684+
@Nullable ClientStreamTracer.Factory streamTracerFactory) {
685+
return new PickResult(subchannel, streamTracerFactory, status, drop, authorityOverride);
686+
}
687+
664688
/**
665689
* A decision to report a connectivity error to the RPC. If the RPC is {@link
666690
* CallOptions#withWaitForReady wait-for-ready}, it will stay buffered. Otherwise, it will fail

api/src/test/java/io/grpc/LoadBalancerTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,26 @@ public void pickResult_withSubchannelAndTracer() {
6464
assertThat(result.isDrop()).isFalse();
6565
}
6666

67+
@Test
68+
public void pickResult_withSubchannelReplacement() {
69+
PickResult result = PickResult.withSubchannel(subchannel, tracerFactory)
70+
.copyWithSubchannel(subchannel2);
71+
assertThat(result.getSubchannel()).isSameInstanceAs(subchannel2);
72+
assertThat(result.getStatus()).isSameInstanceAs(Status.OK);
73+
assertThat(result.getStreamTracerFactory()).isSameInstanceAs(tracerFactory);
74+
assertThat(result.isDrop()).isFalse();
75+
}
76+
77+
@Test
78+
public void pickResult_withStreamTracerFactory() {
79+
PickResult result = PickResult.withSubchannel(subchannel)
80+
.copyWithStreamTracerFactory(tracerFactory);
81+
assertThat(result.getSubchannel()).isSameInstanceAs(subchannel);
82+
assertThat(result.getStatus()).isSameInstanceAs(Status.OK);
83+
assertThat(result.getStreamTracerFactory()).isSameInstanceAs(tracerFactory);
84+
assertThat(result.isDrop()).isFalse();
85+
}
86+
6787
@Test
6888
public void pickResult_withNoResult() {
6989
PickResult result = PickResult.withNoResult();

services/src/main/java/io/grpc/protobuf/services/HealthCheckingLoadBalancerFactory.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,30 @@ void setHealthCheckedService(@Nullable String service) {
144144
public String toString() {
145145
return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString();
146146
}
147+
148+
@Override
149+
public void updateBalancingState(
150+
io.grpc.ConnectivityState newState, LoadBalancer.SubchannelPicker newPicker) {
151+
delegate().updateBalancingState(newState, new HealthCheckPicker(newPicker));
152+
}
153+
154+
private final class HealthCheckPicker extends LoadBalancer.SubchannelPicker {
155+
private final LoadBalancer.SubchannelPicker delegate;
156+
157+
HealthCheckPicker(LoadBalancer.SubchannelPicker delegate) {
158+
this.delegate = delegate;
159+
}
160+
161+
@Override
162+
public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) {
163+
LoadBalancer.PickResult result = delegate.pickSubchannel(args);
164+
LoadBalancer.Subchannel subchannel = result.getSubchannel();
165+
if (subchannel instanceof SubchannelImpl) {
166+
return result.copyWithSubchannel(((SubchannelImpl) subchannel).delegate());
167+
}
168+
return result;
169+
}
170+
}
147171
}
148172

149173
@VisibleForTesting

util/src/main/java/io/grpc/util/HealthProducerHelper.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import com.google.common.annotations.VisibleForTesting;
2424
import io.grpc.Attributes;
25+
import io.grpc.ConnectivityState;
2526
import io.grpc.ConnectivityStateInfo;
2627
import io.grpc.Internal;
2728
import io.grpc.LoadBalancer;
@@ -84,6 +85,31 @@ protected LoadBalancer.Helper delegate() {
8485
return delegate;
8586
}
8687

88+
@Override
89+
public void updateBalancingState(
90+
ConnectivityState newState, LoadBalancer.SubchannelPicker newPicker) {
91+
delegate.updateBalancingState(newState, new HealthProducerPicker(newPicker));
92+
}
93+
94+
private static final class HealthProducerPicker extends LoadBalancer.SubchannelPicker {
95+
private final LoadBalancer.SubchannelPicker delegate;
96+
97+
HealthProducerPicker(LoadBalancer.SubchannelPicker delegate) {
98+
this.delegate = delegate;
99+
}
100+
101+
@Override
102+
public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) {
103+
LoadBalancer.PickResult result = delegate.pickSubchannel(args);
104+
LoadBalancer.Subchannel subchannel = result.getSubchannel();
105+
if (subchannel instanceof HealthProducerSubchannel) {
106+
return result.copyWithSubchannel(
107+
((HealthProducerSubchannel) subchannel).delegate());
108+
}
109+
return result;
110+
}
111+
}
112+
87113
// The parent subchannel in the health check producer LB chain. It duplicates subchannel state to
88114
// both the state listener and health listener.
89115
@VisibleForTesting

util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,14 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
442442

443443
Subchannel subchannel = pickResult.getSubchannel();
444444
if (subchannel != null) {
445-
return PickResult.withSubchannel(subchannel, new ResultCountingClientStreamTracerFactory(
446-
subchannel.getAttributes().get(ENDPOINT_TRACKER_KEY),
447-
pickResult.getStreamTracerFactory()));
445+
EndpointTracker tracker = subchannel.getAttributes().get(ENDPOINT_TRACKER_KEY);
446+
if (subchannel instanceof OutlierDetectionSubchannel) {
447+
subchannel = ((OutlierDetectionSubchannel) subchannel).delegate();
448+
}
449+
return pickResult.copyWithSubchannel(subchannel)
450+
.copyWithStreamTracerFactory(new ResultCountingClientStreamTracerFactory(
451+
tracker,
452+
pickResult.getStreamTracerFactory()));
448453
}
449454

450455
return pickResult;

util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ public void delegatePick() throws Exception {
408408
// Make sure that we can pick the single READY subchannel.
409409
SubchannelPicker picker = pickerCaptor.getAllValues().get(2);
410410
PickResult pickResult = picker.pickSubchannel(mock(PickSubchannelArgs.class));
411-
Subchannel s = ((OutlierDetectionSubchannel) pickResult.getSubchannel()).delegate();
411+
Subchannel s = pickResult.getSubchannel();
412412
if (s instanceof HealthProducerHelper.HealthProducerSubchannel) {
413413
s = ((HealthProducerHelper.HealthProducerSubchannel) s).delegate();
414414
}

xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -252,42 +252,55 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) {
252252
args = args.toBuilder().setAddresses(addresses).setAttributes(attrsBuilder.build()).build();
253253
final Subchannel subchannel = delegate().createSubchannel(args);
254254

255-
return new ForwardingSubchannel() {
256-
@Override
257-
public void start(SubchannelStateListener listener) {
258-
delegate().start(new SubchannelStateListener() {
259-
@Override
260-
public void onSubchannelState(ConnectivityStateInfo newState) {
261-
// Do nothing if LB has been shutdown
262-
if (xdsClient != null && newState.getState().equals(ConnectivityState.READY)) {
263-
// Get locality based on the connected address attributes
264-
ClusterLocality updatedClusterLocality = createClusterLocalityFromAttributes(
265-
subchannel.getConnectedAddressAttributes());
266-
ClusterLocality oldClusterLocality = localityAtomicReference
267-
.getAndSet(updatedClusterLocality);
268-
oldClusterLocality.release();
255+
return new ClusterImplSubchannel(subchannel, localityAtomicReference);
256+
}
257+
258+
private final class ClusterImplSubchannel extends ForwardingSubchannel {
259+
private final Subchannel delegate;
260+
private final AtomicReference<ClusterLocality> localityAtomicReference;
261+
262+
private ClusterImplSubchannel(
263+
Subchannel delegate, AtomicReference<ClusterLocality> localityAtomicReference) {
264+
this.delegate = delegate;
265+
this.localityAtomicReference = localityAtomicReference;
266+
}
267+
268+
@Override
269+
public void start(SubchannelStateListener listener) {
270+
delegate().start(
271+
new SubchannelStateListener() {
272+
@Override
273+
public void onSubchannelState(ConnectivityStateInfo newState) {
274+
// Do nothing if LB has been shutdown
275+
if (xdsClient != null && newState.getState().equals(ConnectivityState.READY)) {
276+
// Get locality based on the connected address attributes
277+
ClusterLocality updatedClusterLocality =
278+
createClusterLocalityFromAttributes(
279+
delegate.getConnectedAddressAttributes());
280+
ClusterLocality oldClusterLocality =
281+
localityAtomicReference.getAndSet(updatedClusterLocality);
282+
oldClusterLocality.release();
283+
}
284+
listener.onSubchannelState(newState);
269285
}
270-
listener.onSubchannelState(newState);
271-
}
272-
});
273-
}
286+
});
287+
}
274288

275-
@Override
276-
public void shutdown() {
277-
localityAtomicReference.get().release();
278-
delegate().shutdown();
279-
}
289+
@Override
290+
public void shutdown() {
291+
localityAtomicReference.get().release();
292+
delegate().shutdown();
293+
}
280294

281-
@Override
282-
public void updateAddresses(List<EquivalentAddressGroup> addresses) {
283-
delegate().updateAddresses(withAdditionalAttributes(addresses));
284-
}
295+
@Override
296+
public void updateAddresses(List<EquivalentAddressGroup> addresses) {
297+
delegate().updateAddresses(withAdditionalAttributes(addresses));
298+
}
285299

286-
@Override
287-
protected Subchannel delegate() {
288-
return subchannel;
289-
}
290-
};
300+
@Override
301+
protected Subchannel delegate() {
302+
return delegate;
303+
}
291304
}
292305

293306
private List<EquivalentAddressGroup> withAdditionalAttributes(
@@ -412,6 +425,11 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
412425
}
413426
PickResult result = delegate.pickSubchannel(args);
414427
if (result.getStatus().isOk() && result.getSubchannel() != null) {
428+
Subchannel subchannel = result.getSubchannel();
429+
if (subchannel instanceof ClusterImplLbHelper.ClusterImplSubchannel) {
430+
subchannel = ((ClusterImplLbHelper.ClusterImplSubchannel) subchannel).delegate();
431+
result = result.copyWithSubchannel(subchannel);
432+
}
415433
if (enableCircuitBreaking) {
416434
if (inFlights.get() >= maxConcurrentRequests) {
417435
if (dropStats != null) {
@@ -437,8 +455,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
437455
stats, inFlights, result.getStreamTracerFactory());
438456
ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance()
439457
.newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats));
440-
result = PickResult.withSubchannel(result.getSubchannel(),
441-
orcaTracerFactory);
458+
result = result.copyWithStreamTracerFactory(orcaTracerFactory);
442459
}
443460
}
444461
if (args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY) != null

xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -508,12 +508,15 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
508508
if (subchannel == null) {
509509
return pickResult;
510510
}
511+
512+
subchannel = ((WrrSubchannel) subchannel).delegate();
511513
if (!enableOobLoadReport) {
512-
return PickResult.withSubchannel(subchannel,
513-
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
514-
reportListeners.get(pick)));
514+
return pickResult.copyWithSubchannel(subchannel)
515+
.copyWithStreamTracerFactory(
516+
OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory(
517+
reportListeners.get(pick)));
515518
} else {
516-
return PickResult.withSubchannel(subchannel);
519+
return pickResult.copyWithSubchannel(subchannel);
517520
}
518521
}
519522

xds/src/main/java/io/grpc/xds/orca/OrcaOobUtil.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,16 @@
3636
import io.grpc.ChannelLogger;
3737
import io.grpc.ChannelLogger.ChannelLogLevel;
3838
import io.grpc.ClientCall;
39+
import io.grpc.ConnectivityState;
3940
import io.grpc.ConnectivityStateInfo;
4041
import io.grpc.ExperimentalApi;
4142
import io.grpc.LoadBalancer;
4243
import io.grpc.LoadBalancer.CreateSubchannelArgs;
4344
import io.grpc.LoadBalancer.Helper;
45+
import io.grpc.LoadBalancer.PickResult;
46+
import io.grpc.LoadBalancer.PickSubchannelArgs;
4447
import io.grpc.LoadBalancer.Subchannel;
48+
import io.grpc.LoadBalancer.SubchannelPicker;
4549
import io.grpc.LoadBalancer.SubchannelStateListener;
4650
import io.grpc.Metadata;
4751
import io.grpc.Status;
@@ -236,6 +240,30 @@ protected Helper delegate() {
236240
return delegate;
237241
}
238242

243+
@Override
244+
public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
245+
delegate.updateBalancingState(newState, new OrcaOobPicker(newPicker));
246+
}
247+
248+
@VisibleForTesting
249+
static final class OrcaOobPicker extends SubchannelPicker {
250+
final SubchannelPicker delegate;
251+
252+
OrcaOobPicker(SubchannelPicker delegate) {
253+
this.delegate = delegate;
254+
}
255+
256+
@Override
257+
public PickResult pickSubchannel(PickSubchannelArgs args) {
258+
PickResult result = delegate.pickSubchannel(args);
259+
Subchannel subchannel = result.getSubchannel();
260+
if (subchannel instanceof SubchannelImpl) {
261+
return result.copyWithSubchannel(((SubchannelImpl) subchannel).delegate());
262+
}
263+
return result;
264+
}
265+
}
266+
239267
@Override
240268
public Subchannel createSubchannel(CreateSubchannelArgs args) {
241269
syncContext.throwIfNotInThisSynchronizationContext();

0 commit comments

Comments
 (0)