Skip to content

Commit b095b3e

Browse files
chore(pihole): reduce cyclometic complexity (#5802)
1 parent 4aac687 commit b095b3e

File tree

1 file changed

+73
-108
lines changed

1 file changed

+73
-108
lines changed

provider/pihole/client_test.go

Lines changed: 73 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,32 @@ func TestNewPiholeClient(t *testing.T) {
9696
}
9797
}
9898

99+
// Helper function to validate records against expected values
100+
func ValidateRecords(t *testing.T, records []*endpoint.Endpoint, expected [][]string, expectedCount int, recordType string) {
101+
t.Helper()
102+
if len(records) != expectedCount {
103+
t.Fatalf("Expected %d %s records returned, got: %d", expectedCount, recordType, len(records))
104+
}
105+
for idx, rec := range records {
106+
if rec.DNSName != expected[idx][0] {
107+
t.Errorf("Got invalid DNS Name: %s, expected: %s", rec.DNSName, expected[idx][0])
108+
}
109+
if rec.Targets[0] != expected[idx][1] {
110+
t.Errorf("Got invalid target: %s, expected: %s", rec.Targets[0], expected[idx][1])
111+
}
112+
}
113+
}
114+
115+
// Helper function to test record retrieval for a specific type
116+
func CheckRecordRetrieval(t *testing.T, cl *piholeClient, recordType string, expected [][]string, expectedCount int) {
117+
t.Helper()
118+
records, err := cl.listRecords(context.Background(), recordType)
119+
if err != nil {
120+
t.Fatal(err)
121+
}
122+
ValidateRecords(t, records, expected, expectedCount, recordType)
123+
}
124+
99125
func TestListRecords(t *testing.T) {
100126
srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
101127
r.ParseForm()
@@ -140,145 +166,81 @@ func TestListRecords(t *testing.T) {
140166
}
141167

142168
// Test retrieve A records unfiltered
143-
arecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeA)
144-
if err != nil {
145-
t.Fatal(err)
146-
}
147-
if len(arecs) != 3 {
148-
t.Fatal("Expected 3 A records returned, got:", len(arecs))
149-
}
150-
// Ensure records were parsed correctly
151-
expected := [][]string{
169+
CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeA, [][]string{
152170
{"test1.example.com", "192.168.1.1"},
153171
{"test2.example.com", "192.168.1.2"},
154172
{"test3.match.com", "192.168.1.3"},
155-
}
156-
for idx, rec := range arecs {
157-
if rec.DNSName != expected[idx][0] {
158-
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
159-
}
160-
if rec.Targets[0] != expected[idx][1] {
161-
t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
162-
}
163-
}
173+
}, 3)
164174

165175
// Test retrieve AAAA records unfiltered
166-
arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeAAAA)
167-
if err != nil {
168-
t.Fatal(err)
169-
}
170-
if len(arecs) != 3 {
171-
t.Fatal("Expected 3 AAAA records returned, got:", len(arecs))
172-
}
173-
// Ensure records were parsed correctly
174-
expected = [][]string{
176+
CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeAAAA, [][]string{
175177
{"test1.example.com", "fc00::1:192:168:1:1"},
176178
{"test2.example.com", "fc00::1:192:168:1:2"},
177179
{"test3.match.com", "fc00::1:192:168:1:3"},
178-
}
179-
for idx, rec := range arecs {
180-
if rec.DNSName != expected[idx][0] {
181-
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
182-
}
183-
if rec.Targets[0] != expected[idx][1] {
184-
t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
185-
}
186-
}
180+
}, 3)
187181

188182
// Test retrieve CNAME records unfiltered
189-
cnamerecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeCNAME)
190-
if err != nil {
191-
t.Fatal(err)
192-
}
193-
if len(cnamerecs) != 3 {
194-
t.Fatal("Expected 3 CAME records returned, got:", len(cnamerecs))
195-
}
196-
// Ensure records were parsed correctly
197-
expected = [][]string{
183+
CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeCNAME, [][]string{
198184
{"test4.example.com", "cname.example.com"},
199185
{"test5.example.com", "cname.example.com"},
200186
{"test6.match.com", "cname.example.com"},
201-
}
202-
for idx, rec := range cnamerecs {
203-
if rec.DNSName != expected[idx][0] {
204-
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
205-
}
206-
if rec.Targets[0] != expected[idx][1] {
207-
t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
208-
}
209-
}
187+
}, 3)
210188

211189
// Same tests but with a domain filter
212-
213190
cfg.DomainFilter = endpoint.NewDomainFilter([]string{"match.com"})
214191
cl, err = newPiholeClient(cfg)
215192
if err != nil {
216193
t.Fatal(err)
217194
}
218195

219196
// Test retrieve A records filtered
220-
arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeA)
221-
if err != nil {
222-
t.Fatal(err)
223-
}
224-
if len(arecs) != 1 {
225-
t.Fatal("Expected 1 A record returned, got:", len(arecs))
226-
}
227-
// Ensure records were parsed correctly
228-
expected = [][]string{
197+
CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeA, [][]string{
229198
{"test3.match.com", "192.168.1.3"},
230-
}
231-
for idx, rec := range arecs {
232-
if rec.DNSName != expected[idx][0] {
233-
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
234-
}
235-
if rec.Targets[0] != expected[idx][1] {
236-
t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
237-
}
238-
}
199+
}, 1)
239200

240201
// Test retrieve AAAA records filtered
241-
arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeAAAA)
242-
if err != nil {
243-
t.Fatal(err)
244-
}
245-
if len(arecs) != 1 {
246-
t.Fatal("Expected 1 AAAA record returned, got:", len(arecs))
247-
}
248-
// Ensure records were parsed correctly
249-
expected = [][]string{
202+
CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeAAAA, [][]string{
250203
{"test3.match.com", "fc00::1:192:168:1:3"},
251-
}
252-
for idx, rec := range arecs {
253-
if rec.DNSName != expected[idx][0] {
254-
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
255-
}
256-
if rec.Targets[0] != expected[idx][1] {
257-
t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
258-
}
259-
}
204+
}, 1)
260205

261206
// Test retrieve CNAME records filtered
262-
cnamerecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeCNAME)
207+
CheckRecordRetrieval(t, cl.(*piholeClient), endpoint.RecordTypeCNAME, [][]string{
208+
{"test6.match.com", "cname.example.com"},
209+
}, 1)
210+
211+
}
212+
213+
// Helper function to test error scenarios
214+
func testErrorScenarios(t *testing.T, srvrErr *httptest.Server) {
215+
t.Helper()
216+
cfgExpired := PiholeConfig{
217+
Server: srvrErr.URL,
218+
}
219+
clExpired, err := newPiholeClient(cfgExpired)
263220
if err != nil {
264221
t.Fatal(err)
265222
}
266-
if len(cnamerecs) != 1 {
267-
t.Fatal("Expected 1 CNAME record returned, got:", len(cnamerecs))
223+
//set clExpired.token to a valid token
224+
clExpired.(*piholeClient).token = "expired"
225+
clExpired.(*piholeClient).cfg.Password = "notcorrect"
226+
227+
cnamerecs, err := clExpired.listRecords(context.Background(), "notarealrecordtype")
228+
if err == nil {
229+
t.Fatal("Should return error, type is unknown ! ")
268230
}
269-
// Ensure records were parsed correctly
270-
expected = [][]string{
271-
{"test6.match.com", "cname.example.com"},
231+
cnamerecs, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME)
232+
if err == nil {
233+
t.Fatal("Should return error on failed auth ! ")
272234
}
273-
for idx, rec := range cnamerecs {
274-
if rec.DNSName != expected[idx][0] {
275-
t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
276-
}
277-
if rec.Targets[0] != expected[idx][1] {
278-
t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
279-
}
235+
clExpired.(*piholeClient).token = "correct"
236+
clExpired.(*piholeClient).cfg.Password = "correct"
237+
cnamerecs, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME)
238+
if len(cnamerecs) != 0 {
239+
t.Fatal("Should return empty on missing data in response ! ")
280240
}
241+
}
281242

243+
func TestErrorScenarios(t *testing.T) {
282244
// Test errors token
283245
srvrErr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
284246
r.ParseForm()
@@ -318,6 +280,7 @@ func TestListRecords(t *testing.T) {
318280
`))
319281
})
320282
defer srvrErr.Close()
283+
321284
cfgExpired := PiholeConfig{
322285
Server: srvrErr.URL,
323286
}
@@ -329,21 +292,23 @@ func TestListRecords(t *testing.T) {
329292
clExpired.(*piholeClient).token = "expired"
330293
clExpired.(*piholeClient).cfg.Password = "notcorrect"
331294

332-
cnamerecs, err = clExpired.listRecords(context.Background(), "notarealrecordtype")
295+
_, err = clExpired.listRecords(context.Background(), "notarealrecordtype")
333296
if err == nil {
334297
t.Fatal("Should return error, type is unknown ! ")
335298
}
336-
cnamerecs, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME)
299+
_, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME)
337300
if err == nil {
338301
t.Fatal("Should return error on failed auth ! ")
339302
}
340303
clExpired.(*piholeClient).token = "correct"
341304
clExpired.(*piholeClient).cfg.Password = "correct"
342-
cnamerecs, err = clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME)
305+
cnamerecs, err := clExpired.listRecords(context.Background(), endpoint.RecordTypeCNAME)
306+
if err != nil {
307+
t.Fatal(err)
308+
}
343309
if len(cnamerecs) != 0 {
344310
t.Fatal("Should return empty on missing data in response ! ")
345311
}
346-
347312
}
348313

349314
func TestCreateRecord(t *testing.T) {

0 commit comments

Comments
 (0)