Skip to content

Commit b76f810

Browse files
authored
chore: make device token thread safe (#571)
1 parent c5456cd commit b76f810

File tree

3 files changed

+209
-13
lines changed

3 files changed

+209
-13
lines changed

datapipelines/api/datapipelines.api

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,6 @@ public final class io/customer/datapipelines/plugins/AutomaticActivityScreenTrac
7676
public fun update (Lcom/segment/analytics/kotlin/core/Settings;Lcom/segment/analytics/kotlin/core/platform/Plugin$UpdateType;)V
7777
}
7878

79-
public final class io/customer/datapipelines/plugins/ContextPlugin : com/segment/analytics/kotlin/core/platform/Plugin {
80-
public field analytics Lcom/segment/analytics/kotlin/core/Analytics;
81-
public fun <init> (Lio/customer/sdk/data/store/DeviceStore;)V
82-
public fun execute (Lcom/segment/analytics/kotlin/core/BaseEvent;)Lcom/segment/analytics/kotlin/core/BaseEvent;
83-
public fun getAnalytics ()Lcom/segment/analytics/kotlin/core/Analytics;
84-
public fun getType ()Lcom/segment/analytics/kotlin/core/platform/Plugin$Type;
85-
public fun setAnalytics (Lcom/segment/analytics/kotlin/core/Analytics;)V
86-
public fun setup (Lcom/segment/analytics/kotlin/core/Analytics;)V
87-
public fun update (Lcom/segment/analytics/kotlin/core/Settings;Lcom/segment/analytics/kotlin/core/platform/Plugin$UpdateType;)V
88-
}
89-
9079
public final class io/customer/datapipelines/plugins/CustomerIODestination : com/segment/analytics/kotlin/core/platform/DestinationPlugin, com/segment/analytics/kotlin/core/platform/VersionedPlugin, sovran/kotlin/Subscriber {
9180
public fun <init> ()V
9281
public fun alias (Lcom/segment/analytics/kotlin/core/AliasEvent;)Lcom/segment/analytics/kotlin/core/BaseEvent;

datapipelines/src/main/kotlin/io/customer/datapipelines/plugins/ContextPlugin.kt

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,35 @@ import io.customer.sdk.data.store.DeviceStore
1212
* Plugin class responsible for updating the context properties in events
1313
* tracked by Customer.io SDK.
1414
*/
15-
class ContextPlugin(private val deviceStore: DeviceStore) : Plugin {
15+
internal class ContextPlugin(
16+
private val deviceStore: DeviceStore,
17+
private val eventProcessor: ContextPluginEventProcessor = DefaultContextPluginEventProcessor()
18+
) : Plugin {
1619
override val type: Plugin.Type = Plugin.Type.Before
1720
override lateinit var analytics: Analytics
1821

22+
@Volatile
1923
internal var deviceToken: String? = null
2024

2125
override fun execute(event: BaseEvent): BaseEvent {
26+
return eventProcessor.execute(event, deviceStore) { deviceToken }
27+
}
28+
}
29+
30+
/**
31+
* Interface to handle the processing of events inside [ContextPlugin].
32+
* Allows custom logic to be injected for testing or extension.
33+
*/
34+
internal interface ContextPluginEventProcessor {
35+
fun execute(event: BaseEvent, deviceStore: DeviceStore, deviceTokenProvider: () -> String?): BaseEvent
36+
}
37+
38+
/**
39+
* Default implementation of [ContextPluginEventProcessor] that sets the user agent
40+
* in the context and ensures the device token is added if not already present.
41+
*/
42+
internal class DefaultContextPluginEventProcessor : ContextPluginEventProcessor {
43+
override fun execute(event: BaseEvent, deviceStore: DeviceStore, deviceTokenProvider: () -> String?): BaseEvent {
2244
// Set user agent in context as it is required by Customer.io Data Pipelines
2345
event.putInContext("userAgent", deviceStore.buildUserAgent())
2446
// Remove analytics library information from context as Customer.io
@@ -28,7 +50,7 @@ class ContextPlugin(private val deviceStore: DeviceStore) : Plugin {
2850
// In case of migration from older versions, the token might already be present in context
2951
// We need to ensure that the token is not overridden to avoid corruption of data
3052
// So we add current token to context only if context does not have any token already
31-
event.findInContextAtPath("device.token").firstOrNull()?.content ?: deviceToken?.let { token ->
53+
event.findInContextAtPath("device.token").firstOrNull()?.content ?: deviceTokenProvider()?.let { token ->
3254
// Device token is expected to be attached to device in context
3355
event.putInContextUnderKey("device", "token", token)
3456
}
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
package io.customer.datapipelines.plugins
2+
3+
import com.segment.analytics.kotlin.core.BaseEvent
4+
import com.segment.analytics.kotlin.core.utilities.putInContextUnderKey
5+
import io.customer.commontest.config.TestConfig
6+
import io.customer.commontest.extensions.flushCoroutines
7+
import io.customer.datapipelines.testutils.core.JUnitTest
8+
import io.customer.datapipelines.testutils.core.testConfiguration
9+
import io.customer.datapipelines.testutils.extensions.deviceToken
10+
import io.customer.datapipelines.testutils.extensions.getStringAtPath
11+
import io.customer.datapipelines.testutils.utils.OutputReaderPlugin
12+
import io.customer.datapipelines.testutils.utils.trackEvents
13+
import io.customer.sdk.DataPipelinesLogger
14+
import io.customer.sdk.core.di.SDKComponent
15+
import io.customer.sdk.data.store.DeviceStore
16+
import io.customer.sdk.data.store.GlobalPreferenceStore
17+
import io.mockk.every
18+
import io.mockk.mockk
19+
import java.util.concurrent.Executors
20+
import java.util.concurrent.TimeUnit
21+
import kotlin.random.Random
22+
import kotlinx.coroutines.test.StandardTestDispatcher
23+
import kotlinx.coroutines.test.runCurrent
24+
import kotlinx.coroutines.test.runTest
25+
import org.amshove.kluent.internal.assertEquals
26+
import org.amshove.kluent.shouldNotBeNull
27+
import org.junit.jupiter.api.Test
28+
29+
/**
30+
* Tests [ContextPlugin] behavior using [StandardTestDispatcher] to simulate realistic coroutine
31+
* scheduling and timing.
32+
*/
33+
class ContextPluginBehaviorTest : JUnitTest(dispatcher = StandardTestDispatcher()) {
34+
private val testScope get() = delegate.testScope
35+
36+
private lateinit var deviceStore: DeviceStore
37+
private lateinit var outputReaderPlugin: OutputReaderPlugin
38+
39+
override fun setup(testConfig: TestConfig) {
40+
super.setup(
41+
testConfiguration {
42+
diGraph {
43+
sdk {
44+
overrideDependency<DataPipelinesLogger>(mockk(relaxed = true))
45+
overrideDependency<DeviceStore>(mockk(relaxed = true))
46+
overrideDependency<GlobalPreferenceStore>(mockk(relaxed = true))
47+
}
48+
}
49+
}
50+
)
51+
52+
val androidSDKComponent = SDKComponent.android()
53+
deviceStore = androidSDKComponent.deviceStore
54+
every { deviceStore.buildUserAgent() } returns "test-user-agent"
55+
56+
outputReaderPlugin = OutputReaderPlugin()
57+
analytics.add(outputReaderPlugin)
58+
59+
// Run all pending coroutines to ensure analytics is initialized and ready to process events
60+
@Suppress("OPT_IN_USAGE")
61+
testScope.runCurrent()
62+
}
63+
64+
/**
65+
* Verifies that the plugin correctly adds the expected device token to the event context
66+
* when the token is accessed from a different thread, including within coroutine dispatchers.
67+
* This test should fail intermittently if token value is not correctly synchronized across threads.
68+
*/
69+
@Test
70+
fun execute_whenDeviceTokenIsSetFromAnotherThread_thenAddsCorrectTokenToEvent() = runTest {
71+
// Define test parameters for easier configuration
72+
val readerExecutionTimeMillis = 5000
73+
val writerCutoffTimeMillis = readerExecutionTimeMillis - 200 // ensure writer ends before reader execution
74+
val minThreadWaitTime = 50
75+
val maxThreadWaitTime = 100
76+
val tokenPrefix = "test-token-"
77+
val currentNanoTime = { System.nanoTime() }
78+
// Setup context plugin with a custom processor to track execution time
79+
val contextPluginProcessor = object : ContextPluginEventProcessor {
80+
val defaultProcessor = DefaultContextPluginEventProcessor()
81+
override fun execute(event: BaseEvent, deviceStore: DeviceStore, deviceTokenProvider: () -> String?): BaseEvent {
82+
// Add execution time to context for verification later
83+
event.putInContextUnderKey("test", "executionStartTime", currentNanoTime())
84+
val result = defaultProcessor.execute(event, deviceStore, deviceTokenProvider)
85+
event.putInContextUnderKey("test", "executionEndTime", currentNanoTime())
86+
return result
87+
}
88+
}
89+
val contextPlugin = ContextPlugin(deviceStore, contextPluginProcessor)
90+
analytics.add(contextPlugin)
91+
// Set initial value for test
92+
val writerLog = mutableMapOf<Long, String>() // (timestamp, read)
93+
// Set initial device token to skip unnecessary null checks and ensure value is fetched for initial events
94+
writerLog[currentNanoTime()] = ""
95+
// Prepare for concurrent execution
96+
val executor = Executors.newFixedThreadPool(2)
97+
val testStartTimeMs = currentNanoTime().nanosToMillis()
98+
99+
// Writer thread: writes tokens at random intervals
100+
val writerThread = executor.submit {
101+
var counter = 1
102+
while (true) {
103+
val nowMs = currentNanoTime().nanosToMillis()
104+
if (nowMs - testStartTimeMs >= writerCutoffTimeMillis) break
105+
106+
val newToken = "${tokenPrefix}${counter++}"
107+
waitUntil(nowMs + Random.nextInt(minThreadWaitTime, maxThreadWaitTime))
108+
109+
sdkInstance.registerDeviceToken(newToken).flushCoroutines(testScope)
110+
writerLog[currentNanoTime()] = newToken
111+
}
112+
}
113+
114+
// Reader thread: executes events with the current device token at random intervals
115+
val readerThread = executor.submit {
116+
var counter = 1
117+
// Ensure writer has started
118+
Thread.sleep(maxThreadWaitTime.toLong())
119+
while (true) {
120+
val nowMs = currentNanoTime().nanosToMillis()
121+
if (nowMs - testStartTimeMs >= readerExecutionTimeMillis) break
122+
123+
waitUntil(nowMs + Random.nextInt(minThreadWaitTime, maxThreadWaitTime))
124+
// Track an event with so that the context is updated with the current device token
125+
sdkInstance.track(name = "test-event-${counter++}").flushCoroutines(testScope)
126+
// Yield to allow other thread to run
127+
Thread.yield()
128+
}
129+
}
130+
131+
// Wait for both threads to finish
132+
writerThread.get(readerExecutionTimeMillis + 500L, TimeUnit.MILLISECONDS)
133+
readerThread.get(readerExecutionTimeMillis + 500L, TimeUnit.MILLISECONDS)
134+
executor.shutdown()
135+
136+
// For each event executed by SDK, verify writer token that was active during the event's execution
137+
val mismatches = outputReaderPlugin.trackEvents.mapNotNull { event ->
138+
val executionStartTime = event.context.getStringAtPath("test.executionStartTime")?.toLong().shouldNotBeNull()
139+
val executionEndTime = event.context.getStringAtPath("test.executionEndTime")?.toLong().shouldNotBeNull()
140+
val actualToken = event.context.deviceToken
141+
142+
// Find the latest write before the event execution end time
143+
val latestWriteBeforeEvent = writerLog
144+
.filterKeys { it <= executionEndTime }
145+
.maxByOrNull { it.key }
146+
// Find the newest write after the latest write
147+
// This is because the writer might have written a new token after the event was executed
148+
// So having a newer token is valid
149+
val nextWriteAfterLatest = writerLog
150+
.filterKeys { it > (latestWriteBeforeEvent?.key ?: Long.MAX_VALUE) }
151+
.minByOrNull { it.key }
152+
// Valid tokens are the latest write before the event and the next write after the latest
153+
val validTokens = setOfNotNull(latestWriteBeforeEvent?.value, nextWriteAfterLatest?.value)
154+
155+
// If the actual token is not in valid tokens, it's a mismatch
156+
if (actualToken !in validTokens) {
157+
return@mapNotNull Triple("$executionStartTime..$executionEndTime", actualToken, validTokens.joinToString(" or "))
158+
}
159+
return@mapNotNull null
160+
}
161+
162+
assertEquals(
163+
expected = 0,
164+
actual = mismatches.size,
165+
message = buildString {
166+
append("Event processed with incorrect device token:\n")
167+
append(
168+
mismatches.joinToString("\n") { (time, actual, expected) ->
169+
"- At $time NS: saw `$actual`, expected `$expected`"
170+
}
171+
)
172+
}
173+
)
174+
}
175+
176+
private fun waitUntil(timeMs: Long) {
177+
val sleepTime = timeMs - System.nanoTime().nanosToMillis()
178+
assert(sleepTime > 0) { "Cannot wait for past time: $timeMs" }
179+
Thread.sleep(sleepTime)
180+
}
181+
182+
private fun Long.nanosToMillis(): Long {
183+
return this / 1_000_000
184+
}
185+
}

0 commit comments

Comments
 (0)