diff --git a/instrumentation/rxjava/rxjava-2.0/library/src/main/java/io/opentelemetry/instrumentation/rxjava/v2_0/TracingAssembly.java b/instrumentation/rxjava/rxjava-2.0/library/src/main/java/io/opentelemetry/instrumentation/rxjava/v2_0/TracingAssembly.java index 30a4292366e3..71596414b08d 100644 --- a/instrumentation/rxjava/rxjava-2.0/library/src/main/java/io/opentelemetry/instrumentation/rxjava/v2_0/TracingAssembly.java +++ b/instrumentation/rxjava/rxjava-2.0/library/src/main/java/io/opentelemetry/instrumentation/rxjava/v2_0/TracingAssembly.java @@ -93,6 +93,10 @@ public final class TracingAssembly { private static Function oldOnParallelAssembly; + @GuardedBy("TracingAssembly.class") + @Nullable + private static Function oldScheduleHandler; + @GuardedBy("TracingAssembly.class") private static boolean enabled; @@ -118,6 +122,8 @@ public void enable() { enableObservable(); + enableWrappedScheduleHandler(); + enableCompletable(); enableSingle(); @@ -142,6 +148,8 @@ public void disable() { disableObservable(); + disableWrappedScheduleHandler(); + disableCompletable(); disableSingle(); @@ -219,6 +227,25 @@ private static void enableObservable() { } } + @GuardedBy("TracingAssembly.class") + private static void enableWrappedScheduleHandler() { + oldScheduleHandler = RxJavaPlugins.getScheduleHandler(); + RxJavaPlugins.setScheduleHandler( + runnable -> { + Context context = Context.current(); + Runnable wrappedRunnable = + () -> { + try (Scope ignored = context.makeCurrent()) { + runnable.run(); + } + }; + // If there was a previous handler, apply it to our wrapped runnable + return oldScheduleHandler != null + ? oldScheduleHandler.apply(wrappedRunnable) + : wrappedRunnable; + }); + } + @GuardedBy("TracingAssembly.class") @SuppressWarnings({"rawtypes", "unchecked"}) private static void enableSingle() { @@ -274,6 +301,12 @@ private static void disableObservable() { oldOnObservableSubscribe = null; } + @GuardedBy("TracingAssembly.class") + private static void disableWrappedScheduleHandler() { + RxJavaPlugins.setScheduleHandler(oldScheduleHandler); + oldScheduleHandler = null; + } + @GuardedBy("TracingAssembly.class") private static void disableCompletable() { RxJavaPlugins.setOnCompletableSubscribe(oldOnCompletableSubscribe); diff --git a/instrumentation/rxjava/rxjava-2.0/testing/src/main/java/io/opentelemetry/instrumentation/rxjava/v2_0/AbstractRxJava2Test.java b/instrumentation/rxjava/rxjava-2.0/testing/src/main/java/io/opentelemetry/instrumentation/rxjava/v2_0/AbstractRxJava2Test.java index 1be2f4112448..4aaf66115dd9 100644 --- a/instrumentation/rxjava/rxjava-2.0/testing/src/main/java/io/opentelemetry/instrumentation/rxjava/v2_0/AbstractRxJava2Test.java +++ b/instrumentation/rxjava/rxjava-2.0/testing/src/main/java/io/opentelemetry/instrumentation/rxjava/v2_0/AbstractRxJava2Test.java @@ -11,6 +11,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.SpanKind; import io.opentelemetry.instrumentation.testing.junit.InstrumentationExtension; import io.opentelemetry.instrumentation.testing.util.ThrowingRunnable; @@ -23,12 +24,19 @@ import io.reactivex.Observable; import io.reactivex.Scheduler; import io.reactivex.Single; +import io.reactivex.disposables.Disposable; +import io.reactivex.functions.Function; import io.reactivex.internal.operators.flowable.FlowablePublish; import io.reactivex.internal.operators.observable.ObservablePublish; +import io.reactivex.plugins.RxJavaPlugins; import io.reactivex.schedulers.Schedulers; import java.util.Comparator; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.stream.Stream; import org.junit.jupiter.api.Test; @@ -91,7 +99,7 @@ public void onComplete() {} } @Test - public void basicMaybe() { + void basicMaybe() { int result = createParentSpan(() -> Maybe.just(1).map(this::addOne).blockingGet()); assertThat(result).isEqualTo(2); testing() @@ -106,7 +114,7 @@ public void basicMaybe() { } @Test - public void twoOperationsMaybe() { + void twoOperationsMaybe() { int result = createParentSpan(() -> Maybe.just(2).map(this::addOne).map(this::addOne).blockingGet()); assertThat(result).isEqualTo(4); @@ -126,7 +134,7 @@ public void twoOperationsMaybe() { } @Test - public void delayedMaybe() { + void delayedMaybe() { int result = createParentSpan( () -> Maybe.just(3).delay(100, TimeUnit.MILLISECONDS).map(this::addOne).blockingGet()); @@ -143,7 +151,7 @@ public void delayedMaybe() { } @Test - public void delayedTwiceMaybe() { + void delayedTwiceMaybe() { int result = createParentSpan( () -> @@ -170,7 +178,7 @@ public void delayedTwiceMaybe() { } @Test - public void basicFlowable() { + void basicFlowable() { Iterable result = createParentSpan( () -> Flowable.fromIterable(asList(5, 6)).map(this::addOne).toList().blockingGet()); @@ -191,7 +199,7 @@ public void basicFlowable() { } @Test - public void twoOperationsFlowable() { + void twoOperationsFlowable() { List result = createParentSpan( () -> @@ -225,7 +233,7 @@ public void twoOperationsFlowable() { } @Test - public void delayedFlowable() { + void delayedFlowable() { List result = createParentSpan( () -> @@ -251,7 +259,7 @@ public void delayedFlowable() { } @Test - public void delayedTwiceFlowable() { + void delayedTwiceFlowable() { List result = createParentSpan( () -> @@ -287,7 +295,7 @@ public void delayedTwiceFlowable() { } @Test - public void maybeFromCallable() { + void maybeFromCallable() { Integer result = createParentSpan( () -> Maybe.fromCallable(() -> addOne(10)).map(this::addOne).blockingGet()); @@ -308,7 +316,7 @@ public void maybeFromCallable() { } @Test - public void basicSingle() { + void basicSingle() { Integer result = createParentSpan(() -> Single.just(0).map(this::addOne).blockingGet()); assertThat(result).isEqualTo(1); testing() @@ -323,7 +331,7 @@ public void basicSingle() { } @Test - public void basicObservable() { + void basicObservable() { List result = createParentSpan(() -> Observable.just(0).map(this::addOne).toList().blockingGet()); assertThat(result).contains(1); @@ -339,7 +347,38 @@ public void basicObservable() { } @Test - public void connectableFlowable() { + void observableFromCallableContextPropagation() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + AtomicReference traceId = new AtomicReference<>(); + AtomicReference innerObservableTraceId = new AtomicReference<>(); + AtomicReference endObservableTraceId = new AtomicReference<>(); + + createParentSpan( + () -> { + traceId.set(Span.current().getSpanContext().getTraceId()); + Disposable unused = + Observable.fromCallable( + () -> { + innerObservableTraceId.set(Span.current().getSpanContext().getTraceId()); + return "success"; + }) + .subscribeOn(Schedulers.io()) + .observeOn(Schedulers.single()) + .subscribe( + data -> { + endObservableTraceId.set(Span.current().getSpanContext().getTraceId()); + latch.countDown(); + }); + assertThat(unused).isNotNull(); + }); + + latch.await(); + assertThat(innerObservableTraceId.get()).isEqualTo(traceId.get()); + assertThat(endObservableTraceId.get()).isEqualTo(traceId.get()); + } + + @Test + void connectableFlowable() { List result = createParentSpan( () -> @@ -361,7 +400,7 @@ public void connectableFlowable() { } @Test - public void connectableObservable() { + void connectableObservable() { List result = createParentSpan( () -> @@ -383,7 +422,7 @@ public void connectableObservable() { } @Test - public void maybeError() { + void maybeError() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy(() -> createParentSpan(() -> Maybe.error(error).blockingGet())) .isEqualTo(error); @@ -395,7 +434,7 @@ public void maybeError() { } @Test - public void flowableError() { + void flowableError() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy(() -> createParentSpan(() -> Flowable.error(error)).toList().blockingGet()) .isEqualTo(error); @@ -407,7 +446,7 @@ public void flowableError() { } @Test - public void singleError() { + void singleError() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy(() -> createParentSpan(() -> Single.error(error)).blockingGet()) .isEqualTo(error); @@ -419,7 +458,7 @@ public void singleError() { } @Test - public void observableError() { + void observableError() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy(() -> createParentSpan(() -> Observable.error(error).toList().blockingGet())) .isEqualTo(error); @@ -431,7 +470,7 @@ public void observableError() { } @Test - public void completableError() { + void completableError() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy( () -> createParentSpan(() -> Completable.error(error).toMaybe().blockingGet())) @@ -444,7 +483,7 @@ public void completableError() { } @Test - public void basicMaybeFailure() { + void basicMaybeFailure() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy( () -> @@ -470,7 +509,7 @@ public void basicMaybeFailure() { } @Test - public void basicFlowableFailure() { + void basicFlowableFailure() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy( () -> @@ -497,7 +536,7 @@ public void basicFlowableFailure() { } @Test - public void basicMaybeCancel() { + void basicMaybeCancel() { createParentSpan( () -> Maybe.just(1).toFlowable().map(this::addOne).subscribe(CancellingSubscriber.INSTANCE)); @@ -509,7 +548,7 @@ public void basicMaybeCancel() { } @Test - public void basicFlowableCancel() { + void basicFlowableCancel() { createParentSpan( () -> Flowable.fromIterable(asList(5, 6)) @@ -523,7 +562,7 @@ public void basicFlowableCancel() { } @Test - public void basicSingleCancel() { + void basicSingleCancel() { createParentSpan( () -> Single.just(1).toFlowable().map(this::addOne).subscribe(CancellingSubscriber.INSTANCE)); @@ -535,7 +574,7 @@ public void basicSingleCancel() { } @Test - public void basicCompletableCancel() { + void basicCompletableCancel() { createParentSpan( () -> Completable.fromCallable(() -> 1) @@ -549,7 +588,7 @@ public void basicCompletableCancel() { } @Test - public void basicObservableCancel() { + void basicObservableCancel() { createParentSpan( () -> Observable.just(1) @@ -564,7 +603,7 @@ public void basicObservableCancel() { } @Test - public void basicMaybeChain() { + void basicMaybeChain() { createParentSpan( () -> Maybe.just(1) @@ -593,7 +632,7 @@ public void basicMaybeChain() { } @Test - public void basicFlowableChain() { + void basicFlowableChain() { createParentSpan( () -> Flowable.fromIterable(asList(5, 6)) @@ -631,7 +670,7 @@ public void basicFlowableChain() { // Publisher chain spans have the correct parents from subscription time @Test - public void maybeChainParentSpan() { + void maybeChainParentSpan() { Maybe maybe = Maybe.just(42).map(this::addOne).map(this::addTwo); testing().runWithSpan("trace-parent", () -> maybe.blockingGet()); testing() @@ -650,7 +689,7 @@ public void maybeChainParentSpan() { } @Test - public void maybeChainHasSubscriptionContext() { + void maybeChainHasSubscriptionContext() { Integer result = createParentSpan( () -> { @@ -680,7 +719,7 @@ public void maybeChainHasSubscriptionContext() { } @Test - public void flowableChainHasSubscriptionContext() { + void flowableChainHasSubscriptionContext() { List result = createParentSpan( () -> { @@ -719,7 +758,7 @@ public void flowableChainHasSubscriptionContext() { } @Test - public void singleChainHasSubscriptionContext() { + void singleChainHasSubscriptionContext() { Integer result = createParentSpan( () -> { @@ -749,7 +788,7 @@ public void singleChainHasSubscriptionContext() { } @Test - public void observableChainHasSubscriptionContext() { + void observableChainHasSubscriptionContext() { List result = createParentSpan( () -> { @@ -779,9 +818,9 @@ public void observableChainHasSubscriptionContext() { .hasParent(trace.getSpan(0)))); } - @ParameterizedTest @MethodSource("schedulers") - public void flowableMultiResults(Scheduler scheduler) { + @ParameterizedTest + void flowableMultiResults(Scheduler scheduler) { List result = testing() .runWithSpan( @@ -821,7 +860,7 @@ public void flowableMultiResults(Scheduler scheduler) { @ParameterizedTest @MethodSource("schedulers") - public void maybeMultipleTraceChains(Scheduler scheduler) { + void maybeMultipleTraceChains(Scheduler scheduler) { int iterations = 100; RxJava2ConcurrencyTestHelper.launchAndWait(scheduler, iterations, 60000, testing()); @SuppressWarnings("unchecked") @@ -851,4 +890,53 @@ public void maybeMultipleTraceChains(Scheduler scheduler) { assertions); testing().clearData(); } + + @Test + void scheduleHandlerChainingPreservesExistingHandler() throws InterruptedException { + AtomicInteger customHandlerCallCount = new AtomicInteger(0); + AtomicBoolean customHandlerExecuted = new AtomicBoolean(false); + + Function originalHandler = + RxJavaPlugins.getScheduleHandler(); + Function customHandler = + runnable -> + () -> { + customHandlerCallCount.incrementAndGet(); + customHandlerExecuted.set(true); + runnable.run(); + }; + + try { + RxJavaPlugins.setScheduleHandler(customHandler); + + CountDownLatch latch = new CountDownLatch(1); + AtomicBoolean observableExecuted = new AtomicBoolean(false); + AtomicReference traceId = new AtomicReference<>(); + + createParentSpan( + () -> { + traceId.set(Span.current().getSpanContext().getTraceId()); + Disposable unused = + Observable.fromCallable( + () -> { + observableExecuted.set(true); + return "test"; + }) + .subscribeOn(Schedulers.io()) + .subscribe(result -> latch.countDown()); + }); + + latch.await(); + assertThat(observableExecuted.get()).isTrue(); + assertThat(customHandlerExecuted.get()).isTrue(); + assertThat(customHandlerCallCount.get()).isGreaterThan(0); + + assertThat(traceId.get()).isNotNull(); + assertThat(traceId.get()).isNotEqualTo("00000000000000000000000000000000"); + + } finally { + // Restore original handler to avoid affecting other tests + RxJavaPlugins.setScheduleHandler(originalHandler); + } + } } diff --git a/instrumentation/rxjava/rxjava-3-common/testing/src/main/java/io/opentelemetry/instrumentation/rxjava/v3/common/AbstractRxJava3Test.java b/instrumentation/rxjava/rxjava-3-common/testing/src/main/java/io/opentelemetry/instrumentation/rxjava/v3/common/AbstractRxJava3Test.java index 36872f47b5f1..d801ca473301 100644 --- a/instrumentation/rxjava/rxjava-3-common/testing/src/main/java/io/opentelemetry/instrumentation/rxjava/v3/common/AbstractRxJava3Test.java +++ b/instrumentation/rxjava/rxjava-3-common/testing/src/main/java/io/opentelemetry/instrumentation/rxjava/v3/common/AbstractRxJava3Test.java @@ -11,6 +11,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.SpanKind; import io.opentelemetry.instrumentation.testing.junit.InstrumentationExtension; import io.opentelemetry.instrumentation.testing.util.ThrowingRunnable; @@ -23,12 +24,19 @@ import io.reactivex.rxjava3.core.Observable; import io.reactivex.rxjava3.core.Scheduler; import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.disposables.Disposable; +import io.reactivex.rxjava3.functions.Function; import io.reactivex.rxjava3.internal.operators.flowable.FlowablePublish; import io.reactivex.rxjava3.internal.operators.observable.ObservablePublish; +import io.reactivex.rxjava3.plugins.RxJavaPlugins; import io.reactivex.rxjava3.schedulers.Schedulers; import java.util.Comparator; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.stream.Stream; import org.junit.jupiter.api.Test; @@ -91,7 +99,7 @@ public void onComplete() {} } @Test - public void basicMaybe() { + void basicMaybe() { int result = createParentSpan(() -> Maybe.just(1).map(this::addOne).blockingGet()); assertThat(result).isEqualTo(2); testing() @@ -106,7 +114,7 @@ public void basicMaybe() { } @Test - public void twoOperationsMaybe() { + void twoOperationsMaybe() { int result = createParentSpan(() -> Maybe.just(2).map(this::addOne).map(this::addOne).blockingGet()); assertThat(result).isEqualTo(4); @@ -126,7 +134,7 @@ public void twoOperationsMaybe() { } @Test - public void delayedMaybe() { + void delayedMaybe() { int result = createParentSpan( () -> Maybe.just(3).delay(100, TimeUnit.MILLISECONDS).map(this::addOne).blockingGet()); @@ -143,7 +151,7 @@ public void delayedMaybe() { } @Test - public void delayedTwiceMaybe() { + void delayedTwiceMaybe() { int result = createParentSpan( () -> @@ -170,7 +178,7 @@ public void delayedTwiceMaybe() { } @Test - public void basicFlowable() { + void basicFlowable() { Iterable result = createParentSpan( () -> Flowable.fromIterable(asList(5, 6)).map(this::addOne).toList().blockingGet()); @@ -191,7 +199,7 @@ public void basicFlowable() { } @Test - public void twoOperationsFlowable() { + void twoOperationsFlowable() { List result = createParentSpan( () -> @@ -225,7 +233,39 @@ public void twoOperationsFlowable() { } @Test - public void delayedFlowable() { + void observableFromCallableContextPropagation() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + AtomicReference traceId = new AtomicReference<>(); + AtomicReference innerObservableTraceId = new AtomicReference<>(); + AtomicReference endObservableTraceId = new AtomicReference<>(); + + createParentSpan( + () -> { + traceId.set(Span.current().getSpanContext().getTraceId()); + Disposable unused = + Observable.fromCallable( + () -> { + innerObservableTraceId.set(Span.current().getSpanContext().getTraceId()); + return "success"; + }) + .subscribeOn(Schedulers.io()) + .observeOn(Schedulers.single()) + .subscribe( + data -> { + endObservableTraceId.set(Span.current().getSpanContext().getTraceId()); + latch.countDown(); + latch.countDown(); + }); + assertThat(unused).isNotNull(); + }); + + latch.await(); + assertThat(innerObservableTraceId.get()).isEqualTo(traceId.get()); + assertThat(endObservableTraceId.get()).isEqualTo(traceId.get()); + } + + @Test + void delayedFlowable() { List result = createParentSpan( () -> @@ -251,7 +291,7 @@ public void delayedFlowable() { } @Test - public void delayedTwiceFlowable() { + void delayedTwiceFlowable() { List result = createParentSpan( () -> @@ -287,7 +327,7 @@ public void delayedTwiceFlowable() { } @Test - public void maybeFromCallable() { + void maybeFromCallable() { Integer result = createParentSpan( () -> Maybe.fromCallable(() -> addOne(10)).map(this::addOne).blockingGet()); @@ -308,7 +348,7 @@ public void maybeFromCallable() { } @Test - public void basicSingle() { + void basicSingle() { Integer result = createParentSpan(() -> Single.just(0).map(this::addOne).blockingGet()); assertThat(result).isEqualTo(1); testing() @@ -323,7 +363,7 @@ public void basicSingle() { } @Test - public void basicObservable() { + void basicObservable() { List result = createParentSpan(() -> Observable.just(0).map(this::addOne).toList().blockingGet()); assertThat(result).contains(1); @@ -339,7 +379,7 @@ public void basicObservable() { } @Test - public void connectableFlowable() { + void connectableFlowable() { List result = createParentSpan( () -> @@ -361,7 +401,7 @@ public void connectableFlowable() { } @Test - public void connectableObservable() { + void connectableObservable() { List result = createParentSpan( () -> @@ -383,7 +423,7 @@ public void connectableObservable() { } @Test - public void maybeError() { + void maybeError() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy(() -> createParentSpan(() -> Maybe.error(error).blockingGet())) .isEqualTo(error); @@ -395,7 +435,7 @@ public void maybeError() { } @Test - public void flowableError() { + void flowableError() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy(() -> createParentSpan(() -> Flowable.error(error)).toList().blockingGet()) .isEqualTo(error); @@ -407,7 +447,7 @@ public void flowableError() { } @Test - public void singleError() { + void singleError() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy(() -> createParentSpan(() -> Single.error(error)).blockingGet()) .isEqualTo(error); @@ -419,7 +459,7 @@ public void singleError() { } @Test - public void observableError() { + void observableError() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy(() -> createParentSpan(() -> Observable.error(error).toList().blockingGet())) .isEqualTo(error); @@ -431,7 +471,7 @@ public void observableError() { } @Test - public void completableError() { + void completableError() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy( () -> createParentSpan(() -> Completable.error(error).toMaybe().blockingGet())) @@ -444,7 +484,7 @@ public void completableError() { } @Test - public void basicMaybeFailure() { + void basicMaybeFailure() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy( () -> @@ -470,7 +510,7 @@ public void basicMaybeFailure() { } @Test - public void basicFlowableFailure() { + void basicFlowableFailure() { IllegalStateException error = new IllegalStateException(EXCEPTION_MESSAGE); assertThatThrownBy( () -> @@ -497,7 +537,7 @@ public void basicFlowableFailure() { } @Test - public void basicMaybeCancel() { + void basicMaybeCancel() { createParentSpan( () -> Maybe.just(1).toFlowable().map(this::addOne).subscribe(CancellingSubscriber.INSTANCE)); @@ -509,7 +549,7 @@ public void basicMaybeCancel() { } @Test - public void basicFlowableCancel() { + void basicFlowableCancel() { createParentSpan( () -> Flowable.fromIterable(asList(5, 6)) @@ -523,7 +563,7 @@ public void basicFlowableCancel() { } @Test - public void basicSingleCancel() { + void basicSingleCancel() { createParentSpan( () -> Single.just(1).toFlowable().map(this::addOne).subscribe(CancellingSubscriber.INSTANCE)); @@ -535,7 +575,7 @@ public void basicSingleCancel() { } @Test - public void basicCompletableCancel() { + void basicCompletableCancel() { createParentSpan( () -> Completable.fromCallable(() -> 1) @@ -549,7 +589,7 @@ public void basicCompletableCancel() { } @Test - public void basicObservableCancel() { + void basicObservableCancel() { createParentSpan( () -> Observable.just(1) @@ -564,7 +604,7 @@ public void basicObservableCancel() { } @Test - public void basicMaybeChain() { + void basicMaybeChain() { createParentSpan( () -> Maybe.just(1) @@ -593,7 +633,7 @@ public void basicMaybeChain() { } @Test - public void basicFlowableChain() { + void basicFlowableChain() { createParentSpan( () -> Flowable.fromIterable(asList(5, 6)) @@ -631,7 +671,7 @@ public void basicFlowableChain() { // Publisher chain spans have the correct parents from subscription time @Test - public void maybeChainParentSpan() { + void maybeChainParentSpan() { Maybe maybe = Maybe.just(42).map(this::addOne).map(this::addTwo); testing().runWithSpan("trace-parent", () -> maybe.blockingGet()); testing() @@ -650,7 +690,7 @@ public void maybeChainParentSpan() { } @Test - public void maybeChainHasSubscriptionContext() { + void maybeChainHasSubscriptionContext() { Integer result = createParentSpan( () -> { @@ -680,7 +720,7 @@ public void maybeChainHasSubscriptionContext() { } @Test - public void flowableChainHasSubscriptionContext() { + void flowableChainHasSubscriptionContext() { List result = createParentSpan( () -> { @@ -719,7 +759,7 @@ public void flowableChainHasSubscriptionContext() { } @Test - public void singleChainHasSubscriptionContext() { + void singleChainHasSubscriptionContext() { Integer result = createParentSpan( () -> { @@ -749,7 +789,7 @@ public void singleChainHasSubscriptionContext() { } @Test - public void observableChainHasSubscriptionContext() { + void observableChainHasSubscriptionContext() { List result = createParentSpan( () -> { @@ -781,7 +821,7 @@ public void observableChainHasSubscriptionContext() { @ParameterizedTest @MethodSource("schedulers") - public void flowableMultiResults(Scheduler scheduler) { + void flowableMultiResults(Scheduler scheduler) { List result = testing() .runWithSpan( @@ -821,7 +861,7 @@ public void flowableMultiResults(Scheduler scheduler) { @ParameterizedTest @MethodSource("schedulers") - public void maybeMultipleTraceChains(Scheduler scheduler) { + void maybeMultipleTraceChains(Scheduler scheduler) { int iterations = 100; RxJava3ConcurrencyTestHelper.launchAndWait(scheduler, iterations, 60000, testing()); @SuppressWarnings("unchecked") @@ -851,4 +891,53 @@ public void maybeMultipleTraceChains(Scheduler scheduler) { assertions); testing().clearData(); } + + @Test + void scheduleHandlerChainingPreservesExistingHandler() throws InterruptedException { + AtomicInteger customHandlerCallCount = new AtomicInteger(0); + AtomicBoolean customHandlerExecuted = new AtomicBoolean(false); + + Function originalHandler = + RxJavaPlugins.getScheduleHandler(); + Function customHandler = + runnable -> + () -> { + customHandlerCallCount.incrementAndGet(); + customHandlerExecuted.set(true); + runnable.run(); + }; + + try { + RxJavaPlugins.setScheduleHandler(customHandler); + + CountDownLatch latch = new CountDownLatch(1); + AtomicBoolean observableExecuted = new AtomicBoolean(false); + AtomicReference traceId = new AtomicReference<>(); + + createParentSpan( + () -> { + traceId.set(Span.current().getSpanContext().getTraceId()); + Disposable unused = + Observable.fromCallable( + () -> { + observableExecuted.set(true); + return "test"; + }) + .subscribeOn(Schedulers.io()) + .subscribe(result -> latch.countDown()); + }); + + latch.await(); + assertThat(observableExecuted.get()).isTrue(); + assertThat(customHandlerExecuted.get()).isTrue(); + assertThat(customHandlerCallCount.get()).isGreaterThan(0); + + assertThat(traceId.get()).isNotNull(); + assertThat(traceId.get()).isNotEqualTo("00000000000000000000000000000000"); + + } finally { + // Restore original handler to avoid affecting other tests + RxJavaPlugins.setScheduleHandler(originalHandler); + } + } } diff --git a/instrumentation/rxjava/rxjava-3.0/library/src/main/java/io/opentelemetry/instrumentation/rxjava/v3_0/TracingAssembly.java b/instrumentation/rxjava/rxjava-3.0/library/src/main/java/io/opentelemetry/instrumentation/rxjava/v3_0/TracingAssembly.java index 191a2ebf9b1f..b9c056a2c3b9 100644 --- a/instrumentation/rxjava/rxjava-3.0/library/src/main/java/io/opentelemetry/instrumentation/rxjava/v3_0/TracingAssembly.java +++ b/instrumentation/rxjava/rxjava-3.0/library/src/main/java/io/opentelemetry/instrumentation/rxjava/v3_0/TracingAssembly.java @@ -97,6 +97,10 @@ public final class TracingAssembly { private static Function oldOnParallelAssembly; + @GuardedBy("TracingAssembly.class") + @Nullable + private static Function oldScheduleHandler; + @GuardedBy("TracingAssembly.class") private static boolean enabled; @@ -122,6 +126,8 @@ public void enable() { enableObservable(); + enableWrappedScheduleHandler(); + enableCompletable(); enableSingle(); @@ -146,6 +152,8 @@ public void disable() { disableObservable(); + disableWrappedScheduleHandler(); + disableCompletable(); disableSingle(); @@ -221,6 +229,25 @@ private static void enableObservable() { })); } + @GuardedBy("TracingAssembly.class") + private static void enableWrappedScheduleHandler() { + oldScheduleHandler = RxJavaPlugins.getScheduleHandler(); + RxJavaPlugins.setScheduleHandler( + runnable -> { + Context context = Context.current(); + Runnable wrappedRunnable = + () -> { + try (Scope ignored = context.makeCurrent()) { + runnable.run(); + } + }; + // If there was a previous handler, apply it to our wrapped runnable + return oldScheduleHandler != null + ? oldScheduleHandler.apply(wrappedRunnable) + : wrappedRunnable; + }); + } + @GuardedBy("TracingAssembly.class") @SuppressWarnings({"rawtypes", "unchecked"}) private static void enableSingle() { @@ -276,6 +303,12 @@ private static void disableObservable() { oldOnObservableSubscribe = null; } + @GuardedBy("TracingAssembly.class") + private static void disableWrappedScheduleHandler() { + RxJavaPlugins.setScheduleHandler(oldScheduleHandler); + oldScheduleHandler = null; + } + @GuardedBy("TracingAssembly.class") private static void disableCompletable() { RxJavaPlugins.setOnCompletableSubscribe(oldOnCompletableSubscribe); diff --git a/instrumentation/rxjava/rxjava-3.1.1/library/src/main/java/io/opentelemetry/instrumentation/rxjava/v3_1_1/TracingAssembly.java b/instrumentation/rxjava/rxjava-3.1.1/library/src/main/java/io/opentelemetry/instrumentation/rxjava/v3_1_1/TracingAssembly.java index 10726a499728..0b4c60938213 100644 --- a/instrumentation/rxjava/rxjava-3.1.1/library/src/main/java/io/opentelemetry/instrumentation/rxjava/v3_1_1/TracingAssembly.java +++ b/instrumentation/rxjava/rxjava-3.1.1/library/src/main/java/io/opentelemetry/instrumentation/rxjava/v3_1_1/TracingAssembly.java @@ -97,6 +97,10 @@ public final class TracingAssembly { private static Function oldOnParallelAssembly; + @GuardedBy("TracingAssembly.class") + @Nullable + private static Function oldScheduleHandler; + @GuardedBy("TracingAssembly.class") private static boolean enabled; @@ -122,6 +126,8 @@ public void enable() { enableObservable(); + enableWrappedScheduleHandler(); + enableCompletable(); enableSingle(); @@ -146,6 +152,8 @@ public void disable() { disableObservable(); + disableWrappedScheduleHandler(); + disableCompletable(); disableSingle(); @@ -186,6 +194,25 @@ private static void enableCompletable() { })); } + @GuardedBy("TracingAssembly.class") + private static void enableWrappedScheduleHandler() { + oldScheduleHandler = RxJavaPlugins.getScheduleHandler(); + RxJavaPlugins.setScheduleHandler( + runnable -> { + Context context = Context.current(); + Runnable wrappedRunnable = + () -> { + try (Scope ignored = context.makeCurrent()) { + runnable.run(); + } + }; + // If there was a previous handler, apply it to our wrapped runnable + return oldScheduleHandler != null + ? oldScheduleHandler.apply(wrappedRunnable) + : wrappedRunnable; + }); + } + @GuardedBy("TracingAssembly.class") @SuppressWarnings({"rawtypes", "unchecked"}) private static void enableFlowable() { @@ -276,6 +303,12 @@ private static void disableObservable() { oldOnObservableSubscribe = null; } + @GuardedBy("TracingAssembly.class") + private static void disableWrappedScheduleHandler() { + RxJavaPlugins.setScheduleHandler(oldScheduleHandler); + oldScheduleHandler = null; + } + @GuardedBy("TracingAssembly.class") private static void disableCompletable() { RxJavaPlugins.setOnCompletableSubscribe(oldOnCompletableSubscribe);