Skip to content

Commit ba9f28e

Browse files
committed
TUN-8786: calculate cli flags once for the diagnostic procedure
## Summary The flags were always being computed when their value is static. Closes TUN-8786
1 parent 77b99cf commit ba9f28e

File tree

4 files changed

+75
-109
lines changed

4 files changed

+75
-109
lines changed

cmd/cloudflared/tunnel/cmd.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"net/url"
88
"os"
9+
"path/filepath"
910
"runtime/trace"
1011
"strings"
1112
"sync"
@@ -560,15 +561,15 @@ func StartServer(
560561
}
561562

562563
readinessServer := metrics.NewReadyServer(clientID, tracker)
564+
cliFlags := nonSecretCliFlags(log, c, nonSecretFlagsList)
563565
diagnosticHandler := diagnostic.NewDiagnosticHandler(
564566
log,
565567
0,
566568
diagnostic.NewSystemCollectorImpl(buildInfo.CloudflaredVersion),
567569
tunnelConfig.NamedTunnel.Credentials.TunnelID,
568570
clientID,
569571
tracker,
570-
c,
571-
nonSecretFlagsList,
572+
cliFlags,
572573
sources,
573574
)
574575
metricsConfig := metrics.Config{
@@ -1309,3 +1310,46 @@ reconnect [delay]
13091310
}
13101311
}
13111312
}
1313+
1314+
func nonSecretCliFlags(log *zerolog.Logger, cli *cli.Context, flagInclusionList []string) map[string]string {
1315+
flagsNames := cli.FlagNames()
1316+
flags := make(map[string]string, len(flagsNames))
1317+
1318+
for _, flag := range flagsNames {
1319+
value := cli.String(flag)
1320+
1321+
if value == "" {
1322+
continue
1323+
}
1324+
1325+
isIncluded := isFlagIncluded(flagInclusionList, flag)
1326+
if !isIncluded {
1327+
continue
1328+
}
1329+
1330+
switch flag {
1331+
case logger.LogDirectoryFlag, logger.LogFileFlag:
1332+
{
1333+
absolute, err := filepath.Abs(value)
1334+
if err != nil {
1335+
log.Error().Err(err).Msgf("could not convert %s path to absolute", flag)
1336+
} else {
1337+
flags[flag] = absolute
1338+
}
1339+
}
1340+
default:
1341+
flags[flag] = value
1342+
}
1343+
}
1344+
return flags
1345+
}
1346+
1347+
func isFlagIncluded(flagInclusionList []string, flag string) bool {
1348+
for _, include := range flagInclusionList {
1349+
if include == flag {
1350+
return true
1351+
}
1352+
}
1353+
1354+
return false
1355+
}

diagnostic/diagnostic_utils_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func helperCreateServer(t *testing.T, listeners *gracenet.Net, tunnelID uuid.UUI
2525
require.NoError(t, err)
2626
log := zerolog.Nop()
2727
tracker := tunnelstate.NewConnTracker(&log)
28-
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, nil, []string{}, []string{})
28+
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, tunnelID, connectorID, tracker, map[string]string{}, []string{})
2929
router := http.NewServeMux()
3030
router.HandleFunc("/diag/tunnel", handler.TunnelStateHandler)
3131
server := &http.Server{

diagnostic/handlers.go

Lines changed: 19 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,24 @@ import (
55
"encoding/json"
66
"net/http"
77
"os"
8-
"path/filepath"
98
"strconv"
109
"time"
1110

1211
"github.com/google/uuid"
1312
"github.com/rs/zerolog"
14-
"github.com/urfave/cli/v2"
1513

16-
"github.com/cloudflare/cloudflared/logger"
1714
"github.com/cloudflare/cloudflared/tunnelstate"
1815
)
1916

2017
type Handler struct {
21-
log *zerolog.Logger
22-
timeout time.Duration
23-
systemCollector SystemCollector
24-
tunnelID uuid.UUID
25-
connectorID uuid.UUID
26-
tracker *tunnelstate.ConnTracker
27-
cli *cli.Context
28-
flagInclusionList []string
29-
icmpSources []string
18+
log *zerolog.Logger
19+
timeout time.Duration
20+
systemCollector SystemCollector
21+
tunnelID uuid.UUID
22+
connectorID uuid.UUID
23+
tracker *tunnelstate.ConnTracker
24+
cliFlags map[string]string
25+
icmpSources []string
3026
}
3127

3228
func NewDiagnosticHandler(
@@ -36,25 +32,24 @@ func NewDiagnosticHandler(
3632
tunnelID uuid.UUID,
3733
connectorID uuid.UUID,
3834
tracker *tunnelstate.ConnTracker,
39-
cli *cli.Context,
40-
flagInclusionList []string,
35+
cliFlags map[string]string,
4136
icmpSources []string,
4237
) *Handler {
4338
logger := log.With().Logger()
4439
if timeout == 0 {
4540
timeout = defaultCollectorTimeout
4641
}
4742

43+
cliFlags[configurationKeyUID] = strconv.Itoa(os.Getuid())
4844
return &Handler{
49-
log: &logger,
50-
timeout: timeout,
51-
systemCollector: systemCollector,
52-
tunnelID: tunnelID,
53-
connectorID: connectorID,
54-
tracker: tracker,
55-
cli: cli,
56-
flagInclusionList: flagInclusionList,
57-
icmpSources: icmpSources,
45+
log: &logger,
46+
timeout: timeout,
47+
systemCollector: systemCollector,
48+
tunnelID: tunnelID,
49+
connectorID: connectorID,
50+
tracker: tracker,
51+
cliFlags: cliFlags,
52+
icmpSources: icmpSources,
5853
}
5954
}
6055

@@ -140,68 +135,15 @@ func (handler *Handler) ConfigurationHandler(writer http.ResponseWriter, _ *http
140135
log.Info().Msg("Collection finished")
141136
}()
142137

143-
flagsNames := handler.cli.FlagNames()
144-
flags := make(map[string]string, len(flagsNames))
145-
146-
for _, flag := range flagsNames {
147-
value := handler.cli.String(flag)
148-
149-
// empty values are not relevant
150-
if value == "" {
151-
continue
152-
}
153-
154-
// exclude flags that are sensitive
155-
isIncluded := handler.isFlagIncluded(flag)
156-
if !isIncluded {
157-
continue
158-
}
159-
160-
switch flag {
161-
case logger.LogDirectoryFlag:
162-
fallthrough
163-
case logger.LogFileFlag:
164-
{
165-
// the log directory may be relative to the instance thus it must be resolved
166-
absolute, err := filepath.Abs(value)
167-
if err != nil {
168-
handler.log.Error().Err(err).Msgf("could not convert %s path to absolute", flag)
169-
} else {
170-
flags[flag] = absolute
171-
}
172-
}
173-
default:
174-
flags[flag] = value
175-
}
176-
}
177-
178-
// The UID is included to help the
179-
// diagnostic tool to understand
180-
// if this instance is managed or not.
181-
flags[configurationKeyUID] = strconv.Itoa(os.Getuid())
182138
encoder := json.NewEncoder(writer)
183139

184-
err := encoder.Encode(flags)
140+
err := encoder.Encode(handler.cliFlags)
185141
if err != nil {
186142
handler.log.Error().Err(err).Msgf("error occurred whilst serializing response")
187143
writer.WriteHeader(http.StatusInternalServerError)
188144
}
189145
}
190146

191-
func (handler *Handler) isFlagIncluded(flag string) bool {
192-
isIncluded := false
193-
194-
for _, include := range handler.flagInclusionList {
195-
if include == flag {
196-
isIncluded = true
197-
198-
break
199-
}
200-
}
201-
202-
return isIncluded
203-
}
204-
205147
func writeResponse(w http.ResponseWriter, bytes []byte, logger *zerolog.Logger) {
206148
bytesWritten, err := w.Write(bytes)
207149
if err != nil {

diagnostic/handlers_test.go

Lines changed: 9 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"encoding/json"
66
"errors"
7-
"flag"
87
"io"
98
"net"
109
"net/http"
@@ -15,7 +14,6 @@ import (
1514
"github.com/rs/zerolog"
1615
"github.com/stretchr/testify/assert"
1716
"github.com/stretchr/testify/require"
18-
"github.com/urfave/cli/v2"
1917

2018
"github.com/cloudflare/cloudflared/connection"
2119
"github.com/cloudflare/cloudflared/diagnostic"
@@ -30,21 +28,6 @@ const (
3028
errorKey = "errkey"
3129
)
3230

33-
func buildCliContext(t *testing.T, flags map[string]string) *cli.Context {
34-
t.Helper()
35-
36-
flagSet := flag.NewFlagSet("", flag.PanicOnError)
37-
ctx := cli.NewContext(cli.NewApp(), flagSet, nil)
38-
39-
for k, v := range flags {
40-
flagSet.String(k, v, "")
41-
err := ctx.Set(k, v)
42-
require.NoError(t, err)
43-
}
44-
45-
return ctx
46-
}
47-
4831
func newTrackerFromConns(t *testing.T, connections []tunnelstate.IndexedConnectionInfo) *tunnelstate.ConnTracker {
4932
t.Helper()
5033

@@ -80,7 +63,6 @@ func (*SystemCollectorMock) Collect(ctx context.Context) (*diagnostic.SystemInfo
8063
si, _ := ctx.Value(systemInformationKey).(*diagnostic.SystemInformation)
8164
ri, _ := ctx.Value(rawInformationKey).(string)
8265
err, _ := ctx.Value(errorKey).(error)
83-
8466
return si, ri, err
8567
}
8668

@@ -122,8 +104,7 @@ func TestSystemHandler(t *testing.T) {
122104
for _, tCase := range tests {
123105
t.Run(tCase.name, func(t *testing.T) {
124106
t.Parallel()
125-
126-
handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil, nil, nil, nil)
107+
handler := diagnostic.NewDiagnosticHandler(&log, 0, &SystemCollectorMock{}, uuid.New(), uuid.New(), nil, map[string]string{}, nil)
127108
recorder := httptest.NewRecorder()
128109
ctx := setCtxValuesForSystemCollector(tCase.systemInfo, tCase.rawInfo, tCase.err)
129110
request, err := http.NewRequestWithContext(ctx, http.MethodGet, "/diag/syste,", nil)
@@ -190,8 +171,7 @@ func TestTunnelStateHandler(t *testing.T) {
190171
tCase.tunnelID,
191172
tCase.clientID,
192173
tracker,
193-
nil,
194-
nil,
174+
map[string]string{},
195175
tCase.icmpSources,
196176
)
197177
recorder := httptest.NewRecorder()
@@ -230,10 +210,10 @@ func TestConfigurationHandler(t *testing.T) {
230210
{
231211
name: "cli with flags",
232212
flags: map[string]string{
233-
"a": "a",
234-
"b": "a",
235-
"c": "a",
236-
"d": "a",
213+
"b": "a",
214+
"c": "a",
215+
"d": "a",
216+
"uid": "0",
237217
},
238218
expected: map[string]string{
239219
"b": "a",
@@ -246,11 +226,11 @@ func TestConfigurationHandler(t *testing.T) {
246226

247227
for _, tCase := range tests {
248228
t.Run(tCase.name, func(t *testing.T) {
229+
t.Parallel()
230+
249231
var response map[string]string
250232

251-
t.Parallel()
252-
ctx := buildCliContext(t, tCase.flags)
253-
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, ctx, []string{"b", "c", "d"}, nil)
233+
handler := diagnostic.NewDiagnosticHandler(&log, 0, nil, uuid.New(), uuid.New(), nil, tCase.flags, nil)
254234
recorder := httptest.NewRecorder()
255235
handler.ConfigurationHandler(recorder, nil)
256236
decoder := json.NewDecoder(recorder.Body)

0 commit comments

Comments
 (0)