diff --git a/build.gradle b/build.gradle
index a72da6d31b..1dd3290230 100644
--- a/build.gradle
+++ b/build.gradle
@@ -142,6 +142,7 @@ subprojects {
     compileJava {
         options.annotationProcessorPath = configurations.apt
         options.compilerArgs += ["-AIgnoreContextWarnings"]
+        options.compilerArgs += ["--add-modules=jdk.incubator.vector"]
         options.encoding = "UTF-8"
     }
 
diff --git a/extended-it/src/test/java/apoc/neo4j/docker/SimilarityEntepriseTest.java b/extended-it/src/test/java/apoc/neo4j/docker/SimilarityEntepriseTest.java
new file mode 100644
index 0000000000..dcaa006b77
--- /dev/null
+++ b/extended-it/src/test/java/apoc/neo4j/docker/SimilarityEntepriseTest.java
@@ -0,0 +1,198 @@
+package apoc.neo4j.docker;
+
+
+import apoc.util.Neo4jContainerExtension;
+import apoc.util.TestContainerUtil;
+import apoc.util.Util;
+import org.junit.*;
+import org.neo4j.driver.Session;
+import org.neo4j.graphdb.Result;
+
+import java.util.List;
+import java.util.Map;
+
+import static apoc.util.TestContainerUtil.createEnterpriseDB;
+import static apoc.util.ExtendedTestContainerUtil.singleResultFirstColumn;
+
+/**
+ * Exploratory timing comparison, run against a Neo4j Enterprise container, between the
+ * {@code custom.search.batchedSimilarity} procedure and the built-in {@code vector.similarity.cosine}
+ * Cypher function. The tests only print elapsed times and results; they assert nothing.
+ */
+public class SimilarityEntepriseTest {
+
+//    private static List nodes = null;
+
+    private static Neo4jContainerExtension neo4jContainer;
+    private static Session session;
+
+    /** Starts the Enterprise container with the vector type enabled, opens the shared session and creates the {@code Similar} test nodes. */
+    @BeforeClass
+    public static void beforeAll() throws InterruptedException {
+        neo4jContainer = createEnterpriseDB(List.of(TestContainerUtil.ApocPackage.EXTENDED), true)
+                .withNeo4jConfig("apoc.import.file.enabled", "true")
+                .withNeo4jConfig("metrics.enabled", "true")
+                .withNeo4jConfig("metrics.csv.interval", "1s")
+                .withNeo4jConfig("dbms.memory.transaction.total.max", "1G")
+                .withNeo4jConfig("server.memory.heap.initial_size", "1G")
+                .withNeo4jConfig("server.memory.heap.max_size", "1G")
+                .withNeo4jConfig("internal.cypher.enable_vector_type", "true")
+                .withNeo4jConfig("metrics.namespaces.enabled", "true");
+        neo4jContainer.start();
+        session = neo4jContainer.getSession();
+
+        session.executeWrite(tx -> tx.run(
+                "CYPHER 25 UNWIND range(0, 50000) as id " +
+                        "CREATE (:Similar {vect: VECTOR([1, 2, id], 3, INTEGER32), id: 1, test: 1}), (:Similar {vect: VECTOR([1, id, 3], 3, INTEGER32), id: 2}), (:Similar {id: 3}), (:Similar {vect: VECTOR([3, 2, id], 3, INTEGER32), ajeje: 1, id: 4}), (:Similar {vect: VECTOR([4, 2, id], 3, INTEGER32), brazorf: 1, id: 5})"
+                ).consume()
+        );
+
+        // todo - i can't use it: Struct tag: 0x56 representing type VECTOR is not supported for this protocol version
+//        nodes = singleResultFirstColumn(session, "MATCH (n:Similar) RETURN collect(n) AS nodes");
+
+//        try (Transaction tx = db.beginTx()) {
+//            tx.findNodes(Label.label("Similar")).forEachRemaining(i -> {
+//                i.setProperty("embedding", new float[]{1, 2, 4});
+//            });
+//            tx.commit();
+//        }
+    }
+
+    /** Closes the shared session before shutting the container down. */
+    @AfterClass
+    public static void afterAll() {
+        session.close();
+        neo4jContainer.close();
+    }
+
+    /** Times {@code custom.search.batchedSimilarity} over all {@code Similar} nodes (topK 5, threshold 0.8) and prints the result. */
+    @Test
+    public void testSimilarityCompare() {
+        long before = System.currentTimeMillis();
+        String s = session.executeRead(tx -> tx.run(
+                "CYPHER 25 MATCH (node:Similar) WITH COLLECT(node) AS nodes " +
+                        "CALL custom.search.batchedSimilarity(nodes, 'vect', VECTOR([1, 2, 3], 3, INTEGER32), 5, 0.8) YIELD node, score " +
+                        "RETURN node.id, score",
+//                "CALL custom.search.batchedSimilarity($nodes, 'null', null, 5, 0.8, {stopWhenFound: true}) YIELD node, score RETURN node, score",
+                Map.of(/*"nodes", nodes*/)).list().toString());
+        long after = System.currentTimeMillis();
+        System.out.println("after - before apoc proc= " + (after - before));
+
+        System.out.println("s = " + s);
+    }
+
+    // TODO - maybe this part: https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/embeddings/compute-similarity/
+    //  runs faster..
+
+    // TODO - https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/embeddings/compute-similarity/
+    //  it seems the Public APIs still need to be written.
+    //  Maybe once they are written, it will be possible to operate with the Java Vector API and SIMD??
+
+    // todo - also cypher with float array
+    // todo - I have this warning: WARNING: Java vector incubator module is not readable. For optimal vector performance, pass '--add-modules jdk.incubator.vector' to enable Vector API.
+    // @Ignore
+    /** Times the pure-Cypher equivalent based on {@code vector.similarity.cosine} and prints the result. */
+    @Test
+    public void testSimilarityWithPureCypherInBatch() {
+
+        long before = System.currentTimeMillis();
+        String cypherRes = session.executeRead(tx -> tx.run("""
+                CYPHER 25
+                MATCH (node:Similar)
+                WITH COLLECT(node) AS nodes
+
+                UNWIND nodes AS node
+                WITH node, vector.similarity.cosine(node.vect, VECTOR([1, 2, 3], 3, INTEGER32)) AS score
+                WHERE score >= $threshold
+                RETURN node.id, score
+                ORDER BY score DESC
+                LIMIT $topK
+                """, Map.of(/*"nodes", nodes, */"threshold", 0.8, "queryVector", new float[]{1, 2, 3}, "topK", 5)).list().toString());
+        long after = System.currentTimeMillis();
+        System.out.println("after - before pure cypher = " + (after - before));
+        System.out.println("cypherRes = " + cypherRes);
+    }
+
+    // todo - remove
+    /** Ignored: calls the procedure with topK 2 and 5 and thresholds 0.8 and 0.95, printing each result. */
+    @Ignore
+    @Test
+    public void testSimilarity() {
+        long before = System.currentTimeMillis();
+        String s = session.executeRead(tx -> tx.run(
+                "MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 2, 0.8) YIELD node, score RETURN node, score",
+                Map.of()).list().toString());
+        long after = System.currentTimeMillis();
+        System.out.println("after - before = " + (after - before));
+
+        System.out.println("s = " + s);
+
+
+        String s1 = session.executeRead(tx -> tx.run(
+                "MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 2, 0.95) YIELD node, score RETURN node, score",
+                Map.of()).list().toString());
+
+        System.out.println("s = " + s1);
+        String s2 = session.executeRead(tx -> tx.run(
+                "MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 5, 0.8) YIELD node, score RETURN node, score",
+                Map.of()).list().toString());
+
+        System.out.println("s = " + s2);
+
+
+        String s12 = session.executeRead(tx -> tx.run(
+                "MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 5, 0.95) YIELD node, score RETURN node, score",
+                Map.of()).list().toString());
+
+        System.out.println("s = " + s12);
+    }
+
+    // todo - remove
+    /** Ignored: calls the procedure with the {@code stopWhenFound} option and prints the result. */
+    @Ignore
+    @Test
+    public void testSimilarityWithStopWhenFound() {
+        String s = session.executeRead(tx -> tx.run(
+                "MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 2, 0.8, {stopWhenFound: true}) YIELD node, score RETURN node, score",
+                Map.of()).list().toString());
+
+        System.out.println("stopWhenFound = " + s);
+    }
+
+    // todo - remove
+    //@Ignore
+    /** Ignored: pure-Cypher cosine similarity on the {@code embedding} property, printing result and elapsed time. */
+    @Test
+    @Ignore
+    public void testSimilarityWithPureCypher() {
+//        try (Transaction tx = db.beginTx()) {
+//            tx.findNodes(Label.label("Similar")).forEachRemaining(i -> {
+//                i.setProperty("embedding", new float[]{1, 2, 4});
+//            });
+//            tx.commit();
+//        }
+
+        long before = System.currentTimeMillis();
+        String cypherRes = session.executeRead(tx -> tx.run("""
+                MATCH (node:Similar)
+                // UNWIND $nodes AS node
+                // 2. Calcola la similarità per ogni nodo
+                WITH node, vector.similarity.cosine(node.embedding, $queryVector) AS score
+                // 3. Filtra i risultati che superano la soglia
+                WHERE score >= $threshold
+                // 4. 
Restituisce il nodo e il punteggio, ordinando per trovare i migliori K
+                RETURN node, score
+                ORDER BY score DESC
+                LIMIT $topK
+                """, Map.of("threshold", 0.8, "queryVector", new float[]{1, 2, 3}, "topK", 5)).list().toString());
+        System.out.println("cypherRes = " + cypherRes);
+        long after = System.currentTimeMillis();
+        System.out.println("after - before cypher match = " + (after - before));
+    }
+
+
+    // todo - pure cypher with float vector
+    // todo - pure cypher with float vector
+    // todo - pure cypher with float vector
+
+}
diff --git a/extended/src/main/java/apoc/algo/Neo4jVectorSimilaritySIMD.java b/extended/src/main/java/apoc/algo/Neo4jVectorSimilaritySIMD.java
new file mode 100644
index 0000000000..661fa59f82
--- /dev/null
+++ b/extended/src/main/java/apoc/algo/Neo4jVectorSimilaritySIMD.java
@@ -0,0 +1,81 @@
+package apoc.algo;
+
+import jdk.incubator.vector.FloatVector;
+import jdk.incubator.vector.VectorOperators;
+import jdk.incubator.vector.VectorSpecies;
+
+// TODO - compare..
+// something like this would be needed, but it cannot be done yet..
+public class Neo4jVectorSimilaritySIMD {
+
+    // Fixed 256-bit SIMD vector "shape" (species) for floats.
+    // It loads 8 floats (8 * 32bit = 256bit) at a time.
+    private static final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_256;
+
+    /**
+     * Computes the raw cosine similarity of two FLOAT32 arrays using the Java Vector API for SIMD acceleration.
+     */
+    private double calculateFloat32_SIMD(float[] v1, float[] v2) {
+        // Zero-initialised accumulator vectors; they hold the partial sums per lane.
+        FloatVector dotProductVec = FloatVector.zero(SPECIES);
+        FloatVector normAVec = FloatVector.zero(SPECIES);
+        FloatVector normBVec = FloatVector.zero(SPECIES);
+
+        // Upper bound of the vectorised loop.
+        // It guarantees that only complete blocks are processed.
+        int loopBound = SPECIES.loopBound(v1.length);
+
+        // --- VECTOR LOOP (SIMD) ---
+        // Processes the data in blocks of SPECIES.length() elements (e.g. 8 at a time).
+        for (int i = 0; i < loopBound; i += SPECIES.length()) {
+            // Load one block from the Java arrays into SIMD vectors
+            FloatVector va = FloatVector.fromArray(SPECIES, v1, i);
+            FloatVector vb = FloatVector.fromArray(SPECIES, v2, i);
+
+            // Partial dot product through FMA (Fused Multiply-Add: a * b + c)
+            // instead of a multiplication followed by an addition.
+            dotProductVec = va.fma(vb, dotProductVec);
+
+            // Partial squared norms
+            normAVec = va.fma(va, normAVec); // va * va + normAVec
+            normBVec = vb.fma(vb, normBVec); // vb * vb + normBVec
+        }
+
+        // "Reduce" the accumulator vectors to a single scalar value (double)
+        // by summing all the lanes of the SIMD vector.
+        double dotProduct = dotProductVec.reduceLanes(VectorOperators.ADD);
+        double normA = normAVec.reduceLanes(VectorOperators.ADD);
+        double normB = normBVec.reduceLanes(VectorOperators.ADD);
+
+        // --- SCALAR LOOP (for the "tail") ---
+        // Processes the remaining elements that did not fit into a complete block.
+        for (int i = loopBound; i < v1.length; i++) {
+            dotProduct += (double) v1[i] * v2[i];
+            normA += (double) v1[i] * v1[i];
+            normB += (double) v2[i] * v2[i];
+        }
+
+        if (normA == 0.0 || normB == 0.0) {
+            return 0.0;
+        }
+
+        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
+    }
+
+    /**
+     * Returns the cosine similarity of two equally sized FLOAT32 arrays, rescaled from [-1, 1] to [0, 1].
+     * A zero-norm vector gives a raw similarity of 0.0, i.e. a result of 0.5.
+     *
+     * @param v1 the first vector
+     * @param v2 the second vector
+     * @return the normalized similarity {@code (cosine + 1) / 2}
+     * @throws IllegalArgumentException if the two arrays have different lengths
+     */
+    public double cosineSimilarity(float[] v1, float[] v2) {
+        if (v1.length != v2.length) {
+            throw new IllegalArgumentException("I vettori devono avere la stessa dimensione");
+        }
+        double rawSimilarity = calculateFloat32_SIMD(v1, v2);
+        return (rawSimilarity + 1) / 2.0;
+    }
+}
\ No newline at end of file
diff --git a/extended/src/main/java/apoc/algo/Similarity.java b/extended/src/main/java/apoc/algo/Similarity.java
new file mode 100644
index 0000000000..5642d1df36
--- /dev/null
+++ b/extended/src/main/java/apoc/algo/Similarity.java
@@ -0,0 +1,219 @@
+package apoc.algo;
+// NOTE: this class does NOT use the Java Vector API (jdk.incubator.vector): the Neo4j
+// vector values are read element by element, so the similarity is computed with plain
+// scalar loops. The SIMD experiment lives in Neo4jVectorSimilaritySIMD and needs the flag
+// --add-modules jdk.incubator.vector
+import apoc.Extended;
+//import jdk.incubator.vector.*;
+import apoc.util.Util;
+import org.neo4j.graphdb.GraphDatabaseService;
+import org.neo4j.graphdb.Node;
+import org.neo4j.graphdb.Transaction;
+import org.neo4j.procedure.*;
+import org.neo4j.values.storable.*;
+//import org.neo4j.values.vector.VectorValue;
+//import org.neo4j.values.vector.VectorValueNotSupported;
+//import org.neo4j.values.vector.VectorValue.VectorType;
+
+import java.util.*;
+import java.util.stream.Stream;
+
+/**
+ * Neo4j procedure class for on-the-fly cosine similarity searches over a batch of nodes
+ * passed in by the caller; it does not use pre-existing vector indexes.
+ * <p>
+ * Because the public Neo4j vector values are accessed element by element, the calculation
+ * is SCALAR rather than SIMD (no temporary arrays are created).
+ * <p>
+ * Scores are the cosine similarity rescaled from [-1, 1] to [0, 1]. Only nodes whose
+ * property is a vector of the same family (floating point / integral) and the same
+ * dimension as the query vector are scored; all other nodes are skipped.
+ */
+@Extended
+public class Similarity {
+    @Context
+    public GraphDatabaseService db;
+
+    @Context
+    public Transaction tx;
+
+    public record SimilarityConfig(boolean stopWhenFound) {
+        public static SimilarityConfig fromMap(Map<String, Object> config) {
+            if (config == null) {
+                config = Collections.emptyMap();
+            }
+            boolean stopWhenFoundConf = Util.toBoolean(config.get("stopWhenFound"));
+            return new SimilarityConfig(stopWhenFoundConf);
+        }
+    }
+
+    // --- Procedure Output Class ---
+    public static class NodeScore {
+        @Description("The found node.")
+        public Node node;
+        @Description("The calculated similarity score, ranging from 0.0 to 1.0.")
+        public double score;
+        public NodeScore(Node node, double score) { this.node = node; this.score = score; }
+    }
+
+    /**
+     * Scores every node of {@code nodes} against {@code queryVector} and returns the best matches.
+     *
+     * @param nodes        candidate nodes; they are rebound to the procedure transaction first
+     * @param propertyName name of the node property holding the vector
+     * @param queryVector  the vector to compare with; anything that is not a floating-point or integral vector yields an empty result
+     * @param topK         maximum number of results; values below 1 yield an empty result
+     * @param threshold    minimum normalized score (0.0 - 1.0) a node needs to be returned
+     * @param config       optional map; {@code stopWhenFound: true} stops scanning as soon as topK matches are collected
+     * @return the matching nodes with their scores, best score first
+     */
+    @Procedure(name = "custom.search.batchedSimilarity", mode = Mode.READ)
+    @Description("Performs a type-safe cosine similarity search on a batch of nodes. Returns the top-K nodes above a given threshold.")
+    public Stream<NodeScore> batchedSimilarity(
+            @Name("nodes") List<Node> nodes,
+            // TODO --
+
+            @Name("propertyName") String propertyName,
+            // TODO -- CHANGE `Object` TO `VectorValue`
+            @Name("queryVector") Object queryVector,
+            @Name("topK") long topK,
+            @Name("threshold") double threshold,
+            @Name(value = "config", defaultValue = "{}") Map<String, Object> config
+    ) {
+        nodes = Util.rebind(nodes, tx);
+        SimilarityConfig conf = SimilarityConfig.fromMap(config);
+
+//        // TODO
+//        // TODO - remove it, mock
+//        // TODO
+//        if (queryVector == null) {
+//            queryVector = Values.int64Vector(1, 2, 3);
+//        }
+//        // TODO - end mock
+
+        PriorityQueue<NodeScore> topKQueue = new PriorityQueue<>(Comparator.comparingDouble((NodeScore a) -> a.score));
+
+        // Type-safe dispatch based on the query vector's class
+        if (queryVector instanceof FloatingPointVector queryVecFloat) {
+            for (Node node : nodes) {
+                Object propertyValue = node.getProperty(propertyName, null);
+                if (propertyValue instanceof FloatingPointVector nodeVecFloat) {
+                    if (processNode(node, nodeVecFloat, queryVecFloat, topK, threshold, topKQueue, conf)) {
+                        break;
+                    }
+                }
+            }
+        } else if (queryVector instanceof IntegralVector queryVecByte) {
+            for (Node node : nodes) {
+                Object propertyValue = node.getProperty(propertyName, null);
+//                // TODO
+//                // TODO
+//                // TODO - REMOVE IT, mock
+//                // TODO
+//                if (propertyValue == "null") {
+//                    if (node.hasProperty("test")) {
+//                        propertyValue = Values.int64Vector(1, 2, 4);
+//                    } else if (node.hasProperty("ajeje")) {
+//                        propertyValue = Values.int64Vector(1, 2, 3);
+//                    } else if (node.hasProperty("brazorf")) {
+//                        propertyValue = Values.int64Vector(1, 3, 4);
+//                    } else {
+//                        propertyValue = Values.int64Vector(1, 3, 3);
+//                    }
+//                }
+//                // TODO - end mock
+                if (propertyValue instanceof IntegralVector nodeVecByte) {
+                    if (processNode(node, nodeVecByte, queryVecByte, topK, threshold, topKQueue, conf)) {
+                        break;
+                    }
+                }
+            }
+        } else {
+            // TODO - error?
+            return Stream.empty(); // Unsupported query vector type
+        }
+
+        List<NodeScore> result = new ArrayList<>(topKQueue);
+        result.sort(Comparator.comparingDouble((NodeScore ns) -> ns.score).reversed());
+        return result.stream();
+    }
+
+    /**
+     * Scores one node and keeps it in the min-heap holding the best {@code topK} candidates so far.
+     *
+     * @return {@code true} when the scan can stop early (stopWhenFound is set and topK nodes were collected), {@code false} to continue
+     */
+    private boolean processNode(Node node, VectorValue nodeVector, VectorValue queryVector, long topK, double threshold, PriorityQueue<NodeScore> topKQueue, SimilarityConfig conf) {
+//        System.out.println("Similarity.processNode" + node.getAllProperties());
+        if (nodeVector.dimensions() != queryVector.dimensions()) {
+            return false;
+        }
+
+        // The raw calculation now dispatches to the appropriate SCALAR calculator
+        double rawSimilarity = calculateRawSimilarity(nodeVector, queryVector);
+        double normalizedScore = (rawSimilarity + 1) / 2.0;
+
+        if (normalizedScore >= threshold) {
+            if (topKQueue.size() < topK) {
+                topKQueue.add(new NodeScore(node, normalizedScore));
+            } else if (!topKQueue.isEmpty() && normalizedScore > topKQueue.peek().score) { // heap is empty only when topK <= 0
+                topKQueue.poll();
+                topKQueue.add(new NodeScore(node, normalizedScore));
+            }
+            // Early-stop check: enough matches were collected
+            if (conf.stopWhenFound() && topKQueue.size() == topK) {
+                return true;
+            }
+        }
+        // Continue with the next node
+        return false;
+    }
+
+    // --- Section with SCALAR calculation functions ---
+
+    /** Picks the scalar calculator matching the vector family of {@code v1}; returns the raw cosine in [-1, 1]. */
+    private double calculateRawSimilarity(VectorValue v1, VectorValue v2) {
+        if (v1 instanceof FloatingPointVector) {
+            return calculateFloat(v1, v2);
+        } else if (v1 instanceof IntegralVector) {
+            return calculateInteger(v1, v2);
+        }
+        throw new UnsupportedOperationException("Unsupported vector type for calculation: " + v1.getClass().getName());
+    }
+
+    /** Raw cosine similarity of two floating-point vectors of equal dimension; 0.0 when either norm is zero. */
+    private double calculateFloat(VectorValue v1, VectorValue v2) {
+        double dotProduct = 0.0;
+        double normA = 0.0;
+        double normB = 0.0;
+        int dimensions = v1.dimensions();
+        for (int i = 0; i < dimensions; i++) {
+            // Accessing data element-by-element, as per the public methods
+            double valA = v1.doubleValue(i);
+            double valB = v2.doubleValue(i);
+            dotProduct += valA * valB;
+            normA += valA * valA;
+            normB += valB * valB;
+        }
+        if (normA == 0.0 || normB == 0.0) return 0.0;
+        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
+    }
+
+    /** Raw cosine similarity of two integral vectors of equal dimension; 0.0 when either norm is zero. */
+    private double calculateInteger(VectorValue v1, VectorValue v2) {
+        double dotProduct = 0.0;
+        double normA = 0.0;
+        double normB = 0.0;
+        int dimensions = v1.dimensions();
+        for (int i = 0; i < dimensions; i++) {
+            // Accumulate in double: float sums lose precision on large integer components
+            double valA = v1.doubleValue(i);
+            double valB = v2.doubleValue(i);
+            dotProduct += valA * valB;
+            normA += valA * valA;
+            normB += valB * valB;
+        }
+        if (normA == 0.0 || normB == 0.0) return 0.0;
+        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
+    }
+}
\ No newline at end of file
diff --git a/extended/src/main/java/apoc/algo/SimilarityToDelete.java b/extended/src/main/java/apoc/algo/SimilarityToDelete.java
new file mode 100644
index 0000000000..34efa6ac7a
--- /dev/null
+++ b/extended/src/main/java/apoc/algo/SimilarityToDelete.java
@@ -0,0 +1,171 @@
+//package apoc.algo;
+//// WARNING: This code uses the Java Vector API, which is a preview/incubator
+//// feature in modern JDKs (e.g., JDK 21+). To compile and run this code,
+//// you must enable the vector module with specific JVM flags,
+//// e.g.: --add-modules jdk.incubator.vector
+//
+//import apoc.Extended;
+//import jdk.incubator.vector.*;
+//import org.neo4j.graphdb.Node;
+//import org.neo4j.procedure.Description;
+//import org.neo4j.procedure.Mode;
+//import org.neo4j.procedure.Name;
+//import org.neo4j.procedure.Procedure;
+//import org.neo4j.values.storable.VectorValue;
+//
+//import java.util.ArrayList;
+//import java.util.Comparator;
+//import java.util.List;
+//import java.util.PriorityQueue;
+//import java.util.stream.Stream;
+//
+///**
+// * A Neo4j Procedure class for performing on-the-fly similarity searches on a dynamic batch of nodes.
+// * It does not use pre-existing indexes but leverages the Java Vector API (SIMD) for maximum compute performance. +// */ +// +//@Extended +//public class SimilarityToDelete { +// +// // --- Procedure Output Class --- +// public static class NodeScore { +// @Description("The found node.") +// public Node node; +// @Description("The calculated similarity score, ranging from 0.0 to 1.0.") +// public double score; +// +// public NodeScore(Node node, double score) { +// this.node = node; +// this.score = score; +// } +// } +// +// // --- SIMD Vector Species --- +// // Select the preferred SIMD vector "shape" (species) for floats from the CPU. +// private static final VectorSpecies FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED; +// // Select the preferred SIMD vector "shape" (species) for bytes from the CPU. +// private static final VectorSpecies BYTE_SPECIES = ByteVector.SPECIES_PREFERRED; +// // The corresponding integer species when widening bytes. +// private static final VectorSpecies INT_SPECIES_FOR_BYTE_WIDENING = ByteVector.SPECIES_PREFERRED.widenedspecies(int.class); +// +// @Procedure(name = "custom.search.batchedSimilarity", mode = Mode.READ) +// @Description("Performs a cosine similarity search on a batch of nodes using SIMD. Returns the top-K nodes above a given threshold.") +// public Stream batchedSimilarity( +// @Name("nodes") List nodes, +// @Name("propertyName") String propertyName, +// @Name("queryVector") VectorValue queryVector, +// @Name("topK") long topK, +// @Name("threshold") double threshold +// ) { +// // Use a PriorityQueue to efficiently maintain the top-K elements (as a min-heap). 
+// PriorityQueue topKQueue = new PriorityQueue<>(Comparator.comparingDouble(a -> a.score)); +// +// for (Node node : nodes) { +// // Requirement: Kernel Neo4j API for property access +// Object propertyValue = node.getProperty(propertyName, null); +// if (!(propertyValue instanceof VectorValue nodeVector)) { +// // Skip the node if it doesn't have the property or it's not a vector. +// continue; +// } +// +// // Compatibility checks +// if (nodeVector.size() != queryVector.size() || nodeVector.valueType().realType() != queryVector.valueType().realType()) { +// continue; +// } +// +// // Calculate the raw similarity score [-1, 1] using the appropriate SIMD function +// double rawSimilarity = calculateRawSimilarity(nodeVector, queryVector); +// +// // Normalize the score to the [0, 1] range +// double normalizedScore = (rawSimilarity + 1) / 2.0; +// +// // Requirement: Early filtering and stop... (threshold) +// if (normalizedScore < threshold) { +// // Immediately discard nodes below the threshold. +// continue; +// } +// +// // Requirement: Early filtering and stop... (top-k) +// if (topKQueue.size() < topK) { +// topKQueue.add(new NodeScore(node, normalizedScore)); +// } else if (normalizedScore > topKQueue.peek().score) { +// topKQueue.poll(); // Remove the worst element (lowest score). +// topKQueue.add(new NodeScore(node, normalizedScore)); // Add the new best one. +// } +// } +// +// // Convert the queue to a sorted list and return it as a stream. 
+// List result = new ArrayList<>(topKQueue); +// result.sort(Comparator.comparingDouble((NodeScore ns) -> ns.score).reversed()); +// return result.stream(); +// } +// +// // --- Section with SIMD calculation functions --- +// +// private double calculateRawSimilarity(VectorValue v1, VectorValue v2) { +// switch (v1.valueType().realType()) { +// case FLOAT32: +// return calculateFloat32_SIMD(v1.floatArray(), v2.floatArray()); +// case INTEGER8: +// return calculateInteger8_SIMD(v1.byteArray(), v2.byteArray()); +// default: +// throw new VectorValueNotSupported("Vector type not supported for cosine similarity: " + v1.valueType().realType()); +// } +// } +// +// private double calculateFloat32_SIMD(float[] v1, float[] v2) { +// FloatVector dotProductVec = FloatVector.zero(FLOAT_SPECIES); +// FloatVector normAVec = FloatVector.zero(FLOAT_SPECIES); +// FloatVector normBVec = FloatVector.zero(FLOAT_SPECIES); +// int loopBound = FLOAT_SPECIES.loopBound(v1.length); +// for (int i = 0; i < loopBound; i += FLOAT_SPECIES.length()) { +// FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, v1, i); +// FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, v2, i); +// dotProductVec = va.fma(vb, dotProductVec); +// normAVec = va.fma(va, normAVec); +// normBVec = vb.fma(vb, normBVec); +// } +// double dotProduct = dotProductVec.reduceLanes(VectorOperators.ADD); +// double normA = normAVec.reduceLanes(VectorOperators.ADD); +// double normB = normBVec.reduceLanes(VectorOperators.ADD); +// // Scalar loop for the "tail" +// for (int i = loopBound; i < v1.length; i++) { +// dotProduct += (double) v1[i] * v2[i]; +// normA += (double) v1[i] * v1[i]; +// normB += (double) v2[i] * v2[i]; +// } +// if (normA == 0.0 || normB == 0.0) return 0.0; +// return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); +// } +// +// private double calculateInteger8_SIMD(byte[] v1, byte[] v2) { +// // Accumulators MUST be of a wider type (Int/Long) to prevent overflow. 
+//        IntVector dotProductVec = IntVector.zero(INT_SPECIES_FOR_BYTE_WIDENING);
+//        IntVector normAVec = IntVector.zero(INT_SPECIES_FOR_BYTE_WIDENING);
+//        IntVector normBVec = IntVector.zero(INT_SPECIES_FOR_BYTE_WIDENING);
+//        int loopBound = BYTE_SPECIES.loopBound(v1.length);
+//        for (int i = 0; i < loopBound; i += BYTE_SPECIES.length()) {
+//            ByteVector ba = ByteVector.fromArray(BYTE_SPECIES, v1, i);
+//            ByteVector bb = ByteVector.fromArray(BYTE_SPECIES, v2, i);
+//            // "Widen" byte vectors to int vectors to perform calculations safely, avoiding overflow.
+//            IntVector ia = ba.widen(VectorOperators.UNSIGNED_BYTE);
+//            IntVector ib = bb.widen(VectorOperators.UNSIGNED_BYTE);
+//            // Perform operations on the int vectors.
+//            dotProductVec = ia.mul(ib).add(dotProductVec);
+//            normAVec = ia.mul(ia).add(normAVec);
+//            normBVec = ib.mul(ib).add(normBVec);
+//        }
+//        // Reduce to 'long' to be 100% safe from overflow on the final sum.
+//        long dotProduct = dotProductVec.reduceLanesToLong(VectorOperators.ADD);
+//        long normA = normAVec.reduceLanesToLong(VectorOperators.ADD);
+//        long normB = normBVec.reduceLanesToLong(VectorOperators.ADD);
+//        // Scalar loop for the "tail"
+//        for (int i = loopBound; i < v1.length; i++) {
+//            dotProduct += (long) v1[i] * v2[i];
+//            normA += (long) v1[i] * v1[i];
+//            normB += (long) v2[i] * v2[i];
+//        }
+//        if (normA == 0 || normB == 0) return 0.0;
+//        return (double) dotProduct / (Math.sqrt((double) normA) * Math.sqrt((double) normB));
+//    }
+//}
\ No newline at end of file
diff --git a/extended/src/test/java/apoc/algo/SimilarityTest.java b/extended/src/test/java/apoc/algo/SimilarityTest.java
new file mode 100644
index 0000000000..bcf9437f33
--- /dev/null
+++ b/extended/src/test/java/apoc/algo/SimilarityTest.java
@@ -0,0 +1,189 @@
+package apoc.algo;
+
+
+import apoc.util.TestUtil;
+import apoc.util.Util;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.neo4j.graphdb.Label;
+import org.neo4j.graphdb.Result;
+import org.neo4j.graphdb.Transaction;
+import org.neo4j.test.rule.DbmsRule;
+import org.neo4j.test.rule.ImpermanentDbmsRule;
+import org.neo4j.values.storable.Values;
+
+import java.util.List;
+import java.util.Map;
+
+import static apoc.util.TestUtil.singleResultFirstColumn;
+
+// TODO
+// TODO
+// TODO - MOVE TO EXTENDED-IT DUE TO ENTERPRISE VECTOR TYPES
+// TODO
+/**
+ * Exploratory timing comparison between the {@code custom.search.batchedSimilarity} procedure and the
+ * {@code vector.similarity.cosine} Cypher function on an impermanent database; results are only printed.
+ */
+public class SimilarityTest {
+
+    private static List nodes = null;
+
+    @ClassRule
+    public static DbmsRule db = new ImpermanentDbmsRule();
+
+    /** Registers the procedure, creates the {@code Similar} nodes, collects them and sets a float-array {@code embedding} on each. */
+    @BeforeClass
+    public static void setUp() throws Exception {
+        TestUtil.registerProcedure(db, Similarity.class);
+        db.executeTransactionally(
+                "UNWIND range(0, 50000) as id CREATE (:Similar {id: 1, test: 1}), (:Similar {id: 2}), (:Similar {id: 3}), (:Similar {ajeje: 1, id: 4}), (:Similar {brazorf: 1, id: 5})");
+
+        nodes = singleResultFirstColumn(db, "MATCH (n:Similar) RETURN collect(n) AS nodes");
+
+        try (Transaction tx = db.beginTx()) {
+            tx.findNodes(Label.label("Similar")).forEachRemaining(i -> {
+                i.setProperty("embedding", new float[]{1, 2, 4});
+            });
+            tx.commit();
+        }
+    }
+
+    /** Times {@code custom.search.batchedSimilarity} on the collected nodes (topK 5, threshold 0.8) and prints the result. */
+    @Test
+    public void testSimilarityCompare() {
+        long before = System.currentTimeMillis();
+        String s = db.executeTransactionally(
+                "CALL custom.search.batchedSimilarity($nodes, 'null', null, 5, 0.8) YIELD node, score RETURN node, score",
+//                "CALL custom.search.batchedSimilarity($nodes, 'null', null, 5, 0.8, {stopWhenFound: true}) YIELD node, score RETURN node, score",
+                Map.of("nodes", nodes), Result::resultAsString);
+        long after = System.currentTimeMillis();
+        System.out.println("after - before = " + (after - before));
+
+        System.out.println("s = " + s);
+    }
+
+    // TODO - maybe this part: https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/embeddings/compute-similarity/
+    //  runs faster..
+
+    // TODO - https://neo4j.com/docs/genai/tutorials/embeddings-vector-indexes/embeddings/compute-similarity/
+    //  it seems the Public APIs still need to be written
+    //  maybe once they are written it will be possible to operate with the Java Vector API and SIMD??
+
+    // todo - pure cypher with float array
+    // todo - I get this warning: WARNING: Java vector incubator module is not readable. For optimal vector performance, pass '--add-modules jdk.incubator.vector' to enable Vector API.
+    /** Times the pure-Cypher equivalent ({@code vector.similarity.cosine} on {@code embedding}) over the collected nodes. */
+    @Test
+    public void testSimilarityWithPureCypherInBatch() {
+
+        long before = System.currentTimeMillis();
+        String cypherRes;
+        // Rebind the nodes to the same transaction that runs the query, and close that transaction afterwards.
+        try (Transaction tx = db.beginTx()) {
+            cypherRes = tx.execute("""
+                    //MATCH (node:Similar)
+                    UNWIND $nodes AS node
+                    // 2. Calcola la similarità per ogni nodo
+                    WITH node, vector.similarity.cosine(node.embedding, $queryVector) AS score
+                    // 3. Filtra i risultati che superano la soglia
+                    WHERE score >= $threshold
+                    // 4. Restituisce il nodo e il punteggio, ordinando per trovare i migliori K
+                    RETURN node, score
+                    ORDER BY score DESC
+                    LIMIT $topK
+                    """, Map.of("nodes", Util.rebind(nodes, tx), "threshold", 0.8, "queryVector", new float[]{1, 2, 3}, "topK", 5)).resultAsString();
+        }
+        long after = System.currentTimeMillis();
+        System.out.println("after - before cypher = " + (after - before));
+        System.out.println("cypherRes = " + cypherRes);
+    }
+
+    // todo - remove
+    /** Ignored: calls the procedure with topK 2 and 5 and thresholds 0.8 and 0.95, printing each result. */
+    @Ignore
+    @Test
+    public void testSimilarity() {
+        long before = System.currentTimeMillis();
+        String s = db.executeTransactionally(
+                "MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 2, 0.8) YIELD node, score RETURN node, score",
+                Map.of(), Result::resultAsString);
+        long after = System.currentTimeMillis();
+        System.out.println("after - before = " + (after - before));
+
+        System.out.println("s = " + s);
+
+
+        String s1 = db.executeTransactionally(
+                "MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 2, 0.95) YIELD node, score RETURN node, score",
+                Map.of(), Result::resultAsString);
+
+        System.out.println("s = " + s1);
+        String s2 = db.executeTransactionally(
+                "MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 5, 0.8) YIELD node, score RETURN node, score",
+                Map.of(), Result::resultAsString);
+
+        System.out.println("s = " + s2);
+
+
+        String s12 = db.executeTransactionally(
+                "MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 5, 0.95) YIELD node, score RETURN node, score",
+                Map.of(), Result::resultAsString);
+
+        System.out.println("s = " + s12);
+    }
+
+    // todo - remove
+    /** Ignored: calls the procedure with the {@code stopWhenFound} option and prints the result. */
+    @Ignore
+    @Test
+    public void testSimilarityWithStopWhenFound() {
+        String s = db.executeTransactionally(
+                "MATCH (n:Similar) WITH collect(n) AS nodes CALL custom.search.batchedSimilarity(nodes, 'null', null, 2, 0.8, {stopWhenFound: true}) YIELD node, score RETURN node, score",
+                Map.of(), Result::resultAsString);
+
+        System.out.println("stopWhenFound = " + s);
+    }
+
+    // todo - remove
+    /** Ignored: pure-Cypher cosine similarity over every {@code Similar} node matched by label, printing the result. */
+    @Ignore
+    @Test
+    public void testSimilarityWithPureCypher() {
+//        try (Transaction tx = db.beginTx()) {
+//            tx.findNodes(Label.label("Similar")).forEachRemaining(i -> {
+//                i.setProperty("embedding", new float[]{1, 2, 4});
+//            });
+//            tx.commit();
+//        }
+
+        long before = System.currentTimeMillis();
+        String cypherRes = db.executeTransactionally("""
+                MATCH (node:Similar)
+                // UNWIND $nodes AS node
+                // 2. Calcola la similarità per ogni nodo
+                WITH node, vector.similarity.cosine(node.embedding, $queryVector) AS score
+                // 3. Filtra i risultati che superano la soglia
+                WHERE score >= $threshold
+                // 4. Restituisce il nodo e il punteggio, ordinando per trovare i migliori K
+                RETURN node, score
+                ORDER BY score DESC
+                LIMIT $topK
+                """, Map.of("threshold", 0.8, "queryVector", new float[]{1, 2, 3}, "topK", 5), Result::resultAsString);
+        System.out.println("cypherRes = " + cypherRes);
+    }
+
+
+    // todo - pure cypher with float vector
+    // todo - pure cypher with float vector
+    // todo - pure cypher with float vector
+
+
+
+
+    // todo - add queryNodes
+
+
+
+}
\ No newline at end of file