Skip to content

Commit 62ec335

Browse files
committed
Add new endpoints to query the sources of each asset and relation.
1 parent 62e50e0 commit 62ec335

File tree

18 files changed

+339
-177
lines changed

18 files changed

+339
-177
lines changed

cmd/go-graphkb/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ func queryFunc(cmd *cobra.Command, args []string) {
168168

169169
q := knowledge.NewQuerier(Database, Database)
170170

171-
r, err := q.Query(ctx, args[0], false)
171+
r, err := q.Query(ctx, args[0])
172172
if err != nil {
173173
logrus.Fatal(err)
174174
}

internal/database/mariadb.go

Lines changed: 72 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ func (m *MariaDB) Close() error {
595595
}
596596

597597
// Query the database with provided intermediate query representation
598-
func (m *MariaDB) Query(ctx context.Context, sql knowledge.SQLTranslation, includeDataSourceInResults bool) (*knowledge.GraphQueryResult, error) {
598+
func (m *MariaDB) Query(ctx context.Context, sql knowledge.SQLTranslation) (*knowledge.GraphQueryResult, error) {
599599
deadline, ok := ctx.Deadline()
600600
// If there is a deadline, we make sure the query stops right after it has been reached.
601601
if ok {
@@ -610,11 +610,69 @@ func (m *MariaDB) Query(ctx context.Context, sql knowledge.SQLTranslation, inclu
610610
}
611611

612612
res := new(knowledge.GraphQueryResult)
613-
res.Cursor = NewMariaDBCursor(rows, sql.ProjectionTypes, includeDataSourceInResults)
613+
res.Cursor = NewMariaDBCursor(rows, sql.ProjectionTypes)
614614
res.Projections = sql.ProjectionTypes
615615
return res, nil
616616
}
617617

618+
func (m *MariaDB) GetAssetSources(ctx context.Context, ids []string) (map[string][]string, error) {
619+
stmt, err := m.db.PrepareContext(ctx, `
620+
SELECT sources.name FROM sources
621+
INNER JOIN assets_by_source ON sources.id = assets_by_source.source_id
622+
WHERE asset_id = ?`)
623+
if err != nil {
624+
return nil, fmt.Errorf("Unable to prepare statement for retrieving asset sources: %w", err)
625+
}
626+
idsSet := make(map[string][]string)
627+
for _, id := range ids {
628+
row, err := stmt.QueryContext(ctx, id)
629+
if err != nil {
630+
return nil, fmt.Errorf("Unable to retrieve sources for asset id %s: %w", id, err)
631+
}
632+
633+
idsSet[id] = []string{}
634+
635+
var source string
636+
for row.Next() {
637+
err = row.Scan(&source)
638+
if err != nil {
639+
return nil, fmt.Errorf("Unable to scan row of asset source: %w", err)
640+
}
641+
idsSet[id] = append(idsSet[id], source)
642+
}
643+
}
644+
return idsSet, nil
645+
}
646+
647+
func (m *MariaDB) GetRelationSources(ctx context.Context, ids []string) (map[string][]string, error) {
648+
stmt, err := m.db.PrepareContext(ctx, `
649+
SELECT sources.name FROM sources
650+
INNER JOIN relations_by_source ON sources.id = relations_by_source.source_id
651+
WHERE relation_id = ?`)
652+
if err != nil {
653+
return nil, fmt.Errorf("Unable to prepare statement for retrieving relation sources: %w", err)
654+
}
655+
idsSet := make(map[string][]string)
656+
for _, id := range ids {
657+
row, err := stmt.QueryContext(ctx, id)
658+
if err != nil {
659+
return nil, fmt.Errorf("Unable to retrieve sources for relation id %s: %w", id, err)
660+
}
661+
662+
idsSet[id] = []string{}
663+
664+
var source string
665+
for row.Next() {
666+
err = row.Scan(&source)
667+
if err != nil {
668+
return nil, fmt.Errorf("Unable to scan row of relation source: %w", err)
669+
}
670+
idsSet[id] = append(idsSet[id], source)
671+
}
672+
}
673+
return idsSet, nil
674+
}
675+
618676
// SaveSuccessfulQuery log an entry to mark a successful query
619677
func (m *MariaDB) SaveSuccessfulQuery(ctx context.Context, cypher, sql string, duration time.Duration) error {
620678
_, err := m.db.ExecContext(ctx, "INSERT INTO query_history (id, timestamp, query_cypher, query_sql, status, execution_time_ms) VALUES (NULL, CURRENT_TIMESTAMP(), ?, ?, 'SUCCESS', ?)",
@@ -706,16 +764,14 @@ func (m *MariaDB) ListSources(ctx context.Context) (map[string]string, error) {
706764
type MariaDBCursor struct {
707765
*sql.Rows
708766

709-
Projections []knowledge.Projection
710-
IncludeDataSourceInResults bool
767+
Projections []knowledge.Projection
711768
}
712769

713770
// NewMariaDBCursor create a new instance of MariaDBCursor
714-
func NewMariaDBCursor(rows *sql.Rows, projections []knowledge.Projection, includeDataSourceInResults bool) *MariaDBCursor {
771+
func NewMariaDBCursor(rows *sql.Rows, projections []knowledge.Projection) *MariaDBCursor {
715772
return &MariaDBCursor{
716-
Rows: rows,
717-
Projections: projections,
718-
IncludeDataSourceInResults: includeDataSourceInResults,
773+
Rows: rows,
774+
Projections: projections,
719775
}
720776
}
721777

@@ -726,16 +782,6 @@ func (mc *MariaDBCursor) HasMore() bool {
726782

727783
// Read read one more item from the cursor
728784
func (mc *MariaDBCursor) Read(ctx context.Context, doc interface{}) error {
729-
type RelationWithIDAndSource struct {
730-
Source string `json:"source,omitempty"`
731-
knowledge.RelationWithID
732-
}
733-
734-
type AssetWithIDAndSource struct {
735-
Source string `json:"source,omitempty"`
736-
knowledge.AssetWithID
737-
}
738-
739785
var err error
740786
var fArr []string
741787

@@ -789,11 +835,9 @@ func (mc *MariaDBCursor) Read(ctx context.Context, doc interface{}) error {
789835
Key: fmt.Sprintf("%v", reflect.ValueOf(items[1])),
790836
}
791837

792-
awi := AssetWithIDAndSource{
793-
AssetWithID: knowledge.AssetWithID{
794-
ID: reflect.ValueOf(items[0]).String(),
795-
Asset: asset,
796-
},
838+
awi := knowledge.AssetWithID{
839+
ID: reflect.ValueOf(items[0]).String(),
840+
Asset: asset,
797841
}
798842
output[i] = awi
799843
case knowledge.EdgeExprType:
@@ -803,12 +847,11 @@ func (mc *MariaDBCursor) Read(ctx context.Context, doc interface{}) error {
803847
return fmt.Errorf("Unable to get %d items to build an edge: %v", itemCount, err)
804848
}
805849

806-
r := RelationWithIDAndSource{
807-
RelationWithID: knowledge.RelationWithID{
808-
From: reflect.ValueOf(items[1]).String(),
809-
To: reflect.ValueOf(items[2]).String(),
810-
Type: schema.RelationKeyType(fmt.Sprintf("%v", reflect.ValueOf(items[3]))),
811-
},
850+
r := knowledge.RelationWithID{
851+
ID: reflect.ValueOf(items[0]).String(),
852+
From: reflect.ValueOf(items[1]).String(),
853+
To: reflect.ValueOf(items[2]).String(),
854+
Type: schema.RelationKeyType(fmt.Sprintf("%v", reflect.ValueOf(items[3]))),
812855
}
813856
output[i] = r
814857
case knowledge.PropertyExprType:
@@ -820,29 +863,6 @@ func (mc *MariaDBCursor) Read(ctx context.Context, doc interface{}) error {
820863
}
821864
}
822865

823-
if mc.IncludeDataSourceInResults {
824-
for i, pt := range mc.Projections {
825-
switch pt.ExpressionType {
826-
case knowledge.NodeExprType:
827-
a := output[i].(AssetWithIDAndSource)
828-
items, err := q.Get(1)
829-
if err != nil {
830-
return fmt.Errorf("Unable to get source property from queue: %w", err)
831-
}
832-
a.Source = reflect.ValueOf(items[0]).String()
833-
output[i] = a
834-
case knowledge.EdgeExprType:
835-
r := output[i].(RelationWithIDAndSource)
836-
items, err := q.Get(1)
837-
if err != nil {
838-
return fmt.Errorf("Unable to get source property from queue: %w", err)
839-
}
840-
r.Source = reflect.ValueOf(items[0]).String()
841-
output[i] = r
842-
}
843-
}
844-
}
845-
846866
val.Elem().Set(reflect.ValueOf(output))
847867
return nil
848868
}

internal/handlers/handler_query.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ func PostQuery(database knowledge.GraphDB, queryHistorizer history.Historizer) h
1515
return func(w http.ResponseWriter, r *http.Request) {
1616
type QueryRequestBody struct {
1717
Query string `json:"q"`
18-
// If set, the name of the data source is returned with the item
19-
IncludeDataSource bool `json:"include_data_source"`
2018
}
2119

2220
type ColumnType struct {
@@ -30,6 +28,16 @@ func PostQuery(database knowledge.GraphDB, queryHistorizer history.Historizer) h
3028
ExecutionTimeMs time.Duration `json:"execution_time_ms"`
3129
}
3230

31+
type AssetWithIDAndSources struct {
32+
Sources []string `json:"sources,omitempty"`
33+
knowledge.AssetWithID
34+
}
35+
36+
type RelationWithIDAndSources struct {
37+
Sources []string `json:"sources,omitempty"`
38+
knowledge.RelationWithID
39+
}
40+
3341
requestBody := QueryRequestBody{}
3442
err := json.NewDecoder(r.Body).Decode(&requestBody)
3543
if err != nil {
@@ -50,7 +58,7 @@ func PostQuery(database knowledge.GraphDB, queryHistorizer history.Historizer) h
5058
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
5159
defer cancel()
5260

53-
res, err := querier.Query(ctx, requestBody.Query, requestBody.IncludeDataSource)
61+
res, err := querier.Query(ctx, requestBody.Query)
5462
if err != nil {
5563
ReplyWithInternalError(w, err)
5664
return
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package handlers
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"net/http"
8+
9+
"github.com/clems4ever/go-graphkb/internal/knowledge"
10+
)
11+
12+
const MaxIds = 20000
13+
14+
func postAssetSources(database knowledge.GraphDB, fetcherFn func(context.Context, []string) (map[string][]string, error)) http.HandlerFunc {
15+
return func(w http.ResponseWriter, r *http.Request) {
16+
type RequestBody struct {
17+
IDs []string `json:"ids"`
18+
}
19+
20+
type ResponseBody struct {
21+
Results map[string][]string `json:"results"`
22+
}
23+
24+
requestBody := RequestBody{}
25+
err := json.NewDecoder(r.Body).Decode(&requestBody)
26+
if err != nil {
27+
ReplyWithInternalError(w, err)
28+
return
29+
}
30+
31+
if len(requestBody.IDs) > MaxIds {
32+
ReplyWithBadRequest(w, fmt.Errorf("A maximum of %d IDs can be requested in one query", MaxIds))
33+
return
34+
}
35+
36+
idsSet := make(map[string]struct{})
37+
for _, id := range requestBody.IDs {
38+
idsSet[id] = struct{}{}
39+
}
40+
41+
ids := []string{}
42+
for k := range idsSet {
43+
ids = append(ids, k)
44+
}
45+
46+
sources, err := fetcherFn(r.Context(), ids)
47+
if err != nil {
48+
ReplyWithInternalError(w, err)
49+
return
50+
}
51+
52+
response := ResponseBody{
53+
Results: sources,
54+
}
55+
56+
err = json.NewEncoder(w).Encode(response)
57+
if err != nil {
58+
ReplyWithInternalError(w, err)
59+
}
60+
}
61+
}
62+
63+
// PostQueryAssetsSources post endpoint to retrieve the sources of a given set of assets
64+
func PostQueryAssetsSources(database knowledge.GraphDB) http.HandlerFunc {
65+
return postAssetSources(database, database.GetAssetSources)
66+
}
67+
68+
// PostQueryAssetsSources post endpoint to retrieve the sources of a given set of assets
69+
func PostQueryRelationsSources(database knowledge.GraphDB) http.HandlerFunc {
70+
return postAssetSources(database, database.GetRelationSources)
71+
}

internal/knowledge/graphdb.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,17 @@ type GraphDB interface {
2727
RemoveAssets(ctx context.Context, sourceName string, assets []Asset) error
2828
RemoveRelations(ctx context.Context, sourceName string, relations []Relation) error
2929

30+
GetAssetSources(ctx context.Context, ids []string) (map[string][]string, error)
31+
GetRelationSources(ctx context.Context, ids []string) (map[string][]string, error)
32+
3033
FlushAll(ctx context.Context) error
3134

3235
CountAssets(ctx context.Context) (int64, error)
3336
CountAssetsBySource(ctx context.Context, sourceName string) (int64, error)
3437
CountRelations(ctx context.Context) (int64, error)
3538
CountRelationsBySource(ctx context.Context, sourceName string) (int64, error)
3639

37-
Query(ctx context.Context, query SQLTranslation, includeDataSourceInResults bool) (*GraphQueryResult, error)
40+
Query(ctx context.Context, query SQLTranslation) (*GraphQueryResult, error)
3841
}
3942

4043
// Cursor is a cursor over the results

internal/knowledge/querier.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ func NewQuerier(db GraphDB, historizer history.Historizer) *Querier {
2626
return &Querier{GraphDB: db, historizer: historizer}
2727
}
2828

29-
// Query run a query against the graph DB. If includeDataSourceInResults is set, the data source name is also part of the results.
30-
func (q *Querier) Query(ctx context.Context, queryString string, includeDataSourceInResults bool) (*QuerierResult, error) {
31-
qr, sql, err := q.queryInternal(ctx, queryString, includeDataSourceInResults)
29+
// Query run a query against the graph DB.
30+
func (q *Querier) Query(ctx context.Context, queryString string) (*QuerierResult, error) {
31+
qr, sql, err := q.queryInternal(ctx, queryString)
3232
if err != nil {
3333
saveErr := q.historizer.SaveFailedQuery(ctx, queryString, sql, err)
3434
if saveErr != nil {
@@ -44,7 +44,7 @@ func (q *Querier) Query(ctx context.Context, queryString string, includeDataSour
4444
return qr, nil
4545
}
4646

47-
func (q *Querier) queryInternal(ctx context.Context, cypherQuery string, includeDataSourceInResults bool) (*QuerierResult, string, error) {
47+
func (q *Querier) queryInternal(ctx context.Context, cypherQuery string) (*QuerierResult, string, error) {
4848
s := Statistics{}
4949

5050
var err error
@@ -58,14 +58,14 @@ func (q *Querier) queryInternal(ctx context.Context, cypherQuery string, include
5858
return nil, "", err
5959
}
6060

61-
translation, err := NewSQLQueryTranslator().Translate(queryCypher, includeDataSourceInResults)
61+
translation, err := NewSQLQueryTranslator().Translate(queryCypher)
6262
if err != nil {
6363
return nil, "", err
6464
}
6565

6666
var res *GraphQueryResult
6767
s.Execution = MeasureDuration(func() {
68-
res, err = q.GraphDB.Query(ctx, *translation, includeDataSourceInResults)
68+
res, err = q.GraphDB.Query(ctx, *translation)
6969
})
7070

7171
if err != nil {

0 commit comments

Comments
 (0)