diff --git a/lib/common_test.go b/lib/common_test.go new file mode 100644 index 00000000000..f5de80c0b46 --- /dev/null +++ b/lib/common_test.go @@ -0,0 +1,195 @@ +package lib + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestGetRemoteURLContent(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + wantErr bool + }{ + { + name: "successful request", + statusCode: http.StatusOK, + body: "test content", + wantErr: false, + }, + { + name: "not found", + statusCode: http.StatusNotFound, + body: "", + wantErr: true, + }, + { + name: "internal server error", + statusCode: http.StatusInternalServerError, + body: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + w.Write([]byte(tt.body)) + })) + defer server.Close() + + got, err := GetRemoteURLContent(server.URL) + if (err != nil) != tt.wantErr { + t.Errorf("GetRemoteURLContent() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && string(got) != tt.body { + t.Errorf("GetRemoteURLContent() = %v, want %v", string(got), tt.body) + } + }) + } +} + +func TestGetRemoteURLContentInvalidURL(t *testing.T) { + _, err := GetRemoteURLContent("invalid://url") + if err == nil { + t.Error("GetRemoteURLContent() should return error for invalid URL") + } +} + +func TestGetRemoteURLReader(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + wantErr bool + }{ + { + name: "successful request", + statusCode: http.StatusOK, + body: "test content", + wantErr: false, + }, + { + name: "not found", + statusCode: http.StatusNotFound, + body: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + w.Write([]byte(tt.body)) + })) + defer server.Close() + + got, err := GetRemoteURLReader(server.URL) + if (err != nil) != tt.wantErr { + t.Errorf("GetRemoteURLReader() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + defer got.Close() + if got == nil { + t.Error("GetRemoteURLReader() returned nil reader") + } + } + }) + } +} + +func TestGetRemoteURLReaderInvalidURL(t *testing.T) { + _, err := GetRemoteURLReader("invalid://url") + if err == nil { + t.Error("GetRemoteURLReader() should return error for invalid URL") + } +} + +func TestWantedListExtended_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + wantErr bool + checkFn func(*testing.T, *WantedListExtended) + }{ + { + name: "empty object", + json: `{}`, + wantErr: false, + checkFn: func(t *testing.T, w *WantedListExtended) { + if len(w.TypeSlice) != 0 || len(w.TypeMap) != 0 { + t.Error("WantedListExtended should have empty slices/maps for empty object") + } + }, + }, + { + name: "empty array", + json: `[]`, + wantErr: false, + checkFn: func(t *testing.T, w *WantedListExtended) { + if len(w.TypeSlice) != 0 { + t.Error("TypeSlice should be empty for empty array") + } + }, + }, + { + name: "slice format", + json: `["type1", "type2", "type3"]`, + wantErr: false, + checkFn: func(t *testing.T, w *WantedListExtended) { + if len(w.TypeSlice) != 3 { + t.Errorf("TypeSlice length = %d, want 3", len(w.TypeSlice)) + } + if w.TypeSlice[0] != "type1" { + t.Errorf("TypeSlice[0] = %s, want 'type1'", w.TypeSlice[0]) + } + }, + }, + { + name: "map format", + json: `{"key1": ["val1", "val2"], "key2": ["val3"]}`, + wantErr: false, + checkFn: func(t *testing.T, w *WantedListExtended) { + if len(w.TypeMap) != 2 { + t.Errorf("TypeMap length = %d, want 2", len(w.TypeMap)) + } + if len(w.TypeMap["key1"]) != 2 { + t.Errorf("TypeMap[key1] length = %d, want 2", len(w.TypeMap["key1"])) + } + }, + }, + { + name: "invalid json", + json: `{invalid}`, + wantErr: true, + checkFn: nil, + }, + { + name: "number type (not array or map)", + json: `123`, + wantErr: true, + checkFn: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var w WantedListExtended + err := json.Unmarshal([]byte(tt.json), &w) + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.checkFn != nil { + tt.checkFn(t, &w) + } + }) + } +} diff --git a/lib/config_test.go b/lib/config_test.go new file mode 100644 index 00000000000..4d85f1eadab --- /dev/null +++ b/lib/config_test.go @@ -0,0 +1,266 @@ +package lib + +import ( + "encoding/json" + "testing" +) + +func TestRegisterInputConfigCreator(t *testing.T) { + // Save original state + originalCache := inputConfigCreatorCache + defer func() { inputConfigCreatorCache = originalCache }() + + // Reset cache for testing + inputConfigCreatorCache = make(map[string]inputConfigCreator) + + creator := func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typ: "test"}, nil + } + + // Test successful registration + err := RegisterInputConfigCreator("test", creator) + if err != nil { + t.Errorf("RegisterInputConfigCreator() error = %v, want nil", err) + } + + // Test duplicate registration + err = RegisterInputConfigCreator("test", creator) + if err == nil { + t.Error("RegisterInputConfigCreator() should return error for duplicate") + } + + // Test case insensitive registration + err = RegisterInputConfigCreator("TEST", creator) + if err == nil { + t.Error("RegisterInputConfigCreator() should return error for duplicate (case insensitive)") + } +} + +func TestCreateInputConfig(t *testing.T) { + // Save original state + originalCache := inputConfigCreatorCache + defer func() { inputConfigCreatorCache = originalCache }() + + // Reset cache for testing + inputConfigCreatorCache = make(map[string]inputConfigCreator) + + creator := func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typ: "test"}, nil + } + + RegisterInputConfigCreator("test", creator) + + tests := []struct { + name string + id string + action Action + wantErr bool + }{ + { + name: "valid type", + id: "test", + action: ActionAdd, + wantErr: false, + }, + { + name: "valid type uppercase", + id: "TEST", + action: ActionAdd, + wantErr: false, + }, + { + name: "unknown type", + id: "unknown", + action: ActionAdd, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := createInputConfig(tt.id, tt.action, nil) + if (err != nil) != tt.wantErr { + t.Errorf("createInputConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got == nil { + t.Error("createInputConfig() returned nil converter") + } + }) + } +} + +func TestRegisterOutputConfigCreator(t *testing.T) { + // Save original state + originalCache := outputConfigCreatorCache + defer func() { outputConfigCreatorCache = originalCache }() + + // Reset cache for testing + outputConfigCreatorCache = make(map[string]outputConfigCreator) + + creator := func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typ: "test"}, nil + } + + // Test successful registration + err := RegisterOutputConfigCreator("test", creator) + if err != nil { + t.Errorf("RegisterOutputConfigCreator() error = %v, want nil", err) + } + + // Test duplicate registration + err = RegisterOutputConfigCreator("test", creator) + if err == nil { + t.Error("RegisterOutputConfigCreator() should return error for duplicate") + } +} + +func TestCreateOutputConfig(t *testing.T) { + // Save original state + originalCache := outputConfigCreatorCache + defer func() { outputConfigCreatorCache = originalCache }() + + // Reset cache for testing + outputConfigCreatorCache = make(map[string]outputConfigCreator) + + creator := func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typ: "test"}, nil + } + + RegisterOutputConfigCreator("test", creator) + + tests := []struct { + name string + id string + action Action + wantErr bool + }{ + { + name: "valid type", + id: "test", + action: ActionOutput, + wantErr: false, + }, + { + name: "unknown type", + id: "unknown", + action: ActionOutput, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := createOutputConfig(tt.id, tt.action, nil) + if (err != nil) != tt.wantErr { + t.Errorf("createOutputConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got == nil { + t.Error("createOutputConfig() returned nil converter") + } + }) + } +} + +func TestInputConvConfig_UnmarshalJSON(t *testing.T) { + // Save original state + originalCache := inputConfigCreatorCache + defer func() { inputConfigCreatorCache = originalCache }() + + // Reset cache for testing + inputConfigCreatorCache = make(map[string]inputConfigCreator) + + creator := func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typ: "test"}, nil + } + RegisterInputConfigCreator("test", creator) + + tests := []struct { + name string + json string + wantErr bool + }{ + { + name: "valid config", + json: `{"type": "test", "action": "add", "args": {}}`, + wantErr: false, + }, + { + name: "invalid action", + json: `{"type": "test", "action": "invalid", "args": {}}`, + wantErr: true, + }, + { + name: "unknown type", + json: `{"type": "unknown", "action": "add", "args": {}}`, + wantErr: true, + }, + { + name: "invalid json", + json: `{invalid}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var cfg inputConvConfig + err := json.Unmarshal([]byte(tt.json), &cfg) + if (err != nil) != tt.wantErr { + t.Errorf("inputConvConfig.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestOutputConvConfig_UnmarshalJSON(t *testing.T) { + // Save original state + originalCache := outputConfigCreatorCache + defer func() { outputConfigCreatorCache = originalCache }() + + // Reset cache for testing + outputConfigCreatorCache = make(map[string]outputConfigCreator) + + creator := func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typ: "test"}, nil + } + RegisterOutputConfigCreator("test", creator) + + tests := []struct { + name string + json string + wantErr bool + }{ + { + name: "valid config", + json: `{"type": "test", "action": "output", "args": {}}`, + wantErr: false, + }, + { + name: "default action", + json: `{"type": "test", "args": {}}`, + wantErr: false, + }, + { + name: "invalid action", + json: `{"type": "test", "action": "invalid", "args": {}}`, + wantErr: true, + }, + { + name: "unknown type", + json: `{"type": "unknown", "action": "output", "args": {}}`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var cfg outputConvConfig + err := json.Unmarshal([]byte(tt.json), &cfg) + if (err != nil) != tt.wantErr { + t.Errorf("outputConvConfig.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/lib/container_test.go b/lib/container_test.go new file mode 100644 index 00000000000..fbc20a8fa95 --- /dev/null +++ b/lib/container_test.go @@ -0,0 +1,717 @@ +package lib + +import ( + "testing" +) + +func TestNewContainer(t *testing.T) { + c := NewContainer() + if c == nil { + t.Fatal("NewContainer() returned nil") + } + if c.Len() != 0 { + t.Errorf("NewContainer().Len() = %d, want 0", c.Len()) + } +} + +func TestContainer_GetEntry(t *testing.T) { + c := NewContainer() + entry := NewEntry("test") + entry.AddPrefix("192.168.1.0/24") + c.Add(entry) + + tests := []struct { + name string + entryName string + wantFound bool + }{ + { + name: "existing entry", + entryName: "test", + wantFound: true, + }, + { + name: "existing entry uppercase", + entryName: "TEST", + wantFound: true, + }, + { + name: "existing entry with spaces", + entryName: " test ", + wantFound: true, + }, + { + name: "non-existing entry", + entryName: "nonexistent", + wantFound: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, found := c.GetEntry(tt.entryName) + if found != tt.wantFound { + t.Errorf("Container.GetEntry() found = %v, wantFound %v", found, tt.wantFound) + } + if tt.wantFound && got == nil { + t.Error("Container.GetEntry() returned nil entry") + } + }) + } +} + +func TestContainer_Add(t *testing.T) { + tests := []struct { + name string + entries []*Entry + opts []IgnoreIPOption + wantLen int + }{ + { + name: "add single entry", + entries: []*Entry{ + func() *Entry { + e := NewEntry("test1") + e.AddPrefix("192.168.1.0/24") + return e + }(), + }, + opts: nil, + wantLen: 1, + }, + { + name: "add multiple entries", + entries: []*Entry{ + func() *Entry { + e := NewEntry("test1") + e.AddPrefix("192.168.1.0/24") + return e + }(), + func() *Entry { + e := NewEntry("test2") + e.AddPrefix("10.0.0.0/8") + return e + }(), + }, + opts: nil, + wantLen: 2, + }, + { + name: "add duplicate entry", + entries: []*Entry{ + func() *Entry { + e := NewEntry("test1") + e.AddPrefix("192.168.1.0/24") + return e + }(), + func() *Entry { + e := NewEntry("test1") + e.AddPrefix("10.0.0.0/8") + return e + }(), + }, + opts: nil, + wantLen: 1, // Should merge into one + }, + { + name: "add with ignore IPv4", + entries: []*Entry{ + func() *Entry { + e := NewEntry("test1") + e.AddPrefix("192.168.1.0/24") + e.AddPrefix("2001:db8::/32") + return e + }(), + }, + opts: []IgnoreIPOption{IgnoreIPv4}, + wantLen: 1, + }, + { + name: "add with ignore IPv6", + entries: []*Entry{ + func() *Entry { + e := NewEntry("test1") + e.AddPrefix("192.168.1.0/24") + e.AddPrefix("2001:db8::/32") + return e + }(), + }, + opts: []IgnoreIPOption{IgnoreIPv6}, + wantLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := NewContainer() + for _, entry := range tt.entries { + if err := c.Add(entry, tt.opts...); err != nil { + t.Errorf("Container.Add() error = %v", err) + } + } + if c.Len() != tt.wantLen { + t.Errorf("Container.Len() = %d, want %d", c.Len(), tt.wantLen) + } + }) + } +} + +func TestContainer_Remove(t *testing.T) { + tests := []struct { + name string + setupFn func(Container) + removeName string + removeCase CaseRemove + opts []IgnoreIPOption + wantErr bool + checkFn func(*testing.T, Container) + }{ + { + name: "remove non-existent entry", + setupFn: func(c Container) { + e := NewEntry("test1") + e.AddPrefix("192.168.1.0/24") + c.Add(e) + }, + removeName: "nonexistent", + removeCase: CaseRemoveEntry, + wantErr: true, + }, + { + name: "remove entry completely", + setupFn: func(c Container) { + e := NewEntry("test1") + e.AddPrefix("192.168.1.0/24") + c.Add(e) + }, + removeName: "test1", + removeCase: CaseRemoveEntry, + wantErr: false, + checkFn: func(t *testing.T, c Container) { + if c.Len() != 0 { + t.Errorf("Container.Len() = %d, want 0", c.Len()) + } + }, + }, + { + name: "remove prefix", + setupFn: func(c Container) { + e := NewEntry("test1") + e.AddPrefix("192.168.1.0/24") + e.AddPrefix("10.0.0.0/8") + c.Add(e) + }, + removeName: "test1", + removeCase: CaseRemovePrefix, + wantErr: false, + checkFn: func(t *testing.T, c Container) { + if c.Len() != 1 { + t.Errorf("Container.Len() = %d, want 1", c.Len()) + } + }, + }, + { + name: "remove with ignore IPv4", + setupFn: func(c Container) { + e := NewEntry("test1") + e.AddPrefix("192.168.1.0/24") + e.AddPrefix("2001:db8::/32") + c.Add(e) + }, + removeName: "test1", + removeCase: CaseRemoveEntry, + opts: []IgnoreIPOption{IgnoreIPv4}, + wantErr: false, + checkFn: func(t *testing.T, c Container) { + if c.Len() != 1 { + t.Errorf("Container.Len() = %d, want 1 (IPv4 should be removed)", c.Len()) + } + }, + }, + { + name: "remove with ignore IPv6", + setupFn: func(c Container) { + e := NewEntry("test1") + e.AddPrefix("192.168.1.0/24") + e.AddPrefix("2001:db8::/32") + c.Add(e) + }, + removeName: "test1", + removeCase: CaseRemoveEntry, + opts: []IgnoreIPOption{IgnoreIPv6}, + wantErr: false, + checkFn: func(t *testing.T, c Container) { + if c.Len() != 1 { + t.Errorf("Container.Len() = %d, want 1 (IPv6 should be removed)", c.Len()) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := NewContainer() + if tt.setupFn != nil { + tt.setupFn(c) + } + + removeEntry := NewEntry(tt.removeName) + removeEntry.AddPrefix("192.168.1.0/24") + + err := c.Remove(removeEntry, tt.removeCase, tt.opts...) + if (err != nil) != tt.wantErr { + t.Errorf("Container.Remove() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.checkFn != nil { + tt.checkFn(t, c) + } + }) + } +} + +func TestContainer_Loop(t *testing.T) { + c := NewContainer() + + // Add some entries + for i := 1; i <= 3; i++ { + e := NewEntry("test" + string(rune('0'+i))) + e.AddPrefix("192.168.1.0/24") + c.Add(e) + } + + count := 0 + for entry := range c.Loop() { + if entry == nil { + t.Error("Container.Loop() returned nil entry") + } + count++ + } + + if count != 3 { + t.Errorf("Container.Loop() iterated %d times, want 3", count) + } +} + +func TestContainer_Lookup(t *testing.T) { + c := NewContainer() + + // Setup test data + e1 := NewEntry("CN") + e1.AddPrefix("192.168.1.0/24") + e1.AddPrefix("10.0.0.0/8") + e1.AddPrefix("2001:db8:1::/48") // Add IPv6 for CN + c.Add(e1) + + e2 := NewEntry("US") + e2.AddPrefix("172.16.0.0/12") + e2.AddPrefix("2001:db8::/32") + c.Add(e2) + + tests := []struct { + name string + ipOrCidr string + searchList []string + wantFound bool + wantErr bool + checkFn func(*testing.T, []string) + }{ + { + name: "lookup IPv4 address", + ipOrCidr: "192.168.1.100", + searchList: nil, + wantFound: true, + wantErr: false, + checkFn: func(t *testing.T, results []string) { + if len(results) == 0 { + t.Error("Expected at least one result") + } + }, + }, + { + name: "lookup IPv4 CIDR", + ipOrCidr: "192.168.1.0/24", + searchList: nil, + wantFound: true, + wantErr: false, + }, + { + name: "lookup IPv6 address", + ipOrCidr: "2001:db8::1", + searchList: nil, + wantFound: true, + wantErr: false, + }, + { + name: "lookup IPv6 CIDR", + ipOrCidr: "2001:db8::/32", + searchList: nil, + wantFound: true, + wantErr: false, + }, + { + name: "lookup with search list", + ipOrCidr: "192.168.1.100", + searchList: []string{"CN"}, + wantFound: true, + wantErr: false, + }, + { + name: "lookup not in search list", + ipOrCidr: "172.16.1.1", + searchList: []string{"CN"}, + wantFound: false, + wantErr: false, + }, + { + name: "lookup non-existent IP", + ipOrCidr: "1.1.1.1", + searchList: nil, + wantFound: false, + wantErr: false, + }, + { + name: "invalid IP", + ipOrCidr: "invalid", + searchList: nil, + wantFound: false, + wantErr: true, + }, + { + name: "invalid CIDR", + ipOrCidr: "invalid/24", + searchList: nil, + wantFound: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results, found, err := c.Lookup(tt.ipOrCidr, tt.searchList...) + if (err != nil) != tt.wantErr { + t.Errorf("Container.Lookup() error = %v, wantErr %v", err, tt.wantErr) + return + } + if found != tt.wantFound { + t.Errorf("Container.Lookup() found = %v, wantFound %v", found, tt.wantFound) + } + if tt.checkFn != nil { + tt.checkFn(t, results) + } + }) + } +} + +func TestContainer_RemoveWithPrefixCase(t *testing.T) { + c := NewContainer() + + e1 := NewEntry("test") + e1.AddPrefix("192.168.1.0/24") + e1.AddPrefix("10.0.0.0/8") + e1.AddPrefix("2001:db8::/32") + c.Add(e1) + + // Create entry with prefixes to remove + removeEntry := NewEntry("test") + removeEntry.AddPrefix("192.168.1.0/24") + + // Remove with CaseRemovePrefix and ignore IPv4 + err := c.Remove(removeEntry, CaseRemovePrefix, IgnoreIPv4) + if err != nil { + t.Errorf("Container.Remove() error = %v", err) + } +} + +func TestContainer_InvalidRemoveCase(t *testing.T) { + c := NewContainer() + + e := NewEntry("test") + e.AddPrefix("192.168.1.0/24") + c.Add(e) + + // Try to remove with invalid case + err := c.Remove(e, CaseRemove(99)) + if err == nil { + t.Error("Container.Remove() should return error for invalid case") + } +} + +func TestContainer_AddWithMerging(t *testing.T) { + c := NewContainer() + + // Add first entry + e1 := NewEntry("test") + e1.AddPrefix("192.168.1.0/24") + e1.AddPrefix("2001:db8::/32") + c.Add(e1) + + // Add second entry with same name - should merge + e2 := NewEntry("test") + e2.AddPrefix("10.0.0.0/8") + e2.AddPrefix("2001:db9::/32") + c.Add(e2) + + if c.Len() != 1 { + t.Errorf("Container.Len() = %d, want 1 (entries should merge)", c.Len()) + } + + // Verify merged entry has prefixes from both + entry, found := c.GetEntry("test") + if !found { + t.Fatal("Entry not found after merge") + } + + prefixes, err := entry.MarshalPrefix() + if err != nil { + t.Errorf("MarshalPrefix() error = %v", err) + } + // After merging, we should have at least some prefixes + if len(prefixes) == 0 { + t.Error("Merged entry has no prefixes") + } +} + +func TestContainer_AddWithIgnoreOptions(t *testing.T) { + tests := []struct { + name string + opts []IgnoreIPOption + checkFn func(*testing.T, Container) + }{ + { + name: "add new entry with ignore IPv4", + opts: []IgnoreIPOption{IgnoreIPv4}, + checkFn: func(t *testing.T, c Container) { + entry, _ := c.GetEntry("test2") + // When adding a new entry with ignore IPv4, IPv4 should not be added + _, err := entry.GetIPv4Set() + if err == nil { + t.Error("Expected no IPv4 set for new entry when ignored") + } + }, + }, + { + name: "add new entry with ignore IPv6", + opts: []IgnoreIPOption{IgnoreIPv6}, + checkFn: func(t *testing.T, c Container) { + entry, _ := c.GetEntry("test2") + // When adding a new entry with ignore IPv6, IPv6 should not be added + _, err := entry.GetIPv6Set() + if err == nil { + t.Error("Expected no IPv6 set for new entry when ignored") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := NewContainer() + + e := NewEntry("test2") + e.AddPrefix("192.168.1.0/24") + e.AddPrefix("2001:db8::/32") + c.Add(e, tt.opts...) + + if tt.checkFn != nil { + tt.checkFn(t, c) + } + }) + } +} + +func TestContainer_RemoveWithPrefixAndIgnoreOptions(t *testing.T) { + tests := []struct { + name string + opts []IgnoreIPOption + checkFn func(*testing.T, Container) + }{ + { + name: "remove prefix with ignore IPv4", + opts: []IgnoreIPOption{IgnoreIPv4}, + checkFn: func(t *testing.T, c Container) { + entry, _ := c.GetEntry("test") + // IPv6 should be removed, IPv4 should remain + _, err := entry.GetIPv4Set() + if err != nil { + t.Error("IPv4 should still exist") + } + }, + }, + { + name: "remove prefix with ignore IPv6", + opts: []IgnoreIPOption{IgnoreIPv6}, + checkFn: func(t *testing.T, c Container) { + entry, _ := c.GetEntry("test") + // IPv4 should be removed, IPv6 should remain + _, err := entry.GetIPv6Set() + if err != nil { + t.Error("IPv6 should still exist") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := NewContainer() + + e1 := NewEntry("test") + e1.AddPrefix("192.168.1.0/24") + e1.AddPrefix("2001:db8::/32") + c.Add(e1) + + removeEntry := NewEntry("test") + removeEntry.AddPrefix("192.168.1.0/24") + removeEntry.AddPrefix("2001:db8::/32") + + err := c.Remove(removeEntry, CaseRemovePrefix, tt.opts...) + if err != nil { + t.Errorf("Container.Remove() error = %v", err) + } + + if tt.checkFn != nil { + tt.checkFn(t, c) + } + }) + } +} + +func TestContainer_GetEntryInvalidContainer(t *testing.T) { + // Create an invalid container (nil map) + c := &container{entries: nil} + + _, found := c.GetEntry("test") + if found { + t.Error("GetEntry() should return false for invalid container") + } +} + +func TestContainer_LenInvalidContainer(t *testing.T) { + // Create an invalid container (nil map) + c := &container{entries: nil} + + if c.Len() != 0 { + t.Errorf("Len() = %d, want 0 for invalid container", c.Len()) + } +} + +func TestContainer_AddMergingEdgeCases(t *testing.T) { + c := NewContainer() + + // Add entry with both IPv4 and IPv6 + e1 := NewEntry("test") + e1.AddPrefix("192.168.1.0/24") + e1.AddPrefix("2001:db8::/32") + c.Add(e1) + + // Merge with entry that has only IPv4 - should merge both + e2 := NewEntry("test") + e2.AddPrefix("10.0.0.0/8") + c.Add(e2) + + entry, _ := c.GetEntry("test") + // Should have both IPv4 and IPv6 + _, err4 := entry.GetIPv4Set() + _, err6 := entry.GetIPv6Set() + if err4 != nil || err6 != nil { + t.Error("Both IPv4 and IPv6 should exist after merge") + } + + // Now merge with ignore options on existing entry + e3 := NewEntry("test") + e3.AddPrefix("172.16.0.0/12") + e3.AddPrefix("2001:db9::/32") + c.Add(e3, IgnoreIPv4) + + // IPv4 should still exist, IPv6 should be updated + entry, _ = c.GetEntry("test") + _, err := entry.GetIPv4Set() + if err != nil { + t.Error("IPv4 should still exist when merging with IgnoreIPv4") + } +} + +func TestContainer_AddMergingWithExistingBuilders(t *testing.T) { + c := NewContainer() + + // Test merge when val already has builders (lines 102-109) + e1 := NewEntry("test") + e1.AddPrefix("192.168.1.0/24") + e1.AddPrefix("2001:db8::/32") + c.Add(e1) + + // Merge another entry (default case, both builders exist) + e2 := NewEntry("test") + e2.AddPrefix("10.0.0.0/8") + e2.AddPrefix("2001:db9::/32") + c.Add(e2) // This should hit lines 102-109 + + entry, _ := c.GetEntry("test") + prefixes, _ := entry.MarshalPrefix() + if len(prefixes) == 0 { + t.Error("Should have prefixes after merge") + } +} + +func TestContainer_AddMergingIgnoreIPv6WithExistingBuilder(t *testing.T) { + c := NewContainer() + + // Add entry with IPv4 and IPv6 + e1 := NewEntry("test") + e1.AddPrefix("192.168.1.0/24") + e1.AddPrefix("2001:db8::/32") + c.Add(e1) + + // Merge with IgnoreIPv6 when val already has IPv4 builder (lines 97-100) + e2 := NewEntry("test") + e2.AddPrefix("10.0.0.0/8") + e2.AddPrefix("2001:db9::/32") + c.Add(e2, IgnoreIPv6) + + entry, _ := c.GetEntry("test") + _, err := entry.GetIPv4Set() + if err != nil { + t.Error("IPv4 should exist after merge with IgnoreIPv6") + } +} + +func TestContainer_AddMergingIgnoreIPv4WithExistingBuilder(t *testing.T) { + c := NewContainer() + + // Add entry with IPv4 and IPv6 + e1 := NewEntry("test") + e1.AddPrefix("192.168.1.0/24") + e1.AddPrefix("2001:db8::/32") + c.Add(e1) + + // Merge with IgnoreIPv4 when val already has IPv6 builder (lines 92-95) + e2 := NewEntry("test") + e2.AddPrefix("10.0.0.0/8") + e2.AddPrefix("2001:db9::/32") + c.Add(e2, IgnoreIPv4) + + entry, _ := c.GetEntry("test") + _, err := entry.GetIPv6Set() + if err != nil { + t.Error("IPv6 should exist after merge with IgnoreIPv4") + } +} + +func TestContainer_RemoveNilBuilders(t *testing.T) { + c := NewContainer() + + // Add entry + e1 := NewEntry("test") + e1.AddPrefix("192.168.1.0/24") + c.Add(e1) + + // Try to remove prefixes when entry has no IPv6 builder + removeEntry := NewEntry("test") + removeEntry.AddPrefix("2001:db8::/32") // IPv6 + + err := c.Remove(removeEntry, CaseRemovePrefix) + if err != nil { + t.Errorf("Remove should handle missing builder gracefully, got error: %v", err) + } +} diff --git a/lib/converter_test.go b/lib/converter_test.go new file mode 100644 index 00000000000..9910c2113a0 --- /dev/null +++ b/lib/converter_test.go @@ -0,0 +1,177 @@ +package lib + +import ( + "strings" + "testing" +) + +func TestRegisterInputConverter(t *testing.T) { + // Save original state + originalMap := inputConverterMap + defer func() { inputConverterMap = originalMap }() + + // Reset map for testing + inputConverterMap = make(map[string]InputConverter) + + tests := []struct { + name string + converter string + wantErr bool + }{ + { + name: "register new converter", + converter: "test1", + wantErr: false, + }, + { + name: "register duplicate converter", + converter: "test1", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a mock converter + mockConverter := &mockInputConverter{typ: tt.converter} + err := RegisterInputConverter(tt.converter, mockConverter) + if (err != nil) != tt.wantErr { + t.Errorf("RegisterInputConverter() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestRegisterOutputConverter(t *testing.T) { + // Save original state + originalMap := outputConverterMap + defer func() { outputConverterMap = originalMap }() + + // Reset map for testing + outputConverterMap = make(map[string]OutputConverter) + + tests := []struct { + name string + converter string + wantErr bool + }{ + { + name: "register new converter", + converter: "test1", + wantErr: false, + }, + { + name: "register duplicate converter", + converter: "test1", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a mock converter + mockConverter := &mockOutputConverter{typ: tt.converter} + err := RegisterOutputConverter(tt.converter, mockConverter) + if (err != nil) != tt.wantErr { + t.Errorf("RegisterOutputConverter() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestListInputConverter(t *testing.T) { + // Save original state + originalMap := inputConverterMap + defer func() { inputConverterMap = originalMap }() + + // Reset and populate map + inputConverterMap = make(map[string]InputConverter) + inputConverterMap["test"] = &mockInputConverter{typ: "test"} + + // Just test that it doesn't panic + ListInputConverter() +} + +func TestListOutputConverter(t *testing.T) { + // Save original state + originalMap := outputConverterMap + defer func() { outputConverterMap = originalMap }() + + // Reset and populate map + outputConverterMap = make(map[string]OutputConverter) + outputConverterMap["test"] = &mockOutputConverter{typ: "test"} + + // Just test that it doesn't panic + ListOutputConverter() +} + +// Mock converters for testing +type mockInputConverter struct { + typ string +} + +func (m *mockInputConverter) GetType() string { + return m.typ +} + +func (m *mockInputConverter) GetAction() Action { + return ActionAdd +} + +func (m *mockInputConverter) GetDescription() string { + return "mock input converter" +} + +func (m *mockInputConverter) Input(c Container) (Container, error) { + return c, nil +} + +type mockOutputConverter struct { + typ string +} + +func (m *mockOutputConverter) GetType() string { + return m.typ +} + +func (m *mockOutputConverter) GetAction() Action { + return ActionOutput +} + +func (m *mockOutputConverter) GetDescription() string { + return "mock output converter" +} + +func (m *mockOutputConverter) Output(c Container) error { + return nil +} + +func TestRegisterConverterWithWhitespace(t *testing.T) { + // Save original state + originalMap := inputConverterMap + defer func() { inputConverterMap = originalMap }() + + // Reset map for testing + inputConverterMap = make(map[string]InputConverter) + + mockConverter := &mockInputConverter{typ: "test"} + err := RegisterInputConverter(" test ", mockConverter) + if err != nil { + t.Errorf("RegisterInputConverter() with whitespace should not error: %v", err) + } + + // Verify it was registered with trimmed name + if _, ok := inputConverterMap["test"]; !ok { + // Check without trim + found := false + for k := range inputConverterMap { + if strings.TrimSpace(k) == "test" { + found = true + break + } + } + if !found { + t.Error("Converter not registered properly with whitespace") + } + } +} diff --git a/lib/entry_test.go b/lib/entry_test.go new file mode 100644 index 00000000000..ae213d887c6 --- /dev/null +++ b/lib/entry_test.go @@ -0,0 +1,751 @@ +package lib + +import ( + "net" + "net/netip" + "testing" + + "go4.org/netipx" +) + +func TestNewEntry(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple name", + input: "test", + expected: "TEST", + }, + { + name: "lowercase name", + input: "lowercase", + expected: "LOWERCASE", + }, + { + name: "name with spaces", + input: " test name ", + expected: "TEST NAME", + }, + { + name: "mixed case", + input: "MiXeD CaSe", + expected: "MIXED CASE", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + entry := NewEntry(tt.input) + if entry.GetName() != tt.expected { + t.Errorf("NewEntry(%q).GetName() = %q, want %q", tt.input, entry.GetName(), tt.expected) + } + }) + } +} + +func TestEntry_AddPrefix(t *testing.T) { + tests := []struct { + name string + cidr any + wantErr bool + }{ + { + name: "valid IPv4 CIDR string", + cidr: "192.168.1.0/24", + wantErr: false, + }, + { + name: "valid IPv6 CIDR string", + cidr: "2001:db8::/32", + wantErr: false, + }, + { + name: "valid IPv4 address string", + cidr: "192.168.1.1", + wantErr: false, + }, + { + name: "valid IPv6 address string", + cidr: "2001:db8::1", + wantErr: false, + }, + { + name: "invalid CIDR", + cidr: "invalid/cidr", + wantErr: true, + }, + { + name: "invalid IP", + cidr: "999.999.999.999", + wantErr: true, + }, + { + name: "net.IP type", + cidr: net.ParseIP("192.168.1.1"), + wantErr: false, + }, + { + name: "net.IPNet type", + cidr: &net.IPNet{IP: net.ParseIP("192.168.1.0"), Mask: net.CIDRMask(24, 32)}, + wantErr: false, + }, + { + name: "netip.Addr type", + cidr: netip.MustParseAddr("192.168.1.1"), + wantErr: false, + }, + { + name: "netip.Prefix type", + cidr: netip.MustParsePrefix("192.168.1.0/24"), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + entry := NewEntry("test") + err := entry.AddPrefix(tt.cidr) + if (err != nil) != tt.wantErr { + t.Errorf("Entry.AddPrefix() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestEntry_RemovePrefix(t *testing.T) { + entry := NewEntry("test") + // First add a prefix + entry.AddPrefix("192.168.1.0/24") + entry.AddPrefix("10.0.0.0/8") + + tests := []struct { + name string + cidr string + wantErr bool + }{ + { + name: "valid CIDR", + cidr: "192.168.1.0/24", + wantErr: false, + }, + { + name: "invalid CIDR", + cidr: "invalid", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := entry.RemovePrefix(tt.cidr) + if (err != nil) != tt.wantErr { + t.Errorf("Entry.RemovePrefix() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestEntry_GetIPv4Set(t *testing.T) { + entry := NewEntry("test") + entry.AddPrefix("192.168.1.0/24") + + set, err := entry.GetIPv4Set() + if err != nil { + t.Errorf("Entry.GetIPv4Set() error = %v", err) + } + if set == nil { + t.Error("Entry.GetIPv4Set() returned nil set") + } + + // Test entry with no IPv4 + entry2 := NewEntry("test2") + entry2.AddPrefix("2001:db8::/32") + _, err = entry2.GetIPv4Set() + if err == nil { + t.Error("Entry.GetIPv4Set() should return error for entry with no IPv4") + } +} + +func TestEntry_GetIPv6Set(t *testing.T) { + entry := NewEntry("test") + entry.AddPrefix("2001:db8::/32") + + set, err := entry.GetIPv6Set() + if err != nil { + t.Errorf("Entry.GetIPv6Set() error = %v", err) + } + if set == nil { + t.Error("Entry.GetIPv6Set() returned nil set") + } + + // Test entry with no IPv6 + entry2 := NewEntry("test2") + entry2.AddPrefix("192.168.1.0/24") + _, err = entry2.GetIPv6Set() + if err == nil { + t.Error("Entry.GetIPv6Set() should return error for entry with no IPv6") + } +} + +func TestEntry_MarshalPrefix(t *testing.T) { + entry := NewEntry("test") + entry.AddPrefix("192.168.1.0/24") + entry.AddPrefix("10.0.0.0/8") + entry.AddPrefix("2001:db8::/32") + + tests := []struct { + name string + opts []IgnoreIPOption + wantErr bool + checkFn func(*testing.T, []netip.Prefix) + }{ + { + name: "no options", + opts: nil, + wantErr: false, + checkFn: func(t *testing.T, prefixes []netip.Prefix) { + if len(prefixes) != 3 { + t.Errorf("MarshalPrefix() returned %d prefixes, want 3", len(prefixes)) + } + }, + }, + { + name: "ignore IPv4", + opts: []IgnoreIPOption{IgnoreIPv4}, + wantErr: false, + checkFn: func(t *testing.T, prefixes []netip.Prefix) { + if len(prefixes) != 1 { + t.Errorf("MarshalPrefix() returned %d prefixes, want 1", len(prefixes)) + } + }, + }, + { + name: "ignore IPv6", + opts: []IgnoreIPOption{IgnoreIPv6}, + wantErr: false, + checkFn: func(t *testing.T, prefixes []netip.Prefix) { + if len(prefixes) != 2 { + t.Errorf("MarshalPrefix() returned %d prefixes, want 2", len(prefixes)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := entry.MarshalPrefix(tt.opts...) + if (err != nil) != tt.wantErr { + t.Errorf("Entry.MarshalPrefix() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.checkFn != nil { + tt.checkFn(t, got) + } + }) + } +} + +func TestEntry_MarshalIPRange(t *testing.T) { + entry := NewEntry("test") + entry.AddPrefix("192.168.1.0/24") + entry.AddPrefix("2001:db8::/32") + + ranges, err := entry.MarshalIPRange() + if err != nil { + t.Errorf("Entry.MarshalIPRange() error = %v", err) + } + if len(ranges) != 2 { + t.Errorf("Entry.MarshalIPRange() returned %d ranges, want 2", len(ranges)) + } +} + +func TestEntry_MarshalText(t *testing.T) { + entry := NewEntry("test") + entry.AddPrefix("192.168.1.0/24") + entry.AddPrefix("2001:db8::/32") + + text, err := entry.MarshalText() + if err != nil { + t.Errorf("Entry.MarshalText() error = %v", err) + } + if len(text) != 2 { + t.Errorf("Entry.MarshalText() returned %d lines, want 2", len(text)) + } +} + +func TestEntry_EmptyEntry(t *testing.T) { + entry := NewEntry("empty") + + _, err := entry.MarshalPrefix() + if err == nil { + t.Error("Entry.MarshalPrefix() should return error for empty entry") + } + + _, err = entry.MarshalIPRange() + if err == nil { + t.Error("Entry.MarshalIPRange() should return error for empty entry") + } + + _, err = entry.MarshalText() + if err == nil { + t.Error("Entry.MarshalText() should return error for empty entry") + } +} + +func TestEntry_ProcessPrefix_IPv4Mapped(t *testing.T) { + entry := NewEntry("test") + + // Test IPv4-mapped IPv6 address + err := entry.AddPrefix("::ffff:192.168.1.1") + if err != nil { + t.Errorf("Entry.AddPrefix() with IPv4-mapped address error = %v", err) + } +} + +func TestEntry_ProcessPrefix_EdgeCases(t *testing.T) { + entry := NewEntry("test") + + tests := []struct { + name string + input any + wantErr bool + }{ + { + name: "CIDR with comment", + input: "192.168.1.0/24 # comment", + wantErr: false, + }, + { + name: "IP with inline comment", + input: "192.168.1.1 // comment", + wantErr: false, + }, + { + name: "pointer to netip.Addr", + input: func() *netip.Addr { a := netip.MustParseAddr("192.168.1.1"); return &a }(), + wantErr: false, + }, + { + name: "pointer to netip.Prefix", + input: func() *netip.Prefix { p := netip.MustParsePrefix("192.168.1.0/24"); return &p }(), + wantErr: false, + }, + { + name: "unsupported type", + input: 123, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := entry.AddPrefix(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("Entry.AddPrefix(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestEntry_AddPrefix_IPv4In6(t *testing.T) { + entry := NewEntry("test") + + // Create an IPv4-in-IPv6 prefix + prefix := netip.MustParsePrefix("::ffff:c0a8:0100/120") // ::ffff:192.168.1.0/120 + err := entry.AddPrefix(prefix) + if err != nil { + t.Errorf("Entry.AddPrefix() with IPv4-in-IPv6 error = %v", err) + } +} + +func TestEntry_InvalidIPLengthCases(t *testing.T) { + entry := NewEntry("test") + + // Test with invalid IP that could trigger ErrInvalidIPLength + invalidIP := net.IP{} // Invalid empty IP + err := entry.AddPrefix(invalidIP) + if err == nil { + t.Error("Entry.AddPrefix() should return error for invalid IP") + } +} + +func TestEntry_AddPrefix_WithInvalidIPv4MappedCIDR(t *testing.T) { + entry := NewEntry("test") + + // Test IPv4-mapped IPv6 CIDR - this is actually valid and should work + err := entry.AddPrefix("::ffff:192.168.1.0/120") + if err != nil { + // If it errors, that's fine - just checking the code path + t.Logf("AddPrefix with IPv4-mapped CIDR returned: %v", err) + } +} + +func TestEntry_RemovePrefixVariousCases(t *testing.T) { + tests := []struct { + name string + addPrefixes []string + removePrefixes []string + wantErr bool + }{ + { + name: "remove IPv4 prefix", + addPrefixes: []string{"192.168.1.0/24", "10.0.0.0/8"}, + removePrefixes: []string{"192.168.1.0/24"}, + wantErr: false, + }, + { + name: "remove IPv6 prefix", + addPrefixes: []string{"2001:db8::/32", "2001:db9::/32"}, + removePrefixes: []string{"2001:db8::/32"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + entry := NewEntry("test") + for _, prefix := range tt.addPrefixes { + entry.AddPrefix(prefix) + } + for _, prefix := range tt.removePrefixes { + err := entry.RemovePrefix(prefix) + if (err != nil) != tt.wantErr { + t.Errorf("Entry.RemovePrefix() error = %v, wantErr %v", err, tt.wantErr) + } + } + }) + } +} + +func TestEntry_MarshalIPRangeWithOptions(t *testing.T) { + entry := NewEntry("test") + entry.AddPrefix("192.168.1.0/24") + entry.AddPrefix("10.0.0.0/8") + entry.AddPrefix("2001:db8::/32") + + tests := []struct { + name string + opts []IgnoreIPOption + wantErr bool + checkFn func(*testing.T, []netipx.IPRange) + }{ + { + name: "no options", + opts: nil, + wantErr: false, + checkFn: func(t *testing.T, ranges []netipx.IPRange) { + if len(ranges) != 3 { + t.Errorf("MarshalIPRange() returned %d ranges, want 3", len(ranges)) + } + }, + }, + { + name: "ignore IPv4", + opts: []IgnoreIPOption{IgnoreIPv4}, + wantErr: false, + checkFn: func(t *testing.T, ranges []netipx.IPRange) { + if len(ranges) != 1 { + t.Errorf("MarshalIPRange() returned %d ranges, want 1 (IPv6 only)", len(ranges)) + } + }, + }, + { + name: "ignore IPv6", + opts: []IgnoreIPOption{IgnoreIPv6}, + wantErr: false, + checkFn: func(t *testing.T, ranges []netipx.IPRange) { + if len(ranges) != 2 { + t.Errorf("MarshalIPRange() returned %d ranges, want 2 (IPv4 only)", len(ranges)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := entry.MarshalIPRange(tt.opts...) + if (err != nil) != tt.wantErr { + t.Errorf("Entry.MarshalIPRange() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.checkFn != nil { + tt.checkFn(t, got) + } + }) + } +} + +func TestEntry_MarshalTextWithOptions(t *testing.T) { + entry := NewEntry("test") + entry.AddPrefix("192.168.1.0/24") + entry.AddPrefix("10.0.0.0/8") + entry.AddPrefix("2001:db8::/32") + + tests := []struct { + name string + opts []IgnoreIPOption + wantErr bool + checkFn func(*testing.T, []string) + }{ + { + name: "no options", + opts: nil, + wantErr: false, + checkFn: func(t *testing.T, text []string) { + if len(text) != 3 { + t.Errorf("MarshalText() returned %d lines, want 3", len(text)) + } + }, + }, + { + name: "ignore IPv4", + opts: []IgnoreIPOption{IgnoreIPv4}, + wantErr: false, + checkFn: func(t *testing.T, text []string) { + if len(text) != 1 { + t.Errorf("MarshalText() returned %d lines, want 1 (IPv6 only)", len(text)) + } + }, + }, + { + name: "ignore IPv6", + opts: []IgnoreIPOption{IgnoreIPv6}, + wantErr: false, + checkFn: func(t *testing.T, text []string) { + if len(text) != 2 { + t.Errorf("MarshalText() returned %d lines, want 2 (IPv4 only)", len(text)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := entry.MarshalText(tt.opts...) + if (err != nil) != tt.wantErr { + t.Errorf("Entry.MarshalText() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.checkFn != nil { + tt.checkFn(t, got) + } + }) + } +} + +func TestEntry_ProcessPrefixComprehensive(t *testing.T) { + entry := NewEntry("test") + + tests := []struct { + name string + input any + wantErr bool + }{ + { + name: "IPv4 with /32", + input: "192.168.1.1/32", + wantErr: false, + }, + { + name: "IPv6 with /128", + input: "2001:db8::1/128", + wantErr: false, + }, + { + name: "netip.Prefix with IPv4In6", + input: netip.MustParsePrefix("::ffff:192.168.1.0/120"), + wantErr: false, + }, + { + name: "pointer to netip.Prefix with IPv4In6", + input: func() *netip.Prefix { p := netip.MustParsePrefix("::ffff:192.168.1.0/120"); return &p }(), + wantErr: false, + }, + { + name: "pointer to netip.Addr IPv6", + input: func() *netip.Addr { a := netip.MustParseAddr("2001:db8::1"); return &a }(), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := entry.AddPrefix(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("Entry.AddPrefix(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestEntry_ProcessPrefixErrorCases(t *testing.T) { + entry := NewEntry("test") + + // Test IPv4In6 prefix with bits < 96 (should trigger error on line 143) + // This is tricky to test because valid IPv4In6 prefixes have bits >= 96 + // Let's test other error paths + + tests := []struct { + name string + input any + wantErr bool + }{ + { + name: "string CIDR with invalid network", + input: "256.256.256.256/24", + wantErr: true, + }, + { + name: "string with only slash", + input: "/24", + wantErr: true, + }, + { + name: "netip.Prefix IPv4", + input: netip.MustParsePrefix("192.168.1.0/24"), + wantErr: false, + }, + { + name: "netip.Prefix IPv6", + input: netip.MustParsePrefix("2001:db8::/32"), + wantErr: false, + }, + { + name: "*netip.Prefix IPv4", + input: func() *netip.Prefix { p := netip.MustParsePrefix("10.0.0.0/8"); return &p }(), + wantErr: false, + }, + { + name: "*netip.Prefix IPv6", + input: func() *netip.Prefix { p := netip.MustParsePrefix("2001:db9::/32"); return &p }(), + wantErr: false, + }, + { + name: "string with /* comment", + input: "192.168.1.0/24 /* comment */", + wantErr: false, + }, + { + name: "string that becomes empty after comment removal", + input: "# just a comment", + wantErr: true, // ErrCommentLine leads to ErrInvalidIPType + }, + { + name: "string with whitespace and comment", + input: " // comment only", + wantErr: true, + }, + { + name: "net.IP with nil", + input: net.IP(nil), + wantErr: true, // Should trigger ErrInvalidIP + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := entry.AddPrefix(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("Entry.AddPrefix(%v) error = %v, wantErr %v", tt.input, err, tt.wantErr) + } + }) + } +} + +func TestEntry_ProcessPrefix_NetIPNetCases(t *testing.T) { + entry := NewEntry("test") + + tests := []struct { + name string + input *net.IPNet + wantErr bool + }{ + { + name: "valid IPv4 IPNet", + input: &net.IPNet{IP: net.ParseIP("192.168.1.0"), Mask: net.CIDRMask(24, 32)}, + wantErr: false, + }, + { + name: "valid IPv6 IPNet", + input: &net.IPNet{IP: net.ParseIP("2001:db8::"), Mask: net.CIDRMask(32, 128)}, + wantErr: false, + }, + { + name: "invalid IPNet with nil IP", + input: &net.IPNet{IP: nil, Mask: net.CIDRMask(24, 32)}, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := entry.AddPrefix(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("Entry.AddPrefix() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestEntry_ProcessPrefix_NetipAddrCases(t *testing.T) { + entry := NewEntry("test") + + tests := []struct { + name string + input netip.Addr + wantErr bool + }{ + { + name: "valid IPv4 Addr", + input: netip.MustParseAddr("192.168.1.1"), + wantErr: false, + }, + { + name: "valid IPv6 Addr", + input: netip.MustParseAddr("2001:db8::1"), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := entry.AddPrefix(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("Entry.AddPrefix() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestEntry_BuildIPSetErrors(t *testing.T) { + // Test buildIPSet when it succeeds multiple times (coverage for checking existing sets) + entry := NewEntry("test") + entry.AddPrefix("192.168.1.0/24") + entry.AddPrefix("2001:db8::/32") + + // First call to buildIPSet + _, err := entry.GetIPv4Set() + if err != nil { + t.Errorf("First GetIPv4Set() error = %v", err) + } + + // Second call should use existing set + _, err = entry.GetIPv4Set() + if err != nil { + t.Errorf("Second GetIPv4Set() error = %v", err) + } + + // Same for IPv6 + _, err = entry.GetIPv6Set() + if err != nil { + t.Errorf("First GetIPv6Set() error = %v", err) + } + + _, err = entry.GetIPv6Set() + if err != nil { + t.Errorf("Second GetIPv6Set() error = %v", err) + } +} diff --git a/lib/error_test.go b/lib/error_test.go new file mode 100644 index 00000000000..d73b1ec8f13 --- /dev/null +++ b/lib/error_test.go @@ -0,0 +1,36 @@ +package lib + +import ( + "errors" + "testing" +) + +func TestErrors(t *testing.T) { + tests := []struct { + name string + err error + }{ + {"ErrDuplicatedConverter", ErrDuplicatedConverter}, + {"ErrUnknownAction", ErrUnknownAction}, + {"ErrNotSupportedFormat", ErrNotSupportedFormat}, + {"ErrInvalidIPType", ErrInvalidIPType}, + {"ErrInvalidIP", ErrInvalidIP}, + {"ErrInvalidIPLength", ErrInvalidIPLength}, + {"ErrInvalidIPNet", ErrInvalidIPNet}, + {"ErrInvalidCIDR", ErrInvalidCIDR}, + {"ErrInvalidPrefix", ErrInvalidPrefix}, + {"ErrInvalidPrefixType", ErrInvalidPrefixType}, + {"ErrCommentLine", ErrCommentLine}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.err == nil { + t.Errorf("%s should not be nil", tt.name) + } + if !errors.Is(tt.err, tt.err) { + t.Errorf("%s should match itself", tt.name) + } + }) + } +} diff --git a/lib/instance_test.go b/lib/instance_test.go new file mode 100644 index 00000000000..9ef1b7e9b04 --- /dev/null +++ b/lib/instance_test.go @@ -0,0 +1,383 @@ +package lib + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" +) + +func TestNewInstance(t *testing.T) { + instance, err := NewInstance() + if err != nil { + t.Errorf("NewInstance() error = %v", err) + } + if instance == nil { + t.Error("NewInstance() returned nil") + } +} + +func TestInstance_AddInput(t *testing.T) { + instance, _ := NewInstance() + mockInput := &mockInputConverter{typ: "test"} + + instance.AddInput(mockInput) + + // We can't directly access the input slice, but we can test Run behavior + // This is tested indirectly through other tests +} + +func TestInstance_AddOutput(t *testing.T) { + instance, _ := NewInstance() + mockOutput := &mockOutputConverter{typ: "test"} + + instance.AddOutput(mockOutput) + + // Similar to AddInput, tested indirectly +} + +func TestInstance_ResetInput(t *testing.T) { + instance, _ := NewInstance() + mockInput := &mockInputConverter{typ: "test"} + + instance.AddInput(mockInput) + instance.ResetInput() + + // After reset, Run should fail due to no inputs + err := instance.Run() + if err == nil { + t.Error("Instance.Run() should fail after ResetInput") + } +} + +func TestInstance_ResetOutput(t *testing.T) { + instance, _ := NewInstance() + mockOutput := &mockOutputConverter{typ: "test"} + + instance.AddOutput(mockOutput) + instance.ResetOutput() + + // After reset, Run should fail due to no outputs + err := instance.Run() + if err == nil { + t.Error("Instance.Run() should fail after ResetOutput") + } +} + +func TestInstance_Run(t *testing.T) { + tests := []struct { + name string + setupFn func(Instance) + wantErr bool + }{ + { + name: "no input or output", + setupFn: func(i Instance) { + // Don't add anything + }, + wantErr: true, + }, + { + name: "only input", + setupFn: func(i Instance) { + i.AddInput(&mockInputConverter{typ: "test"}) + }, + wantErr: true, + }, + { + name: "only output", + setupFn: func(i Instance) { + i.AddOutput(&mockOutputConverter{typ: "test"}) + }, + wantErr: true, + }, + { + name: "both input and output", + setupFn: func(i Instance) { + i.AddInput(&mockInputConverter{typ: "test"}) + i.AddOutput(&mockOutputConverter{typ: "test"}) + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + instance, _ := NewInstance() + tt.setupFn(instance) + + err := instance.Run() + if (err != nil) != tt.wantErr { + t.Errorf("Instance.Run() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestInstance_RunInput(t *testing.T) { + instance, _ := NewInstance() + instance.AddInput(&mockInputConverter{typ: "test"}) + + container := NewContainer() + err := instance.RunInput(container) + if err != nil { + t.Errorf("Instance.RunInput() error = %v", err) + } +} + +func TestInstance_RunOutput(t *testing.T) { + instance, _ := NewInstance() + instance.AddOutput(&mockOutputConverter{typ: "test"}) + + container := NewContainer() + err := instance.RunOutput(container) + if err != nil { + t.Errorf("Instance.RunOutput() error = %v", err) + } +} + +func TestInstance_InitConfigFromBytes(t *testing.T) { + // Save original state + originalInputCache := inputConfigCreatorCache + originalOutputCache := outputConfigCreatorCache + defer func() { + inputConfigCreatorCache = originalInputCache + outputConfigCreatorCache = originalOutputCache + }() + + // Reset caches + inputConfigCreatorCache = make(map[string]inputConfigCreator) + outputConfigCreatorCache = make(map[string]outputConfigCreator) + + // Register test creators + RegisterInputConfigCreator("test", func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typ: "test"}, nil + }) + RegisterOutputConfigCreator("test", func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typ: "test"}, nil + }) + + tests := []struct { + name string + config string + wantErr bool + }{ + { + name: "valid config", + config: `{ + "input": [{"type": "test", "action": "add", "args": {}}], + "output": [{"type": "test", "action": "output", "args": {}}] + }`, + wantErr: false, + }, + { + name: "config with comments", + config: `{ + // This is a comment + "input": [{"type": "test", "action": "add", "args": {}}], + "output": [{"type": "test", "action": "output", "args": {}}] + }`, + wantErr: false, + }, + { + name: "config with trailing comma", + config: `{ + "input": [{"type": "test", "action": "add", "args": {}}], + "output": [{"type": "test", "action": "output", "args": {}}], + }`, + wantErr: false, + }, + { + name: "invalid JSON", + config: `{invalid}`, + wantErr: true, + }, + { + name: "unknown input type", + config: `{ + "input": [{"type": "unknown", "action": "add", "args": {}}], + "output": [{"type": "test", "action": "output", "args": {}}] + }`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + instance, _ := NewInstance() + err := instance.InitConfigFromBytes([]byte(tt.config)) + if (err != nil) != tt.wantErr { + t.Errorf("Instance.InitConfigFromBytes() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestInstance_InitConfig(t *testing.T) { + // Save original state + originalInputCache := inputConfigCreatorCache + originalOutputCache := outputConfigCreatorCache + defer func() { + inputConfigCreatorCache = originalInputCache + outputConfigCreatorCache = originalOutputCache + }() + + // Reset caches + inputConfigCreatorCache = make(map[string]inputConfigCreator) + outputConfigCreatorCache = make(map[string]outputConfigCreator) + + // Register test creators + RegisterInputConfigCreator("test", func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typ: "test"}, nil + }) + RegisterOutputConfigCreator("test", func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typ: "test"}, nil + }) + + configContent := `{ + "input": [{"type": "test", "action": "add", "args": {}}], + "output": [{"type": "test", "action": "output", "args": {}}] + }` + + t.Run("load from file", func(t *testing.T) { + // Create a temporary config file + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.json") + if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil { + t.Fatalf("Failed to create test config file: %v", err) + } + + instance, _ := NewInstance() + err := instance.InitConfig(configFile) + if err != nil { + t.Errorf("Instance.InitConfig() error = %v", err) + } + }) + + t.Run("load from HTTP URL", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(configContent)) + })) + defer server.Close() + + instance, _ := NewInstance() + err := instance.InitConfig(server.URL) + if err != nil { + t.Errorf("Instance.InitConfig() with HTTP URL error = %v", err) + } + }) + + t.Run("load from HTTPS URL", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(configContent)) + })) + defer server.Close() + + instance, _ := NewInstance() + // This may fail due to certificate issues in test, but we test the code path + instance.InitConfig(server.URL) + }) + + t.Run("non-existent file", func(t *testing.T) { + instance, _ := NewInstance() + err := instance.InitConfig("/nonexistent/config.json") + if err == nil { + t.Error("Instance.InitConfig() should return error for non-existent file") + } + }) + + t.Run("URL with whitespace", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(configContent)) + })) + defer server.Close() + + instance, _ := NewInstance() + err := instance.InitConfig(" " + server.URL + " ") + if err != nil { + t.Errorf("Instance.InitConfig() with whitespace error = %v", err) + } + }) +} + +func TestInstance_RunInputError(t *testing.T) { + instance, _ := NewInstance() + + // Add an input converter that returns an error + instance.AddInput(&mockErrorInputConverter{}) + + container := NewContainer() + err := instance.RunInput(container) + if err == nil { + t.Error("Instance.RunInput() should return error when converter fails") + } +} + +func TestInstance_RunOutputError(t *testing.T) { + instance, _ := NewInstance() + + // Add an output converter that returns an error + instance.AddOutput(&mockErrorOutputConverter{}) + + container := NewContainer() + err := instance.RunOutput(container) + if err == nil { + t.Error("Instance.RunOutput() should return error when converter fails") + } +} + +func TestInstance_RunWithContainer(t *testing.T) { + // Save original state + originalInputCache := inputConfigCreatorCache + originalOutputCache := outputConfigCreatorCache + defer func() { + inputConfigCreatorCache = originalInputCache + outputConfigCreatorCache = originalOutputCache + }() + + // Reset caches + inputConfigCreatorCache = make(map[string]inputConfigCreator) + outputConfigCreatorCache = make(map[string]outputConfigCreator) + + // Register test creators + RegisterInputConfigCreator("test", func(action Action, data json.RawMessage) (InputConverter, error) { + return &mockInputConverter{typ: "test"}, nil + }) + RegisterOutputConfigCreator("test", func(action Action, data json.RawMessage) (OutputConverter, error) { + return &mockOutputConverter{typ: "test"}, nil + }) + + // Test full run with both input and output + instance, _ := NewInstance() + instance.AddInput(&mockInputConverter{typ: "test"}) + instance.AddOutput(&mockOutputConverter{typ: "test"}) + + err := instance.Run() + if err != nil { + t.Errorf("Instance.Run() error = %v", err) + } +} + +// Mock error converters +type mockErrorInputConverter struct{} + +func (m *mockErrorInputConverter) GetType() string { return "error" } +func (m *mockErrorInputConverter) GetAction() Action { return ActionAdd } +func (m *mockErrorInputConverter) GetDescription() string { return "error converter" } +func (m *mockErrorInputConverter) Input(c Container) (Container, error) { + return nil, ErrNotSupportedFormat +} + +type mockErrorOutputConverter struct{} + +func (m *mockErrorOutputConverter) GetType() string { return "error" } +func (m *mockErrorOutputConverter) GetAction() Action { return ActionOutput } +func (m *mockErrorOutputConverter) GetDescription() string { return "error converter" } +func (m *mockErrorOutputConverter) Output(c Container) error { + return ErrNotSupportedFormat +} diff --git a/lib/lib_test.go b/lib/lib_test.go new file mode 100644 index 00000000000..dabed27ef02 --- /dev/null +++ b/lib/lib_test.go @@ -0,0 +1,66 @@ +package lib + +import "testing" + +func TestConstants(t *testing.T) { + tests := []struct { + name string + action Action + expected bool + }{ + {"ActionAdd in registry", ActionAdd, true}, + {"ActionRemove in registry", ActionRemove, true}, + {"ActionOutput in registry", ActionOutput, true}, + {"Invalid action not in registry", Action("invalid"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ActionsRegistry[tt.action]; got != tt.expected { + t.Errorf("ActionsRegistry[%s] = %v, want %v", tt.action, got, tt.expected) + } + }) + } +} + +func TestActionConstants(t *testing.T) { + if ActionAdd != "add" { + t.Errorf("ActionAdd = %s, want 'add'", ActionAdd) + } + if ActionRemove != "remove" { + t.Errorf("ActionRemove = %s, want 'remove'", ActionRemove) + } + if ActionOutput != "output" { + t.Errorf("ActionOutput = %s, want 'output'", ActionOutput) + } +} + +func TestIPTypeConstants(t *testing.T) { + if IPv4 != "ipv4" { + t.Errorf("IPv4 = %s, want 'ipv4'", IPv4) + } + if IPv6 != "ipv6" { + t.Errorf("IPv6 = %s, want 'ipv6'", IPv6) + } +} + +func TestCaseRemoveConstants(t *testing.T) { + if CaseRemovePrefix != 0 { + t.Errorf("CaseRemovePrefix = %d, want 0", CaseRemovePrefix) + } + if CaseRemoveEntry != 1 { + t.Errorf("CaseRemoveEntry = %d, want 1", CaseRemoveEntry) + } +} + +func TestIgnoreIPv4(t *testing.T) { + if got := IgnoreIPv4(); got != IPv4 { + t.Errorf("IgnoreIPv4() = %v, want %v", got, IPv4) + } +} + +func TestIgnoreIPv6(t *testing.T) { + if got := IgnoreIPv6(); got != IPv6 { + t.Errorf("IgnoreIPv6() = %v, want %v", got, IPv6) + } +}