Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .mise.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[tools]
go = "1.25.3"
go = "latest"
59 changes: 52 additions & 7 deletions cached_source.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package pubgrub

import "fmt"
import (
"fmt"
"sync"
)

// CachedSource wraps a Source and caches GetVersions and GetDependencies calls
// to improve performance when the same queries are made repeatedly.
Expand All @@ -23,6 +26,8 @@ import "fmt"
type CachedSource struct {
source Source

mu sync.Mutex

// Cache for GetVersions results
versionsCache map[Name][]Version
versionsCalls int
Expand All @@ -45,27 +50,37 @@ func NewCachedSource(source Source) *CachedSource {

// GetVersions returns all available versions for a package, caching the result.
func (c *CachedSource) GetVersions(name Name) ([]Version, error) {
c.mu.Lock()
c.versionsCalls++

// Check cache first
if versions, ok := c.versionsCache[name]; ok {
c.versionsCacheHits++
return versions, nil
out := cloneVersions(versions)
c.mu.Unlock()
Comment on lines +59 to +60
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mutex is unlocked after cloning but before returning. If cloning is expensive, consider cloning after unlocking to reduce lock contention. Move c.mu.Unlock() to line 58 (before cloning).

Suggested change
out := cloneVersions(versions)
c.mu.Unlock()
c.mu.Unlock()
out := cloneVersions(versions)

Copilot uses AI. Check for mistakes.
return out, nil
}
c.mu.Unlock()

// Cache miss - fetch from underlying source
versions, err := c.source.GetVersions(name)
if err != nil {
return nil, err
}

cloned := cloneVersions(versions)

// Store in cache
c.versionsCache[name] = versions
return versions, nil
c.mu.Lock()
c.versionsCache[name] = cloned
c.mu.Unlock()

return cloneVersions(cloned), nil
}

// GetDependencies returns dependencies for a specific package version, caching the result.
func (c *CachedSource) GetDependencies(name Name, version Version) ([]Term, error) {
c.mu.Lock()
c.depsCalls++

// Create cache key from name and version
Expand All @@ -74,18 +89,26 @@ func (c *CachedSource) GetDependencies(name Name, version Version) ([]Term, erro
// Check cache first
if deps, ok := c.depsCache[key]; ok {
c.depsCacheHits++
return deps, nil
out := cloneTerms(deps)
c.mu.Unlock()
Comment on lines +92 to +93
Copy link

Copilot AI Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mutex is unlocked after cloning but before returning. If cloning is expensive, consider cloning after unlocking to reduce lock contention. Move c.mu.Unlock() to line 91 (before cloning).

Suggested change
out := cloneTerms(deps)
c.mu.Unlock()
c.mu.Unlock()
out := cloneTerms(deps)

Copilot uses AI. Check for mistakes.
return out, nil
}
c.mu.Unlock()

// Cache miss - fetch from underlying source
deps, err := c.source.GetDependencies(name, version)
if err != nil {
return nil, err
}

cloned := cloneTerms(deps)

// Store in cache
c.depsCache[key] = deps
return deps, nil
c.mu.Lock()
c.depsCache[key] = cloned
c.mu.Unlock()

return cloneTerms(cloned), nil
}

// CacheStats returns statistics about cache performance.
Expand All @@ -105,6 +128,7 @@ type CacheStats struct {

// GetCacheStats returns cache performance statistics.
func (c *CachedSource) GetCacheStats() CacheStats {
c.mu.Lock()
stats := CacheStats{
VersionsCalls: c.versionsCalls,
VersionsCacheHits: c.versionsCacheHits,
Expand All @@ -113,6 +137,7 @@ func (c *CachedSource) GetCacheStats() CacheStats {
TotalCalls: c.versionsCalls + c.depsCalls,
TotalCacheHits: c.versionsCacheHits + c.depsCacheHits,
}
c.mu.Unlock()

if stats.VersionsCalls > 0 {
stats.VersionsHitRate = float64(stats.VersionsCacheHits) / float64(stats.VersionsCalls)
Expand All @@ -131,10 +156,30 @@ func (c *CachedSource) GetCacheStats() CacheStats {

// ClearCache clears all cached data while preserving the underlying source.
func (c *CachedSource) ClearCache() {
c.mu.Lock()
c.versionsCache = make(map[Name][]Version)
c.depsCache = make(map[string][]Term)
c.versionsCalls = 0
c.versionsCacheHits = 0
c.depsCalls = 0
c.depsCacheHits = 0
c.mu.Unlock()
}

func cloneVersions(in []Version) []Version {
if len(in) == 0 {
return nil
}
out := make([]Version, len(in))
copy(out, in)
return out
}

func cloneTerms(in []Term) []Term {
if len(in) == 0 {
return nil
}
out := make([]Term, len(in))
copy(out, in)
return out
}
63 changes: 63 additions & 0 deletions cached_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,66 @@ func TestCachedSource_Integration(t *testing.T) {
t.Error("expected some calls to be made")
}
}

func TestCachedSource_ReturnsCopiesForVersions(t *testing.T) {
inner := &InMemorySource{}
v1 := SimpleVersion("1.0.0")
v2 := SimpleVersion("2.0.0")
pkg := MakeName("pkg")

inner.AddPackage(pkg, v1, nil)
inner.AddPackage(pkg, v2, nil)

cached := NewCachedSource(inner)

versions, err := cached.GetVersions(pkg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(versions) != 2 {
t.Fatalf("expected 2 versions, got %d", len(versions))
}

versions[0] = SimpleVersion("9.9.9")

versions2, err := cached.GetVersions(pkg)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if versions2[0].String() == "9.9.9" {
t.Fatalf("cached versions slice should not be mutable by callers")
}
}

func TestCachedSource_ReturnsCopiesForDependencies(t *testing.T) {
inner := &InMemorySource{}
pkg := MakeName("pkg")
v1 := SimpleVersion("1.0.0")
dep := MakeName("dep")

inner.AddPackage(pkg, v1, []Term{
NewTerm(dep, EqualsCondition{Version: v1}),
})

cached := NewCachedSource(inner)

deps, err := cached.GetDependencies(pkg, v1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(deps) != 1 {
t.Fatalf("expected 1 dependency, got %d", len(deps))
}

deps = append(deps, NewTerm(dep, EqualsCondition{Version: SimpleVersion("2.0.0")}))
_ = deps

deps2, err := cached.GetDependencies(pkg, v1)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(deps2) != 1 {
t.Fatalf("cached dependencies slice should not be mutable by callers; got %d entries", len(deps2))
}
}
13 changes: 13 additions & 0 deletions solver_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ type SolverOptions struct {
// Logger enables debug logging of solver operations.
// When nil, no logging is performed.
Logger *slog.Logger

// PreferHighestVersions forces the solver to pick the highest allowed version.
// When false, the solver uses a dependency-flexibility heuristic.
PreferHighestVersions bool
}

// SolverOption is a functional option for configuring the solver.
Expand All @@ -49,6 +53,7 @@ func defaultSolverOptions() SolverOptions {
return SolverOptions{
TrackIncompatibilities: false,
MaxSteps: defaultMaxSteps,
PreferHighestVersions: false,
}
}

Expand Down Expand Up @@ -106,3 +111,11 @@ func WithLogger(logger *slog.Logger) SolverOption {
opts.Logger = logger
}
}

// WithPreferHighestVersions forces the solver to choose the highest allowed version.
// This matches Bundler-style "latest" resolution behavior.
func WithPreferHighestVersions(enabled bool) SolverOption {
return func(opts *SolverOptions) {
opts.PreferHighestVersions = enabled
}
}
53 changes: 40 additions & 13 deletions source_combined.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ type CombinedSource []Source
// GetVersions queries all sources and returns the combined set of versions
// in sorted order. Returns an error only if all sources fail with non-NotFound errors.
func (s CombinedSource) GetVersions(name Name) ([]Version, error) {
var ret []Version
unique := make(map[string]Version)
var firstErr error
var sawNotFound bool
for _, source := range s {
versions, err := source.GetVersions(name)
Expand All @@ -50,29 +51,48 @@ func (s CombinedSource) GetVersions(name Name) ([]Version, error) {
sawNotFound = true
continue
}
return nil, err
if firstErr == nil {
firstErr = err
}
continue
}
for _, ver := range versions {
key := ver.String()
if _, ok := unique[key]; !ok {
unique[key] = ver
}
}
ret = append(ret, versions...)
}

if len(ret) == 0 {
if sawNotFound {
return nil, &PackageNotFoundError{Package: name}
if len(unique) > 0 {
ret := make([]Version, 0, len(unique))
for _, ver := range unique {
ret = append(ret, ver)
}
return nil, &PackageNotFoundError{Package: name}

// sort the versions
slices.SortFunc(ret, func(a Version, b Version) int {
return a.Sort(b)
})

return ret, nil
}

// sort the versions
slices.SortFunc(ret, func(a Version, b Version) int {
return a.Sort(b)
})
if firstErr != nil {
return nil, firstErr
}

return ret, nil
if sawNotFound {
return nil, &PackageNotFoundError{Package: name}
}

return nil, &PackageNotFoundError{Package: name}
}

// GetDependencies queries sources in order and returns dependencies from the
// first source that has the specified package version.
func (s CombinedSource) GetDependencies(name Name, version Version) ([]Term, error) {
var lastErr error
for _, source := range s {
deps, err := source.GetDependencies(name, version)
if err != nil {
Expand All @@ -84,13 +104,20 @@ func (s CombinedSource) GetDependencies(name Name, version Version) ([]Term, err
case errors.As(err, &verErr):
continue
default:
return nil, err
if lastErr == nil {
lastErr = err
}
continue
}
} else {
return deps, nil
}
}

if lastErr != nil {
return nil, lastErr
}

return nil, &PackageVersionNotFoundError{Package: name, Version: version}
}

Expand Down
Loading
Loading