Skip to content

Commit 15abcfa

Browse files
authored
Fix MGet bug, randomize fan out distribution (#1885)
* Fix MGet bug, randomize fan out distribution Signed-off-by: Chase Engelbrecht <[email protected]> * Fix ktlint Signed-off-by: Chase Engelbrecht <[email protected]> * Fix import Signed-off-by: Chase Engelbrecht <[email protected]> --------- Signed-off-by: Chase Engelbrecht <[email protected]>
1 parent 7c2ad41 commit 15abcfa

File tree

4 files changed

+201
-45
lines changed

4 files changed

+201
-45
lines changed

alerting/src/main/kotlin/org/opensearch/alerting/DocumentLevelMonitorRunner.kt

Lines changed: 4 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ import org.opensearch.commons.alerting.util.AlertingException
3131
import org.opensearch.core.action.ActionListener
3232
import org.opensearch.core.common.breaker.CircuitBreakingException
3333
import org.opensearch.core.common.io.stream.Writeable
34-
import org.opensearch.core.index.shard.ShardId
3534
import org.opensearch.core.rest.RestStatus
3635
import org.opensearch.index.IndexNotFoundException
3736
import org.opensearch.index.seqno.SequenceNumbers
@@ -245,7 +244,8 @@ class DocumentLevelMonitorRunner : MonitorRunner() {
245244
* thus effectively making the fan-out a single node operation.
246245
* This is done to avoid de-dupe Alerts generated by Aggregation Sigma Rules
247246
**/
248-
val localNode = monitorCtx.clusterService!!.localNode()
247+
val clusterService = monitorCtx.clusterService!!
248+
val localNode = clusterService.localNode()
249249
val nodeMap: Map<String, DiscoveryNode> = if (docLevelMonitorInput?.fanoutEnabled == true) {
250250
getNodes(monitorCtx)
251251
} else {
@@ -254,10 +254,10 @@ class DocumentLevelMonitorRunner : MonitorRunner() {
254254
}
255255

256256
val nodeShardAssignments = distributeShards(
257-
monitorCtx,
257+
monitorCtx.totalNodesFanOut,
258258
nodeMap.keys.toList(),
259259
shards.toList(),
260-
concreteIndexName
260+
monitorCtx.clusterService!!.state().metadata.index(concreteIndexName).index
261261
)
262262

263263
val responses: Collection<DocLevelMonitorFanOutResponse> = suspendCoroutine { cont ->
@@ -609,41 +609,4 @@ class DocumentLevelMonitorRunner : MonitorRunner() {
609609
private fun getNodes(monitorCtx: MonitorRunnerExecutionContext): Map<String, DiscoveryNode> {
610610
return monitorCtx.clusterService!!.state().nodes.dataNodes.filter { it.value.version >= Version.CURRENT }
611611
}
612-
613-
private fun distributeShards(
614-
monitorCtx: MonitorRunnerExecutionContext,
615-
allNodes: List<String>,
616-
shards: List<String>,
617-
index: String,
618-
): Map<String, MutableSet<ShardId>> {
619-
val totalShards = shards.size
620-
val numFanOutNodes = allNodes.size.coerceAtMost((totalShards + 1) / 2)
621-
val totalNodes = monitorCtx.totalNodesFanOut.coerceAtMost(numFanOutNodes)
622-
val shardsPerNode = totalShards / totalNodes
623-
var shardsRemaining = totalShards % totalNodes
624-
625-
val shardIdList = shards.map {
626-
ShardId(monitorCtx.clusterService!!.state().metadata.index(index).index, it.toInt())
627-
}
628-
val nodes = allNodes.subList(0, totalNodes)
629-
630-
val nodeShardAssignments = mutableMapOf<String, MutableSet<ShardId>>()
631-
var idx = 0
632-
for (node in nodes) {
633-
val nodeShardAssignment = mutableSetOf<ShardId>()
634-
for (i in 1..shardsPerNode) {
635-
nodeShardAssignment.add(shardIdList[idx++])
636-
}
637-
nodeShardAssignments[node] = nodeShardAssignment
638-
}
639-
640-
for (node in nodes) {
641-
if (shardsRemaining == 0) {
642-
break
643-
}
644-
nodeShardAssignments[node]!!.add(shardIdList[idx++])
645-
--shardsRemaining
646-
}
647-
return nodeShardAssignments
648-
}
649612
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.alerting
6+
7+
import org.apache.logging.log4j.LogManager
8+
import org.apache.logging.log4j.Logger
9+
import org.opensearch.core.index.Index
10+
import org.opensearch.core.index.shard.ShardId
11+
12+
private val logger: Logger = LogManager.getLogger("FanOutEligibility")
13+
14+
fun distributeShards(
15+
maxFanoutNodes: Int,
16+
allNodes: List<String>,
17+
shards: List<String>,
18+
index: Index,
19+
): Map<String, MutableSet<ShardId>> {
20+
val totalShards = shards.size
21+
val numFanOutNodes = allNodes.size.coerceAtMost(totalShards)
22+
val totalNodes = maxFanoutNodes.coerceAtMost(numFanOutNodes)
23+
24+
val shardIdList = shards.map {
25+
ShardId(index, it.toInt())
26+
}
27+
val shuffledNodes = allNodes.shuffled()
28+
val nodes = shuffledNodes.subList(0, totalNodes)
29+
30+
val nodeShardAssignments = nodes.associateWith { mutableSetOf<ShardId>() }
31+
32+
if (nodeShardAssignments.isEmpty()) {
33+
logger.error("No nodes eligible for fanout")
34+
return nodeShardAssignments
35+
}
36+
37+
shardIdList.forEachIndexed { idx, shardId ->
38+
val nodeIdx = idx % nodes.size
39+
val node = nodes[nodeIdx]
40+
nodeShardAssignments[node]!!.add(shardId)
41+
}
42+
43+
return nodeShardAssignments
44+
}

alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportDocLevelMonitorFanOutAction.kt

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ class TransportDocLevelMonitorFanOutAction
595595
val findingStr =
596596
finding.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS)
597597
.string()
598-
log.debug("Findings: $findingStr")
598+
log.trace("Findings: $findingStr")
599599

600600
if (shouldCreateFinding and (
601601
monitor.shouldCreateSingleAlertForFindings == null ||
@@ -1267,23 +1267,38 @@ class TransportDocLevelMonitorFanOutAction
12671267
findingIdToDocSource: MutableMap<String, MultiGetItemResponse>
12681268
) {
12691269
val docFieldTags = parseSampleDocTags(monitor.triggers)
1270-
val request = MultiGetRequest()
12711270

12721271
// Perform mGet request in batches.
12731272
findingToDocPairs.chunked(findingsIndexBatchSize).forEach { batch ->
1273+
val request = MultiGetRequest()
1274+
val docIdToFindingId = mutableMapOf<String, String>()
1275+
12741276
batch.forEach { (findingId, docIdAndIndex) ->
12751277
val docIdAndIndexSplit = docIdAndIndex.split("|")
12761278
val docId = docIdAndIndexSplit[0]
1279+
docIdToFindingId[docId] = findingId
1280+
12771281
val concreteIndex = docIdAndIndexSplit[1]
12781282
if (findingId.isNotEmpty() && docId.isNotEmpty() && concreteIndex.isNotEmpty()) {
12791283
val docItem = MultiGetRequest.Item(concreteIndex, docId)
12801284
if (docFieldTags.isNotEmpty())
12811285
docItem.fetchSourceContext(FetchSourceContext(true, docFieldTags.toTypedArray(), emptyArray()))
12821286
request.add(docItem)
12831287
}
1284-
val response = client.suspendUntil { client.multiGet(request, it) }
1285-
response.responses.forEach { item ->
1288+
}
1289+
1290+
val startMget = System.currentTimeMillis()
1291+
val response = client.suspendUntil { client.multiGet(request, it) }
1292+
val mgetDuration = System.currentTimeMillis() - startMget
1293+
log.debug(
1294+
"DocLevelMonitor ${monitor.id} mget retrieved [${response.responses.size}] documents. Took: ${mgetDuration}ms"
1295+
)
1296+
response.responses.forEach { item ->
1297+
val findingId = docIdToFindingId[item.id]
1298+
if (findingId != null) {
12861299
findingIdToDocSource[findingId] = item
1300+
} else {
1301+
log.error("Unable to find finding ID for document with ID [${item.id}] for monitor [${monitor.id}]")
12871302
}
12881303
}
12891304
}
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.alerting
6+
7+
import org.opensearch.core.index.Index
8+
import org.opensearch.core.index.shard.ShardId
9+
import org.opensearch.test.OpenSearchTestCase
10+
11+
class MonitorFanOutUtilsTests : OpenSearchTestCase() {
12+
fun `test distribute few shards many nodes`() {
13+
val result = distributeShards(
14+
1000,
15+
listOf("nodeA", "nodeB", "nodeC", "nodeD", "nodeE"),
16+
listOf("0", "1"),
17+
Index("index1", "id1")
18+
)
19+
20+
validateDistribution(result, 2, listOf(1), 2)
21+
}
22+
23+
fun `test distribute randomizes the assigned node`() {
24+
val nodes = mutableSetOf<String>()
25+
26+
// Picking a node to distribute to is random. To reduce test flakiness, we run this 100 times to give a (1/5)^99 chance
27+
// that the same node is picked every time
28+
repeat(100) {
29+
val result = distributeShards(
30+
1000,
31+
listOf("nodeA", "nodeB", "nodeC", "nodeD", "nodeE"),
32+
listOf("0"),
33+
Index("index1", "id1")
34+
)
35+
36+
validateDistribution(result, 1, listOf(1), 1)
37+
nodes.addAll(result.keys)
38+
}
39+
40+
assertTrue(nodes.size > 1)
41+
}
42+
43+
fun `test distribute many shards few nodes`() {
44+
val result = distributeShards(
45+
1000,
46+
listOf("nodeA", "nodeB", "nodeC"),
47+
listOf("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"),
48+
Index("index1", "id1")
49+
)
50+
51+
validateDistribution(result, 3, listOf(3, 4), 10)
52+
}
53+
54+
fun `test distribute max nodes limits`() {
55+
val result = distributeShards(
56+
2,
57+
listOf("nodeA", "nodeB", "nodeC"),
58+
listOf("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"),
59+
Index("index1", "id1")
60+
)
61+
62+
validateDistribution(result, 2, listOf(5), 10)
63+
}
64+
65+
fun `test distribute edge case 1 shard`() {
66+
val result = distributeShards(
67+
1000,
68+
listOf("nodeA", "nodeB", "nodeC"),
69+
listOf("0"),
70+
Index("index1", "id1")
71+
)
72+
73+
validateDistribution(result, 1, listOf(1), 1)
74+
}
75+
76+
fun `test distribute edge case 1 node`() {
77+
val result = distributeShards(
78+
1000,
79+
listOf("nodeA"),
80+
listOf("0", "1", "2"),
81+
Index("index1", "id1")
82+
)
83+
84+
validateDistribution(result, 1, listOf(3), 3)
85+
}
86+
87+
fun `test distribute edge case 1 shard 1 node`() {
88+
val result = distributeShards(
89+
1000,
90+
listOf("nodeA"),
91+
listOf("0"),
92+
Index("index1", "id1")
93+
)
94+
95+
validateDistribution(result, 1, listOf(1), 1)
96+
}
97+
98+
fun `test distribute edge case no nodes does not throw`() {
99+
val result = distributeShards(
100+
1000,
101+
listOf(),
102+
listOf("0"),
103+
Index("index1", "id1")
104+
)
105+
106+
validateDistribution(result, 0, listOf(), 0)
107+
}
108+
109+
fun `test distribute edge case no shards does not throw`() {
110+
val result = distributeShards(
111+
1000,
112+
listOf("nodeA"),
113+
listOf(),
114+
Index("index1", "id1")
115+
)
116+
117+
validateDistribution(result, 0, listOf(), 0)
118+
}
119+
120+
private fun validateDistribution(
121+
result: Map<String, MutableSet<ShardId>>,
122+
expectedNodeCount: Int,
123+
expectedShardsPerNode: List<Int>,
124+
expectedTotalShardCount: Int
125+
) {
126+
assertEquals(expectedNodeCount, result.keys.size)
127+
var shardCount = 0
128+
result.forEach { (_, shards) ->
129+
assertTrue(expectedShardsPerNode.contains(shards.size))
130+
shardCount += shards.size
131+
}
132+
assertEquals(expectedTotalShardCount, shardCount)
133+
}
134+
}

0 commit comments

Comments (0)