diff --git a/plugins/core/jetbrains-community/build.gradle.kts b/plugins/core/jetbrains-community/build.gradle.kts index 491a82f4166..c6a3128e8e2 100644 --- a/plugins/core/jetbrains-community/build.gradle.kts +++ b/plugins/core/jetbrains-community/build.gradle.kts @@ -18,15 +18,22 @@ buildscript { } } +private val generatedSrcDir = project.layout.buildDirectory.dir("generated-src") sourceSets { main { - java.srcDir(project.layout.buildDirectory.dir("generated-src")) + java.srcDir(generatedSrcDir) + } +} + +idea { + module { + generatedSourceDirs = generatedSourceDirs.toMutableSet() + generatedSrcDir.get().asFile } } val generateTelemetry = tasks.register("generateTelemetry") { inputFiles = listOf(file("${project.projectDir}/resources/telemetryOverride.json")) - outputDirectory = project.layout.buildDirectory.dir("generated-src").get().asFile + outputDirectory = generatedSrcDir.get().asFile doFirst { outputDirectory.deleteRecursively() diff --git a/plugins/core/jetbrains-community/src-233/com/intellij/platform/diagnostic/telemetry/helpers/UseWithoutActiveScope.kt b/plugins/core/jetbrains-community/src-233/com/intellij/platform/diagnostic/telemetry/helpers/UseWithoutActiveScope.kt new file mode 100644 index 00000000000..688e054bfa2 --- /dev/null +++ b/plugins/core/jetbrains-community/src-233/com/intellij/platform/diagnostic/telemetry/helpers/UseWithoutActiveScope.kt @@ -0,0 +1,27 @@ +// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.intellij.platform.diagnostic.telemetry.helpers + +import io.opentelemetry.api.common.AttributeKey +import io.opentelemetry.api.common.Attributes +import io.opentelemetry.api.trace.Span +import io.opentelemetry.api.trace.StatusCode +import kotlin.coroutines.cancellation.CancellationException + +val EXCEPTION_ESCAPED = AttributeKey.booleanKey("exception.escaped") + +inline fun Span.useWithoutActiveScope(operation: (Span) -> T): T { + try { + return operation(this) + } catch (e: CancellationException) { + recordException(e, Attributes.of(EXCEPTION_ESCAPED, true)) + throw e + } catch (e: Throwable) { + recordException(e, Attributes.of(EXCEPTION_ESCAPED, true)) + setStatus(StatusCode.ERROR) + throw e + } finally { + end() + } +} diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/services/telemetry/otel/OTelService.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/services/telemetry/otel/OTelService.kt new file mode 100644 index 00000000000..db3e66c976a --- /dev/null +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/services/telemetry/otel/OTelService.kt @@ -0,0 +1,171 @@ +// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +@file:Suppress("UnusedPrivateClass") + +package software.aws.toolkits.jetbrains.services.telemetry.otel + +import com.intellij.openapi.Disposable +import com.intellij.openapi.components.Service +import com.intellij.openapi.components.service +import com.intellij.openapi.diagnostic.thisLogger +import com.intellij.openapi.util.SystemInfoRt +import com.intellij.platform.util.http.ContentType +import com.intellij.platform.util.http.httpPost +import com.intellij.serviceContainer.NonInjectable +import io.opentelemetry.api.common.AttributeKey +import io.opentelemetry.api.common.Attributes +import io.opentelemetry.api.trace.propagation.W3CTraceContextPropagator +import io.opentelemetry.context.Context +import io.opentelemetry.context.propagation.ContextPropagators +import io.opentelemetry.exporter.internal.otlp.traces.TraceRequestMarshaler +import io.opentelemetry.sdk.OpenTelemetrySdk +import io.opentelemetry.sdk.resources.Resource +import io.opentelemetry.sdk.trace.ReadWriteSpan +import io.opentelemetry.sdk.trace.ReadableSpan +import io.opentelemetry.sdk.trace.SdkTracerProvider +import io.opentelemetry.sdk.trace.SpanProcessor +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.launch +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider +import software.amazon.awssdk.http.ContentStreamProvider +import software.amazon.awssdk.http.HttpExecuteRequest +import software.amazon.awssdk.http.SdkHttpMethod +import software.amazon.awssdk.http.SdkHttpRequest +import software.amazon.awssdk.http.apache.ApacheHttpClient +import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner +import java.io.ByteArrayOutputStream +import java.net.ConnectException + +private class BasicOtlpSpanProcessor( + private val coroutineScope: CoroutineScope, + private val traceUrl: String = "http://127.0.0.1:4318/v1/traces", +) : SpanProcessor { + override fun onStart(parentContext: Context, span: ReadWriteSpan) {} + override fun isStartRequired() = false + override fun isEndRequired() = true + + override fun onEnd(span: ReadableSpan) { + val data = span.toSpanData() + coroutineScope.launch { + try { + val item = TraceRequestMarshaler.create(listOf(data)) + + httpPost(traceUrl, contentLength = item.binarySerializedSize.toLong(), contentType = ContentType.XProtobuf) { + item.writeBinaryTo(this) + } + } catch (e: CancellationException) { + throw e + } catch (e: ConnectException) { + thisLogger().warn("Cannot export (url=$traceUrl): ${e.message}") + } catch (e: Throwable) { + thisLogger().error("Cannot export (url=$traceUrl)", e) + } + } + } +} + +private class SigV4OtlpSpanProcessor( + private val coroutineScope: CoroutineScope, + private val traceUrl: String, + private val creds: AwsCredentialsProvider, +) : SpanProcessor { + override fun onStart(parentContext: Context, span: ReadWriteSpan) {} + override fun isStartRequired() = false + override fun isEndRequired() = true + + private val client = ApacheHttpClient.create() + + override fun onEnd(span: ReadableSpan) { + coroutineScope.launch { + val data = span.toSpanData() + try { + val item = TraceRequestMarshaler.create(listOf(data)) + // calculate the sigv4 header + val signer = AwsV4HttpSigner.create() + val httpRequest = + SdkHttpRequest.builder() + .uri(traceUrl) + .method(SdkHttpMethod.POST) + .putHeader("Content-Type", "application/x-protobuf") + .build() + + val baos = ByteArrayOutputStream() + item.writeBinaryTo(baos) + val payload = ContentStreamProvider.fromByteArray(baos.toByteArray()) + val signedRequest = signer.sign { + it.identity(creds.resolveIdentity().get()) + it.request(httpRequest) + it.payload(payload) + it.putProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, "osis") + it.putProperty(AwsV4HttpSigner.REGION_NAME, "us-west-2") + } + + // Create and HTTP client and send the request. ApacheHttpClient requires the 'apache-client' module. + client.prepareRequest( + HttpExecuteRequest.builder() + .request(signedRequest.request()) + .contentStreamProvider(signedRequest.payload().orElse(null)) + .build() + ).call() + } catch (e: CancellationException) { + throw e + } catch (e: ConnectException) { + thisLogger().warn("Cannot export (url=$traceUrl): ${e.message}") + } catch (e: Throwable) { + thisLogger().error("Cannot export (url=$traceUrl)", e) + } + } + } +} + +private object StdoutSpanProcessor : SpanProcessor { + override fun onStart(parentContext: Context, span: ReadWriteSpan) {} + override fun isStartRequired() = false + override fun isEndRequired() = true + + override fun onEnd(span: ReadableSpan) { + println(span.toSpanData()) + } +} + +@Service +class OTelService @NonInjectable internal constructor(spanProcessors: List) : Disposable { + @Suppress("unused") + constructor() : this(listOf(StdoutSpanProcessor)) + + private val sdkDelegate = lazy { + OpenTelemetrySdk.builder() + .setTracerProvider( + SdkTracerProvider.builder() + .apply { + spanProcessors.forEach { + addSpanProcessor(it) + } + } + .setResource( + Resource.create( + Attributes.builder() + .put(AttributeKey.stringKey("os.type"), SystemInfoRt.OS_NAME) + .put(AttributeKey.stringKey("os.version"), SystemInfoRt.OS_VERSION) + .put(AttributeKey.stringKey("host.arch"), System.getProperty("os.arch")) + .build() + ) + ) + .build() + ) + .setPropagators(ContextPropagators.create(W3CTraceContextPropagator.getInstance())) + .build() + } + internal val sdk: OpenTelemetrySdk by sdkDelegate + + override fun dispose() { + if (sdkDelegate.isInitialized()) { + sdk.close() + } + } + + companion object { + fun getSdk() = service().sdk + } +} diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/services/telemetry/otel/OtelBase.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/services/telemetry/otel/OtelBase.kt new file mode 100644 index 00000000000..b5cf95d4fe3 --- /dev/null +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/services/telemetry/otel/OtelBase.kt @@ -0,0 +1,243 @@ +// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package software.aws.toolkits.jetbrains.services.telemetry.otel + +import com.intellij.openapi.application.ApplicationInfo +import com.intellij.openapi.application.ApplicationManager +import com.intellij.platform.diagnostic.telemetry.helpers.useWithoutActiveScope +import io.opentelemetry.api.common.AttributeKey +import io.opentelemetry.api.common.Attributes +import io.opentelemetry.api.trace.Span +import io.opentelemetry.api.trace.SpanBuilder +import io.opentelemetry.api.trace.SpanContext +import io.opentelemetry.api.trace.SpanKind +import io.opentelemetry.context.Context +import io.opentelemetry.context.ContextKey +import io.opentelemetry.context.Scope +import io.opentelemetry.sdk.trace.ReadWriteSpan +import kotlinx.coroutines.CoroutineScope +import software.amazon.awssdk.services.toolkittelemetry.model.AWSProduct +import software.aws.toolkits.core.utils.error +import software.aws.toolkits.core.utils.getLogger +import software.aws.toolkits.core.utils.warn +import software.aws.toolkits.jetbrains.isDeveloperMode +import software.aws.toolkits.jetbrains.services.telemetry.PluginResolver +import java.time.Instant +import java.util.concurrent.TimeUnit +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext +import com.intellij.platform.diagnostic.telemetry.helpers.use as ijUse +import com.intellij.platform.diagnostic.telemetry.helpers.useWithScope as ijUseWithScope + +val AWS_PRODUCT_CONTEXT_KEY = ContextKey.named("pluginDescriptor") +internal val PLUGIN_ATTRIBUTE_KEY = AttributeKey.stringKey("plugin") + +class DefaultSpan(context: Context?, delegate: Span) : BaseSpan(context, delegate) + +class DefaultSpanBuilder(delegate: SpanBuilder) : AbstractSpanBuilder(delegate) { + override fun doStartSpan() = DefaultSpan(parent, delegate.startSpan()) +} + +// temporary; will be generated +abstract class BaseSpan>( + context: Context?, + delegate: Span, +) : AbstractBaseSpan(context, delegate as ReadWriteSpan) + +abstract class AbstractSpanBuilder< + BuilderType : AbstractSpanBuilder, + SpanType : AbstractBaseSpan, + >( + protected val delegate: SpanBuilder, +) : SpanBuilder { + /** + * Same as [com.intellij.platform.diagnostic.telemetry.helpers.use] except downcasts to specific subclass of [BaseSpan] + * + * @inheritdoc + */ + inline fun use(operation: (SpanType) -> T): T = + // FIX_WHEN_MIN_IS_241: not worth fixing for 233 + if (ApplicationInfo.getInstance().build.baselineVersion == 233) { + startSpan().useWithoutActiveScope { span -> + span.makeCurrent().use { + operation(span as SpanType) + } + } + } else { + startSpan().ijUse { span -> + operation(span as SpanType) + } + } + + /** + * Same as [com.intellij.platform.diagnostic.telemetry.helpers.useWithScope] except downcasts to specific subclass of [BaseSpan] + * + * @inheritdoc + */ + suspend inline fun useWithScope( + context: CoroutineContext = EmptyCoroutineContext, + crossinline operation: suspend CoroutineScope.(SpanType) -> T, + ): T = + ijUseWithScope(context) { span -> + operation(span as SpanType) + } + + protected var parent: Context? = null + override fun setParent(context: Context): BuilderType { + parent = context + delegate.setParent(context) + return this as BuilderType + } + + override fun setNoParent(): BuilderType { + parent = null + delegate.setNoParent() + return this as BuilderType + } + + override fun addLink(spanContext: SpanContext): BuilderType { + delegate.addLink(spanContext) + return this as BuilderType + } + + override fun addLink( + spanContext: SpanContext, + attributes: Attributes, + ): BuilderType { + delegate.addLink(spanContext, attributes) + return this as BuilderType + } + + override fun setAttribute(key: String, value: String): BuilderType { + delegate.setAttribute(key, value) + return this as BuilderType + } + + override fun setAttribute(key: String, value: Long): BuilderType { + delegate.setAttribute(key, value) + return this as BuilderType + } + + override fun setAttribute(key: String, value: Double): BuilderType { + delegate.setAttribute(key, value) + return this as BuilderType + } + + override fun setAttribute(key: String, value: Boolean): BuilderType { + delegate.setAttribute(key, value) + return this as BuilderType + } + + override fun setAttribute( + key: AttributeKey, + value: V & Any, + ): BuilderType { + delegate.setAttribute(key, value) + return this as BuilderType + } + + override fun setAllAttributes(attributes: Attributes): BuilderType { + delegate.setAllAttributes(attributes) + return this as BuilderType + } + + override fun setSpanKind(spanKind: SpanKind): BuilderType { + delegate.setSpanKind(spanKind) + return this as BuilderType + } + + override fun setStartTimestamp(startTimestamp: Long, unit: TimeUnit): BuilderType { + delegate.setStartTimestamp(startTimestamp, unit) + return this as BuilderType + } + + override fun setStartTimestamp(startTimestamp: Instant): BuilderType { + delegate.setStartTimestamp(startTimestamp) + return this as BuilderType + } + + protected abstract fun doStartSpan(): SpanType + + override fun startSpan(): SpanType { + var parent = parent + if (parent == null) { + parent = Context.current() + } + requireNotNull(parent) + + val contextValue = parent.get(AWS_PRODUCT_CONTEXT_KEY) + if (contextValue == null) { + val s = Span.fromContextOrNull(parent) + parent = if (s is AbstractBaseSpan<*> && s.context != null) { + s.context.with(Span.fromContext(parent)) + } else { + parent.with(AWS_PRODUCT_CONTEXT_KEY, resolvePluginName()) + } + setParent(parent) + } + requireNotNull(parent) + + parent.get(AWS_PRODUCT_CONTEXT_KEY)?.toString()?.let { + setAttribute(PLUGIN_ATTRIBUTE_KEY, it) + } ?: run { + LOG.warn { "Reached setAttribute with null AWS_PRODUCT_CONTEXT_KEY, but should not be possible" } + } + + return doStartSpan() + } + + private companion object { + val LOG = getLogger>() + fun resolvePluginName() = PluginResolver.fromStackTrace(Thread.currentThread().stackTrace).product + } +} + +abstract class AbstractBaseSpan>(internal val context: Context?, private val delegate: ReadWriteSpan) : Span by delegate { + protected open val requiredFields: Collection = emptySet() + + /** + * Same as [com.intellij.platform.diagnostic.telemetry.helpers.use] except downcasts to specific subclass of [BaseSpan] + * + * @inheritdoc + */ + inline fun use(operation: (SpanType) -> T): T = + ijUse { span -> + operation(span as SpanType) + } + + fun metadata(key: String, value: String?): SpanType { + delegate.setAttribute(key, value) + return this as SpanType + } + + override fun end() { + validateRequiredAttributes() + delegate.end() + } + + override fun end(timestamp: Long, unit: TimeUnit) { + validateRequiredAttributes() + delegate.end() + } + + private fun validateRequiredAttributes() { + val missingFields = requiredFields.filter { delegate.getAttribute(AttributeKey.stringKey(it)) == null } + val message = { "${delegate.name} is missing required fields: ${missingFields.joinToString(", ")}" } + + if (missingFields.isNotEmpty()) { + when { + ApplicationManager.getApplication().isUnitTestMode -> error(message()) + isDeveloperMode() -> LOG.error(block = message) + else -> LOG.error(block = message) + } + } + } + + override fun makeCurrent(): Scope = + context?.with(this)?.makeCurrent() ?: super.makeCurrent() + + private companion object { + val LOG = getLogger>() + } +} diff --git a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/utils/ThreadingUtils.kt b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/utils/ThreadingUtils.kt index 06df30c72f5..2fdca6b8eec 100644 --- a/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/utils/ThreadingUtils.kt +++ b/plugins/core/jetbrains-community/src/software/aws/toolkits/jetbrains/utils/ThreadingUtils.kt @@ -14,6 +14,7 @@ import com.intellij.openapi.util.ThrowableComputable import com.intellij.util.ExceptionUtil import com.intellij.util.concurrency.AppExecutorUtil import com.intellij.util.concurrency.Semaphore +import io.opentelemetry.context.Context import software.aws.toolkits.jetbrains.services.telemetry.PluginResolver import java.time.Duration import java.util.concurrent.Future @@ -81,8 +82,9 @@ fun pluginAwareExecuteOnPooledThread(action: () -> T): Future { * worker thread will not contain original call stack. Necessary for telemetry. */ val pluginResolver = PluginResolver.fromCurrentThread() + val context = Context.current() return ApplicationManager.getApplication().executeOnPooledThread { PluginResolver.setThreadLocal(pluginResolver) - action() + context.wrap(action).call() } } diff --git a/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/telemetry/otel/OtelBaseTest.kt b/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/telemetry/otel/OtelBaseTest.kt new file mode 100644 index 00000000000..445df85b241 --- /dev/null +++ b/plugins/core/jetbrains-community/tst/software/aws/toolkits/jetbrains/services/telemetry/otel/OtelBaseTest.kt @@ -0,0 +1,324 @@ +// Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package software.aws.toolkits.jetbrains.services.telemetry.otel + +import com.intellij.openapi.application.ApplicationManager +import com.intellij.testFramework.ApplicationExtension +import io.opentelemetry.api.trace.TraceId +import io.opentelemetry.context.Context +import io.opentelemetry.extension.kotlin.asContextElement +import io.opentelemetry.sdk.common.CompletableResultCode +import io.opentelemetry.sdk.trace.ReadWriteSpan +import io.opentelemetry.sdk.trace.ReadableSpan +import io.opentelemetry.sdk.trace.SpanProcessor +import kotlinx.coroutines.delay +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.AfterAllCallback +import org.junit.jupiter.api.extension.AfterEachCallback +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.api.extension.RegisterExtension +import software.amazon.awssdk.services.toolkittelemetry.model.AWSProduct +import software.aws.toolkits.core.utils.getLogger +import software.aws.toolkits.core.utils.warn +import software.aws.toolkits.jetbrains.core.coroutines.getCoroutineBgContext +import software.aws.toolkits.jetbrains.utils.pluginAwareExecuteOnPooledThread +import software.aws.toolkits.jetbrains.utils.satisfiesKt +import software.aws.toolkits.jetbrains.utils.spinUntil +import java.util.concurrent.TimeUnit + +@ExtendWith(ApplicationExtension::class) +class OtelBaseTest { + private companion object { + @RegisterExtension + val otelExtension = OtelExtension() + } + + @Test + fun `context propagates from parent to child - happy case`() { + spanBuilder("tracer", "parentSpan").use { + spanBuilder("anotherTracer", "childSpan").use {} + } + + assertThat(otelExtension.completedSpans).hasSize(2).satisfiesKt { spans -> + val parent = spans.first() + val child = spans.last() + + assertThat(parent.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(child.parentSpanContext.traceId).isEqualTo(parent.spanContext.traceId) + assertThat(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)) + } + } + + @Test + fun `context propagates from parent to child - happy case coroutines`() = runTest { + spanBuilder("tracer", "parentSpan").useWithScope { + spanBuilder("anotherTracer", "childSpan").useWithScope {} + } + + assertThat(otelExtension.completedSpans).hasSize(2).satisfiesKt { spans -> + val parent = spans.first() + val child = spans.last() + + assertThat(parent.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(child.parentSpanContext.traceId).isEqualTo(parent.spanContext.traceId) + assertThat(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)) + } + } + + @Test + fun `context propagates from parent to child - with context override`() { + spanBuilder("tracer", "parentSpan").setParent(Context.current().with(AWS_PRODUCT_CONTEXT_KEY, AWSProduct.AMAZON_Q_FOR_VS_CODE)).use { + spanBuilder("anotherTracer", "childSpan").use {} + } + + assertThat(otelExtension.completedSpans).hasSize(2).satisfiesKt { spans -> + val parent = spans.first() + val child = spans.last() + + assertThat(parent.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(child.parentSpanContext.traceId).isEqualTo(parent.spanContext.traceId) + assertThat(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo("Amazon Q For VS Code") + assertThat(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)) + } + } + + @Test + fun `context propagates from parent to child when child overrides context`() { + spanBuilder("tracer", "parentSpan").use { + // parent->child relationship is still maintained because Context.current() will return parent context + spanBuilder("anotherTracer", "childSpan").setParent(Context.current().with(AWS_PRODUCT_CONTEXT_KEY, AWSProduct.AMAZON_Q_FOR_VS_CODE)).use {} + } + + assertThat(otelExtension.completedSpans).hasSize(2).satisfiesKt { spans -> + val parent = spans.first() + val child = spans.last() + + assertThat(parent.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(child.parentSpanContext.traceId).isEqualTo(parent.spanContext.traceId) + assertThat(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isNotEqualTo(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)) + assertThat(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo("Amazon Q For VS Code") + } + } + + @Test + fun `context override does not propagate from parent to child when switching threads`() { + spanBuilder("tracer", "parentSpan").setParent(Context.current().with(AWS_PRODUCT_CONTEXT_KEY, AWSProduct.AMAZON_Q_FOR_VS_CODE)).use { + ApplicationManager.getApplication().executeOnPooledThread { + spanBuilder("anotherTracer", "childSpan").use {} + }.get(10, TimeUnit.SECONDS) + } + + assertThat(otelExtension.completedSpans).hasSize(2).satisfiesKt { spans -> + val parent = spans.first() + val child = spans.last() + + assertThat(parent.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(child.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo("Amazon Q For VS Code") + assertThat(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isNotEqualTo(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)) + } + } + + @Test + fun `context override propagates from parent to child when switching threads while preserving thread-local`() { + spanBuilder("tracer", "parentSpan").setParent(Context.current().with(AWS_PRODUCT_CONTEXT_KEY, AWSProduct.AMAZON_Q_FOR_VS_CODE)).use { + pluginAwareExecuteOnPooledThread { + spanBuilder("anotherTracer", "childSpan").use {} + }.get(10, TimeUnit.SECONDS) + } + + assertThat(otelExtension.completedSpans).hasSize(2).satisfiesKt { spans -> + val parent = spans.first() + val child = spans.last() + + assertThat(parent.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(child.parentSpanContext.traceId).isEqualTo(parent.spanContext.traceId) + assertThat(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo("Amazon Q For VS Code") + assertThat(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)) + } + } + + @Test + fun `context override propagates from parent to child when only child is coroutine`() { + spanBuilder("tracer", "parentSpan").setParent(Context.current().with(AWS_PRODUCT_CONTEXT_KEY, AWSProduct.AMAZON_Q_FOR_VS_CODE)).use { + runTest { + spanBuilder("anotherTracer", "childSpan").useWithScope { + delay(10000) + } + } + } + + assertThat(otelExtension.completedSpans).hasSize(2).satisfiesKt { spans -> + val parent = spans.first() + val child = spans.last() + + assertThat(parent.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(child.parentSpanContext.traceId).isEqualTo(parent.spanContext.traceId) + assertThat(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo("Amazon Q For VS Code") + assertThat(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)) + } + } + + @Test + fun `context override propagates from parent to child when only parent is coroutine`() = runTest { + spanBuilder("tracer", "parentSpan").setParent(Context.current().with(AWS_PRODUCT_CONTEXT_KEY, AWSProduct.AMAZON_Q_FOR_VS_CODE)).useWithScope { + spanBuilder("anotherTracer", "childSpan").use {} + } + + assertThat(otelExtension.completedSpans).hasSize(2).satisfiesKt { spans -> + val parent = spans.first() + val child = spans.last() + + assertThat(parent.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(child.parentSpanContext.traceId).isEqualTo(parent.spanContext.traceId) + assertThat(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo("Amazon Q For VS Code") + assertThat(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)) + } + } + + @Test + fun `context override propagates from parent to child when both are coroutines`() = runTest { + spanBuilder("tracer", "parentSpan").setParent(Context.current().with(AWS_PRODUCT_CONTEXT_KEY, AWSProduct.AMAZON_Q_FOR_VS_CODE)).useWithScope { + spanBuilder("anotherTracer", "childSpan").useWithScope {} + } + + assertThat(otelExtension.completedSpans).hasSize(2).satisfiesKt { spans -> + val parent = spans.first() + val child = spans.last() + + assertThat(parent.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(child.parentSpanContext.traceId).isEqualTo(parent.spanContext.traceId) + assertThat(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo("Amazon Q For VS Code") + assertThat(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)) + } + } + + @Test + fun `context override does not propagate from parent to child coroutines if context is not preserved`() { + spanBuilder("tracer", "parentSpan").setParent(Context.current().with(AWS_PRODUCT_CONTEXT_KEY, AWSProduct.AMAZON_Q_FOR_VS_CODE)).use { + runBlocking(getCoroutineBgContext()) { + spanBuilder("anotherTracer", "childSpan").use {} + } + } + + assertThat(otelExtension.completedSpans).hasSize(2).satisfiesKt { spans -> + val parent = spans.first() + val child = spans.last() + + assertThat(parent.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(child.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo("Amazon Q For VS Code") + assertThat(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isNotEqualTo(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)) + } + } + + @Test + fun `context override propagates from parent to child coroutines with manual coroutine context propagation`() { + spanBuilder("tracer", "parentSpan").setParent(Context.current().with(AWS_PRODUCT_CONTEXT_KEY, AWSProduct.AMAZON_Q_FOR_VS_CODE)).use { + runBlocking(getCoroutineBgContext() + Context.current().asContextElement()) { + spanBuilder("anotherTracer", "childSpan").use {} + } + } + + assertThat(otelExtension.completedSpans).hasSize(2).satisfiesKt { spans -> + val parent = spans.first() + val child = spans.last() + + assertThat(parent.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(child.parentSpanContext.traceId).isEqualTo(parent.spanContext.traceId) + assertThat(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo("Amazon Q For VS Code") + assertThat(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)) + } + } + + @Test + fun `context propagates from parent to child when child#end is after parent#end`() { + spanBuilder("tracer", "parentSpan").setParent(Context.current().with(AWS_PRODUCT_CONTEXT_KEY, AWSProduct.AMAZON_Q_FOR_VS_CODE)).use { + pluginAwareExecuteOnPooledThread { + spanBuilder("anotherTracer", "childSpan").use { + Thread.sleep(100) + } + } + } + spinUntil(java.time.Duration.ofSeconds(10)) { + otelExtension.completedSpans.size == 2 + } + + assertThat(otelExtension.completedSpans).hasSize(2).satisfiesKt { spans -> + val parent = spans.last() + val child = spans.first() + + assertThat(parent.hasEnded()) + assertThat(child.hasEnded()) + + // child started after parent + assertThat(child.toSpanData().startEpochNanos).isGreaterThanOrEqualTo(parent.toSpanData().startEpochNanos) + // and called end after parent + assertThat(child.toSpanData().endEpochNanos).isGreaterThanOrEqualTo(parent.toSpanData().endEpochNanos) + + assertThat(parent.parentSpanContext.traceId).isEqualTo(TraceId.getInvalid()) + assertThat(child.parentSpanContext.traceId).isEqualTo(parent.spanContext.traceId) + assertThat(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo("Amazon Q For VS Code") + assertThat(child.getAttribute(PLUGIN_ATTRIBUTE_KEY)).isEqualTo(parent.getAttribute(PLUGIN_ATTRIBUTE_KEY)) + } + } + + private fun spanBuilder(tracer: String, spanName: String) = DefaultSpanBuilder(otelExtension.sdk.sdk.getTracer(tracer).spanBuilder(spanName)) +} + +class OtelExtension : AfterEachCallback, AfterAllCallback { + private val openSpans = mutableSetOf() + private val _completedSpans = mutableListOf() + val completedSpans + get() = _completedSpans.reversed().toList() + + val sdk = OTelService( + listOf( + // should probably be a service loader + object : SpanProcessor { + override fun isStartRequired() = true + override fun isEndRequired() = true + + override fun onStart(parentContext: Context, span: ReadWriteSpan) { + openSpans.add(span) + } + + override fun onEnd(span: ReadableSpan) { + _completedSpans.add(span) + + if (!openSpans.contains(span)) { + LOG.warn(RuntimeException("Span ended without corresponding start")) { span.toString() } + } + openSpans.remove(span) + } + + override fun forceFlush(): CompletableResultCode { + assert(openSpans.isEmpty()) { "Not all open spans were closed: ${openSpans.joinToString(", ")}" } + return CompletableResultCode.ofSuccess() + } + } + ) + ) + + override fun afterEach(context: ExtensionContext?) { + reset() + } + + override fun afterAll(context: ExtensionContext?) { + sdk.sdk.shutdown() + } + + fun reset() { + openSpans.clear() + _completedSpans.clear() + } + + companion object { + private val LOG = getLogger() + } +}