diff --git a/CHANGELOG.md b/CHANGELOG.md index 4690a92f8..f0e2c2c90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - fix: doc comments from object fields should be present in generated GraphQL schema [#630](https://github.com/hypermodeinc/modus/pull/630) - feat: add neo4j support in modus [#636](https://github.com/hypermodeinc/modus/pull/636) +- perf: improve locking code [#637](https://github.com/hypermodeinc/modus/pull/637) ## UNRELEASED - Go SDK diff --git a/cspell.json b/cspell.json index 8af85c4ab..a8d82ae80 100644 --- a/cspell.json +++ b/cspell.json @@ -94,6 +94,7 @@ "ldflags", "legacymodels", "Lessable", + "lestrrat", "linkname", "logit", "logits", @@ -143,6 +144,7 @@ "prereleases", "promhttp", "ptrs", + "puzpuzpuz", "quickstart", "reindex", "renameio", @@ -203,6 +205,7 @@ "Weasley", "Weaviate", "wundergraph", + "xsync", "xxhash", "zenquotes", "zerolog" diff --git a/runtime/collections/collection.go b/runtime/collections/collection.go index bf4850016..93dffc607 100644 --- a/runtime/collections/collection.go +++ b/runtime/collections/collection.go @@ -11,31 +11,33 @@ package collections import ( "fmt" - "sync" "github.com/hypermodeinc/modus/runtime/collections/index/interfaces" + "github.com/puzpuzpuz/xsync/v3" ) type collection struct { - collectionNamespaceMap map[string]interfaces.CollectionNamespace - mu sync.RWMutex + collectionNamespaceMap *xsync.MapOf[string, interfaces.CollectionNamespace] } func newCollection() *collection { return &collection{ - collectionNamespaceMap: map[string]interfaces.CollectionNamespace{}, + collectionNamespaceMap: xsync.NewMapOf[string, interfaces.CollectionNamespace](), } } func (c *collection) getCollectionNamespaceMap() map[string]interfaces.CollectionNamespace { - return c.collectionNamespaceMap + m := make(map[string]interfaces.CollectionNamespace, c.collectionNamespaceMap.Size()) + c.collectionNamespaceMap.Range(func(key string, value interfaces.CollectionNamespace) bool { + m[key] = value + return true + }) + + return m } func (c *collection) findNamespace(namespace string) (interfaces.CollectionNamespace, error) { - c.mu.RLock() - defer c.mu.RUnlock() - - ns, found := c.collectionNamespaceMap[namespace] + ns, found := c.collectionNamespaceMap.Load(namespace) if !found { return nil, errNamespaceNotFound } @@ -43,34 +45,16 @@ func (c *collection) findNamespace(namespace string) (interfaces.CollectionNames } func (c *collection) findOrCreateNamespace(namespace string, index interfaces.CollectionNamespace) (interfaces.CollectionNamespace, error) { - c.mu.RLock() - ns, found := c.collectionNamespaceMap[namespace] - if found { - defer c.mu.RUnlock() - return ns, nil - } - - c.mu.RUnlock() - c.mu.Lock() - defer c.mu.Unlock() - - ns, found = c.collectionNamespaceMap[namespace] - if found { - return ns, nil - } - - c.collectionNamespaceMap[namespace] = index - return index, nil + result, _ := c.collectionNamespaceMap.LoadOrStore(namespace, index) + return result, nil // TODO: remove unused error } func (c *collection) createCollectionNamespace(namespace string, index interfaces.CollectionNamespace) (interfaces.CollectionNamespace, error) { - c.mu.Lock() - defer c.mu.Unlock() - - if _, found := c.collectionNamespaceMap[namespace]; found { + _, found := c.collectionNamespaceMap.Load(namespace) + if found { return nil, fmt.Errorf("namespace with name %s already exists", namespace) } - c.collectionNamespaceMap[namespace] = index + c.collectionNamespaceMap.Store(namespace, index) return index, nil } diff --git a/runtime/collections/factory.go b/runtime/collections/factory.go index 262f40b26..b3c6e173d 100644 --- a/runtime/collections/factory.go +++ b/runtime/collections/factory.go @@ -19,6 +19,7 @@ import ( "github.com/hypermodeinc/modus/runtime/collections/index/interfaces" "github.com/hypermodeinc/modus/runtime/db" "github.com/hypermodeinc/modus/runtime/logger" + "github.com/puzpuzpuz/xsync/v3" ) const collectionFactoryWriteInterval = 1 @@ -40,7 +41,7 @@ func newCollectionFactory() *collectionFactory { return &collectionFactory{ collectionMap: map[string]*collection{ "": { - collectionNamespaceMap: map[string]interfaces.CollectionNamespace{}, + collectionNamespaceMap: xsync.NewMapOf[string, interfaces.CollectionNamespace](), }, }, quit: make(chan struct{}), @@ -86,7 +87,7 @@ func (cf *collectionFactory) readFromPostgres(ctx context.Context) bool { resetTimerFaster := false var err error for _, namespaceCollectionFactory := range cf.collectionMap { - for _, col := range namespaceCollectionFactory.collectionNamespaceMap { + for _, col := range namespaceCollectionFactory.getCollectionNamespaceMap() { resetTimerFaster, err = loadTextsIntoCollection(ctx, col) if err != nil { logger.Err(ctx, err). diff --git a/runtime/collections/in_mem/hnsw/vector_index.go b/runtime/collections/in_mem/hnsw/vector_index.go index d93b3137c..74eb746c1 100644 --- a/runtime/collections/in_mem/hnsw/vector_index.go +++ b/runtime/collections/in_mem/hnsw/vector_index.go @@ -102,11 +102,11 @@ func (ims *HnswVectorIndex) Search(ctx context.Context, query []float32, maxResu func (ims *HnswVectorIndex) SearchWithKey(ctx context.Context, queryKey string, maxResults int, filter index.SearchFilter) (utils.MaxTupleHeap, error) { ims.mu.RLock() + defer ims.mu.RUnlock() query, found := ims.HnswIndex.Lookup(queryKey) if !found { return nil, fmt.Errorf("key not found") } - ims.mu.RUnlock() if query == nil { return nil, nil } diff --git a/runtime/collections/in_mem/sequential/vector_index.go b/runtime/collections/in_mem/sequential/vector_index.go index 315ba8196..76d564079 100644 --- a/runtime/collections/in_mem/sequential/vector_index.go +++ b/runtime/collections/in_mem/sequential/vector_index.go @@ -103,8 +103,8 @@ func (ims *SequentialVectorIndex) Search(ctx context.Context, query []float32, m func (ims *SequentialVectorIndex) SearchWithKey(ctx context.Context, queryKey string, maxResults int, filter index.SearchFilter) (utils.MaxTupleHeap, error) { ims.mu.RLock() + defer ims.mu.RUnlock() query := ims.VectorMap[queryKey] - ims.mu.RUnlock() if query == nil { return nil, nil } diff --git a/runtime/collections/vector.go b/runtime/collections/vector.go index e04fca7ff..aedb640ca 100644 --- a/runtime/collections/vector.go +++ b/runtime/collections/vector.go @@ -98,7 +98,7 @@ func deleteIndexesNotInManifest(ctx context.Context, man *manifest.Manifest) { Msg("Failed to find collection.") continue } - for _, collNs := range col.collectionNamespaceMap { + for _, collNs := range col.getCollectionNamespaceMap() { vectorIndexMap := collNs.GetVectorIndexMap() if vectorIndexMap == nil { continue @@ -160,7 +160,7 @@ func processManifestCollections(ctx context.Context, man *manifest.Manifest) { } } } - for _, collNs := range col.collectionNamespaceMap { + for _, collNs := range col.getCollectionNamespaceMap() { for searchMethodName, searchMethod := range collectionInfo.SearchMethods { vi, err := collNs.GetVectorIndex(ctx, searchMethodName) diff --git a/runtime/db/db.go b/runtime/db/db.go index 49414b162..f9fae3ced 100644 --- a/runtime/db/db.go +++ b/runtime/db/db.go @@ -53,7 +53,7 @@ type runtimePostgresWriter struct { buffer chan inferenceHistory quit chan struct{} done chan struct{} - mu sync.RWMutex + once sync.Once } type inferenceHistory struct { @@ -67,37 +67,39 @@ type inferenceHistory struct { } func (w *runtimePostgresWriter) GetPool(ctx context.Context) (*pgxpool.Pool, error) { - w.mu.RLock() + var initErr error + w.once.Do(func() { + var connStr string + var err error + if secrets.HasSecret("MODUS_DB") { + connStr, err = secrets.GetSecretValue("MODUS_DB") + } else if secrets.HasSecret("HYPERMODE_METADATA_DB") { + // fallback to old secret name + // TODO: remove this after the transition is complete + connStr, err = secrets.GetSecretValue("HYPERMODE_METADATA_DB") + } else { + return + } + + if err != nil { + initErr = err + return + } + + if pool, err := pgxpool.New(ctx, connStr); err != nil { + initErr = err + } else { + w.dbpool = pool + } + }) + if w.dbpool != nil { - defer w.mu.RUnlock() return w.dbpool, nil - } - w.mu.RUnlock() - - w.mu.Lock() - defer w.mu.Unlock() - - var connStr string - var err error - if secrets.HasSecret("MODUS_DB") { - connStr, err = secrets.GetSecretValue("MODUS_DB") - } else if secrets.HasSecret("HYPERMODE_METADATA_DB") { - // fallback to old secret name - // TODO: remove this after the transition is complete - connStr, err = secrets.GetSecretValue("HYPERMODE_METADATA_DB") + } else if initErr != nil { + return nil, initErr } else { return nil, errDbNotConfigured } - if err != nil { - return nil, err - } - - if pool, err := pgxpool.New(ctx, connStr); err != nil { - return nil, err - } else { - w.dbpool = pool - return pool, nil - } } func (w *runtimePostgresWriter) Write(data inferenceHistory) { diff --git a/runtime/dgraphclient/registry.go b/runtime/dgraphclient/registry.go index 6d8b46c01..7e720e96f 100644 --- a/runtime/dgraphclient/registry.go +++ b/runtime/dgraphclient/registry.go @@ -14,7 +14,6 @@ import ( "crypto/x509" "fmt" "strings" - "sync" "github.com/hypermodeinc/modus/lib/manifest" "github.com/hypermodeinc/modus/runtime/manifestdata" @@ -26,13 +25,13 @@ import ( "github.com/dgraph-io/dgo/v240" "github.com/dgraph-io/dgo/v240/protos/api" + "github.com/puzpuzpuz/xsync/v3" ) var dgr = newDgraphRegistry() type dgraphRegistry struct { - sync.RWMutex - dgraphConnectorCache map[string]*dgraphConnector + cache *xsync.MapOf[string, *dgraphConnector] } type authCreds struct { @@ -51,35 +50,41 @@ func (a *authCreds) RequireTransportSecurity() bool { func newDgraphRegistry() *dgraphRegistry { return &dgraphRegistry{ - dgraphConnectorCache: make(map[string]*dgraphConnector), + cache: xsync.NewMapOf[string, *dgraphConnector](), } } func ShutdownConns() { - dgr.Lock() - defer dgr.Unlock() - for _, ds := range dgr.dgraphConnectorCache { - ds.conn.Close() - } - clear(dgr.dgraphConnectorCache) + dgr.cache.Range(func(key string, _ *dgraphConnector) bool { + if connector, ok := dgr.cache.LoadAndDelete(key); ok { + connector.conn.Close() + } + return true + }) } func (dr *dgraphRegistry) getDgraphConnector(ctx context.Context, dgName string) (*dgraphConnector, error) { - dr.RLock() - ds, ok := dr.dgraphConnectorCache[dgName] - dr.RUnlock() - if ok { - return ds, nil - } - - dr.Lock() - defer dr.Unlock() + var creationErr error + ds, _ := dr.cache.LoadOrCompute(dgName, func() *dgraphConnector { + conn, err := createConnector(ctx, dgName) + if err != nil { + creationErr = err + return nil + } + return conn + }) - if ds, ok := dr.dgraphConnectorCache[dgName]; ok { - return ds, nil + if creationErr != nil { + dr.cache.Delete(dgName) + return nil, creationErr } - info, ok := manifestdata.GetManifest().Connections[dgName] + return ds, nil +} + +func createConnector(ctx context.Context, dgName string) (*dgraphConnector, error) { + man := manifestdata.GetManifest() + info, ok := man.Connections[dgName] if !ok { return nil, fmt.Errorf("dgraph connection [%s] not found", dgName) } @@ -130,10 +135,10 @@ func (dr *dgraphRegistry) getDgraphConnector(ctx context.Context, dgName string) return nil, err } - ds = &dgraphConnector{ + ds := &dgraphConnector{ conn: conn, dgClient: dgo.NewDgraphClient(api.NewDgraphClient(conn)), } - dr.dgraphConnectorCache[dgName] = ds + return ds, nil } diff --git a/runtime/go.mod b/runtime/go.mod index a73e98319..9e9aef097 100644 --- a/runtime/go.mod +++ b/runtime/go.mod @@ -35,6 +35,7 @@ require ( github.com/neo4j/neo4j-go-driver/v5 v5.27.0 github.com/prometheus/client_golang v1.20.5 github.com/prometheus/common v0.60.1 + github.com/puzpuzpuz/xsync/v3 v3.4.0 github.com/rs/cors v1.11.1 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.33.0 diff --git a/runtime/go.sum b/runtime/go.sum index 261ebf559..59b07d6b5 100644 --- a/runtime/go.sum +++ b/runtime/go.sum @@ -230,6 +230,8 @@ github.com/prometheus/common v0.60.1 h1:FUas6GcOw66yB/73KC+BOZoFJmbo/1pojoILArPA github.com/prometheus/common v0.60.1/go.mod h1:h0LYf1R1deLSKtD4Vdg8gy4RuOvENW2J/h19V5NADQw= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4= +github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= github.com/r3labs/sse/v2 v2.10.0 h1:hFEkLLFY4LDifoHdiCN/LlGBAdVJYsANaLqNYa1l/v0= github.com/r3labs/sse/v2 v2.10.0/go.mod h1:Igau6Whc+F17QUgML1fYe1VPZzTV6EMCnYktEmkNJ7I= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= diff --git a/runtime/neo4jclient/registry.go b/runtime/neo4jclient/registry.go index a073a1e23..b6ee19738 100644 --- a/runtime/neo4jclient/registry.go +++ b/runtime/neo4jclient/registry.go @@ -12,58 +12,58 @@ package neo4jclient import ( "context" "fmt" - "sync" "github.com/hypermodeinc/modus/lib/manifest" "github.com/hypermodeinc/modus/runtime/manifestdata" "github.com/hypermodeinc/modus/runtime/secrets" + "github.com/neo4j/neo4j-go-driver/v5/neo4j" + "github.com/puzpuzpuz/xsync/v3" ) var n4j = newNeo4jRegistry() type neo4jRegistry struct { - sync.RWMutex - neo4jDriverCache map[string]neo4j.DriverWithContext + cache *xsync.MapOf[string, neo4j.DriverWithContext] } func newNeo4jRegistry() *neo4jRegistry { return &neo4jRegistry{ - neo4jDriverCache: make(map[string]neo4j.DriverWithContext), + cache: xsync.NewMapOf[string, neo4j.DriverWithContext](), } } func CloseDrivers(ctx context.Context) { - n4j.Lock() - defer n4j.Unlock() - - removed := make([]string, 0) - - for key, driver := range n4j.neo4jDriverCache { - driver.Close(ctx) - removed = append(removed, key) - } - - for _, key := range removed { - delete(n4j.neo4jDriverCache, key) - } + n4j.cache.Range(func(key string, _ neo4j.DriverWithContext) bool { + if driver, ok := n4j.cache.LoadAndDelete(key); ok { + driver.Close(ctx) + } + return true + }) } func (nr *neo4jRegistry) getDriver(ctx context.Context, n4jName string) (neo4j.DriverWithContext, error) { - nr.RLock() - ds, ok := nr.neo4jDriverCache[n4jName] - nr.RUnlock() - if ok { - return ds, nil - } - nr.Lock() - defer nr.Unlock() + var creationErr error + driver, _ := n4j.cache.LoadOrCompute(n4jName, func() neo4j.DriverWithContext { + driver, err := createDriver(ctx, n4jName) + if err != nil { + creationErr = err + return nil + } + return driver + }) - if driver, ok := nr.neo4jDriverCache[n4jName]; ok { - return driver, nil + if creationErr != nil { + n4j.cache.Delete(n4jName) + return nil, creationErr } - info, ok := manifestdata.GetManifest().Connections[n4jName] + return driver, nil +} + +func createDriver(ctx context.Context, n4jName string) (neo4j.DriverWithContext, error) { + man := manifestdata.GetManifest() + info, ok := man.Connections[n4jName] if !ok { return nil, fmt.Errorf("Neo4j connection [%s] not found", n4jName) } @@ -100,8 +100,6 @@ func (nr *neo4jRegistry) getDriver(ctx context.Context, n4jName string) (neo4j.D return nil, err } - nr.neo4jDriverCache[n4jName] = driver - return driver, nil } diff --git a/runtime/sqlclient/pooling.go b/runtime/sqlclient/pooling.go index 3709ea27c..8dd446989 100644 --- a/runtime/sqlclient/pooling.go +++ b/runtime/sqlclient/pooling.go @@ -16,67 +16,63 @@ import ( "github.com/hypermodeinc/modus/lib/manifest" "github.com/hypermodeinc/modus/runtime/manifestdata" "github.com/hypermodeinc/modus/runtime/secrets" - "github.com/jackc/pgx/v5/pgxpool" ) // ShutdownPGPools shuts down all the PostgreSQL connection pools. func ShutdownPGPools() { - dsr.Lock() - defer dsr.Unlock() + dsr.cache.Range(func(key string, _ *postgresqlDS) bool { + if ds, ok := dsr.cache.LoadAndDelete(key); ok { + ds.pool.Close() + } + return true + }) +} + +func (r *dsRegistry) getPostgresDS(ctx context.Context, dsName string) (*postgresqlDS, error) { + var creationErr error + ds, _ := r.cache.LoadOrCompute(dsName, func() *postgresqlDS { + ds, err := createDS(ctx, dsName) + if err != nil { + creationErr = err + return nil + } + return ds + }) - for _, ds := range dsr.pgCache { - ds.pool.Close() + if creationErr != nil { + r.cache.Delete(dsName) + return nil, creationErr } - clear(dsr.pgCache) + return ds, nil } -func (r *dsRegistry) getPGPool(ctx context.Context, dsName string) (*postgresqlDS, error) { - // fast path - r.RLock() - ds, ok := r.pgCache[dsName] - r.RUnlock() - if ok { - return ds, nil +func createDS(ctx context.Context, dsName string) (*postgresqlDS, error) { + man := manifestdata.GetManifest() + info, ok := man.Connections[dsName] + if !ok { + return nil, fmt.Errorf("postgresql connection [%s] not found", dsName) } - // slow path - r.Lock() - defer r.Unlock() - - // we do another lookup to make sure any other goroutine didn't create it - if ds, ok := r.pgCache[dsName]; ok { - return ds, nil + if info.ConnectionType() != manifest.ConnectionTypePostgresql { + return nil, fmt.Errorf("[%s] is not a postgresql connection", dsName) } - for name, info := range manifestdata.GetManifest().Connections { - if name != dsName { - continue - } - - if info.ConnectionType() != manifest.ConnectionTypePostgresql { - return nil, fmt.Errorf("[%s] is not a postgresql connection", dsName) - } - - conf := info.(manifest.PostgresqlConnectionInfo) - if conf.ConnStr == "" { - return nil, fmt.Errorf("postgresql connection [%s] has empty connString", dsName) - } - - fullConnStr, err := secrets.ApplySecretsToString(ctx, info, conf.ConnStr) - if err != nil { - return nil, fmt.Errorf("failed to apply secrets to connection string for connection [%s]: %w", dsName, err) - } + conf := info.(manifest.PostgresqlConnectionInfo) + if conf.ConnStr == "" { + return nil, fmt.Errorf("postgresql connection [%s] has empty connString", dsName) + } - dbpool, err := pgxpool.New(ctx, fullConnStr) - if err != nil { - return nil, fmt.Errorf("failed to connect to postgres connection [%s]: %w", dsName, err) - } + connStr, err := secrets.ApplySecretsToString(ctx, info, conf.ConnStr) + if err != nil { + return nil, fmt.Errorf("failed to apply secrets to connection string for connection [%s]: %w", dsName, err) + } - r.pgCache[dsName] = &postgresqlDS{pool: dbpool} - return r.pgCache[dsName], nil + pool, err := pgxpool.New(ctx, connStr) + if err != nil { + return nil, fmt.Errorf("failed to connect to postgres connection [%s]: %w", dsName, err) } - return nil, fmt.Errorf("postgresql connection [%s] not found", dsName) + return &postgresqlDS{pool}, nil } diff --git a/runtime/sqlclient/registry.go b/runtime/sqlclient/registry.go index b58bd2d71..7aa013688 100644 --- a/runtime/sqlclient/registry.go +++ b/runtime/sqlclient/registry.go @@ -9,19 +9,16 @@ package sqlclient -import ( - "sync" -) +import "github.com/puzpuzpuz/xsync/v3" var dsr = newDSRegistry() type dsRegistry struct { - sync.RWMutex - pgCache map[string]*postgresqlDS + cache *xsync.MapOf[string, *postgresqlDS] } func newDSRegistry() *dsRegistry { return &dsRegistry{ - pgCache: make(map[string]*postgresqlDS), + cache: xsync.NewMapOf[string, *postgresqlDS](), } } diff --git a/runtime/sqlclient/sqlclient.go b/runtime/sqlclient/sqlclient.go index 5eb6c8d13..4f25751c9 100644 --- a/runtime/sqlclient/sqlclient.go +++ b/runtime/sqlclient/sqlclient.go @@ -60,7 +60,7 @@ func ExecuteQuery(ctx context.Context, connectionName, dbType, statement, params func doExecuteQuery(ctx context.Context, dsName, dsType, stmt string, params []any) (*dbResponse, error) { switch dsType { case "postgresql": - ds, err := dsr.getPGPool(ctx, dsName) + ds, err := dsr.getPostgresDS(ctx, dsName) if err != nil { return nil, err }