Skip to content

Commit e114641

Browse files
authored
feat: add api NewAdapterByDB (#54)
1 parent 293a343 commit e114641

File tree

2 files changed

+65
-11
lines changed

2 files changed

+65
-11
lines changed

adapter.go

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,10 @@ type CasbinRule struct {
4949

5050
// adapter represents the MongoDB adapter for policy storage.
5151
type adapter struct {
52-
clientOption *options.ClientOptions
53-
client *mongo.Client
54-
collection *mongo.Collection
55-
timeout time.Duration
56-
filtered bool
52+
client *mongo.Client
53+
collection *mongo.Collection
54+
timeout time.Duration
55+
filtered bool
5756
}
5857

5958
// finalizer is the destructor for adapter.
@@ -103,9 +102,7 @@ func NewAdapterWithCollectionName(clientOption *options.ClientOptions, databaseN
103102

104103
// baseNewAdapter is a base constructor for Adapter
105104
func baseNewAdapter(clientOption *options.ClientOptions, databaseName string, collectionName string, timeout ...interface{}) (persist.BatchAdapter, error) {
106-
a := &adapter{
107-
clientOption: clientOption,
108-
}
105+
a := &adapter{}
109106
a.filtered = false
110107

111108
if len(timeout) == 1 {
@@ -117,7 +114,7 @@ func baseNewAdapter(clientOption *options.ClientOptions, databaseName string, co
117114
}
118115

119116
// Open the DB, create it if not existed.
120-
err := a.open(databaseName, collectionName)
117+
err := a.open(clientOption, databaseName, collectionName)
121118
if err != nil {
122119
return nil, err
123120
}
@@ -140,11 +137,45 @@ func NewFilteredAdapter(url string) (persist.FilteredAdapter, error) {
140137
return a.(*adapter), nil
141138
}
142139

143-
func (a *adapter) open(databaseName string, collectionName string) error {
140+
type AdapterConfig struct {
141+
DatabaseName string
142+
CollectionName string
143+
Timeout time.Duration
144+
IsFiltered bool
145+
}
146+
147+
func NewAdapterByDB(client *mongo.Client, config *AdapterConfig) (persist.BatchAdapter, error) {
148+
if config == nil {
149+
config = &AdapterConfig{}
150+
}
151+
if config.CollectionName == "" {
152+
config.CollectionName = defaultCollectionName
153+
}
154+
if config.DatabaseName == "" {
155+
config.DatabaseName = defaultDatabaseName
156+
}
157+
if config.Timeout == 0 {
158+
config.Timeout = defaultTimeout
159+
}
160+
161+
a := &adapter{
162+
client: client,
163+
collection: client.Database(config.DatabaseName).Collection(config.CollectionName),
164+
timeout: config.Timeout,
165+
filtered: config.IsFiltered,
166+
}
167+
168+
// Call the destructor when the object is released.
169+
runtime.SetFinalizer(a, finalizer)
170+
171+
return a, nil
172+
}
173+
174+
func (a *adapter) open(clientOption *options.ClientOptions, databaseName string, collectionName string) error {
144175
ctx, cancel := context.WithTimeout(context.TODO(), a.timeout)
145176
defer cancel()
146177

147-
client, err := mongo.Connect(ctx, a.clientOption)
178+
client, err := mongo.Connect(ctx, clientOption)
148179
if err != nil {
149180
return err
150181
}

adapter_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
package mongodbadapter
1616

1717
import (
18+
"context"
1819
"fmt"
20+
"go.mongodb.org/mongo-driver/mongo"
1921
"os"
2022
"strings"
2123
"testing"
@@ -481,6 +483,27 @@ func TestNewAdapterWithCollectionName(t *testing.T) {
481483
}
482484
}
483485

486+
func TestNewAdapterByDB(t *testing.T) {
487+
uri := getDbURL()
488+
if !strings.HasPrefix(uri, "mongodb+srv://") && !strings.HasPrefix(uri, "mongodb://") {
489+
uri = fmt.Sprint("mongodb://" + uri)
490+
}
491+
mongoClientOption := mongooptions.Client().ApplyURI(uri)
492+
client, err := mongo.Connect(context.Background(), mongoClientOption)
493+
if err != nil {
494+
panic(err)
495+
}
496+
497+
config := AdapterConfig{
498+
DatabaseName: "casbin_custom",
499+
CollectionName: "casbin_rule_custom",
500+
}
501+
_, err = NewAdapterByDB(client, &config)
502+
if err != nil {
503+
panic(err)
504+
}
505+
}
506+
484507
func TestUpdatePolicy(t *testing.T) {
485508
initPolicy(t, getDbURL())
486509

0 commit comments

Comments
 (0)