Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 28 additions & 28 deletions internal/core/bfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,88 +8,88 @@ import (
"github.com/microsoft/typescript-go/internal/collections"
)

type BreadthFirstSearchResult[N comparable] struct {
type BreadthFirstSearchResult[N any] struct {
Stopped bool
Path []N
}

type breadthFirstSearchJob[N comparable] struct {
type breadthFirstSearchJob[N any] struct {
node N
parent *breadthFirstSearchJob[N]
}

type BreadthFirstSearchLevel[N comparable] struct {
jobs *collections.OrderedMap[N, *breadthFirstSearchJob[N]]
type BreadthFirstSearchLevel[K comparable, N interface{ Key() K }] struct {
jobs *collections.OrderedMap[K, *breadthFirstSearchJob[N]]
}

func (l *BreadthFirstSearchLevel[N]) Has(node N) bool {
return l.jobs.Has(node)
func (l *BreadthFirstSearchLevel[K, N]) Has(node N) bool {
return l.jobs.Has(node.Key())
}

func (l *BreadthFirstSearchLevel[N]) Delete(node N) {
l.jobs.Delete(node)
func (l *BreadthFirstSearchLevel[K, N]) Delete(node N) {
l.jobs.Delete(node.Key())
}

func (l *BreadthFirstSearchLevel[N]) Range(f func(node N) bool) {
for node := range l.jobs.Keys() {
if !f(node) {
func (l *BreadthFirstSearchLevel[K, N]) Range(f func(node N) bool) {
for job := range l.jobs.Values() {
if !f(job.node) {
return
}
}
}

type BreadthFirstSearchOptions[N comparable] struct {
type BreadthFirstSearchOptions[K comparable, N interface{ Key() K }] struct {
// Visited is a set of nodes that have already been visited.
// If nil, a new set will be created.
Visited *collections.SyncSet[N]
Visited *collections.SyncSet[K]
// PreprocessLevel is a function that, if provided, will be called
// before each level, giving the caller an opportunity to remove nodes.
PreprocessLevel func(*BreadthFirstSearchLevel[N])
PreprocessLevel func(*BreadthFirstSearchLevel[K, N])
}

// BreadthFirstSearchParallel performs a breadth-first search on a graph
// starting from the given node. It processes nodes in parallel and returns the path
// from the first node that satisfies the `visit` function back to the start node.
func BreadthFirstSearchParallel[N comparable](
func BreadthFirstSearchParallel[K comparable, N interface{ Key() K }](
start N,
neighbors func(N) []N,
visit func(node N) (isResult bool, stop bool),
) BreadthFirstSearchResult[N] {
return BreadthFirstSearchParallelEx(start, neighbors, visit, BreadthFirstSearchOptions[N]{})
return BreadthFirstSearchParallelEx(start, neighbors, visit, BreadthFirstSearchOptions[K, N]{})
}

// BreadthFirstSearchParallelEx is an extension of BreadthFirstSearchParallel that allows
// the caller to pass a pre-seeded set of already-visited nodes and a preprocessing function
// that can be used to remove nodes from each level before parallel processing.
func BreadthFirstSearchParallelEx[N comparable](
func BreadthFirstSearchParallelEx[K comparable, N interface{ Key() K }](
start N,
neighbors func(N) []N,
visit func(node N) (isResult bool, stop bool),
options BreadthFirstSearchOptions[N],
options BreadthFirstSearchOptions[K, N],
) BreadthFirstSearchResult[N] {
visited := options.Visited
if visited == nil {
visited = &collections.SyncSet[N]{}
visited = &collections.SyncSet[K]{}
}

type result struct {
stop bool
job *breadthFirstSearchJob[N]
next *collections.OrderedMap[N, *breadthFirstSearchJob[N]]
next *collections.OrderedMap[K, *breadthFirstSearchJob[N]]
}

var fallback *breadthFirstSearchJob[N]
// processLevel processes each node at the current level in parallel.
// It produces either a list of jobs to be processed in the next level,
// or a result if the visit function returns true for any node.
processLevel := func(index int, jobs *collections.OrderedMap[N, *breadthFirstSearchJob[N]]) result {
processLevel := func(index int, jobs *collections.OrderedMap[K, *breadthFirstSearchJob[N]]) result {
var lowestFallback atomic.Int64
var lowestGoal atomic.Int64
var nextJobCount atomic.Int64
lowestGoal.Store(math.MaxInt64)
lowestFallback.Store(math.MaxInt64)
if options.PreprocessLevel != nil {
options.PreprocessLevel(&BreadthFirstSearchLevel[N]{jobs: jobs})
options.PreprocessLevel(&BreadthFirstSearchLevel[K, N]{jobs: jobs})
}
next := make([][]*breadthFirstSearchJob[N], jobs.Size())
var wg sync.WaitGroup
Expand All @@ -103,7 +103,7 @@ func BreadthFirstSearchParallelEx[N comparable](
}

// If we have already visited this node, skip it.
if !visited.AddIfAbsent(j.node) {
if !visited.AddIfAbsent(j.node.Key()) {
// Note that if we are here, we already visited this node at a
// previous *level*, which means `visit` must have returned false,
// so we don't need to update our result indices. This holds true
Expand Down Expand Up @@ -152,13 +152,13 @@ func BreadthFirstSearchParallelEx[N comparable](
_, fallback, _ = jobs.EntryAt(int(index))
}
}
nextJobs := collections.NewOrderedMapWithSizeHint[N, *breadthFirstSearchJob[N]](int(nextJobCount.Load()))
nextJobs := collections.NewOrderedMapWithSizeHint[K, *breadthFirstSearchJob[N]](int(nextJobCount.Load()))
for _, jobs := range next {
for _, j := range jobs {
if !nextJobs.Has(j.node) {
if !nextJobs.Has(j.node.Key()) {
// Deduplicate synchronously to avoid messy locks and spawning
// unnecessary goroutines.
nextJobs.Set(j.node, j)
nextJobs.Set(j.node.Key(), j)
}
}
}
Expand All @@ -175,8 +175,8 @@ func BreadthFirstSearchParallelEx[N comparable](
}

levelIndex := 0
level := collections.NewOrderedMapFromList([]collections.MapEntry[N, *breadthFirstSearchJob[N]]{
{Key: start, Value: &breadthFirstSearchJob[N]{node: start}},
level := collections.NewOrderedMapFromList([]collections.MapEntry[K, *breadthFirstSearchJob[N]]{
{Key: start.Key(), Value: &breadthFirstSearchJob[N]{node: start}},
})
for level.Size() > 0 {
result := processLevel(levelIndex, level)
Expand Down
52 changes: 28 additions & 24 deletions internal/core/bfs_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package core_test

import (
"sort"
"slices"
"sync"
"testing"

Expand All @@ -10,37 +10,41 @@ import (
"gotest.tools/v3/assert"
)

type node string

func (n node) Key() node { return n }

func TestBreadthFirstSearchParallel(t *testing.T) {
t.Parallel()
t.Run("basic functionality", func(t *testing.T) {
t.Parallel()
// Test basic functionality with a simple DAG
// Graph: A -> B, A -> C, B -> D, C -> D
graph := map[string][]string{
graph := map[node][]node{
"A": {"B", "C"},
"B": {"D"},
"C": {"D"},
"D": {},
}

children := func(node string) []string {
children := func(node node) []node {
return graph[node]
}

t.Run("find specific node", func(t *testing.T) {
t.Parallel()
result := core.BreadthFirstSearchParallel("A", children, func(node string) (bool, bool) {
result := core.BreadthFirstSearchParallel(node("A"), children, func(node node) (bool, bool) {
return node == "D", true
})
assert.Equal(t, result.Stopped, true, "Expected search to stop at D")
assert.DeepEqual(t, result.Path, []string{"D", "B", "A"})
assert.DeepEqual(t, result.Path, []node{"D", "B", "A"})
})

t.Run("visit all nodes", func(t *testing.T) {
t.Parallel()
var mu sync.Mutex
var visitedNodes []string
result := core.BreadthFirstSearchParallel("A", children, func(node string) (bool, bool) {
var visitedNodes []node
result := core.BreadthFirstSearchParallel("A", children, func(node node) (bool, bool) {
mu.Lock()
defer mu.Unlock()
visitedNodes = append(visitedNodes, node)
Expand All @@ -52,16 +56,16 @@ func TestBreadthFirstSearchParallel(t *testing.T) {
assert.Assert(t, result.Path == nil, "Expected nil path when visit function never returns true")

// Should visit all nodes exactly once
sort.Strings(visitedNodes)
expected := []string{"A", "B", "C", "D"}
slices.Sort(visitedNodes)
expected := []node{"A", "B", "C", "D"}
assert.DeepEqual(t, visitedNodes, expected)
})
})

t.Run("early termination", func(t *testing.T) {
t.Parallel()
// Test that nodes below the target level are not visited
graph := map[string][]string{
graph := map[node][]node{
"Root": {"L1A", "L1B"},
"L1A": {"L2A", "L2B"},
"L1B": {"L2C"},
Expand All @@ -71,14 +75,14 @@ func TestBreadthFirstSearchParallel(t *testing.T) {
"L3A": {},
}

children := func(node string) []string {
children := func(node node) []node {
return graph[node]
}

var visited collections.SyncSet[string]
core.BreadthFirstSearchParallelEx("Root", children, func(node string) (bool, bool) {
var visited collections.SyncSet[node]
core.BreadthFirstSearchParallelEx("Root", children, func(node node) (bool, bool) {
return node == "L2B", true // Stop at level 2
}, core.BreadthFirstSearchOptions[string]{
}, core.BreadthFirstSearchOptions[node, node]{
Visited: &visited,
})

Expand All @@ -94,26 +98,26 @@ func TestBreadthFirstSearchParallel(t *testing.T) {
t.Run("returns fallback when no other result found", func(t *testing.T) {
t.Parallel()
// Test that fallback behavior works correctly
graph := map[string][]string{
graph := map[node][]node{
"A": {"B", "C"},
"B": {"D"},
"C": {"D"},
"D": {},
}

children := func(node string) []string {
children := func(node node) []node {
return graph[node]
}

var visited collections.SyncSet[string]
result := core.BreadthFirstSearchParallelEx("A", children, func(node string) (bool, bool) {
var visited collections.SyncSet[node]
result := core.BreadthFirstSearchParallelEx("A", children, func(node node) (bool, bool) {
return node == "A", false // Record A as a fallback, but do not stop
}, core.BreadthFirstSearchOptions[string]{
}, core.BreadthFirstSearchOptions[node, node]{
Visited: &visited,
})

assert.Equal(t, result.Stopped, false, "Expected search to not stop early")
assert.DeepEqual(t, result.Path, []string{"A"})
assert.DeepEqual(t, result.Path, []node{"A"})
assert.Assert(t, visited.Has("B"), "Expected to visit B")
assert.Assert(t, visited.Has("C"), "Expected to visit C")
assert.Assert(t, visited.Has("D"), "Expected to visit D")
Expand All @@ -122,18 +126,18 @@ func TestBreadthFirstSearchParallel(t *testing.T) {
t.Run("returns a stop result over a fallback", func(t *testing.T) {
t.Parallel()
// Test that a stop result is preferred over a fallback
graph := map[string][]string{
graph := map[node][]node{
"A": {"B", "C"},
"B": {"D"},
"C": {"D"},
"D": {},
}

children := func(node string) []string {
children := func(node node) []node {
return graph[node]
}

result := core.BreadthFirstSearchParallel("A", children, func(node string) (bool, bool) {
result := core.BreadthFirstSearchParallel("A", children, func(node node) (bool, bool) {
switch node {
case "A":
return true, false // Record fallback
Expand All @@ -145,6 +149,6 @@ func TestBreadthFirstSearchParallel(t *testing.T) {
})

assert.Equal(t, result.Stopped, true, "Expected search to stop at D")
assert.DeepEqual(t, result.Path, []string{"D", "B", "A"})
assert.DeepEqual(t, result.Path, []node{"D", "B", "A"})
})
}
5 changes: 5 additions & 0 deletions internal/project/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,11 @@ func (p *Project) ConfigFilePath() tspath.Path {
return p.configFilePath
}

// Needed for BreadthFirstSearch
func (p *Project) Key() *Project {
return p
}

// GetProgram implements ls.Host.
func (p *Project) GetProgram() *compiler.Program {
return p.Program
Expand Down
2 changes: 1 addition & 1 deletion internal/project/projectcollection.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func (c *ProjectCollection) findDefaultConfiguredProjectWorker(fileName string,
}
return false, false
},
core.BreadthFirstSearchOptions[*Project]{
core.BreadthFirstSearchOptions[*Project, *Project]{
Visited: visited,
},
)
Expand Down
19 changes: 14 additions & 5 deletions internal/project/projectcollectionbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,15 @@ type searchNode struct {
logger *logging.LogTree
}

func (n searchNode) Key() searchNodeKey {
return searchNodeKey{configFileName: n.configFileName, loadKind: n.loadKind}
}

type searchNodeKey struct {
configFileName string
loadKind projectLoadKind
}

type searchResult struct {
project *dirty.SyncMapEntry[tspath.Path, *Project]
retain collections.Set[tspath.Path]
Expand All @@ -483,13 +492,13 @@ func (b *projectCollectionBuilder) findOrCreateDefaultConfiguredProjectWorker(
path tspath.Path,
configFileName string,
loadKind projectLoadKind,
visited *collections.SyncSet[searchNode],
visited *collections.SyncSet[searchNodeKey],
fallback *searchResult,
logger *logging.LogTree,
) searchResult {
var configs collections.SyncMap[tspath.Path, *tsoptions.ParsedCommandLine]
if visited == nil {
visited = &collections.SyncSet[searchNode]{}
visited = &collections.SyncSet[searchNodeKey]{}
}

search := core.BreadthFirstSearchParallelEx(
Expand Down Expand Up @@ -558,9 +567,9 @@ func (b *projectCollectionBuilder) findOrCreateDefaultConfiguredProjectWorker(
node.logger.Log("Project does not contain file")
return false, false
},
core.BreadthFirstSearchOptions[searchNode]{
core.BreadthFirstSearchOptions[searchNodeKey, searchNode]{
Visited: visited,
PreprocessLevel: func(level *core.BreadthFirstSearchLevel[searchNode]) {
PreprocessLevel: func(level *core.BreadthFirstSearchLevel[searchNodeKey, searchNode]) {
level.Range(func(node searchNode) bool {
if node.loadKind == projectLoadKindFind && level.Has(searchNode{configFileName: node.configFileName, loadKind: projectLoadKindCreate, logger: node.logger}) {
// Remove find requests when a create request for the same project is already present.
Expand Down Expand Up @@ -626,7 +635,7 @@ func (b *projectCollectionBuilder) findOrCreateDefaultConfiguredProjectWorker(
// If we didn't find anything, we can retain everything we visited,
// since the whole graph must have been traversed (i.e., the set of
// retained projects is guaranteed to be deterministic).
visited.Range(func(node searchNode) bool {
visited.Range(func(node searchNodeKey) bool {
retain.Add(b.toPath(node.configFileName))
return true
})
Expand Down
Loading
Loading