Skip to content

Commit eb883ee

Browse files
authored
pkg/acquisition: remove/restore mock datasources after usage (#4190)
this allows us to override regular datasources as well during tests
1 parent 795222e commit eb883ee

File tree

2 files changed

+58
-24
lines changed

2 files changed

+58
-24
lines changed

pkg/acquisition/acquisition_test.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,19 @@ func (*MockSourceCantRun) CanRun() error { return errors.New("can't run bro")
7676
func (*MockSourceCantRun) GetName() string { return "mock_cant_run" }
7777

7878
// appendMockSource is only used to add mock source for tests.
79-
func appendMockSource() {
80-
registry.RegisterTestFactory("mock", func() types.DataSource { return &MockSource{} })
81-
registry.RegisterTestFactory("mock_cant_run", func() types.DataSource { return &MockSourceCantRun{} })
79+
func appendMockSource(t *testing.T) {
80+
t.Helper()
81+
82+
restore := registry.RegisterTestFactory("mock", func() types.DataSource { return &MockSource{} })
83+
t.Cleanup(restore)
84+
restore = registry.RegisterTestFactory("mock_cant_run", func() types.DataSource { return &MockSourceCantRun{} })
85+
t.Cleanup(restore)
8286
}
8387

8488
func TestDataSourceConfigure(t *testing.T) {
8589
ctx := t.Context()
8690

87-
appendMockSource()
91+
appendMockSource(t)
8892

8993
tests := []struct {
9094
TestName string
@@ -218,7 +222,7 @@ filename: foo.log
218222
}
219223

220224
func TestLoadAcquisitionFromFiles(t *testing.T) {
221-
appendMockSource()
225+
appendMockSource(t)
222226
t.Setenv("TEST_ENV", "test_value2")
223227

224228
ctx := t.Context()
@@ -554,7 +558,8 @@ func TestConfigureByDSN(t *testing.T) {
554558
},
555559
}
556560

557-
registry.RegisterTestFactory("mockdsn", func() types.DataSource { return &MockSourceByDSN{} })
561+
restore := registry.RegisterTestFactory("mockdsn", func() types.DataSource { return &MockSourceByDSN{} })
562+
t.Cleanup(restore)
558563

559564
for _, tc := range tests {
560565
t.Run(tc.dsn, func(t *testing.T) {

pkg/acquisition/registry/registry.go

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,53 +3,82 @@ package registry
33
import (
44
"errors"
55
"fmt"
6+
"sync"
67

78
"github.com/crowdsecurity/crowdsec/pkg/acquisition/types"
89
"github.com/crowdsecurity/crowdsec/pkg/cwversion/component"
910
)
1011

1112
// factoriesByName is filled at init time so the application can report
1213
// if the datasources are unsupported, or simply excluded from the build.
13-
// We don't need to guard with a mutex if all writes to the map are done
14-
// inside init() functions, as they are guaranteed to run sequentially.
15-
var factoriesByName = map[string]func() types.DataSource{}
14+
var (
15+
factoriesByName = map[string]types.DataSourceFactory{}
16+
mu sync.RWMutex
17+
)
18+
19+
func register(module string, factory types.DataSourceFactory) (restore func()) {
20+
if module == "" {
21+
panic("registry: datasource type is empty")
22+
}
23+
24+
if factory == nil {
25+
panic("registry: factory is nil for " + module)
26+
}
27+
28+
mu.Lock()
29+
prev, had := factoriesByName[module]
30+
factoriesByName[module] = factory
31+
mu.Unlock()
32+
33+
return func() {
34+
mu.Lock()
35+
if had {
36+
factoriesByName[module] = prev
37+
} else {
38+
delete(factoriesByName, module)
39+
}
40+
mu.Unlock()
41+
}
42+
}
1643

1744
// RegisterFactory registers a datasource constructor in the factoriesByName map.
1845
// It must be called in the init() function of the datasource package.
1946
// In addition, the build component is registered so it will be reported
2047
// by the "cscli version / crowdsec --version" commands.
21-
func RegisterFactory(moduleName string, factory types.DataSourceFactory) {
22-
component.Register("datasource_" + moduleName)
23-
factoriesByName[moduleName] = factory
48+
func RegisterFactory(module string, factory types.DataSourceFactory) {
49+
component.Register("datasource_" + module)
50+
register(module, factory)
2451
}
2552

2653
// RegisterTestFactory does not attempt to register it as a component,
2754
// production code should call RegisterFactory() instead and make the datasource
2855
// code optional using the appropriate build tag.
2956
// This function may be called outside init().
30-
func RegisterTestFactory(moduleName string, factory types.DataSourceFactory) {
31-
factoriesByName[moduleName] = factory
57+
func RegisterTestFactory(module string, factory types.DataSourceFactory) (restore func()) {
58+
return register(module, factory)
3259
}
3360

34-
func LookupFactory(moduleName string) (types.DataSourceFactory, error) {
35-
source, registered := factoriesByName[moduleName]
36-
if registered {
37-
return source, nil
61+
func LookupFactory(module string) (types.DataSourceFactory, error) {
62+
if module == "" {
63+
return nil, errors.New("data source type is empty")
3864
}
3965

40-
built, known := component.Built["datasource_"+moduleName]
66+
mu.RLock()
67+
factory, registered := factoriesByName[module]
68+
mu.RUnlock()
4169

42-
if moduleName == "" {
43-
return nil, errors.New("data source type is empty")
70+
if registered {
71+
return factory, nil
4472
}
4573

74+
built, known := component.Built["datasource_"+module]
4675
if !known {
47-
return nil, fmt.Errorf("unknown data source %s", moduleName)
76+
return nil, fmt.Errorf("unknown data source %s", module)
4877
}
4978

5079
if built {
51-
panic("datasource " + moduleName + " is built but not registered")
80+
panic("datasource " + module + " is built but not registered")
5281
}
5382

54-
return nil, fmt.Errorf("data source %s is not built in this version of crowdsec", moduleName)
83+
return nil, fmt.Errorf("data source %s is not built in this version of crowdsec", module)
5584
}

0 commit comments

Comments
 (0)