Skip to content

Commit b3af7a3

Browse files
authored
Add unit tests (#1195)
* added unit tests Signed-off-by: Etai Lev Ran <[email protected]> * lint fix Signed-off-by: Etai Lev Ran <[email protected]> * simplify tests by removing redundant and using Signed-off-by: Etai Lev Ran <[email protected]> * changed to use cmp.Diff where appropriate Signed-off-by: Etai Lev Ran <[email protected]> --------- Signed-off-by: Etai Lev Ran <[email protected]>
1 parent 216d5bf commit b3af7a3

File tree

6 files changed

+374
-84
lines changed

6 files changed

+374
-84
lines changed

pkg/epp/datalayer/attributemap.go

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,27 @@ type AttributeMap interface {
3232
Put(string, Cloneable)
3333
Get(string) (Cloneable, bool)
3434
Keys() []string
35+
Clone() *Attributes
3536
}
3637

37-
// Attributes provides a goroutine safe implementation of AttributeMap.
38+
// Attributes provides a goroutine-safe implementation of AttributeMap.
3839
type Attributes struct {
3940
data sync.Map
4041
}
4142

42-
// NewAttributes return a new attribute map instance.
43+
// NewAttributes returns a new instance of Attributes.
4344
func NewAttributes() *Attributes {
44-
return &Attributes{
45-
data: sync.Map{},
46-
}
45+
return &Attributes{}
4746
}
4847

49-
// Put adds (or updates) an attribute in the map.
48+
// Put adds or updates an attribute in the map.
5049
func (a *Attributes) Put(key string, value Cloneable) {
51-
a.data.Store(key, value) // TODO: Clone into map?
50+
if value != nil {
51+
a.data.Store(key, value) // TODO: Clone into map to ensure isolation
52+
}
5253
}
5354

54-
// Get returns an attribute from the map.
55+
// Get retrieves an attribute by key, returning a cloned copy.
5556
func (a *Attributes) Get(key string) (Cloneable, bool) {
5657
val, ok := a.data.Load(key)
5758
if !ok {
@@ -60,30 +61,31 @@ func (a *Attributes) Get(key string) (Cloneable, bool) {
6061
if cloneable, ok := val.(Cloneable); ok {
6162
return cloneable.Clone(), true
6263
}
63-
return nil, false // shouldn't happen since Put accepts Cloneables only
64+
return nil, false
6465
}
6566

66-
// Keys returns an array of all the names of attributes stored in the map.
67+
// Keys returns all keys in the attribute map.
6768
func (a *Attributes) Keys() []string {
68-
keys := []string{}
69+
var keys []string
6970
a.data.Range(func(key, _ any) bool {
70-
if k, ok := key.(string); ok {
71-
keys = append(keys, k)
71+
if sk, ok := key.(string); ok {
72+
keys = append(keys, sk)
7273
}
73-
return true // continue iteration
74+
return true
7475
})
7576
return keys
7677
}
7778

78-
// Clone the attributes object itself.
79+
// Clone creates a deep copy of the entire Attributes map.
7980
func (a *Attributes) Clone() *Attributes {
80-
cloned := &Attributes{
81-
data: sync.Map{},
82-
}
83-
84-
a.data.Range(func(k, v interface{}) bool {
85-
cloned.data.Store(k, v)
81+
clone := NewAttributes()
82+
a.data.Range(func(key, value any) bool {
83+
if sk, ok := key.(string); ok {
84+
if v, ok := value.(Cloneable); ok {
85+
clone.Put(sk, v)
86+
}
87+
}
8688
return true
8789
})
88-
return cloned
90+
return clone
8991
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
Copyright 2025 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package datalayer
18+
19+
import (
20+
"testing"
21+
22+
"github.com/google/go-cmp/cmp"
23+
"github.com/stretchr/testify/assert"
24+
)
25+
26+
type dummy struct {
27+
Text string
28+
}
29+
30+
func (d *dummy) Clone() Cloneable {
31+
return &dummy{Text: d.Text}
32+
}
33+
34+
func TestExpectPutThenGetToMatch(t *testing.T) {
35+
attrs := NewAttributes()
36+
original := &dummy{"foo"}
37+
attrs.Put("a", original)
38+
39+
got, ok := attrs.Get("a")
40+
assert.True(t, ok, "expected key to exist")
41+
assert.NotSame(t, original, got, "expected Get to return a clone, not original")
42+
43+
dv, ok := got.(*dummy)
44+
assert.True(t, ok, "expected value to be of type *dummy")
45+
assert.Equal(t, "foo", dv.Text)
46+
}
47+
48+
func TestExpectKeysToMatchAdded(t *testing.T) {
49+
attrs := NewAttributes()
50+
attrs.Put("x", &dummy{"1"})
51+
attrs.Put("y", &dummy{"2"})
52+
53+
keys := attrs.Keys()
54+
assert.Len(t, keys, 2)
55+
assert.ElementsMatch(t, keys, []string{"x", "y"})
56+
}
57+
58+
func TestCloneReturnsCopy(t *testing.T) {
59+
original := NewAttributes()
60+
original.Put("k", &dummy{"value"})
61+
62+
cloned := original.Clone()
63+
64+
kOrig, _ := original.Get("k")
65+
kClone, _ := cloned.Get("k")
66+
67+
assert.NotSame(t, kOrig, kClone, "expected cloned value to be a different instance")
68+
if diff := cmp.Diff(kOrig, kClone); diff != "" {
69+
t.Errorf("Unexpected output (-want +got): %v", diff)
70+
}
71+
}

pkg/epp/datalayer/datasource.go

Lines changed: 34 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -23,126 +23,98 @@ import (
2323
"sync"
2424
)
2525

26-
// DataSource is an interface required from all data layer data collection
27-
// sources.
26+
// DataSource provides raw data to registered Extractors.
2827
type DataSource interface {
29-
// Name returns the name of this datasource.
3028
Name() string
31-
32-
// AddExtractor adds an extractor to the data source.
33-
// The extractor will be called whenever the Collector might
29+
// AddExtractor adds an extractor to the data source. Multiple
30+
// Extractors can be registered.
31+
// The extractor will be called whenever the DataSource might
3432
// have some new raw information regarding an endpoint.
3533
// The Extractor's expected input type should be validated against
3634
// the data source's output type upon registration.
3735
AddExtractor(extractor Extractor) error
38-
3936
// Collect is triggered by the data layer framework to fetch potentially new
40-
// data for an endpoint. It passes retrieved data to registered Extractors.
37+
// data for an endpoint. Collect calls registered Extractors to convert the
38+
// raw data into structured attributes.
4139
Collect(ep Endpoint)
4240
}
4341

44-
// Extractor is used to convert raw data into relevant data layer information
45-
// for an endpoint. They are called by data sources whenever new data might be
46-
// available. Multiple Extractors can be registered with a source. Extractors
47-
// are expected to save their output with an endpoint so it becomes accessible
48-
// to consumers in other subsystem of the inference gateway (e.g., when making
49-
// scheduling decisions).
42+
// Extractor transforms raw data into structured attributes.
5043
type Extractor interface {
51-
// Name returns the name of the extractor.
5244
Name() string
53-
54-
// ExpectedType defines the type expected by the extractor. It must match
55-
// the output type of the data source where the extractor is registered.
45+
// ExpectedType defines the type expected by the extractor.
5646
ExpectedInputType() reflect.Type
57-
58-
// Extract transforms the data source output into a concrete attribute that
59-
// is stored on the given endpoint.
47+
// Extract transforms the raw data source output into a concrete structured
48+
// attribute, stored on the given endpoint.
6049
Extract(data any, ep Endpoint)
6150
}
6251

63-
var (
64-
// defaultDataSources is the system default data source registry.
65-
defaultDataSources = DataSourceRegistry{}
66-
)
52+
var defaultDataSources = DataSourceRegistry{}
6753

68-
// DataSourceRegistry stores named data sources and makes them
69-
// accessible to other subsystems in the inference gateway.
54+
// DataSourceRegistry stores named data sources.
7055
type DataSourceRegistry struct {
7156
sources sync.Map
7257
}
7358

74-
// Register adds a source to the registry.
59+
// Register adds a new DataSource to the registry.
7560
func (dsr *DataSourceRegistry) Register(src DataSource) error {
7661
if src == nil {
7762
return errors.New("unable to register a nil data source")
7863
}
79-
80-
if _, found := dsr.sources.Load(src.Name()); found {
64+
if _, loaded := dsr.sources.LoadOrStore(src.Name(), src); loaded {
8165
return fmt.Errorf("unable to register duplicate data source: %s", src.Name())
8266
}
83-
dsr.sources.Store(src.Name(), src)
8467
return nil
8568
}
8669

87-
// GetNamedSource returns the named data source, if found.
70+
// GetNamedSource fetches a source by name.
8871
func (dsr *DataSourceRegistry) GetNamedSource(name string) (DataSource, bool) {
89-
if name == "" {
90-
return nil, false
91-
}
92-
93-
if val, found := dsr.sources.Load(name); found {
72+
if val, ok := dsr.sources.Load(name); ok {
9473
if ds, ok := val.(DataSource); ok {
9574
return ds, true
96-
} // ignore type assertion failures and fall through
75+
}
9776
}
9877
return nil, false
9978
}
10079

101-
// GetSources returns all sources registered.
80+
// GetSources returns all registered sources.
10281
func (dsr *DataSourceRegistry) GetSources() []DataSource {
103-
sources := []DataSource{}
82+
var result []DataSource
10483
dsr.sources.Range(func(_, val any) bool {
10584
if ds, ok := val.(DataSource); ok {
106-
sources = append(sources, ds)
85+
result = append(result, ds)
10786
}
108-
return true // continue iteration
87+
return true
10988
})
110-
return sources
89+
return result
11190
}
11291

113-
// RegisterSource adds the data source to the default registry.
92+
// --- default registry accessors ---
93+
11494
func RegisterSource(src DataSource) error {
11595
return defaultDataSources.Register(src)
11696
}
11797

118-
// GetNamedSource returns the named source from the default registry,
119-
// if found.
12098
func GetNamedSource(name string) (DataSource, bool) {
12199
return defaultDataSources.GetNamedSource(name)
122100
}
123101

124-
// GetSources returns all sources in the default registry.
125102
func GetSources() []DataSource {
126103
return defaultDataSources.GetSources()
127104
}
128105

129106
// ValidateExtractorType checks if an extractor can handle
130-
// the collector's output.
131-
func ValidateExtractorType(collectorOutputType, extractorInputType reflect.Type) error {
132-
if collectorOutputType == extractorInputType {
133-
return nil
134-
}
135-
136-
// extractor accepts anything (i.e., interface{})
137-
if extractorInputType.Kind() == reflect.Interface && extractorInputType.NumMethod() == 0 {
138-
return nil
107+
// the DataSource's output. It should be called by a DataSource
108+
// when an extractor is added.
109+
func ValidateExtractorType(collectorOutput, extractorInput reflect.Type) error {
110+
if collectorOutput == nil || extractorInput == nil {
111+
return errors.New("extractor input type or data source output type can't be nil")
139112
}
140-
141-
// check if collector output implements extractor input interface
142-
if collectorOutputType.Implements(extractorInputType) {
113+
if collectorOutput == extractorInput ||
114+
(extractorInput.Kind() == reflect.Interface && extractorInput.NumMethod() == 0) ||
115+
(extractorInput.Kind() == reflect.Interface && collectorOutput.Implements(extractorInput)) {
143116
return nil
144117
}
145-
146-
return fmt.Errorf("extractor input type %v cannot handle collector output type %v",
147-
extractorInputType, collectorOutputType)
118+
return fmt.Errorf("extractor input type %v cannot handle data source output type %v",
119+
extractorInput, collectorOutput)
148120
}

0 commit comments

Comments
 (0)