Skip to content

Commit bec683b

Browse files
committed
TUN-7700: Implement feature selector to determine if connections will prefer post quantum cryptography
1 parent 38d3c3c commit bec683b

File tree

8 files changed

+385
-43
lines changed

8 files changed

+385
-43
lines changed

cmd/cloudflared/tunnel/cmd.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ func StartServer(
392392
observer.SendURL(quickTunnelURL)
393393
}
394394

395-
tunnelConfig, orchestratorConfig, err := prepareTunnelConfig(c, info, log, logTransport, observer, namedTunnel)
395+
tunnelConfig, orchestratorConfig, err := prepareTunnelConfig(ctx, c, info, log, logTransport, observer, namedTunnel)
396396
if err != nil {
397397
log.Err(err).Msg("Couldn't start tunnel")
398398
return err

cmd/cloudflared/tunnel/config_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ import (
44
"testing"
55

66
"github.com/stretchr/testify/require"
7+
8+
"github.com/cloudflare/cloudflared/features"
79
)
810

911
func TestDedup(t *testing.T) {
1012
expected := []string{"a", "b"}
11-
actual := dedup([]string{"a", "b", "a"})
13+
actual := features.Dedup([]string{"a", "b", "a"})
1214
require.ElementsMatch(t, expected, actual)
1315
}

cmd/cloudflared/tunnel/configuration.go

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package tunnel
22

33
import (
4+
"context"
45
"crypto/tls"
56
"fmt"
67
"net"
@@ -112,6 +113,7 @@ func dnsProxyStandAlone(c *cli.Context, namedTunnel *connection.NamedTunnelPrope
112113
}
113114

114115
func prepareTunnelConfig(
116+
ctx context.Context,
115117
c *cli.Context,
116118
info *cliutil.BuildInfo,
117119
log, logTransport *zerolog.Logger,
@@ -131,22 +133,36 @@ func prepareTunnelConfig(
131133
tags = append(tags, tunnelpogs.Tag{Name: "ID", Value: clientID.String()})
132134

133135
transportProtocol := c.String("protocol")
134-
needPQ := c.Bool("post-quantum")
135-
if needPQ {
136+
137+
clientFeatures := features.Dedup(append(c.StringSlice("features"), features.DefaultFeatures...))
138+
139+
staticFeatures := features.StaticFeatures{}
140+
if c.Bool("post-quantum") {
136141
if FipsEnabled {
137142
return nil, nil, fmt.Errorf("post-quantum not supported in FIPS mode")
138143
}
144+
pqMode := features.PostQuantumStrict
145+
staticFeatures.PostQuantumMode = &pqMode
146+
}
147+
featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, staticFeatures, log)
148+
if err != nil {
149+
return nil, nil, errors.Wrap(err, "Failed to create feature selector")
150+
}
151+
pqMode := featureSelector.PostQuantumMode()
152+
if pqMode == features.PostQuantumStrict {
139153
// Error if the user tries to force a non-quic transport protocol
140154
if transportProtocol != connection.AutoSelectFlag && transportProtocol != connection.QUIC.String() {
141155
return nil, nil, fmt.Errorf("post-quantum is only supported with the quic transport")
142156
}
143157
transportProtocol = connection.QUIC.String()
144-
}
145-
146-
clientFeatures := dedup(append(c.StringSlice("features"), features.DefaultFeatures...))
147-
if needPQ {
148158
clientFeatures = append(clientFeatures, features.FeaturePostQuantum)
159+
160+
log.Info().Msgf(
161+
"Using hybrid post-quantum key agreement %s",
162+
supervisor.PQKexName,
163+
)
149164
}
165+
150166
namedTunnel.Client = tunnelpogs.ClientInfo{
151167
ClientID: clientID[:],
152168
Features: clientFeatures,
@@ -202,13 +218,6 @@ func prepareTunnelConfig(
202218
log.Warn().Str("edgeIPVersion", edgeIPVersion.String()).Err(err).Msg("Overriding edge-ip-version")
203219
}
204220

205-
if needPQ {
206-
log.Info().Msgf(
207-
"Using hybrid post-quantum key agreement %s",
208-
supervisor.PQKexName,
209-
)
210-
}
211-
212221
tunnelConfig := &supervisor.TunnelConfig{
213222
GracePeriod: gracePeriod,
214223
ReplaceExisting: c.Bool("force"),
@@ -233,7 +242,7 @@ func prepareTunnelConfig(
233242
NamedTunnel: namedTunnel,
234243
ProtocolSelector: protocolSelector,
235244
EdgeTLSConfigs: edgeTLSConfigs,
236-
NeedPQ: needPQ,
245+
FeatureSelector: featureSelector,
237246
MaxEdgeAddrRetries: uint8(c.Int("max-edge-addr-retries")),
238247
UDPUnregisterSessionTimeout: c.Duration(udpUnregisterSessionTimeoutFlag),
239248
DisableQUICPathMTUDiscovery: c.Bool(quicDisablePathMTUDiscovery),
@@ -276,25 +285,6 @@ func isRunningFromTerminal() bool {
276285
return term.IsTerminal(int(os.Stdout.Fd()))
277286
}
278287

279-
// Remove any duplicates from the slice
280-
func dedup(slice []string) []string {
281-
282-
// Convert the slice into a set
283-
set := make(map[string]bool, 0)
284-
for _, str := range slice {
285-
set[str] = true
286-
}
287-
288-
// Convert the set back into a slice
289-
keys := make([]string, len(set))
290-
i := 0
291-
for str := range set {
292-
keys[i] = str
293-
i++
294-
}
295-
return keys
296-
}
297-
298288
// ParseConfigIPVersion returns the IP version from possible expected values from config
299289
func parseConfigIPVersion(version string) (v allregions.ConfigIPVersion, err error) {
300290
switch version {

features/features.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,22 @@ func Contains(feature string) bool {
2828
}
2929
return false
3030
}
31+
32+
// Remove any duplicates from the slice
33+
func Dedup(slice []string) []string {
34+
35+
// Convert the slice into a set
36+
set := make(map[string]bool, 0)
37+
for _, str := range slice {
38+
set[str] = true
39+
}
40+
41+
// Convert the set back into a slice
42+
keys := make([]string, len(set))
43+
i := 0
44+
for str := range set {
45+
keys[i] = str
46+
i++
47+
}
48+
return keys
49+
}

features/selector.go

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
package features
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"hash/fnv"
8+
"net"
9+
"sync"
10+
"time"
11+
12+
"github.com/rs/zerolog"
13+
)
14+
15+
const (
16+
featureSelectorHostname = "cfd-features.argotunnel.com"
17+
defaultRefreshFreq = time.Hour * 6
18+
lookupTimeout = time.Second * 10
19+
)
20+
21+
type PostQuantumMode uint8
22+
23+
const (
24+
PostQuantumDisabled PostQuantumMode = iota
25+
// Prefer post quantum, but fallback if connection cannot be established
26+
PostQuantumPrefer
27+
// If the user passes the --post-quantum flag, we override
28+
// CurvePreferences to only support hybrid post-quantum key agreements.
29+
PostQuantumStrict
30+
)
31+
32+
// If the TXT record adds other fields, the umarshal logic will ignore those keys
33+
// If the TXT record is missing a key, the field will unmarshal to the default Go value
34+
type featuresRecord struct {
35+
PostQuantumPercentage int32 `json:"pq"`
36+
}
37+
38+
func NewFeatureSelector(ctx context.Context, accountTag string, staticFeatures StaticFeatures, logger *zerolog.Logger) (*FeatureSelector, error) {
39+
return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), staticFeatures, defaultRefreshFreq)
40+
}
41+
42+
// FeatureSelector determines if this account will try new features. It preiodically queries a DNS TXT record
43+
// to see which features are turned on
44+
type FeatureSelector struct {
45+
accountHash int32
46+
logger *zerolog.Logger
47+
resolver resolver
48+
49+
staticFeatures StaticFeatures
50+
51+
// lock protects concurrent access to dynamic features
52+
lock sync.RWMutex
53+
features featuresRecord
54+
}
55+
56+
// Features set by user provided flags
57+
type StaticFeatures struct {
58+
PostQuantumMode *PostQuantumMode
59+
}
60+
61+
func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, staticFeatures StaticFeatures, refreshFreq time.Duration) (*FeatureSelector, error) {
62+
selector := &FeatureSelector{
63+
accountHash: switchThreshold(accountTag),
64+
logger: logger,
65+
resolver: resolver,
66+
staticFeatures: staticFeatures,
67+
}
68+
69+
if err := selector.refresh(ctx); err != nil {
70+
logger.Err(err).Msg("Failed to fetch features, default to disable")
71+
}
72+
73+
go selector.refreshLoop(ctx, refreshFreq)
74+
75+
return selector, nil
76+
}
77+
78+
func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode {
79+
if fs.staticFeatures.PostQuantumMode != nil {
80+
return *fs.staticFeatures.PostQuantumMode
81+
}
82+
83+
fs.lock.RLock()
84+
defer fs.lock.RUnlock()
85+
86+
if fs.features.PostQuantumPercentage > fs.accountHash {
87+
return PostQuantumPrefer
88+
}
89+
return PostQuantumDisabled
90+
}
91+
92+
func (fs *FeatureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) {
93+
ticker := time.NewTicker(refreshFreq)
94+
for {
95+
select {
96+
case <-ctx.Done():
97+
return
98+
case <-ticker.C:
99+
err := fs.refresh(ctx)
100+
if err != nil {
101+
fs.logger.Err(err).Msg("Failed to refresh feature selector")
102+
}
103+
}
104+
}
105+
}
106+
107+
func (fs *FeatureSelector) refresh(ctx context.Context) error {
108+
record, err := fs.resolver.lookupRecord(ctx)
109+
if err != nil {
110+
return err
111+
}
112+
113+
var features featuresRecord
114+
if err := json.Unmarshal(record, &features); err != nil {
115+
return err
116+
}
117+
118+
pq_enabled := features.PostQuantumPercentage > fs.accountHash
119+
fs.logger.Debug().Int32("account_hash", fs.accountHash).Int32("pq_perct", features.PostQuantumPercentage).Bool("pq_enabled", pq_enabled).Msg("Refreshed feature")
120+
121+
fs.lock.Lock()
122+
defer fs.lock.Unlock()
123+
124+
fs.features = features
125+
126+
return nil
127+
}
128+
129+
// resolver represents an object that can look up featuresRecord
130+
type resolver interface {
131+
lookupRecord(ctx context.Context) ([]byte, error)
132+
}
133+
134+
type dnsResolver struct {
135+
resolver *net.Resolver
136+
}
137+
138+
func newDNSResolver() *dnsResolver {
139+
return &dnsResolver{
140+
resolver: net.DefaultResolver,
141+
}
142+
}
143+
144+
func (dr *dnsResolver) lookupRecord(ctx context.Context) ([]byte, error) {
145+
ctx, cancel := context.WithTimeout(ctx, lookupTimeout)
146+
defer cancel()
147+
148+
records, err := dr.resolver.LookupTXT(ctx, featureSelectorHostname)
149+
if err != nil {
150+
return nil, err
151+
}
152+
153+
if len(records) == 0 {
154+
return nil, fmt.Errorf("No TXT record found for %s to determine which features to opt-in", featureSelectorHostname)
155+
}
156+
157+
return []byte(records[0]), nil
158+
}
159+
160+
func switchThreshold(accountTag string) int32 {
161+
h := fnv.New32a()
162+
_, _ = h.Write([]byte(accountTag))
163+
return int32(h.Sum32() % 100)
164+
}

0 commit comments

Comments
 (0)