Skip to content

Commit d024565

Browse files
committed
Add generic
1 parent 3c2ceb3 commit d024565

File tree

13 files changed

+499
-125
lines changed

13 files changed

+499
-125
lines changed

adapter/adapter.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package adapter
2+
3+
import (
4+
"context"
5+
"github.com/gocql/gocql"
6+
"reflect"
7+
"strings"
8+
9+
q "github.com/core-go/cassandra"
10+
)
11+
12+
type Adapter[T any] struct {
13+
DB *gocql.ClusterConfig
14+
Table string
15+
Schema *q.Schema
16+
JsonColumnMap map[string]string
17+
versionField string
18+
versionIndex int
19+
versionDBField string
20+
}
21+
22+
func NewAdapter[T any](db *gocql.ClusterConfig, tableName string) (*Adapter[T], error) {
23+
return NewAdapterWithVersion[T](db, tableName, "")
24+
}
25+
func NewAdapterWithVersion[T any](db *gocql.ClusterConfig, tableName string, versionField string) (*Adapter[T], error) {
26+
var t T
27+
modelType := reflect.TypeOf(t)
28+
if modelType.Kind() == reflect.Ptr {
29+
modelType = modelType.Elem()
30+
}
31+
schema := q.CreateSchema(modelType)
32+
jsonColumnMapT := q.MakeJsonColumnMap(modelType)
33+
jsonColumnMap := q.GetWritableColumns(schema.Fields, jsonColumnMapT)
34+
adapter := &Adapter[T]{DB: db, Table: tableName, Schema: schema, JsonColumnMap: jsonColumnMap, versionField: "", versionIndex: -1}
35+
if len(versionField) > 0 {
36+
index := q.FindFieldIndex(modelType, versionField)
37+
if index >= 0 {
38+
_, dbFieldName, exist := q.GetFieldByIndex(modelType, index)
39+
if !exist {
40+
dbFieldName = strings.ToLower(versionField)
41+
}
42+
adapter.versionField = versionField
43+
adapter.versionIndex = index
44+
adapter.versionDBField = dbFieldName
45+
}
46+
}
47+
return adapter, nil
48+
}
49+
50+
func (a *Adapter[T]) Create(ctx context.Context, model T) (int64, error) {
51+
query, args := q.BuildToInsertWithVersion(a.Table, model, a.versionIndex, false, a.Schema)
52+
ses, err := a.DB.CreateSession()
53+
if err != nil {
54+
return -1, err
55+
}
56+
defer ses.Close()
57+
er2 := q.Exec(ses, query, args...)
58+
if er2 != nil {
59+
return 0, er2
60+
}
61+
return 1, nil
62+
}
63+
func (a *Adapter[T]) Update(ctx context.Context, model T) (int64, error) {
64+
query, args := q.BuildToUpdateWithVersion(a.Table, model, a.versionIndex, a.Schema)
65+
ses, err := a.DB.CreateSession()
66+
if err != nil {
67+
return -1, err
68+
}
69+
defer ses.Close()
70+
er2 := q.Exec(ses, query, args...)
71+
if er2 != nil {
72+
return 0, er2
73+
}
74+
return 1, nil
75+
}
76+
func (a *Adapter[T]) Save(ctx context.Context, model T) (int64, error) {
77+
query, args := q.BuildToInsertWithVersion(a.Table, model, a.versionIndex, true, a.Schema)
78+
ses, err := a.DB.CreateSession()
79+
if err != nil {
80+
return -1, err
81+
}
82+
defer ses.Close()
83+
er2 := q.Exec(ses, query, args...)
84+
if er2 != nil {
85+
return 0, er2
86+
}
87+
return 1, nil
88+
}
89+
func (a *Adapter[T]) Patch(ctx context.Context, model map[string]interface{}) (int64, error) {
90+
dbColumnMap := q.JSONToColumns(model, a.JsonColumnMap)
91+
query, values := q.BuildToPatchWithVersion(a.Table, dbColumnMap, a.Schema.SKeys, a.versionDBField)
92+
ses, err := a.DB.CreateSession()
93+
if err != nil {
94+
return -1, err
95+
}
96+
defer ses.Close()
97+
er2 := q.Exec(ses, query, values...)
98+
if er2 == nil {
99+
return 1, er2
100+
}
101+
return 0, er2
102+
}

adapter/generic_adapter.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package adapter
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"fmt"
8+
"github.com/gocql/gocql"
9+
"reflect"
10+
11+
q "github.com/core-go/cassandra"
12+
)
13+
14+
type GenericAdapter[T any, K any] struct {
15+
*Adapter[*T]
16+
Map map[string]int
17+
Fields string
18+
Keys []string
19+
IdMap bool
20+
}
21+
22+
func NewGenericAdapter[T any, K any](db *gocql.ClusterConfig, tableName string) (*GenericAdapter[T, K], error) {
23+
return NewGenericAdapterWithVersion[T, K](db, tableName, "")
24+
}
25+
func NewGenericAdapterWithVersion[T any, K any](db *gocql.ClusterConfig, tableName string, versionField string) (*GenericAdapter[T, K], error) {
26+
adapter, err := NewAdapterWithVersion[*T](db, tableName, versionField)
27+
if err != nil {
28+
return nil, err
29+
}
30+
31+
var t T
32+
modelType := reflect.TypeOf(t)
33+
if modelType.Kind() != reflect.Struct {
34+
return nil, errors.New("T must be a struct")
35+
}
36+
37+
_, primaryKeys := q.FindPrimaryKeys(modelType)
38+
var k K
39+
kType := reflect.TypeOf(k)
40+
idMap := false
41+
if len(primaryKeys) > 1 {
42+
if kType.Kind() == reflect.Map {
43+
idMap = true
44+
} else if kType.Kind() != reflect.Struct {
45+
return nil, errors.New("for composite keys, K must be a struct or a map")
46+
}
47+
}
48+
49+
fieldsIndex, err := q.GetColumnIndexes(modelType)
50+
if err != nil {
51+
return nil, err
52+
}
53+
fields := q.BuildFieldsBySchema(adapter.Schema)
54+
return &GenericAdapter[T, K]{adapter, fieldsIndex, fields, primaryKeys, idMap}, nil
55+
}
56+
func (a *GenericAdapter[T, K]) All(ctx context.Context) ([]T, error) {
57+
var objs []T
58+
query := fmt.Sprintf("select %s from %s", a.Fields, a.Table)
59+
ses, err := a.DB.CreateSession()
60+
if err != nil {
61+
return objs, err
62+
}
63+
defer ses.Close()
64+
err = q.Query(ses, a.Map, &objs, query)
65+
return objs, err
66+
}
67+
func toMap(obj interface{}) (map[string]interface{}, error) {
68+
b, err := json.Marshal(obj)
69+
if err != nil {
70+
return nil, err
71+
}
72+
im := make(map[string]interface{})
73+
er2 := json.Unmarshal(b, &im)
74+
return im, er2
75+
}
76+
func (a *GenericAdapter[T, K]) getId(k K) (interface{}, error) {
77+
if len(a.Keys) >= 2 && !a.IdMap {
78+
ri, err := toMap(k)
79+
return ri, err
80+
} else {
81+
return k, nil
82+
}
83+
}
84+
func (a *GenericAdapter[T, K]) Load(ctx context.Context, id K) (*T, error) {
85+
ip, er0 := a.getId(id)
86+
if er0 != nil {
87+
return nil, er0
88+
}
89+
var objs []T
90+
queryAll := fmt.Sprintf("select %s from %s ", a.Fields, a.Table)
91+
query, args := q.BuildFindById(queryAll, ip, a.JsonColumnMap, a.Schema.SKeys)
92+
ses, err := a.DB.CreateSession()
93+
if err != nil {
94+
return nil, err
95+
}
96+
defer ses.Close()
97+
err = q.Query(ses, a.Map, &objs, query, args...)
98+
if len(objs) > 0 {
99+
return &objs[0], nil
100+
}
101+
return nil, nil
102+
}
103+
func (a *GenericAdapter[T, K]) Exist(ctx context.Context, id K) (bool, error) {
104+
ip, er0 := a.getId(id)
105+
if er0 != nil {
106+
return false, er0
107+
}
108+
query := fmt.Sprintf("select %s from %s ", a.Schema.SColumns[0], a.Table)
109+
query1, args := q.BuildFindById(query, ip, a.JsonColumnMap, a.Schema.SKeys)
110+
ses, err := a.DB.CreateSession()
111+
if err != nil {
112+
return false, err
113+
}
114+
defer ses.Close()
115+
res, err := q.QueryMap(ses, nil, query1, args...)
116+
if err != nil {
117+
return false, err
118+
}
119+
if len(res) > 0 {
120+
return true, nil
121+
}
122+
return false, nil
123+
}
124+
func (a *GenericAdapter[T, K]) Delete(ctx context.Context, id K) (int64, error) {
125+
ip, er0 := a.getId(id)
126+
if er0 != nil {
127+
return -1, er0
128+
}
129+
query := fmt.Sprintf("delete from %s ", a.Table)
130+
query1, args := q.BuildFindById(query, ip, a.JsonColumnMap, a.Schema.SKeys)
131+
ses, err := a.DB.CreateSession()
132+
if err != nil {
133+
return 0, err
134+
}
135+
defer ses.Close()
136+
er2 := q.Exec(ses, query1, args...)
137+
if er2 == nil {
138+
return 1, er2
139+
}
140+
return 0, er2
141+
}

adapter/search.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package adapter
2+
3+
import (
4+
"context"
5+
"github.com/gocql/gocql"
6+
"reflect"
7+
8+
q "github.com/core-go/cassandra"
9+
)
10+
11+
type SearchAdapter[T any, K any, F any] struct {
12+
*GenericAdapter[T, K]
13+
BuildQuery func(F) (string, []interface{})
14+
Mp func(*T)
15+
Map map[string]int
16+
}
17+
18+
func NewSearchAdapter[T any, K any, F any](db *gocql.ClusterConfig, table string, buildQuery func(F) (string, []interface{}), options ...func(*T)) (*SearchAdapter[T, K, F], error) {
19+
return NewSearchAdapterWithVersion[T, K, F](db, table, buildQuery, "", options...)
20+
}
21+
func NewSearchAdapterWithVersion[T any, K any, F any](db *gocql.ClusterConfig, table string, buildQuery func(F) (string, []interface{}), versionField string, opts ...func(*T)) (*SearchAdapter[T, K, F], error) {
22+
adapter, err := NewGenericAdapterWithVersion[T, K](db, table, versionField)
23+
if err != nil {
24+
return nil, err
25+
}
26+
var mp func(*T)
27+
if len(opts) >= 1 {
28+
mp = opts[0]
29+
}
30+
var t T
31+
modelType := reflect.TypeOf(t)
32+
if modelType.Kind() == reflect.Ptr {
33+
modelType = modelType.Elem()
34+
}
35+
fieldsIndex, err := q.GetColumnIndexes(modelType)
36+
if err != nil {
37+
return nil, err
38+
}
39+
builder := &SearchAdapter[T, K, F]{GenericAdapter: adapter, Map: fieldsIndex, BuildQuery: buildQuery, Mp: mp}
40+
return builder, nil
41+
}
42+
43+
func (b *SearchAdapter[T, K, F]) Search(ctx context.Context, filter F, limit int64, next string) ([]T, string, error) {
44+
var objs []T
45+
sql, params := b.BuildQuery(filter)
46+
ses, err := b.DB.CreateSession()
47+
defer ses.Close()
48+
49+
if err != nil {
50+
return objs, "", err
51+
}
52+
nextPageToken, er2 := q.QueryWithMap(ses, b.Map, &objs, sql, params, limit, next)
53+
if b.Mp != nil {
54+
l := len(objs)
55+
for i := 0; i < l; i++ {
56+
b.Mp(&objs[i])
57+
}
58+
}
59+
return objs, nextPageToken, er2
60+
}

batch/batch_inserter.go

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,59 +8,49 @@ import (
88
"github.com/gocql/gocql"
99
)
1010

11-
type BatchInserter struct {
11+
type BatchInserter[T any] struct {
1212
db *gocql.ClusterConfig
13-
tableName string
14-
Map func(ctx context.Context, model interface{}) (interface{}, error)
13+
table string
14+
Map func(*T)
1515
VersionIndex int
1616
Schema *c.Schema
1717
}
1818

19-
func NewBatchInserter(db *gocql.ClusterConfig, tableName string, modelType reflect.Type, options ...func(context.Context, interface{}) (interface{}, error)) *BatchInserter {
20-
var mp func(context.Context, interface{}) (interface{}, error)
19+
func NewBatchInserter[T any](db *gocql.ClusterConfig, table string, options ...func(*T)) *BatchInserter[T] {
20+
var mp func(*T)
2121
if len(options) > 0 && options[0] != nil {
2222
mp = options[0]
2323
}
24-
return NewBatchInserterWithVersion(db, tableName, modelType, mp)
24+
return NewBatchInserterWithVersion[T](db, table, mp)
2525
}
26-
func NewBatchInserterWithVersion(db *gocql.ClusterConfig, tableName string, modelType reflect.Type, mp func(context.Context, interface{}) (interface{}, error), options ...int) *BatchInserter {
26+
func NewBatchInserterWithVersion[T any](db *gocql.ClusterConfig, table string, mp func(*T), options ...int) *BatchInserter[T] {
27+
var t T
28+
modelType := reflect.TypeOf(t)
29+
if modelType.Kind() != reflect.Struct {
30+
panic("T must be a struct")
31+
}
2732
versionIndex := -1
2833
if len(options) > 0 && options[0] >= 0 {
2934
versionIndex = options[0]
3035
}
3136
schema := c.CreateSchema(modelType)
32-
return &BatchInserter{db: db, tableName: tableName, Schema: schema, VersionIndex: versionIndex, Map: mp}
37+
return &BatchInserter[T]{db: db, table: table, Schema: schema, VersionIndex: versionIndex, Map: mp}
3338
}
34-
func (w *BatchInserter) Write(ctx context.Context, models interface{}) ([]int, []int, error) {
35-
successIndices := make([]int, 0)
36-
failIndices := make([]int, 0)
37-
var models2 interface{}
38-
var er0 error
39+
func (w *BatchInserter[T]) Write(ctx context.Context, models []T) error {
40+
l := len(models)
41+
if l == 0 {
42+
return nil
43+
}
3944
if w.Map != nil {
40-
models2, er0 = c.MapModels(ctx, models, w.Map)
41-
if er0 != nil {
42-
s0 := reflect.ValueOf(models2)
43-
_, er0b := c.InterfaceSlice(models2)
44-
failIndices = c.ToArrayIndex(s0, failIndices)
45-
return successIndices, failIndices, er0b
45+
for i := 0; i < l; i++ {
46+
w.Map(&models[i])
4647
}
47-
} else {
48-
models2 = models
4948
}
5049
session, er0 := w.db.CreateSession()
5150
if er0 != nil {
52-
return successIndices, failIndices, er0
51+
return er0
5352
}
5453
defer session.Close()
55-
_, err := c.InsertBatchWithVersion(ctx, session, w.tableName, models2, w.VersionIndex, w.Schema)
56-
s := reflect.ValueOf(models)
57-
if err == nil {
58-
// Return full success
59-
successIndices = c.ToArrayIndex(s, successIndices)
60-
return successIndices, failIndices, err
61-
} else {
62-
// Return full fail
63-
failIndices = c.ToArrayIndex(s, failIndices)
64-
}
65-
return successIndices, failIndices, err
54+
_, err := c.InsertBatchWithSizeAndVersion(ctx, session, l, w.table, models, w.VersionIndex, w.Schema)
55+
return err
6656
}

0 commit comments

Comments
 (0)