Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@ import (
"strings"

"github.com/antlr/antlr4/runtime/Go/antlr"
"github.com/cloudprivacylabs/lpg/v2"
"github.com/cloudprivacylabs/opencypher/parser"
)

type Evaluatable interface {
Evaluate(*EvalContext) (Value, error)
}

type ResultPath struct {
Result *lpg.Path
Symbols map[string]Value
}

type regularQuery struct {
singleQuery Evaluatable
unions []union
Expand Down Expand Up @@ -81,11 +87,11 @@ type multiPartQuery struct {
}

type ReadingClause interface {
GetResults(*EvalContext) (ResultSet, error)
GetResults(*EvalContext) ([]ResultPath, error)
}

type UpdatingClause interface {
Update(*EvalContext, ResultSet) (Value, error)
Update(*EvalContext, []ResultPath) (Value, error)
TopLevelUpdate(*EvalContext) (Value, error)
}

Expand Down Expand Up @@ -445,8 +451,9 @@ func oC_SinglePartQuery(ctx *parser.OC_SinglePartQueryContext) singlePartQuery {
return ret
}

//oC_MultiPartQuery
// : ( ( oC_ReadingClause SP? )* ( oC_UpdatingClause SP? )* oC_With SP? )+ oC_SinglePartQuery ;
// oC_MultiPartQuery
//
// : ( ( oC_ReadingClause SP? )* ( oC_UpdatingClause SP? )* oC_With SP? )+ oC_SinglePartQuery ;
func oC_MultiPartQuery(ctx *parser.OC_MultiPartQueryContext) multiPartQuery {
ret := multiPartQuery{parts: []multiPartQueryPart{}}
count := ctx.GetChildCount()
Expand Down Expand Up @@ -663,11 +670,11 @@ func oC_ComparisonExpression(ctx *parser.OC_ComparisonExpressionContext) Express
}

// oC_AddOrSubtractExpression :
// oC_MultiplyDivideModuloExpression (
// ( SP? '+' SP? oC_MultiplyDivideModuloExpression ) |
// ( SP? '-' SP? oC_MultiplyDivideModuloExpression )
// )*
//
// oC_MultiplyDivideModuloExpression (
// ( SP? '+' SP? oC_MultiplyDivideModuloExpression ) |
// ( SP? '-' SP? oC_MultiplyDivideModuloExpression )
// )*
func oC_AddOrSubtractExpression(ctx *parser.OC_AddOrSubtractExpressionContext) Expression {
ret := &addOrSubtractExpression{}
target := &ret.add
Expand Down Expand Up @@ -695,10 +702,11 @@ func oC_AddOrSubtractExpression(ctx *parser.OC_AddOrSubtractExpressionContext) E
}

// oC_MultiplyDivideModuloExpression :
// oC_PowerOfExpression (
// ( SP? '*' SP? oC_PowerOfExpression ) |
// ( SP? '/' SP? oC_PowerOfExpression ) |
// ( SP? '%' SP? oC_PowerOfExpression ) )* ;
//
// oC_PowerOfExpression (
// ( SP? '*' SP? oC_PowerOfExpression ) |
// ( SP? '/' SP? oC_PowerOfExpression ) |
// ( SP? '%' SP? oC_PowerOfExpression ) )* ;
func oC_MultiplyDivideModuloExpression(ctx *parser.OC_MultiplyDivideModuloExpressionContext) Expression {
ret := &multiplyDivideModuloExpression{}
count := ctx.GetChildCount()
Expand Down Expand Up @@ -728,7 +736,8 @@ func oC_MultiplyDivideModuloExpression(ctx *parser.OC_MultiplyDivideModuloExpres
}

// oC_PowerOfExpression :
// oC_UnaryAddOrSubtractExpression ( SP? '^' SP? oC_UnaryAddOrSubtractExpression )* ;
//
// oC_UnaryAddOrSubtractExpression ( SP? '^' SP? oC_UnaryAddOrSubtractExpression )* ;
func oC_PowerOfExpression(ctx *parser.OC_PowerOfExpressionContext) Evaluatable {
ret := powerOfExpression{}
for _, x := range ctx.AllOC_UnaryAddOrSubtractExpression() {
Expand Down Expand Up @@ -1170,7 +1179,7 @@ func oC_FilterExpression(ctx *parser.OC_FilterExpressionContext) filterExpressio
return ret
}

//oC_RelationshipsPattern : oC_NodePattern ( SP? oC_PatternElementChain )+ ;
// oC_RelationshipsPattern : oC_NodePattern ( SP? oC_PatternElementChain )+ ;
// oC_PatternElementChain : oC_RelationshipPattern SP? oC_NodePattern ;
func oC_RelationshipsPattern(ctx *parser.OC_RelationshipsPatternContext) relationshipsPattern {
ret := relationshipsPattern{
Expand Down
53 changes: 53 additions & 0 deletions cartesian.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package opencypher

// CartesianProuductPaths builds the product of all the resultpaths
func CartesianProductPaths(ctx *EvalContext, numItems int, getItem func(int, *EvalContext) ([]ResultPath, error), filter func([]ResultPath) bool) [][]ResultPath {
result := make([][]ResultPath, 0)
product := make([][]ResultPath, numItems)
indexes := make([]int, numItems)

columnProcessor := func(next func(int)) func(int) {
return func(column int) {
product[column], _ = getItem(column, ctx)
for i := range product[column] {
indexes[column] = i
next(column + 1)
}
}
}

capture := func(int) {
row := make([]ResultPath, 0, numItems)
for i, x := range indexes {
row = append(row, product[i][x])
}
if filter(row) {
result = append(result, row)
}
}

next := columnProcessor(capture)
for column := numItems - 2; column >= 0; column-- {
next = columnProcessor(next)
}
next(0)
return result
}

func AllPathsToResultSets(paths [][]ResultPath) []ResultSet {
res := make([]ResultSet, len(paths))
for i, path := range paths {
for j := range path {
res[i].Rows[i] = paths[i][j].Symbols
}
}
return res
}

func ResultPathToResultSet(resultPath []ResultPath) ResultSet {
rs := ResultSet{Rows: make([]map[string]Value, 0, len(resultPath))}
for _, rp := range resultPath {
rs.Rows = append(rs.Rows, rp.Symbols)
}
return rs
}
129 changes: 129 additions & 0 deletions cartesian_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package opencypher

import (
"encoding/json"
"fmt"
"os"
"testing"

"github.com/cloudprivacylabs/lpg/v2"
)

func TestCartesianProductPaths(t *testing.T) {
f, err := os.Open("testdata/g1.json")
if err != nil {
t.Error(err)
return
}
target := lpg.NewGraph()
err = lpg.JSON{}.Decode(target, json.NewDecoder(f))
if err != nil {
t.Error(err)
return
}
path := &lpg.Path{}
pe := make([]lpg.PathElement, 0)
for itr := target.GetEdges(); itr.Next(); {
// path.Append(lpg.PathElement{
// Edge: itr.Edge(),
// })
pe = append(pe, lpg.PathElement{
Edge: itr.Edge(),
})
}
path.Append(pe...)
{
tests := []struct {
rp [][]ResultPath
expLen int
}{
{
rp: [][]ResultPath{
{
ResultPath{
Result: path,
},
ResultPath{
Result: &lpg.Path{},
},
ResultPath{
Result: &lpg.Path{},
},
},
{
ResultPath{
Result: &lpg.Path{},
},
ResultPath{
Result: &lpg.Path{},
},
},
{
ResultPath{
Result: &lpg.Path{},
},
ResultPath{
Result: &lpg.Path{},
},
},
},
expLen: 12,
},
{
rp: [][]ResultPath{
{
ResultPath{
Result: path,
},
ResultPath{
Result: &lpg.Path{},
},
ResultPath{
Result: &lpg.Path{},
},
},
{
ResultPath{
Result: &lpg.Path{},
},
},
{
ResultPath{
Result: &lpg.Path{},
},
},
},
expLen: 3,
},
{
rp: [][]ResultPath{
{
ResultPath{
Result: path,
},
},
{
ResultPath{
Result: &lpg.Path{},
},
},
},
expLen: 1,
},
}

for _, test := range tests {
prod := CartesianProductPaths(NewEvalContext(target), len(test.rp), func(i int, ec *EvalContext) ([]ResultPath, error) {
return test.rp[i], nil
}, func([]ResultPath) bool {
return true
})
if len(prod) != test.expLen {
t.Errorf("Got %d", len(prod))
for _, x := range prod {
fmt.Println("test2:", x)
}
}
}
}
}
25 changes: 15 additions & 10 deletions eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,27 +456,32 @@ func (query singlePartQuery) Evaluate(ctx *EvalContext) (Value, error) {
}
return nil
}
results := *NewResultSet()
results := make([]ResultPath, 0)
var rs ResultSet
if len(query.read) > 0 {
for _, r := range query.read {
rs, err := r.GetResults(ctx)
for i, r := range query.read {
rp, err := r.GetResults(ctx)
if err != nil {
return nil, err
}
results.Add(rs)
if len(rp) != 0 {
rs.Rows = append(rs.Rows, rp[i].Symbols)
}
}

for _, upd := range query.update {
v, err := upd.Update(ctx, results)
if err != nil {
return nil, err
}
results = v.Get().(ResultSet)
results = v.Get().([]ResultPath)
}
if query.ret == nil {
return RValue{Value: *NewResultSet()}, nil
}
err := project(results.Rows)

// rs = ResultPathToResultSet(results)
err := project(rs.Rows)
if err != nil {
return nil, err
}
Expand All @@ -489,15 +494,15 @@ func (query singlePartQuery) Evaluate(ctx *EvalContext) (Value, error) {
return nil, err
}
if v != nil && v.Get() != nil {
results = v.Get().(ResultSet)
results = v.Get().([]ResultPath)
}
}
if query.ret == nil {
return RValue{Value: *NewResultSet()}, nil
}

if len(results.Rows) > 0 {
for _, row := range results.Rows {
if len(rs.Rows) > 0 {
for _, row := range rs.Rows {
val, err := query.ret.projection.items.Project(ctx, row)
if err != nil {
return nil, err
Expand Down Expand Up @@ -593,7 +598,7 @@ func (pe propertyExpression) Evaluate(ctx *EvalContext) (Value, error) {
return val, nil
}

func (unwind unwind) GetResults(ctx *EvalContext) (ResultSet, error) { panic("Unimplemented") }
func (unwind unwind) GetResults(ctx *EvalContext) ([]ResultPath, error) { panic("Unimplemented") }
func (ls listComprehension) Evaluate(ctx *EvalContext) (Value, error) { panic("Unimplemented") }
func (p patternComprehension) Evaluate(ctx *EvalContext) (Value, error) { panic("Unimplemented") }
func (flt filterAtom) Evaluate(ctx *EvalContext) (Value, error) { panic("Unimplemented") }
Expand Down
8 changes: 4 additions & 4 deletions lang.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,15 @@ func (p PatternPart) FindRelative(this *lpg.Node) ([]*lpg.Node, error) {

resultAccumulator := matchResultAccumulator{
evalCtx: ctx,
result: NewResultSet(),
result: []ResultPath{},
}
err = pattern.Run(ctx.graph, symbols, &resultAccumulator)
if err != nil {
return nil, err
}

ret := make([]*lpg.Node, 0, len(resultAccumulator.result.Rows))
for _, row := range resultAccumulator.result.Rows {
rs := ResultPathToResultSet(resultAccumulator.result)
ret := make([]*lpg.Node, 0, len(rs.Rows))
for _, row := range rs.Rows {
t, ok := row["target"]
if !ok {
continue
Expand Down
Loading