77 "time"
88
99 "github.com/rs/zerolog"
10+
11+ "github.com/cloudflare/cloudflared/edgediscovery"
1012)
1113
1214const (
@@ -24,7 +26,7 @@ const (
2426
2527var (
2628 // ProtocolList represents a list of supported protocols for communication with the edge.
27- ProtocolList = []Protocol {H2mux , HTTP2 , QUIC }
29+ ProtocolList = []Protocol {H2mux , HTTP2 , HTTP2Warp , QUIC , QUICWarp }
2830)
2931
3032type Protocol int64
@@ -36,6 +38,12 @@ const (
3638 HTTP2
3739 // QUIC is used only with named tunnels.
3840 QUIC
41+ // HTTP2Warp is used only with named tunnels. It's useful for warp-routing where we don't want to fallback to
42+ // H2mux on HTTP2 failure to connect.
43+ HTTP2Warp
44+ //QUICWarp is used only with named tunnels. It's useful for warp-routing where we want to fallback to HTTP2 but
45+ // dont' want HTTP2 to fallback to H2mux
46+ QUICWarp
3947)
4048
4149// Fallback returns the fallback protocol and whether the protocol has a fallback
@@ -45,8 +53,12 @@ func (p Protocol) fallback() (Protocol, bool) {
4553 return 0 , false
4654 case HTTP2 :
4755 return H2mux , true
56+ case HTTP2Warp :
57+ return 0 , false
4858 case QUIC :
4959 return HTTP2 , true
60+ case QUICWarp :
61+ return HTTP2Warp , true
5062 default :
5163 return 0 , false
5264 }
@@ -56,9 +68,9 @@ func (p Protocol) String() string {
5668 switch p {
5769 case H2mux :
5870 return "h2mux"
59- case HTTP2 :
71+ case HTTP2 , HTTP2Warp :
6072 return "http2"
61- case QUIC :
73+ case QUIC , QUICWarp :
6274 return "quic"
6375 default :
6476 return fmt .Sprintf ("unknown protocol" )
@@ -71,11 +83,11 @@ func (p Protocol) TLSSettings() *TLSSettings {
7183 return & TLSSettings {
7284 ServerName : edgeH2muxTLSServerName ,
7385 }
74- case HTTP2 :
86+ case HTTP2 , HTTP2Warp :
7587 return & TLSSettings {
7688 ServerName : edgeH2TLSServerName ,
7789 }
78- case QUIC :
90+ case QUIC , QUICWarp :
7991 return & TLSSettings {
8092 ServerName : edgeQUICServerName ,
8193 NextProtos : []string {"argotunnel" },
@@ -108,29 +120,36 @@ func (s *staticProtocolSelector) Fallback() (Protocol, bool) {
108120}
109121
110122type autoProtocolSelector struct {
111- lock sync.RWMutex
112- current Protocol
113- switchThrehold int32
114- fetchFunc PercentageFetcher
115- refreshAfter time.Time
116- ttl time.Duration
117- log * zerolog.Logger
123+ lock sync.RWMutex
124+
125+ current Protocol
126+
127+ // protocolPool is desired protocols in the order of priority they should be picked in.
128+ protocolPool []Protocol
129+
130+ switchThreshold int32
131+ fetchFunc PercentageFetcher
132+ refreshAfter time.Time
133+ ttl time.Duration
134+ log * zerolog.Logger
118135}
119136
120137func newAutoProtocolSelector (
121138 current Protocol ,
122- switchThrehold int32 ,
139+ protocolPool []Protocol ,
140+ switchThreshold int32 ,
123141 fetchFunc PercentageFetcher ,
124142 ttl time.Duration ,
125143 log * zerolog.Logger ,
126144) * autoProtocolSelector {
127145 return & autoProtocolSelector {
128- current : current ,
129- switchThrehold : switchThrehold ,
130- fetchFunc : fetchFunc ,
131- refreshAfter : time .Now ().Add (ttl ),
132- ttl : ttl ,
133- log : log ,
146+ current : current ,
147+ protocolPool : protocolPool ,
148+ switchThreshold : switchThreshold ,
149+ fetchFunc : fetchFunc ,
150+ refreshAfter : time .Now ().Add (ttl ),
151+ ttl : ttl ,
152+ log : log ,
134153 }
135154}
136155
@@ -141,28 +160,39 @@ func (s *autoProtocolSelector) Current() Protocol {
141160 return s .current
142161 }
143162
144- percentage , err := s . fetchFunc ( )
163+ protocol , err := getProtocol ( s . protocolPool , s . fetchFunc , s . switchThreshold )
145164 if err != nil {
146165 s .log .Err (err ).Msg ("Failed to refresh protocol" )
147166 return s .current
148167 }
168+ s .current = protocol
149169
150- if s .switchThrehold < percentage {
151- s .current = HTTP2
152- } else {
153- s .current = H2mux
154- }
155170 s .refreshAfter = time .Now ().Add (s .ttl )
156171 return s .current
157172}
158173
174+ func getProtocol (protocolPool []Protocol , fetchFunc PercentageFetcher , switchThreshold int32 ) (Protocol , error ) {
175+ protocolPercentages , err := fetchFunc ()
176+ if err != nil {
177+ return 0 , err
178+ }
179+ for _ , protocol := range protocolPool {
180+ protocolPercentage := protocolPercentages .GetPercentage (protocol .String ())
181+ if protocolPercentage > switchThreshold {
182+ return protocol , nil
183+ }
184+ }
185+
186+ return protocolPool [len (protocolPool )- 1 ], nil
187+ }
188+
159189func (s * autoProtocolSelector ) Fallback () (Protocol , bool ) {
160190 s .lock .RLock ()
161191 defer s .lock .RUnlock ()
162192 return s .current .fallback ()
163193}
164194
165- type PercentageFetcher func () (int32 , error )
195+ type PercentageFetcher func () (edgediscovery. ProtocolPercents , error )
166196
167197func NewProtocolSelector (
168198 protocolFlag string ,
@@ -179,54 +209,76 @@ func NewProtocolSelector(
179209 }, nil
180210 }
181211
182- // warp routing cannot be served over h2mux connections
183- if warpRoutingEnabled {
184- if protocolFlag == H2mux .String () {
185- log .Warn ().Msg ("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it." )
186- }
187-
188- if protocolFlag == QUIC .String () {
189- return & staticProtocolSelector {
190- current : QUIC ,
191- }, nil
192- }
212+ threshold := switchThreshold (namedTunnel .Credentials .AccountTag )
213+ fetchedProtocol , err := getProtocol ([]Protocol {QUIC , HTTP2 }, fetchFunc , threshold )
214+ if err != nil {
215+ log .Err (err ).Msg ("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command." )
193216 return & staticProtocolSelector {
194217 current : HTTP2 ,
195218 }, nil
196219 }
220+ if warpRoutingEnabled {
221+ if protocolFlag == H2mux .String () || fetchedProtocol == H2mux {
222+ log .Warn ().Msg ("Warp routing is not supported in h2mux protocol. Upgrading to http2 to allow it." )
223+ protocolFlag = HTTP2 .String ()
224+ fetchedProtocol = HTTP2Warp
225+ }
226+ return selectWarpRoutingProtocols (protocolFlag , fetchFunc , ttl , log , threshold , fetchedProtocol )
227+ }
228+
229+ return selectNamedTunnelProtocols (protocolFlag , fetchFunc , ttl , log , threshold , fetchedProtocol )
230+ }
197231
232+ func selectNamedTunnelProtocols (
233+ protocolFlag string ,
234+ fetchFunc PercentageFetcher ,
235+ ttl time.Duration ,
236+ log * zerolog.Logger ,
237+ threshold int32 ,
238+ protocol Protocol ,
239+ ) (ProtocolSelector , error ) {
198240 if protocolFlag == H2mux .String () {
199241 return & staticProtocolSelector {
200242 current : H2mux ,
201243 }, nil
202244 }
203245
204246 if protocolFlag == QUIC .String () {
205- return newAutoProtocolSelector (QUIC , explicitHTTP2FallbackThreshold , fetchFunc , ttl , log ), nil
247+ return newAutoProtocolSelector (QUIC , [] Protocol { QUIC , HTTP2 , H2mux }, explicitHTTP2FallbackThreshold , fetchFunc , ttl , log ), nil
206248 }
207249
208- http2Percentage , err := fetchFunc ()
209- if err != nil {
210- log .Err (err ).Msg ("Unable to lookup protocol. Defaulting to `http2`. If this fails, you can set `--protocol h2mux` in your cloudflared command." )
211- return & staticProtocolSelector {
212- current : HTTP2 ,
213- }, nil
214- }
215250 if protocolFlag == HTTP2 .String () {
216- if http2Percentage < 0 {
217- return newAutoProtocolSelector (H2mux , explicitHTTP2FallbackThreshold , fetchFunc , ttl , log ), nil
218- }
219- return newAutoProtocolSelector (HTTP2 , explicitHTTP2FallbackThreshold , fetchFunc , ttl , log ), nil
251+ return newAutoProtocolSelector (HTTP2 , []Protocol {HTTP2 , H2mux }, explicitHTTP2FallbackThreshold , fetchFunc , ttl , log ), nil
220252 }
221253
222254 if protocolFlag != autoSelectFlag {
223255 return nil , fmt .Errorf ("Unknown protocol %s, %s" , protocolFlag , AvailableProtocolFlagMessage )
224256 }
225- threshold := switchThreshold (namedTunnel .Credentials .AccountTag )
226- if threshold < http2Percentage {
227- return newAutoProtocolSelector (HTTP2 , threshold , fetchFunc , ttl , log ), nil
257+
258+ return newAutoProtocolSelector (protocol , []Protocol {QUIC , HTTP2 , H2mux }, threshold , fetchFunc , ttl , log ), nil
259+ }
260+
261+ func selectWarpRoutingProtocols (
262+ protocolFlag string ,
263+ fetchFunc PercentageFetcher ,
264+ ttl time.Duration ,
265+ log * zerolog.Logger ,
266+ threshold int32 ,
267+ protocol Protocol ,
268+ ) (ProtocolSelector , error ) {
269+ if protocolFlag == QUIC .String () {
270+ return newAutoProtocolSelector (QUICWarp , []Protocol {QUICWarp , HTTP2Warp }, explicitHTTP2FallbackThreshold , fetchFunc , ttl , log ), nil
228271 }
229- return newAutoProtocolSelector (H2mux , threshold , fetchFunc , ttl , log ), nil
272+
273+ if protocolFlag == HTTP2 .String () {
274+ return newAutoProtocolSelector (HTTP2Warp , []Protocol {HTTP2Warp }, explicitHTTP2FallbackThreshold , fetchFunc , ttl , log ), nil
275+ }
276+
277+ if protocolFlag != autoSelectFlag {
278+ return nil , fmt .Errorf ("Unknown protocol %s, %s" , protocolFlag , AvailableProtocolFlagMessage )
279+ }
280+
281+ return newAutoProtocolSelector (protocol , []Protocol {QUICWarp , HTTP2Warp }, threshold , fetchFunc , ttl , log ), nil
230282}
231283
232284func switchThreshold (accountTag string ) int32 {
0 commit comments