|
| 1 | +package handlers |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "encoding/json" |
| 6 | + "net/http" |
| 7 | + "net/http/httptest" |
| 8 | + "net/url" |
| 9 | + "strings" |
| 10 | + "testing" |
| 11 | + "time" |
| 12 | + |
| 13 | + "github.com/go-authgate/authgate/internal/config" |
| 14 | + "github.com/go-authgate/authgate/internal/metrics" |
| 15 | + "github.com/go-authgate/authgate/internal/models" |
| 16 | + "github.com/go-authgate/authgate/internal/services" |
| 17 | + "github.com/go-authgate/authgate/internal/store" |
| 18 | + |
| 19 | + "github.com/gin-gonic/gin" |
| 20 | + "github.com/google/uuid" |
| 21 | + "github.com/stretchr/testify/assert" |
| 22 | + "github.com/stretchr/testify/require" |
| 23 | +) |
| 24 | + |
| 25 | +func setupDeviceTestEnv(t *testing.T) (*gin.Engine, *store.Store) { |
| 26 | + t.Helper() |
| 27 | + gin.SetMode(gin.TestMode) |
| 28 | + |
| 29 | + cfg := &config.Config{ |
| 30 | + DeviceCodeExpiration: 30 * time.Minute, |
| 31 | + PollingInterval: 5, |
| 32 | + BaseURL: "http://localhost:8080", |
| 33 | + } |
| 34 | + |
| 35 | + s, err := store.New(context.Background(), "sqlite", ":memory:", &config.Config{}) |
| 36 | + require.NoError(t, err) |
| 37 | + |
| 38 | + auditSvc := services.NewAuditService(s, false, 0) |
| 39 | + deviceSvc := services.NewDeviceService(s, cfg, auditSvc, metrics.NewNoopMetrics()) |
| 40 | + userSvc := services.NewUserService(s, nil, nil, "local", false, auditSvc, nil, 0) |
| 41 | + authzSvc := services.NewAuthorizationService(s, cfg, auditSvc) |
| 42 | + handler := NewDeviceHandler(deviceSvc, userSvc, authzSvc, cfg) |
| 43 | + |
| 44 | + r := gin.New() |
| 45 | + r.POST("/oauth/device/code", handler.DeviceCodeRequest) |
| 46 | + |
| 47 | + return r, s |
| 48 | +} |
| 49 | + |
| 50 | +func createDeviceFlowClient( |
| 51 | + t *testing.T, |
| 52 | + s *store.Store, |
| 53 | + active bool, |
| 54 | + deviceFlowEnabled bool, |
| 55 | +) *models.OAuthApplication { |
| 56 | + t.Helper() |
| 57 | + status := models.ClientStatusActive |
| 58 | + if !active { |
| 59 | + status = models.ClientStatusInactive |
| 60 | + } |
| 61 | + client := &models.OAuthApplication{ |
| 62 | + ClientID: uuid.New().String(), |
| 63 | + ClientName: "Device Test Client", |
| 64 | + UserID: uuid.New().String(), |
| 65 | + Scopes: "email profile", |
| 66 | + GrantTypes: "device_code", |
| 67 | + EnableDeviceFlow: deviceFlowEnabled, |
| 68 | + Status: status, |
| 69 | + } |
| 70 | + require.NoError(t, s.CreateClient(client)) |
| 71 | + return client |
| 72 | +} |
| 73 | + |
| 74 | +func TestDeviceCodeRequest_Success(t *testing.T) { |
| 75 | + r, s := setupDeviceTestEnv(t) |
| 76 | + client := createDeviceFlowClient(t, s, true, true) |
| 77 | + |
| 78 | + w := httptest.NewRecorder() |
| 79 | + form := url.Values{"client_id": {client.ClientID}, "scope": {"email"}} |
| 80 | + req, _ := http.NewRequest( |
| 81 | + http.MethodPost, |
| 82 | + "/oauth/device/code", |
| 83 | + strings.NewReader(form.Encode()), |
| 84 | + ) |
| 85 | + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 86 | + r.ServeHTTP(w, req) |
| 87 | + |
| 88 | + assert.Equal(t, http.StatusOK, w.Code) |
| 89 | + |
| 90 | + var resp map[string]any |
| 91 | + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) |
| 92 | + assert.NotEmpty(t, resp["device_code"]) |
| 93 | + assert.NotEmpty(t, resp["user_code"]) |
| 94 | + assert.Contains(t, resp["verification_uri"], "/device") |
| 95 | + assert.NotZero(t, resp["expires_in"]) |
| 96 | + assert.NotZero(t, resp["interval"]) |
| 97 | +} |
| 98 | + |
| 99 | +func TestDeviceCodeRequest_MissingClientID(t *testing.T) { |
| 100 | + r, _ := setupDeviceTestEnv(t) |
| 101 | + |
| 102 | + w := httptest.NewRecorder() |
| 103 | + req, _ := http.NewRequest(http.MethodPost, "/oauth/device/code", nil) |
| 104 | + r.ServeHTTP(w, req) |
| 105 | + |
| 106 | + assert.Equal(t, http.StatusBadRequest, w.Code) |
| 107 | + |
| 108 | + var resp map[string]string |
| 109 | + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) |
| 110 | + assert.Equal(t, "invalid_request", resp["error"]) |
| 111 | +} |
| 112 | + |
| 113 | +func TestDeviceCodeRequest_UnknownClient(t *testing.T) { |
| 114 | + r, _ := setupDeviceTestEnv(t) |
| 115 | + |
| 116 | + w := httptest.NewRecorder() |
| 117 | + form := url.Values{"client_id": {"nonexistent-client"}} |
| 118 | + req, _ := http.NewRequest( |
| 119 | + http.MethodPost, |
| 120 | + "/oauth/device/code", |
| 121 | + strings.NewReader(form.Encode()), |
| 122 | + ) |
| 123 | + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 124 | + r.ServeHTTP(w, req) |
| 125 | + |
| 126 | + assert.Equal(t, http.StatusBadRequest, w.Code) |
| 127 | + |
| 128 | + var resp map[string]string |
| 129 | + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) |
| 130 | + assert.Equal(t, "invalid_client", resp["error"]) |
| 131 | +} |
| 132 | + |
| 133 | +func TestDeviceCodeRequest_InactiveClient(t *testing.T) { |
| 134 | + r, s := setupDeviceTestEnv(t) |
| 135 | + client := createDeviceFlowClient(t, s, false, true) |
| 136 | + |
| 137 | + w := httptest.NewRecorder() |
| 138 | + form := url.Values{"client_id": {client.ClientID}} |
| 139 | + req, _ := http.NewRequest( |
| 140 | + http.MethodPost, |
| 141 | + "/oauth/device/code", |
| 142 | + strings.NewReader(form.Encode()), |
| 143 | + ) |
| 144 | + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 145 | + r.ServeHTTP(w, req) |
| 146 | + |
| 147 | + assert.Equal(t, http.StatusBadRequest, w.Code) |
| 148 | + |
| 149 | + var resp map[string]string |
| 150 | + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) |
| 151 | + assert.Equal(t, "invalid_client", resp["error"]) |
| 152 | +} |
| 153 | + |
| 154 | +func TestDeviceCodeRequest_DeviceFlowNotEnabled(t *testing.T) { |
| 155 | + r, s := setupDeviceTestEnv(t) |
| 156 | + // Create client with device flow enabled first, then disable it |
| 157 | + // (GORM default:true skips zero-value false on insert) |
| 158 | + client := createDeviceFlowClient(t, s, true, true) |
| 159 | + client.EnableDeviceFlow = false |
| 160 | + require.NoError(t, s.UpdateClient(client)) |
| 161 | + |
| 162 | + w := httptest.NewRecorder() |
| 163 | + form := url.Values{"client_id": {client.ClientID}} |
| 164 | + req, _ := http.NewRequest( |
| 165 | + http.MethodPost, |
| 166 | + "/oauth/device/code", |
| 167 | + strings.NewReader(form.Encode()), |
| 168 | + ) |
| 169 | + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 170 | + r.ServeHTTP(w, req) |
| 171 | + |
| 172 | + assert.Equal(t, http.StatusBadRequest, w.Code) |
| 173 | + |
| 174 | + var resp map[string]string |
| 175 | + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) |
| 176 | + assert.Equal(t, "unauthorized_client", resp["error"]) |
| 177 | +} |
| 178 | + |
| 179 | +func TestDeviceCodeRequest_JSONBody(t *testing.T) { |
| 180 | + r, s := setupDeviceTestEnv(t) |
| 181 | + client := createDeviceFlowClient(t, s, true, true) |
| 182 | + |
| 183 | + w := httptest.NewRecorder() |
| 184 | + body := `{"client_id":"` + client.ClientID + `","scope":"profile"}` |
| 185 | + req, _ := http.NewRequest(http.MethodPost, "/oauth/device/code", strings.NewReader(body)) |
| 186 | + req.Header.Set("Content-Type", "application/json") |
| 187 | + r.ServeHTTP(w, req) |
| 188 | + |
| 189 | + assert.Equal(t, http.StatusOK, w.Code) |
| 190 | + |
| 191 | + var resp map[string]any |
| 192 | + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) |
| 193 | + assert.NotEmpty(t, resp["device_code"]) |
| 194 | +} |
| 195 | + |
| 196 | +func TestDeviceCodeRequest_DefaultScope(t *testing.T) { |
| 197 | + r, s := setupDeviceTestEnv(t) |
| 198 | + client := createDeviceFlowClient(t, s, true, true) |
| 199 | + |
| 200 | + w := httptest.NewRecorder() |
| 201 | + form := url.Values{"client_id": {client.ClientID}} |
| 202 | + req, _ := http.NewRequest( |
| 203 | + http.MethodPost, |
| 204 | + "/oauth/device/code", |
| 205 | + strings.NewReader(form.Encode()), |
| 206 | + ) |
| 207 | + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 208 | + r.ServeHTTP(w, req) |
| 209 | + |
| 210 | + assert.Equal(t, http.StatusOK, w.Code) |
| 211 | + |
| 212 | + var resp map[string]any |
| 213 | + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) |
| 214 | + assert.NotEmpty(t, resp["device_code"]) |
| 215 | + assert.NotEmpty(t, resp["user_code"]) |
| 216 | +} |
0 commit comments