diff --git a/codepropertygraph/build.sbt b/codepropertygraph/build.sbt index 75465b764..e70d575e4 100644 --- a/codepropertygraph/build.sbt +++ b/codepropertygraph/build.sbt @@ -3,7 +3,7 @@ name := "codepropertygraph" dependsOn(Projects.protoBindings) libraryDependencies ++= Seq( - "io.shiftleft" % "tinkergraph-gremlin" % "3.3.4.16", + "io.shiftleft" % "tinkergraph-gremlin" % "3.3.4.17-MP-SNAPSHOT", "com.michaelpollmeier" %% "gremlin-scala" % "3.3.4.13", "com.google.guava" % "guava" % "21.0", "org.apache.commons" % "commons-lang3" % "3.5", diff --git a/codepropertygraph/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/OnDiskOverflowConfig.scala b/codepropertygraph/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/OnDiskOverflowConfig.scala index 70cb25ae4..f36ea5528 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/OnDiskOverflowConfig.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/OnDiskOverflowConfig.scala @@ -3,6 +3,7 @@ package io.shiftleft.codepropertygraph.cpgloading import scala.compat.java8.OptionConverters._ /** configure graphdb to use ondisk overflow. + * if the file specified by `graphLocation` already exists, we'll initialize the graph from there * if `graphLocation` is specified, graph will be saved there on close, and can be reloaded by just instantiating one with the same setting * otherwise, system tmp directory is used (e.g. `/tmp`) and graph won't be saved on close */ case class OnDiskOverflowConfig(graphLocation: Option[String] = None, diff --git a/cpg2overflowdb/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/ProtoEdgeSerializer.java b/cpg2overflowdb/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/ProtoEdgeSerializer.java index a97901108..5faa051d1 100644 --- a/cpg2overflowdb/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/ProtoEdgeSerializer.java +++ b/cpg2overflowdb/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/ProtoEdgeSerializer.java @@ -9,10 +9,20 @@ import org.apache.tinkerpop.gremlin.tinkergraph.storage.Serializer; import java.util.Map; -import java.util.concurrent.atomic.AtomicLong; +import java.util.SortedMap; +import java.util.TreeMap; public class ProtoEdgeSerializer extends Serializer { + /* TODO move definition of property indices to json schema + * (or - better - ensure it's always in the same order when generating the cpg.proto and + * use the index there) */ + final Map> propertyIndexByEdgeAndPropertyName; + + public ProtoEdgeSerializer(Map> propertyIndexByEdgeAndPropertyName) { + this.propertyIndexByEdgeAndPropertyName = propertyIndexByEdgeAndPropertyName; + } + @Override protected long getId(ProtoEdgeWithId edgeWithId) { return edgeWithId.id; @@ -24,10 +34,13 @@ protected String getLabel(ProtoEdgeWithId edgeWithId) { } @Override - protected Map getProperties(ProtoEdgeWithId edgeWithId) { - final Map propertyMap = new THashMap<>(edgeWithId.edge.getPropertyCount()); + protected SortedMap getProperties(ProtoEdgeWithId edgeWithId) { + final SortedMap propertyMap = new TreeMap<>(); + final String edgeType = edgeWithId.edge.getType().name(); + final Map propertyIndexByName = propertyIndexByEdgeAndPropertyName.get(edgeType); + for (Cpg.CpgStruct.Edge.Property property : edgeWithId.edge.getPropertyList()) { - final String key = property.getName().name(); + final Integer key = propertyIndexByName.get(property.getName().name()); final Cpg.PropertyValue propertyValue = property.getValue(); switch(propertyValue.getValueCase()) { case INT_VALUE: diff --git a/cpg2overflowdb/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/ProtoNodeSerializer.java b/cpg2overflowdb/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/ProtoNodeSerializer.java index b9718ce03..7a79ac14a 100644 --- a/cpg2overflowdb/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/ProtoNodeSerializer.java +++ b/cpg2overflowdb/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/ProtoNodeSerializer.java @@ -10,14 +10,24 @@ import java.util.HashMap; import java.util.Map; import java.util.Set; +import java.util.SortedMap; +import java.util.TreeMap; public class ProtoNodeSerializer extends Serializer { + /* TODO move definition of property indices to json schema + * (or - better - ensure it's always in the same order when generating the cpg.proto and + * use the index there) */ + final Map> propertyIndexByEdgeAndPropertyName; + //NodeId -> EdgeLabel -> EdgeId private final Map> inEdgesByNodeId; private final Map> outEdgesByNodeId; - public ProtoNodeSerializer(Map> inEdgesByNodeId, Map> outEdgesByNodeId) { + public ProtoNodeSerializer(Map> propertyIndexByEdgeAndPropertyName, + Map> inEdgesByNodeId, + Map> outEdgesByNodeId) { + this.propertyIndexByEdgeAndPropertyName = propertyIndexByEdgeAndPropertyName; this.inEdgesByNodeId = inEdgesByNodeId; this.outEdgesByNodeId = outEdgesByNodeId; } @@ -33,10 +43,13 @@ protected String getLabel(Cpg.CpgStruct.Node node) { } @Override - protected Map getProperties(Cpg.CpgStruct.Node node) { - final Map propertyMap = new THashMap<>(node.getPropertyCount()); + protected SortedMap getProperties(Cpg.CpgStruct.Node node) { + final SortedMap propertyMap = new TreeMap<>(); + final String nodeType = node.getType().name(); + final Map propertyIndexByName = propertyIndexByEdgeAndPropertyName.get(nodeType); + for (Cpg.CpgStruct.Node.Property property : node.getPropertyList()) { - final String key = property.getName().name(); + final Integer key = propertyIndexByName.get(property.getName().name()); final Cpg.PropertyValue propertyValue = property.getValue(); switch (propertyValue.getValueCase()) { case INT_VALUE: diff --git a/cpg2overflowdb/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/ProtoToOverflowDb.scala b/cpg2overflowdb/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/ProtoToOverflowDb.scala index a58e0b2a3..061e8d537 100644 --- a/cpg2overflowdb/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/ProtoToOverflowDb.scala +++ b/cpg2overflowdb/src/main/scala/io/shiftleft/codepropertygraph/cpgloading/ProtoToOverflowDb.scala @@ -1,12 +1,13 @@ package io.shiftleft.codepropertygraph.cpgloading -import java.io.{File, FileInputStream} +import java.io.File import java.nio.file.{Files, Path} import java.util.{HashMap => JHashMap, Map => JMap} import gnu.trove.set.TLongSet import gnu.trove.set.hash.TLongHashSet import io.shiftleft.proto.cpg.Cpg.CpgStruct +import io.shiftleft.codepropertygraph.generated.{edges, nodes} import org.apache.logging.log4j.LogManager import org.apache.tinkerpop.gremlin.tinkergraph.storage.OndiskOverflow @@ -27,7 +28,20 @@ object ProtoToOverflowDb extends App { type EdgeLabel = String private lazy val logger = LogManager.getLogger(getClass) - private lazy val edgeSerializer = new ProtoEdgeSerializer + + private val edgePropertyIndexByNameAndElementName: JMap[String, JMap[String, Integer]] = + edges.Factories.All.map { factory => + (factory.forLabel, propertyIndexByName(factory.propertyNamesByIndex)) + }.toMap.asJava + + private lazy val nodePropertyIndexByNameAndElementName: JMap[String, JMap[String, Integer]] = + nodes.Factories.All.map { factory => + (factory.forLabel, propertyIndexByName(factory.propertyNamesByIndex)) + }.toMap.asJava + + private lazy val edgeSerializer = + new ProtoEdgeSerializer(edgePropertyIndexByNameAndElementName) + private lazy val nodeFilter = new NodeFilter parseConfig.map(run) @@ -71,7 +85,7 @@ object ProtoToOverflowDb extends App { overflowDb.getEdgeMVMap.put(edgeWithId.id, edgeSerializer.serialize(edgeWithId)) } - val nodeSerializer = new ProtoNodeSerializer(inEdgesByNodeId, outEdgesByNodeId) + val nodeSerializer = new ProtoNodeSerializer(nodePropertyIndexByNameAndElementName, inEdgesByNodeId, outEdgesByNodeId) cpgProto.getNodeList.asScala.par.filter(nodeFilter.filterNode).foreach { node => overflowDb.getVertexMVMap.put(node.getKey, nodeSerializer.serialize(node)) } @@ -89,6 +103,13 @@ object ProtoToOverflowDb extends App { }.parse(args, Config(cpg = null)) } + + private def propertyIndexByName( + propertyNamesByIndex: JMap[Integer, String]): JMap[String, Integer] = + propertyNamesByIndex.asScala.map { + case (idx, name) => (name, idx) + }.asJava + } case class Config(cpg: File, writeTo: Option[File] = None) \ No newline at end of file diff --git a/project/DomainClassCreator.scala b/project/DomainClassCreator.scala index 9a273a11b..d2da48bfc 100644 --- a/project/DomainClassCreator.scala +++ b/project/DomainClassCreator.scala @@ -51,7 +51,7 @@ object DomainClassCreator { package $edgesPackage import java.lang.{Boolean => JBoolean, Long => JLong} - import java.util.{Set => JSet} + import java.util.{HashMap => JHashMap, Set => JSet, TreeMap} import org.apache.tinkerpop.gremlin.structure.Property import org.apache.tinkerpop.gremlin.structure.{Vertex, VertexProperty} import org.apache.tinkerpop.gremlin.tinkergraph.structure.{EdgeRef, SpecializedElementFactory, SpecializedTinkerEdge, TinkerGraph, TinkerProperty, TinkerVertex, VertexRef} @@ -91,16 +91,36 @@ object DomainClassCreator { val edgeClassName = edgeType.className val edgeClassNameDb = s"${edgeClassName}Db" val keysQuoted = keys.map('"' + _.name + '"') - val keyToValueMap = keys - .map { key => - s""" "${key.name}" -> { instance: $edgeClassNameDb => instance.${camelCase(key.name)}()}""" - } - .mkString(",\n") + val keyConstants = keys.map(key => s"""val ${camelCase(key.name).capitalize} = "${key.name}" """).mkString("\n") + + val keyToValueMap = keys.map { key => + s""" "${key.name}" -> { instance: $edgeClassNameDb => instance.${camelCase(key.name)}()}""" + }.mkString(",\n") + + //note: schema changes will change the binary format - we could define constants in the schema to avoid that + val keysWithIndex = keys.zipWithIndex + val keyIndexConstants = keysWithIndex.map { case (key, idx) => + s"""val ${camelCase(key.name).capitalize} = $idx""" + }.mkString("\n") + + val propertyNamesByIndex = keysWithIndex.map { case (key, idx) => + val keyName = camelCase(key.name).capitalize + s"""ret.put($edgeClassName.Keys.Indices.$keyName, $edgeClassName.Keys.$keyName)""" + }.mkString("\n") + + val propertyTypeByIndex = keysWithIndex.map { case (key, idx) => + val keyName = camelCase(key.name).capitalize + s"""ret.put($edgeClassName.Keys.Indices.$keyName, classOf[${getBaseType(key)}])""" + }.mkString("\n") val companionObject = s""" |object $edgeClassName { | val Label = "${edgeType.name}" | object Keys { + | $keyConstants + | object Indices { + | $keyIndexConstants + | } | val All: JSet[String] = Set(${keysQuoted.mkString(", ")}).asJava | val KeyToValue: Map[String, $edgeClassNameDb => Any] = Map( | $keyToValueMap @@ -117,6 +137,18 @@ object DomainClassCreator { | | override def createEdgeRef(id: JLong, graph: TinkerGraph, outVertex: VertexRef[_ <: TinkerVertex], inVertex: VertexRef[_ <: TinkerVertex]) = | ${edgeClassName}(id, graph) + | + | override def propertyNamesByIndex() = { + | val ret = new JHashMap[Integer, String] + | $propertyNamesByIndex + | ret + | } + | + | override def propertyTypeByIndex() = { + | val ret = new JHashMap[Integer, Class[_]] + | $propertyTypeByIndex + | ret + | } | } | | def apply(wrapped: $edgeClassNameDb) = new EdgeRef(wrapped) with $edgeClassName @@ -131,7 +163,10 @@ object DomainClassCreator { class ${edgeClassNameDb}(_graph: TinkerGraph, _id: Long, private val _outVertex: Vertex, _inVertex: Vertex) extends SpecializedTinkerEdge(_graph, _id, _outVertex, $edgeClassName.Label, _inVertex, $edgeClassName.Keys.All) { - ${propertyBasedFields(keys)} + ${propertyBasedFields(edgeClassName, keys)} + + ${propertiesByStorageIdx(edgeClassName, keys)} + override protected def specificProperty[A](key: String): Property[A] = $edgeClassName.Keys.KeyToValue.get(key) match { case None => Property.empty[A] @@ -187,7 +222,7 @@ object DomainClassCreator { import gremlin.scala._ import io.shiftleft.codepropertygraph.generated.EdgeKeys import java.lang.{Boolean => JBoolean, Long => JLong} - import java.util.{Collections => JCollections, HashMap => JHashMap, Iterator => JIterator, Map => JMap, Set => JSet} + import java.util.{Collections => JCollections, HashMap => JHashMap, Iterator => JIterator, Map => JMap, Set => JSet, TreeMap} import org.apache.tinkerpop.gremlin.structure.{Vertex, VertexProperty} import org.apache.tinkerpop.gremlin.tinkergraph.structure.{SpecializedElementFactory, SpecializedTinkerVertex, TinkerGraph, TinkerVertexProperty, VertexRef} import org.apache.tinkerpop.gremlin.util.iterator.IteratorUtils @@ -314,6 +349,23 @@ object DomainClassCreator { keys: List[Property], nodeToInEdges: mutable.MultiMap[String, String]) = { val keyConstants = keys.map(key => s"""val ${camelCase(key.name).capitalize} = "${key.name}" """).mkString("\n") + + //note: schema changes will change the binary format - we could define constants in the schema to avoid that + val keysWithIndex = keys.zipWithIndex + val keyIndexConstants = keysWithIndex.map { case (key, idx) => + s"""val ${camelCase(key.name).capitalize} = $idx""" + }.mkString("\n") + + val propertyNamesByIndex = keysWithIndex.map { case (key, idx) => + val keyName = camelCase(key.name).capitalize + s"""ret.put(${nodeType.className}.Keys.Indices.$keyName, ${nodeType.className}.Keys.$keyName)""" + }.mkString("\n") + + val propertyTypeByIndex = keysWithIndex.map { case (key, idx) => + val keyName = camelCase(key.name).capitalize + s"""ret.put(${nodeType.className}.Keys.Indices.$keyName, classOf[${getBaseType(key)}])""" + }.mkString("\n") + val keyToValueMap = keys .map { property: Property => getHigherType(property) match { @@ -322,8 +374,7 @@ object DomainClassCreator { case HigherValueType.Option => s""" "${property.name}" -> { instance: ${nodeType.classNameDb} => instance.${camelCase(property.name)}.orNull}""" } - } - .mkString(",\n") + } .mkString(",\n") def outEdges(nodeType: NodeType): List[String] = { nodeType.outEdges.map(_.edgeName) @@ -343,6 +394,9 @@ object DomainClassCreator { val Label = "${nodeType.name}" object Keys { $keyConstants + object Indices { + $keyIndexConstants + } val All: JSet[String] = Set(${keys .map { key => camelCase(key.name).capitalize @@ -363,6 +417,17 @@ object DomainClassCreator { override def createVertex(id: JLong, graph: TinkerGraph) = new ${nodeType.classNameDb}(id, graph) override def createVertexRef(vertex: ${nodeType.classNameDb}) = ${nodeType.className}(vertex) override def createVertexRef(id: JLong, graph: TinkerGraph) = ${nodeType.className}(id, graph) + override def propertyNamesByIndex() = { + val ret = new JHashMap[Integer, String] + $propertyNamesByIndex + ret + } + + override def propertyTypeByIndex() = { + val ret = new JHashMap[Integer, Class[_]] + $propertyTypeByIndex + ret + } } def apply(wrapped: ${nodeType.classNameDb}) = new VertexRef(wrapped) with ${nodeType.className} @@ -394,7 +459,7 @@ object DomainClassCreator { val memberName = camelCase(key.name) Cardinality.fromName(key.cardinality) match { case Cardinality.One => - s"""if (${memberName} != null) { properties.put("${key.name}", ${memberName}) }""" + s"""properties.put("${key.name}", ${memberName})""" case Cardinality.ZeroOrOne => s"""${memberName}.map { value => properties.put("${key.name}", value) }""" case Cardinality.List => // need java list, e.g. for VertexSerializer @@ -521,9 +586,12 @@ object DomainClassCreator { override def specificKeys() = ${nodeType.className}.Keys.All /* all properties */ + // TODO optimise: load all properties in one go override def valueMap: JMap[String, AnyRef] = $valueMapImpl - ${propertyBasedFields(keys)} + ${propertyBasedFields(nodeType.className, keys)} + + ${propertiesByStorageIdx(nodeType.className, keys)} override def canEqual(that: Any): Boolean = that != null && that.isInstanceOf[${nodeType.classNameDb}] override def productElement(n: Int): Any = @@ -845,31 +913,69 @@ object Utils { case HigherValueType.List => s"List[${getBaseType(property)}]" } - def propertyBasedFields(properties: List[Property]): String = + def propertyBasedFields(elementName: String, properties: List[Property]): String = properties.map { property => val name = camelCase(property.name) + val nameCapitalized = name.capitalize + val tpeBase = getBaseType(property) val tpe = getCompleteType(property) + val readProperty = s"graph().asInstanceOf[TinkerGraph].readProperty(this, $elementName.Keys.Indices.$nameCapitalized, classOf[$tpeBase])" getHigherType(property) match { case HigherValueType.None => - /** TODO: rather than returning `null`, throw an exception, since this is a schema violation: - s"""|var _$name: $tpe = null - |def $name: $tpe = - | if (_$name == null) { - | throw new AssertionError("property $name is mandatory but hasn't been initialised yet") - |} else { _$name } """.stripMargin - */ s"""|private var _$name: $tpe = null - |def $name(): $tpe = _$name""".stripMargin + |def $name(): $tpe = { + | if (_$name == null) { + | _$name = $readProperty.asInstanceOf[$tpeBase] + | } + | validateMandatoryProperty($elementName.Keys.$nameCapitalized, _$name) + | _$name + |}""".stripMargin case HigherValueType.Option => s"""|private var _$name: $tpe = None - |def $name(): $tpe = _$name""".stripMargin + |def $name(): $tpe = { + | if (_$name == null) { + | _$name = Option($readProperty).asInstanceOf[$tpe] + | } + | _$name + |}""".stripMargin case HigherValueType.List => s"""|private var _$name: $tpe = Nil - |def $name(): $tpe = _$name""".stripMargin + |def $name(): $tpe = { + | if (_$name == null) { + | val property = $readProperty + | if (property == null) _$name = Nil + | else _$name = property.asInstanceOf[$tpe] + | } + | _$name + |}""".stripMargin } }.mkString("\n\n") + def propertiesByStorageIdx(elementName: String, keys: List[Property]): String = { + val putKeysImpl = keys.map { key: Property => + val memberName = camelCase(key.name) + val memberNameCapitalized = memberName.capitalize + val propertyAccessor = s"$elementName.Keys.Indices.$memberNameCapitalized" + Cardinality.fromName(key.cardinality) match { + case Cardinality.One => + s"""properties.put($propertyAccessor, $memberName)""" + case Cardinality.ZeroOrOne => + s"""${memberName}.map { value => properties.put($propertyAccessor, value) }""" + case Cardinality.List => // need java list, e.g. for VertexSerializer + s"""if (${memberName}.nonEmpty) { properties.put($propertyAccessor, $memberName.asJava) }""" + } + }.mkString("\n") + + s""" + |override def propertiesByStorageIdx = { + | val properties = new TreeMap[Integer, Object] + | // TODO optimise: load all properties in one go + | $putKeysImpl + | properties + |}""".stripMargin + } + def updateSpecificPropertyBody(properties: List[Property]): String = { val caseNotFound = s"""PropertyErrorRegister.logPropertyErrorIfFirst(getClass, key)"""