|
4 | 4 | package api |
5 | 5 |
|
6 | 6 | import ( |
| 7 | + "encoding/json" |
7 | 8 | "io" |
8 | 9 | "net/http" |
9 | 10 | "regexp" |
10 | 11 | "strings" |
11 | 12 | "testing" |
12 | 13 |
|
| 14 | + "github.com/BurntSushi/toml" |
| 15 | + "github.com/pingcap/tiproxy/lib/config" |
13 | 16 | "github.com/stretchr/testify/require" |
14 | 17 | ) |
15 | 18 |
|
16 | 19 | func TestConfig(t *testing.T) { |
17 | 20 | _, doHTTP := createServer(t, nil) |
18 | 21 |
|
19 | | - doHTTP(t, http.MethodGet, "/api/admin/config", nil, func(t *testing.T, r *http.Response) { |
| 22 | + doHTTP(t, http.MethodGet, "/api/admin/config", nil, nil, func(t *testing.T, r *http.Response) { |
20 | 23 | all, err := io.ReadAll(r.Body) |
21 | 24 | require.NoError(t, err) |
22 | 25 | require.Equal(t, ` |
@@ -72,42 +75,75 @@ max-backups = 3 |
72 | 75 | `, string(regexp.MustCompile("workdir = '.+'\n").ReplaceAll(all, nil))) |
73 | 76 | require.Equal(t, http.StatusOK, r.StatusCode) |
74 | 77 | }) |
75 | | - doHTTP(t, http.MethodGet, "/api/admin/config?format=json", nil, func(t *testing.T, r *http.Response) { |
| 78 | + doHTTP(t, http.MethodGet, "/api/admin/config?format=json", nil, nil, func(t *testing.T, r *http.Response) { |
76 | 79 | all, err := io.ReadAll(r.Body) |
77 | 80 | require.NoError(t, err) |
78 | 81 | require.Equal(t, `{"proxy":{"addr":"0.0.0.0:6000","pd-addrs":"127.0.0.1:2379","frontend-keepalive":{"enabled":true},"backend-healthy-keepalive":{"enabled":true,"idle":60000000000,"cnt":5,"intvl":3000000000,"timeout":15000000000},"backend-unhealthy-keepalive":{"enabled":true,"idle":10000000000,"cnt":5,"intvl":1000000000,"timeout":5000000000},"graceful-close-conn-timeout":15},"api":{"addr":"0.0.0.0:3080"},"advance":{"ignore-wrong-namespace":true},"security":{"server-tls":{"min-tls-version":"1.2"},"server-http-tls":{"min-tls-version":"1.2"},"cluster-tls":{"min-tls-version":"1.2"},"sql-tls":{"min-tls-version":"1.2"}},"log":{"encoder":"tidb","level":"info","log-file":{"max-size":300,"max-days":3,"max-backups":3}}}`, |
79 | 82 | string(regexp.MustCompile(`"workdir":"[^"]+",`).ReplaceAll(all, nil))) |
80 | 83 | require.Equal(t, http.StatusOK, r.StatusCode) |
81 | 84 | }) |
82 | 85 |
|
83 | | - doHTTP(t, http.MethodPut, "/api/admin/config", strings.NewReader("security.require-backend-tls = true"), func(t *testing.T, r *http.Response) { |
| 86 | + doHTTP(t, http.MethodPut, "/api/admin/config", strings.NewReader("security.require-backend-tls = true"), nil, func(t *testing.T, r *http.Response) { |
84 | 87 | require.Equal(t, http.StatusOK, r.StatusCode) |
85 | 88 | }) |
86 | 89 | sum := "" |
87 | 90 | sumreg := regexp.MustCompile(`{"config_checksum":(.+)}`) |
88 | | - doHTTP(t, http.MethodGet, "/api/debug/health", nil, func(t *testing.T, r *http.Response) { |
| 91 | + doHTTP(t, http.MethodGet, "/api/debug/health", nil, nil, func(t *testing.T, r *http.Response) { |
89 | 92 | all, err := io.ReadAll(r.Body) |
90 | 93 | require.NoError(t, err) |
91 | 94 | sum = string(sumreg.Find(all)) |
92 | 95 | require.Equal(t, http.StatusOK, r.StatusCode) |
93 | 96 | }) |
94 | | - doHTTP(t, http.MethodPut, "/api/admin/config", strings.NewReader("proxy.require-back = false"), func(t *testing.T, r *http.Response) { |
| 97 | + doHTTP(t, http.MethodPut, "/api/admin/config", strings.NewReader("proxy.require-back = false"), nil, func(t *testing.T, r *http.Response) { |
95 | 98 | // no error |
96 | 99 | require.Equal(t, http.StatusOK, r.StatusCode) |
97 | 100 | }) |
98 | | - doHTTP(t, http.MethodGet, "/api/debug/health", nil, func(t *testing.T, r *http.Response) { |
| 101 | + doHTTP(t, http.MethodGet, "/api/debug/health", nil, nil, func(t *testing.T, r *http.Response) { |
99 | 102 | all, err := io.ReadAll(r.Body) |
100 | 103 | require.NoError(t, err) |
101 | 104 | require.Equal(t, sum, string(sumreg.Find(all))) |
102 | 105 | require.Equal(t, http.StatusOK, r.StatusCode) |
103 | 106 | }) |
104 | | - doHTTP(t, http.MethodPut, "/api/admin/config", strings.NewReader("security.require-backend-tls = false"), func(t *testing.T, r *http.Response) { |
| 107 | + doHTTP(t, http.MethodPut, "/api/admin/config", strings.NewReader("security.require-backend-tls = false"), nil, func(t *testing.T, r *http.Response) { |
105 | 108 | require.Equal(t, http.StatusOK, r.StatusCode) |
106 | 109 | }) |
107 | | - doHTTP(t, http.MethodGet, "/api/debug/health", nil, func(t *testing.T, r *http.Response) { |
| 110 | + doHTTP(t, http.MethodGet, "/api/debug/health", nil, nil, func(t *testing.T, r *http.Response) { |
108 | 111 | all, err := io.ReadAll(r.Body) |
109 | 112 | require.NoError(t, err) |
110 | 113 | require.NotEqual(t, sum, string(sumreg.Find(all))) |
111 | 114 | require.Equal(t, http.StatusOK, r.StatusCode) |
112 | 115 | }) |
113 | 116 | } |
| 117 | + |
| 118 | +func TestAcceptType(t *testing.T) { |
| 119 | + _, doHTTP := createServer(t, nil) |
| 120 | + checkRespContentType := func(expectedType string, r *http.Response) { |
| 121 | + require.Equal(t, http.StatusOK, r.StatusCode) |
| 122 | + data, err := io.ReadAll(r.Body) |
| 123 | + require.NoError(t, err) |
| 124 | + var cfg config.Config |
| 125 | + switch expectedType { |
| 126 | + case "json": |
| 127 | + require.Contains(t, r.Header.Get("Content-Type"), "application/json") |
| 128 | + require.NoError(t, json.Unmarshal(data, &cfg)) |
| 129 | + default: |
| 130 | + require.Contains(t, r.Header.Get("Content-Type"), "application/toml") |
| 131 | + require.NoError(t, toml.Unmarshal(data, &cfg)) |
| 132 | + } |
| 133 | + } |
| 134 | + doHTTP(t, http.MethodGet, "/api/admin/config", nil, nil, func(t *testing.T, r *http.Response) { |
| 135 | + checkRespContentType("toml", r) |
| 136 | + }) |
| 137 | + doHTTP(t, http.MethodGet, "/api/admin/config", nil, map[string]string{"Accept": "application/json"}, func(t *testing.T, r *http.Response) { |
| 138 | + checkRespContentType("json", r) |
| 139 | + }) |
| 140 | + doHTTP(t, http.MethodGet, "/api/admin/config", nil, map[string]string{"Accept": "application/toml"}, func(t *testing.T, r *http.Response) { |
| 141 | + checkRespContentType("toml", r) |
| 142 | + }) |
| 143 | + doHTTP(t, http.MethodGet, "/api/admin/config?format=json", nil, nil, func(t *testing.T, r *http.Response) { |
| 144 | + checkRespContentType("json", r) |
| 145 | + }) |
| 146 | + doHTTP(t, http.MethodGet, "/api/admin/config?format=JSON", nil, nil, func(t *testing.T, r *http.Response) { |
| 147 | + checkRespContentType("json", r) |
| 148 | + }) |
| 149 | +} |
0 commit comments