@@ -96,6 +96,32 @@ func TestNewPiholeClient(t *testing.T) {
9696 }
9797}
9898
99+ // Helper function to validate records against expected values
100+ func ValidateRecords (t * testing.T , records []* endpoint.Endpoint , expected [][]string , expectedCount int , recordType string ) {
101+ t .Helper ()
102+ if len (records ) != expectedCount {
103+ t .Fatalf ("Expected %d %s records returned, got: %d" , expectedCount , recordType , len (records ))
104+ }
105+ for idx , rec := range records {
106+ if rec .DNSName != expected [idx ][0 ] {
107+ t .Errorf ("Got invalid DNS Name: %s, expected: %s" , rec .DNSName , expected [idx ][0 ])
108+ }
109+ if rec .Targets [0 ] != expected [idx ][1 ] {
110+ t .Errorf ("Got invalid target: %s, expected: %s" , rec .Targets [0 ], expected [idx ][1 ])
111+ }
112+ }
113+ }
114+
115+ // Helper function to test record retrieval for a specific type
116+ func CheckRecordRetrieval (t * testing.T , cl * piholeClient , recordType string , expected [][]string , expectedCount int ) {
117+ t .Helper ()
118+ records , err := cl .listRecords (context .Background (), recordType )
119+ if err != nil {
120+ t .Fatal (err )
121+ }
122+ ValidateRecords (t , records , expected , expectedCount , recordType )
123+ }
124+
99125func TestListRecords (t * testing.T ) {
100126 srvr := newTestServer (t , func (w http.ResponseWriter , r * http.Request ) {
101127 r .ParseForm ()
@@ -140,145 +166,81 @@ func TestListRecords(t *testing.T) {
140166 }
141167
142168 // Test retrieve A records unfiltered
143- arecs , err := cl .listRecords (context .Background (), endpoint .RecordTypeA )
144- if err != nil {
145- t .Fatal (err )
146- }
147- if len (arecs ) != 3 {
148- t .Fatal ("Expected 3 A records returned, got:" , len (arecs ))
149- }
150- // Ensure records were parsed correctly
151- expected := [][]string {
169+ CheckRecordRetrieval (t , cl .(* piholeClient ), endpoint .RecordTypeA , [][]string {
152170 {"test1.example.com" , "192.168.1.1" },
153171 {"test2.example.com" , "192.168.1.2" },
154172 {"test3.match.com" , "192.168.1.3" },
155- }
156- for idx , rec := range arecs {
157- if rec .DNSName != expected [idx ][0 ] {
158- t .Error ("Got invalid DNS Name:" , rec .DNSName , "expected:" , expected [idx ][0 ])
159- }
160- if rec .Targets [0 ] != expected [idx ][1 ] {
161- t .Error ("Got invalid target:" , rec .Targets [0 ], "expected:" , expected [idx ][1 ])
162- }
163- }
173+ }, 3 )
164174
165175 // Test retrieve AAAA records unfiltered
166- arecs , err = cl .listRecords (context .Background (), endpoint .RecordTypeAAAA )
167- if err != nil {
168- t .Fatal (err )
169- }
170- if len (arecs ) != 3 {
171- t .Fatal ("Expected 3 AAAA records returned, got:" , len (arecs ))
172- }
173- // Ensure records were parsed correctly
174- expected = [][]string {
176+ CheckRecordRetrieval (t , cl .(* piholeClient ), endpoint .RecordTypeAAAA , [][]string {
175177 {"test1.example.com" , "fc00::1:192:168:1:1" },
176178 {"test2.example.com" , "fc00::1:192:168:1:2" },
177179 {"test3.match.com" , "fc00::1:192:168:1:3" },
178- }
179- for idx , rec := range arecs {
180- if rec .DNSName != expected [idx ][0 ] {
181- t .Error ("Got invalid DNS Name:" , rec .DNSName , "expected:" , expected [idx ][0 ])
182- }
183- if rec .Targets [0 ] != expected [idx ][1 ] {
184- t .Error ("Got invalid target:" , rec .Targets [0 ], "expected:" , expected [idx ][1 ])
185- }
186- }
180+ }, 3 )
187181
188182 // Test retrieve CNAME records unfiltered
189- cnamerecs , err := cl .listRecords (context .Background (), endpoint .RecordTypeCNAME )
190- if err != nil {
191- t .Fatal (err )
192- }
193- if len (cnamerecs ) != 3 {
194- t .Fatal ("Expected 3 CAME records returned, got:" , len (cnamerecs ))
195- }
196- // Ensure records were parsed correctly
197- expected = [][]string {
183+ CheckRecordRetrieval (t , cl .(* piholeClient ), endpoint .RecordTypeCNAME , [][]string {
198184 {"test4.example.com" , "cname.example.com" },
199185 {"test5.example.com" , "cname.example.com" },
200186 {"test6.match.com" , "cname.example.com" },
201- }
202- for idx , rec := range cnamerecs {
203- if rec .DNSName != expected [idx ][0 ] {
204- t .Error ("Got invalid DNS Name:" , rec .DNSName , "expected:" , expected [idx ][0 ])
205- }
206- if rec .Targets [0 ] != expected [idx ][1 ] {
207- t .Error ("Got invalid target:" , rec .Targets [0 ], "expected:" , expected [idx ][1 ])
208- }
209- }
187+ }, 3 )
210188
211189 // Same tests but with a domain filter
212-
213190 cfg .DomainFilter = endpoint .NewDomainFilter ([]string {"match.com" })
214191 cl , err = newPiholeClient (cfg )
215192 if err != nil {
216193 t .Fatal (err )
217194 }
218195
219196 // Test retrieve A records filtered
220- arecs , err = cl .listRecords (context .Background (), endpoint .RecordTypeA )
221- if err != nil {
222- t .Fatal (err )
223- }
224- if len (arecs ) != 1 {
225- t .Fatal ("Expected 1 A record returned, got:" , len (arecs ))
226- }
227- // Ensure records were parsed correctly
228- expected = [][]string {
197+ CheckRecordRetrieval (t , cl .(* piholeClient ), endpoint .RecordTypeA , [][]string {
229198 {"test3.match.com" , "192.168.1.3" },
230- }
231- for idx , rec := range arecs {
232- if rec .DNSName != expected [idx ][0 ] {
233- t .Error ("Got invalid DNS Name:" , rec .DNSName , "expected:" , expected [idx ][0 ])
234- }
235- if rec .Targets [0 ] != expected [idx ][1 ] {
236- t .Error ("Got invalid target:" , rec .Targets [0 ], "expected:" , expected [idx ][1 ])
237- }
238- }
199+ }, 1 )
239200
240201 // Test retrieve AAAA records filtered
241- arecs , err = cl .listRecords (context .Background (), endpoint .RecordTypeAAAA )
242- if err != nil {
243- t .Fatal (err )
244- }
245- if len (arecs ) != 1 {
246- t .Fatal ("Expected 1 AAAA record returned, got:" , len (arecs ))
247- }
248- // Ensure records were parsed correctly
249- expected = [][]string {
202+ CheckRecordRetrieval (t , cl .(* piholeClient ), endpoint .RecordTypeAAAA , [][]string {
250203 {"test3.match.com" , "fc00::1:192:168:1:3" },
251- }
252- for idx , rec := range arecs {
253- if rec .DNSName != expected [idx ][0 ] {
254- t .Error ("Got invalid DNS Name:" , rec .DNSName , "expected:" , expected [idx ][0 ])
255- }
256- if rec .Targets [0 ] != expected [idx ][1 ] {
257- t .Error ("Got invalid target:" , rec .Targets [0 ], "expected:" , expected [idx ][1 ])
258- }
259- }
204+ }, 1 )
260205
261206 // Test retrieve CNAME records filtered
262- cnamerecs , err = cl .listRecords (context .Background (), endpoint .RecordTypeCNAME )
207+ CheckRecordRetrieval (t , cl .(* piholeClient ), endpoint .RecordTypeCNAME , [][]string {
208+ {"test6.match.com" , "cname.example.com" },
209+ }, 1 )
210+
211+ }
212+
213+ // Helper function to test error scenarios
214+ func testErrorScenarios (t * testing.T , srvrErr * httptest.Server ) {
215+ t .Helper ()
216+ cfgExpired := PiholeConfig {
217+ Server : srvrErr .URL ,
218+ }
219+ clExpired , err := newPiholeClient (cfgExpired )
263220 if err != nil {
264221 t .Fatal (err )
265222 }
266- if len (cnamerecs ) != 1 {
267- t .Fatal ("Expected 1 CNAME record returned, got:" , len (cnamerecs ))
223+ //set clExpired.token to a valid token
224+ clExpired .(* piholeClient ).token = "expired"
225+ clExpired .(* piholeClient ).cfg .Password = "notcorrect"
226+
227+ cnamerecs , err := clExpired .listRecords (context .Background (), "notarealrecordtype" )
228+ if err == nil {
229+ t .Fatal ("Should return error, type is unknown ! " )
268230 }
269- // Ensure records were parsed correctly
270- expected = [][] string {
271- { "test6.match.com" , "cname.example.com" },
231+ cnamerecs , err = clExpired . listRecords ( context . Background (), endpoint . RecordTypeCNAME )
232+ if err == nil {
233+ t . Fatal ( "Should return error on failed auth ! " )
272234 }
273- for idx , rec := range cnamerecs {
274- if rec .DNSName != expected [idx ][0 ] {
275- t .Error ("Got invalid DNS Name:" , rec .DNSName , "expected:" , expected [idx ][0 ])
276- }
277- if rec .Targets [0 ] != expected [idx ][1 ] {
278- t .Error ("Got invalid target:" , rec .Targets [0 ], "expected:" , expected [idx ][1 ])
279- }
235+ clExpired .(* piholeClient ).token = "correct"
236+ clExpired .(* piholeClient ).cfg .Password = "correct"
237+ cnamerecs , err = clExpired .listRecords (context .Background (), endpoint .RecordTypeCNAME )
238+ if len (cnamerecs ) != 0 {
239+ t .Fatal ("Should return empty on missing data in response ! " )
280240 }
241+ }
281242
243+ func TestErrorScenarios (t * testing.T ) {
282244 // Test errors token
283245 srvrErr := newTestServer (t , func (w http.ResponseWriter , r * http.Request ) {
284246 r .ParseForm ()
@@ -318,6 +280,7 @@ func TestListRecords(t *testing.T) {
318280 ` ))
319281 })
320282 defer srvrErr .Close ()
283+
321284 cfgExpired := PiholeConfig {
322285 Server : srvrErr .URL ,
323286 }
@@ -329,21 +292,23 @@ func TestListRecords(t *testing.T) {
329292 clExpired .(* piholeClient ).token = "expired"
330293 clExpired .(* piholeClient ).cfg .Password = "notcorrect"
331294
332- cnamerecs , err = clExpired .listRecords (context .Background (), "notarealrecordtype" )
295+ _ , err = clExpired .listRecords (context .Background (), "notarealrecordtype" )
333296 if err == nil {
334297 t .Fatal ("Should return error, type is unknown ! " )
335298 }
336- cnamerecs , err = clExpired .listRecords (context .Background (), endpoint .RecordTypeCNAME )
299+ _ , err = clExpired .listRecords (context .Background (), endpoint .RecordTypeCNAME )
337300 if err == nil {
338301 t .Fatal ("Should return error on failed auth ! " )
339302 }
340303 clExpired .(* piholeClient ).token = "correct"
341304 clExpired .(* piholeClient ).cfg .Password = "correct"
342- cnamerecs , err = clExpired .listRecords (context .Background (), endpoint .RecordTypeCNAME )
305+ cnamerecs , err := clExpired .listRecords (context .Background (), endpoint .RecordTypeCNAME )
306+ if err != nil {
307+ t .Fatal (err )
308+ }
343309 if len (cnamerecs ) != 0 {
344310 t .Fatal ("Should return empty on missing data in response ! " )
345311 }
346-
347312}
348313
349314func TestCreateRecord (t * testing.T ) {
0 commit comments