@@ -5,15 +5,19 @@ import (
55 "encoding/json"
66 "errors"
77 "io"
8+ "net"
89 "net/http"
910 "net/http/httptest"
1011 "testing"
1112
13+ "github.com/google/uuid"
1214 "github.com/rs/zerolog"
1315 "github.com/stretchr/testify/assert"
1416 "github.com/stretchr/testify/require"
1517
18+ "github.com/cloudflare/cloudflared/connection"
1619 "github.com/cloudflare/cloudflared/diagnostic"
20+ "github.com/cloudflare/cloudflared/tunnelstate"
1721)
1822
1923type SystemCollectorMock struct {}
@@ -24,6 +28,23 @@ const (
2428 errorKey = "errkey"
2529)
2630
31+ func newTrackerFromConns (t * testing.T , connections []tunnelstate.IndexedConnectionInfo ) * tunnelstate.ConnTracker {
32+ t .Helper ()
33+
34+ log := zerolog .Nop ()
35+ tracker := tunnelstate .NewConnTracker (& log )
36+
37+ for _ , conn := range connections {
38+ tracker .OnTunnelEvent (connection.Event {
39+ Index : conn .Index ,
40+ EventType : connection .Connected ,
41+ Protocol : conn .Protocol ,
42+ EdgeAddress : conn .EdgeAddress ,
43+ })
44+ }
45+
46+ return tracker
47+ }
2748func setCtxValuesForSystemCollector (
2849 systemInfo * diagnostic.SystemInformation ,
2950 rawInfo string ,
@@ -83,7 +104,7 @@ func TestSystemHandler(t *testing.T) {
83104 for _ , tCase := range tests {
84105 t .Run (tCase .name , func (t * testing.T ) {
85106 t .Parallel ()
86- handler := diagnostic .NewDiagnosticHandler (& log , 0 , & SystemCollectorMock {})
107+ handler := diagnostic .NewDiagnosticHandler (& log , 0 , & SystemCollectorMock {}, uuid . New (), uuid . New (), nil )
87108 recorder := httptest .NewRecorder ()
88109 ctx := setCtxValuesForSystemCollector (tCase .systemInfo , tCase .rawInfo , tCase .err )
89110 request , err := http .NewRequestWithContext (ctx , http .MethodGet , "/diag/syste," , nil )
@@ -106,3 +127,58 @@ func TestSystemHandler(t *testing.T) {
106127 })
107128 }
108129}
130+
131+ func TestTunnelStateHandler (t * testing.T ) {
132+ t .Parallel ()
133+
134+ log := zerolog .Nop ()
135+ tests := []struct {
136+ name string
137+ tunnelID uuid.UUID
138+ clientID uuid.UUID
139+ connections []tunnelstate.IndexedConnectionInfo
140+ }{
141+ {
142+ name : "case1" ,
143+ tunnelID : uuid .New (),
144+ clientID : uuid .New (),
145+ },
146+ {
147+ name : "case2" ,
148+ tunnelID : uuid .New (),
149+ clientID : uuid .New (),
150+ connections : []tunnelstate.IndexedConnectionInfo {{
151+ ConnectionInfo : tunnelstate.ConnectionInfo {
152+ IsConnected : true ,
153+ Protocol : connection .QUIC ,
154+ EdgeAddress : net .IPv4 (100 , 100 , 100 , 100 ),
155+ },
156+ Index : 0 ,
157+ }},
158+ },
159+ }
160+
161+ for _ , tCase := range tests {
162+ t .Run (tCase .name , func (t * testing.T ) {
163+ t .Parallel ()
164+ tracker := newTrackerFromConns (t , tCase .connections )
165+ handler := diagnostic .NewDiagnosticHandler (& log , 0 , nil , tCase .tunnelID , tCase .clientID , tracker )
166+ recorder := httptest .NewRecorder ()
167+ handler .TunnelStateHandler (recorder , nil )
168+ decoder := json .NewDecoder (recorder .Body )
169+
170+ var response struct {
171+ TunnelID uuid.UUID `json:"tunnelID,omitempty"`
172+ ConnectorID uuid.UUID `json:"connectorID,omitempty"`
173+ Connections []tunnelstate.IndexedConnectionInfo `json:"connections,omitempty"`
174+ }
175+
176+ err := decoder .Decode (& response )
177+ require .NoError (t , err )
178+ assert .Equal (t , http .StatusOK , recorder .Code )
179+ assert .Equal (t , tCase .tunnelID , response .TunnelID )
180+ assert .Equal (t , tCase .clientID , response .ConnectorID )
181+ assert .Equal (t , tCase .connections , response .Connections )
182+ })
183+ }
184+ }
0 commit comments