diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/HeldTaskCompletion.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/HeldTaskCompletion.scala index f0b0cf5fc9e1..f6410ddc1064 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/HeldTaskCompletion.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/HeldTaskCompletion.scala @@ -1,7 +1,11 @@ package io.joern.dataflowengineoss.queryengine +import io.shiftleft.codepropertygraph.generated.nodes.{Call, CfgNode} + +import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable import scala.collection.parallel.CollectionConverters._ +import scala.language.postfixOps /** Complete held tasks using the result table. The result table is modified in the process. * @@ -34,7 +38,6 @@ class HeldTaskCompletion( * created, `changed` is set to true for the result's table entry and `resultsProductByTask` is updated. */ def completeHeldTasks(): Unit = { - deduplicateResultTable() val toProcess = heldTasks.distinct.sortBy(x => @@ -46,6 +49,11 @@ class HeldTaskCompletion( def noneChanged = toProcess.map { t => t.fingerprint -> false }.toMap var changed: Map[TaskFingerprint, Boolean] = allChanged + val groupMap + : mutable.Map[TaskFingerprint, Map[((CfgNode, List[Call], Boolean), (CfgNode, List[Call], Boolean)), List[ + TableEntry + ]]] = mutable.Map() + val rwlock = new ReentrantReadWriteLock() while (changed.values.toList.contains(true)) { val taskResultsPairs = toProcess @@ -61,7 +69,7 @@ class HeldTaskCompletion( changed = noneChanged taskResultsPairs.foreach { case (t, resultsForTask, newResults) => - addCompletedTasksToMainTable(newResults.toList) + addCompletedTasksToMainTable(newResults.toList, groupMap, rwlock) newResults.foreach { case (fingerprint, _) => changed += fingerprint -> true } @@ -117,11 +125,73 @@ class HeldTaskCompletion( pathSeq.distinct.size == pathSeq.size } - private def addCompletedTasksToMainTable(results: List[(TaskFingerprint, TableEntry)]): Unit = { - results.groupBy(_._1).foreach { case (fingerprint, resultList) => - val entries = resultList.map(_._2) - val old = resultTable.getOrElse(fingerprint, Vector()).toList - resultTable.put(fingerprint, deduplicateTableEntries(old ++ entries)) + private def addCompletedTasksToMainTable( + results: List[(TaskFingerprint, TableEntry)], + groupMap: mutable.Map[TaskFingerprint, Map[((CfgNode, List[Call], Boolean), (CfgNode, List[Call], Boolean)), List[ + TableEntry + ]]], + rwlock: ReentrantReadWriteLock + ): Unit = { + results.groupBy(_._1).par.foreach { case (fingerprint, resultList) => + val entries = resultList.par.map(_._2) + val newGroups = entries + .groupBy { result => + val head = result.path.headOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get + val last = result.path.lastOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get + (head, last) + } + + rwlock.readLock().lock() + val old = resultTable.getOrElse(fingerprint, Vector()).toList + val oldGroups = groupMap.getOrElse( + fingerprint, + old + .groupBy { result => + val head = result.path.headOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get + val last = result.path.lastOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get + (head, last) + } + ) + rwlock.readLock().unlock() + + val mergedGroups = oldGroups ++ newGroups.map { case (k, v) => + k -> { + val old = oldGroups.getOrElse(k, List()) + val maxLen = if (old.length > 0) { + old.head.path.length + } else { 0 } + + val gtOrEqualMax = v.filter(x => x.path.length >= maxLen) + val gtMax = gtOrEqualMax.filter(x => x.path.length > maxLen) + + if (gtMax.length > 0) { + // new list contains elements with paths exceeding the max. retain new list elements only + // that have max length + var newMaxLen = maxLen + gtMax.foreach(x => { + if (x.path.length > newMaxLen) { + newMaxLen = x.path.length + } + }) + val element = gtMax.filter(x => x.path.length == newMaxLen).par.minBy(computePriority) + List(element) + } else if (gtOrEqualMax == 0) { + // new list contains all elements with paths less than the max. retain old list elements only + old + } else { + // new list contains all elements with paths less than or equal to the max but not exceeding it. + // append new list elements that are equal to max + val element = (old ++ gtOrEqualMax.par.filter(x => x.path.length == maxLen)).par.minBy(computePriority) + List(element) + } + } + } + val mergedList = mergedGroups.map { case (_, list) => list.head}.toList + + rwlock.writeLock().lock() + resultTable.put(fingerprint, mergedList) + groupMap.update(fingerprint, mergedGroups) + rwlock.writeLock().unlock() } } @@ -168,4 +238,17 @@ class HeldTaskCompletion( .toList } + private def computePriority(entry: TableEntry): BigInt = { + var priority: BigInt = entry.path.length + val multiplier: BigInt = 131072 // 2^17 + + entry.path.foreach(element => { + priority = priority + element.callSiteStack.length * multiplier + priority = priority + element.node.id() * multiplier * 64 + priority = priority + element.isOutputArg.hashCode() * multiplier * multiplier *64 + priority = priority + element.visible.hashCode() * multiplier * 64 + }) + priority + } + } diff --git a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskSolver.scala b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskSolver.scala index 2fbe0aa3ff4b..ba96cc3ea606 100644 --- a/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskSolver.scala +++ b/dataflowengineoss/src/main/scala/io/joern/dataflowengineoss/queryengine/TaskSolver.scala @@ -5,6 +5,7 @@ import io.joern.dataflowengineoss.semanticsloader.Semantics import io.shiftleft.codepropertygraph.generated.nodes._ import io.shiftleft.semanticcpg.language.{toCfgNodeMethods, toExpressionMethods} +import java.security.MessageDigest import java.util.concurrent.Callable import scala.collection.mutable @@ -49,6 +50,7 @@ class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[Cfg val parentTask = r.taskStack(i) val pathToSink = r.path.slice(0, r.path.map(_.node).indexOf(parentTask.sink)) val newPath = pathToSink :+ PathElement(parentTask.sink, parentTask.callSiteStack) + (parentTask, TableEntry(path = newPath)) }.toList }