diff --git a/.gitignore b/.gitignore index 63b11496..59c69d71 100644 --- a/.gitignore +++ b/.gitignore @@ -46,4 +46,8 @@ src/avatars/images .devcontainer/customize.sh # Runtime data -uploads/ \ No newline at end of file +uploads/ + +# OAuth Providers (sensitive, do not commit) +oauth-providers/ +plugins/ \ No newline at end of file diff --git a/docs/oauth-implementation-guide.md b/docs/oauth-implementation-guide.md new file mode 100644 index 00000000..d635201a --- /dev/null +++ b/docs/oauth-implementation-guide.md @@ -0,0 +1,286 @@ +# OAuth 功能实现指南 + +## 概述 + +本文档描述了 Cheese Backend 中 OAuth 登录功能的实现。该功能允许用户使用第三方 OAuth 提供商(如学校的 OAuth 系统)进行登录。 + +## 核心组件 + +### 1. OAuth Service (`src/auth/oauth/oauth.service.ts`) +- 动态加载和管理多个 OAuth 提供商 +- 支持插件机制和 npm 包加载 +- 提供统一的 OAuth 流程接口 + +### 2. OAuth 类型定义 (`src/auth/oauth/oauth.types.ts`) +- `OAuthProvider` 接口:定义提供商必须实现的方法 +- `BaseOAuthProvider` 抽象类:提供通用实现 +- `OAuthUserInfo` 接口:标准化用户信息格式 + +### 3. 数据库模型 (`src/auth/oauth/oauth.prisma`) +- `UserOAuthConnection` 模型:存储用户与第三方账号的关联关系 + +### 4. 控制器路由 (`src/users/users.controller.ts`) +- `GET /users/auth/oauth/providers` - 获取可用的 OAuth 提供商列表 +- `GET /users/auth/oauth/login/:providerId` - 重定向到 OAuth 提供商授权页面 +- `GET /users/auth/oauth/callback/:providerId` - 处理 OAuth 回调 + +## 环境配置 + +在 `.env` 文件中添加以下配置: + +```bash +# 启用的 OAuth 提供商(逗号分隔) +OAUTH_ENABLED_PROVIDERS=ruc,google + +# OAuth 插件搜索路径 +OAUTH_PLUGIN_PATHS=plugins/oauth + +# 是否允许从 npm 包加载提供商 +OAUTH_ALLOW_NPM_LOADING=false + +# 提供商凭据(以 'ruc' 为例) +OAUTH_RUC_CLIENT_ID=your-client-id +OAUTH_RUC_CLIENT_SECRET=your-client-secret +OAUTH_RUC_REDIRECT_URL=http://localhost:3000/users/auth/oauth/callback/ruc + +# 前端重定向路径 +FRONTEND_OAUTH_SUCCESS_PATH=/oauth-success +FRONTEND_OAUTH_ERROR_PATH=/oauth-error +``` + +## 实现新的 OAuth 提供商 + +### 方法一:插件文件 + +在 `plugins/oauth/` 目录下创建提供商实现文件: + +```javascript +// plugins/oauth/your-provider.js +const axios = require('axios'); + +class YourOAuthProvider { + constructor(config) { + this.config = { + ...config, + authorizationUrl: 'https://your-provider.com/oauth/authorize', + tokenUrl: 'https://your-provider.com/oauth/token', + scope: ['read:user', 'user:email'], + }; + } + + getConfig() { + return this.config; + } + + getAuthorizationUrl(state, accessType) { + const params = new URLSearchParams({ + client_id: this.config.clientId, + redirect_uri: this.config.redirectUrl, + scope: this.config.scope.join(' '), + response_type: 'code', + }); + + if (state) params.append('state', state); + if (accessType) params.append('access_type', accessType); + + return `${this.config.authorizationUrl}?${params.toString()}`; + } + + async handleCallback(code, state) { + const response = await axios.post(this.config.tokenUrl, { + client_id: this.config.clientId, + client_secret: this.config.clientSecret, + code, + grant_type: 'authorization_code', + redirect_uri: this.config.redirectUrl, + }); + + return response.data.access_token; + } + + async getUserInfo(accessToken) { + const response = await axios.get('https://your-provider.com/api/user', { + headers: { 'Authorization': `Bearer ${accessToken}` }, + }); + + const userData = response.data; + return { + id: userData.id.toString(), + email: userData.email, + name: userData.name, + username: userData.username, + preferredUsername: userData.preferred_username, + }; + } +} + +function createProvider(config) { + return new YourOAuthProvider(config); +} + +module.exports = { createProvider, default: createProvider }; +``` + +### 方法二:TypeScript 实现 + +```typescript +// plugins/oauth/your-provider.ts +import axios from 'axios'; +import { BaseOAuthProvider, OAuthProviderConfig, OAuthUserInfo } from '../../src/auth/oauth/oauth.types'; + +export class YourOAuthProvider extends BaseOAuthProvider { + constructor(config: OAuthProviderConfig) { + super({ + ...config, + authorizationUrl: 'https://your-provider.com/oauth/authorize', + tokenUrl: 'https://your-provider.com/oauth/token', + scope: ['read:user', 'user:email'], + }); + } + + async handleCallback(code: string, state?: string): Promise { + const response = await axios.post(this.config.tokenUrl, { + client_id: this.config.clientId, + client_secret: this.config.clientSecret, + code, + grant_type: 'authorization_code', + redirect_uri: this.config.redirectUrl, + }); + + return response.data.access_token; + } + + async getUserInfo(accessToken: string): Promise { + const response = await axios.get('https://your-provider.com/api/user', { + headers: { 'Authorization': `Bearer ${accessToken}` }, + }); + + const userData = response.data; + return { + id: userData.id.toString(), + email: userData.email, + name: userData.name, + username: userData.username, + preferredUsername: userData.preferred_username, + }; + } +} + +export function createProvider(config: OAuthProviderConfig) { + return new YourOAuthProvider(config); +} +``` + +## OAuth 流程 + +1. **用户点击 OAuth 登录** + - 前端调用 `GET /users/auth/oauth/providers` 获取可用提供商 + - 前端引导用户访问 `GET /users/auth/oauth/login/:providerId` + +2. **重定向到提供商** + - 后端生成授权 URL 并重定向用户到 OAuth 提供商 + +3. **用户授权并回调** + - 用户在提供商页面完成授权 + - 提供商重定向用户到 `GET /users/auth/oauth/callback/:providerId` + +4. **处理回调** + - 后端交换 authorization code 获取 access token + - 使用 access token 获取用户信息 + - 根据用户信息进行登录或注册 + - 生成 JWT token 并重定向到前端 + +5. **用户登录成功** + - 前端从 URL 参数获取 JWT token + - 从 Cookie 获取 refresh token + - 完成登录状态设置 + +## 用户同步逻辑 + +### 1. 检查已有绑定 +- 查询 `UserOAuthConnection` 表,查找是否已有该提供商和用户 ID 的绑定 +- 如果找到且关联用户未被删除,直接登录 + +### 2. 按邮箱匹配现有用户 +- 如果未找到绑定但 OAuth 提供了邮箱 +- 查找本地是否有相同邮箱的活跃用户 +- 如果找到,创建新的 OAuth 绑定并登录 + +### 3. 创建新用户 +- 如果既无绑定又无邮箱匹配,创建新用户 +- 生成唯一用户名(基于 OAuth 用户信息) +- 创建用户、用户档案和 OAuth 绑定 +- 生成随机密码(用户不会使用) + +## 安全考虑 + +1. **State 参数验证** + - OAuth 流程中的 state 参数用于防止 CSRF 攻击 + - 建议在生产环境中实现 state 验证 + +2. **路径遍历防护** + - 插件加载时验证路径安全性 + - 只允许加载预期目录下的文件 + +3. **配置验证** + - 验证提供商 ID 只包含安全字符 + - 检查必要配置项的存在 + +4. **错误处理** + - 统一的错误处理和日志记录 + - 不泄露敏感信息给前端 + +## 扩展功能 + +### 添加新提供商支持 +1. 实现提供商类(参考上面的示例) +2. 将实现文件放在 `plugins/oauth/` 目录 +3. 在环境变量中配置提供商凭据 +4. 将提供商 ID 添加到 `OAUTH_ENABLED_PROVIDERS` + +### 自定义用户信息映射 +在提供商实现的 `getUserInfo` 方法中,可以自定义如何将提供商的用户数据映射到标准的 `OAuthUserInfo` 格式。 + +### 长期令牌管理 +可以在 `UserOAuthConnection` 表中存储 OAuth refresh token,用于长期访问第三方 API(如果需要)。 + +## 前端集成 + +前端需要实现以下页面: + +1. **登录页面** + - 调用 `/users/auth/oauth/providers` 获取提供商列表 + - 为每个提供商提供登录按钮 + +2. **OAuth 成功页面** (`/oauth-success`) + - 从 URL 参数获取 `token` 和 `email` + - 设置应用登录状态 + +3. **OAuth 错误页面** (`/oauth-error`) + - 显示错误信息 + - 提供重新登录选项 + +## 故障排除 + +### 常见问题 + +1. **提供商未加载** + - 检查环境变量配置 + - 确认插件文件路径和格式正确 + - 查看应用启动日志 + +2. **回调失败** + - 确认回调 URL 配置正确 + - 检查提供商的 client credentials + - 查看错误日志了解具体失败原因 + +3. **用户信息获取失败** + - 检查用户信息 API 的 URL 和权限要求 + - 确认 access token 有效且有足够权限 + +### 调试技巧 + +1. 启用详细日志记录 +2. 使用 OAuth 提供商的开发者工具 +3. 检查网络请求和响应 +4. 验证 JWT token 的内容和有效性 \ No newline at end of file diff --git a/docs/oauth.md b/docs/oauth.md new file mode 100644 index 00000000..f0c7a7ec --- /dev/null +++ b/docs/oauth.md @@ -0,0 +1,198 @@ +明白了。我将研究如何将cheese-auth仓库中的OAuth认证功能(特别是你们学校的OAuth登录支持)作为可选的登录方式整合回cheese-backend仓库中,并确保其不会破坏现有的认证逻辑。 + +我会仔细分析cheese-auth中OAuth相关的模块、依赖和接口,评估其如何在cheese-backend中以模块化方式集成,包括必要的中间件、配置、用户信息同步等部分。完成后会整理成一份清晰的迁移方案供你参考。 + + +# Cheese-Backend 集成 OAuth 登录功能 + +## 背景与目标 + +Cheese-Auth 仓库实现了可动态扩展的 OAuth 登录模块(支持学校自有 OAuth 登录等)。现在希望将该功能“反向移植”到 Cheese-Backend 中,作为**可选**的登录方式,同时保持原有用户名/密码等认证流程不受影响。具体目标包括: + +1. **保持原有认证系统正常运行:** 新增 OAuth 登录支持时,不影响 Cheese-Backend 现有的密码登录、2FA 等机制。 +2. **独立集成 OAuth 模块:** 不直接依赖 cheese-auth 服务或库,而是在 Cheese-Backend 内部重新实现类似的 OAuth 支持。 +3. **模块化、可插拔设计:** 以模块方式集成 OAuth 登录,方便后续添加更多 OAuth 服务商。启用与否、支持哪些提供商都通过配置控制。 +4. **映射关键组件:** 基于 cheese-auth 的实现,识别 OAuth 功能的关键组件(如 OAuth 路由、Token 交换、回调处理、用户同步逻辑等),并说明如何在 Cheese-Backend 中对应整合。 +5. **提供实施指南:** 包括必要的代码结构说明、配置项、依赖库及注意事项,便于在 Cheese-Backend 中实现和配置该功能。 + +下面将按照以上思路详细介绍 OAuth 模块的设计与集成方案。 + +## OAuth 模块设计与提供程序接口 + +在 Cheese-Backend 中新增一个 **OAuth 模块**,包含 OAuth 服务(`OAuthService`)和提供程序接口定义,使其逻辑与 cheese-auth 保持一致。具体设计: + +* **提供程序接口 (`OAuthProvider`):** 定义每种 OAuth 登录方式需实现的方法,包括: + + * `getAuthorizationUrl(state?, accessType?)`:生成跳转到第三方 OAuth 提供商的认证 URL。 + * `handleCallback(code, state?)`:处理提供商回调,使用授权码换取访问令牌。 + * `getUserInfo(accessToken)`: 用访问令牌调用提供商的用户信息接口,获取用户资料。 + + 以及 `getConfig()` 返回提供商配置信息等。通过统一接口,Cheese-Backend 可与任意提供商交互,而具体差异由各提供商实现自行处理。 + +* **提供商配置 (`OAuthProviderConfig`):** 包含提供商的标识`id`、名称`name`、`clientId`、`clientSecret`、`authorizationUrl`(授权页面地址)、`tokenUrl`(换取 token 的接口)、`redirectUrl`(回调地址)、所需`scope`等。这些配置主要由环境变量提供。 + +* **OAuth 服务 (`OAuthService`):** NestJS Injectable 单例服务,负责**动态加载**和管理多个 OAuth 提供商。其职责: + + * **读取配置:** 在模块初始化时读取环境变量 `OAUTH_ENABLED_PROVIDERS`(启用的提供商列表);如果为空则不启用任何 OAuth 登录(保证模块可选,不配置就相当于禁用)。此外读取 `OAUTH_PLUGIN_PATHS`(插件搜索路径)和 `OAUTH_ALLOW_NPM_LOADING`(是否允许从 npm 包加载提供商)等配置。 + * **加载提供商实现:** 对每个启用的提供商ID: + + * 检查环境中是否提供了该提供商所需的 `CLIENT_ID`、`CLIENT_SECRET`、`REDIRECT_URL` 等凭据;若缺失则跳过并发出警告。 + * 在配置的插件目录下查找对应的提供商模块文件。如约定目录结构,在 `plugins/oauth/{providerId}/index.js` 或 `plugins/oauth/{providerId}.js` 等位置寻找实现。加载前会校验路径安全,防止路径遍历等风险。找到文件后通过 `import()` 动态加载模块,并调用其导出工厂函数实例化提供商对象。 + * 如本地未找到且允许 npm 加载,则尝试从已安装的 `@sageseekersociety/cheese-auth-{providerId}-oauth-provider` 包导入实现。为安全考虑,默认不启用 npm 动态加载(`OAUTH_ALLOW_NPM_LOADING=false`),除非明确配置启用。 + * 调用 `registerProvider()` 将实例化的提供商注册到内部映射表,以备后续根据ID查找使用。注册时以提供商ID生成唯一注入 token,确保各 provider 可独立注入(如果需要)。 + * **提供查询接口:** 提供 `getProvider(id)` 和 `getAllProviders()` 方法用于业务层获取提供商实例,`getProvidersConfig()` 则返回所有已注册提供商的配置信息概要(屏蔽掉敏感的 clientSecret)。这用于向前端提供可用的 OAuth 选项列表。 + +上述设计确保 OAuth 模块是自包含、可选加载的:只有在环境配置了提供商时才实际发挥作用,否则对系统无影响。同时,采用插件机制使得新增提供商无需修改核心代码,只需在指定路径放入实现或安装对应包并更新配置即可。 + +## 动态加载 OAuth 提供商实现 + +**cheese-auth** 的 OAuth 功能通过插件机制实现了提供商的可拔插支持。我们将在 Cheese-Backend 延续这一设计。关键实现点: + +* **插件目录:** 可在 Cheese-Backend 仓库中新建 `plugins/oauth/` 目录(或其他配置路径),用于存放各 OAuth 提供商的实现代码。实现可以是单文件或一个目录模块,遵循 cheese-auth 定义的导出规范(导出 `createProvider` 工厂函数或默认导出一个创建函数)。这样 OAuthService 能动态加载对应模块并调用工厂函数生成提供商实例。 +* **提供商实现规范:** 每个提供商插件应当返回一个实现了 `OAuthProvider` 接口的对象,一般可通过继承基类 `BaseOAuthProvider` 简化实现。Cheese-Auth 定义了 `BaseOAuthProvider` 抽象类提供通用的授权 URL 构建逻辑(附加 client\_id、redirect\_uri、scope 等参数),各提供商只需实现其特有的 `handleCallback`(用授权码获取 token)和 `getUserInfo`(用 token 获取用户信息)逻辑。例如,“学校自有 OAuth”(假设标识为`ruc`)的提供商实现会定义好学校认证服务器的 `authorizationUrl` 和 `tokenUrl`,并使用 `axios` 等HTTP库在 `handleCallback` 中向学校 OAuth 服务发送 token 请求,获取 access\_token 和(可选)refresh\_token,再实现 `getUserInfo` 调用学校的用户信息接口,返回包括学号、姓名、邮箱等字段的 `OAuthUserInfo` 对象。 +* **安全考虑:** OAuthService 在加载插件时会验证提供商ID仅包含安全字符(字母、数字、`-`、`_`),防止拼接路径时出现不安全的目录跳转。对于文件插件,使用`path.resolve`并确保目标路径仍在配置的基准目录下,避免非预期路径的代码被加载。对于 npm 插件,仅在明确允许时才尝试,且要求提前安装好对应包版本,以减少运行时从不受信任源下载代码的风险。这一系列检查确保我们在加载第三方提供商实现时尽可能降低安全隐患。 + +通过以上机制,Cheese-Backend 可像 Cheese-Auth 一样支持**动态扩展** OAuth 登录提供商。例如,要新增对 Google 登录的支持,只需开发符合规范的 `google` 提供商模块,放入插件目录并在环境变量中把 `google` 加入启用列表,无需修改服务器核心代码。这满足了可插拔的扩展需求。 + +## OAuth 登录相关路由 + +OAuth 模块加载后,需要在 API 层新增相应的接口供前端使用,主要包括: + +* **获取提供商列表:** `GET /users/auth/oauth/providers` – 返回当前后端配置并启用的 OAuth 提供商列表。Cheese-Auth 实现返回了状态码和消息,以及每个提供商的 `id`、显示名称`name`等基本信息,供前端展示登录选项。在 Cheese-Backend 中,我们可以类似地通过 `oauthService.getProvidersConfig()` 获取提供商配置列表,并封装成统一响应格式。 + +* **跳转 OAuth 登录:** `GET /users/auth/oauth/login/:providerId` – 用户选择某个 OAuth 选项后,前端引导浏览器请求此接口,后端据此构造对应提供商的**授权登录URL**并重定向。实现细节: + + * 后端通过 `oauthService.getProvider(providerId)` 查找对应的提供商实例。若未找到(提供商ID无效或未启用),返回 404 错误。 + * 找到提供商后,调用其 `getAuthorizationUrl(state, accessType)` 方法生成第三方认证页面的完整 URL。参数 `state` 用于防范 CSRF(可由前端产生并传回,用于回调校验),`access_type` 则供某些平台请求离线访问权限(refresh token)。 + * 使用 NestJS 的 `@Redirect()` 装饰器直接将响应定位到该 URL。这样浏览器会被重定向至提供商的登录页面,用户在第三方完成认证授权后,浏览器将跳转回我们配置的回调地址。 + +* **处理 OAuth 回调:** `GET /users/auth/oauth/callback/:providerId` – 第三方认证完成后将用户带回此接口,并附加授权码(`code`)和之前的`state`参数。后端需要处理如下流程: + + 1. 根据 `providerId` 找到对应的 OAuthProvider 实例,若不存在则返回 404。 + 2. 调用 `provider.handleCallback(code, state)`,向提供商的 Token 接口换取访问令牌(Access Token)。如提供商发回了错误(例如用户拒绝授权),应抛出异常进入错误流程处理。 + 3. 调用 `provider.getUserInfo(accessToken)` 获取用户基本信息(ID、姓名、邮箱等)。 + 4. 调用应用的用户服务逻辑,将此第三方用户信息与本地用户系统对接:执行登录或注册流程。我们将在下一节详述 **loginWithOAuth** 的实现。 + 5. `loginWithOAuth` 返回本地用户DTO及一个**应用内刷新令牌**(`refreshToken`)。随后,通过 `sessionService.refreshSession(refreshToken)` 颁发新的 Refresh Token 和对应的短期 JWT 访问令牌。这样,我们复用了系统现有的 Session/JWT 机制,OAuth 登录用户最终获得与密码登录用户相同格式的认证令牌。 + 6. 构造前端跳转:后端拿到 JWT后,需要引导浏览器回到前端指定页面,并把令牌传给前端。cheese-auth的做法是从配置中读取 `FRONTEND_BASE_URL` 和 `FRONTEND_OAUTH_SUCCESS_PATH` 作为成功登录后前端接收页面的地址。将 JWT Access Token 附加在URL的查询参数(如 `token=`)中传递。 同时,出于用户体验考虑,可以附加用户标识信息,例如 email(cheese-auth 将用户邮箱作为参数,以便前端显示“已使用xx邮箱登录”提示)。 + 7. 设置 Cookie:与常规登录一样,设置HTTP-Only的 `REFRESH_TOKEN` Cookie,路径限定为 `/users/auth`(由 `COOKIE_BASE_PATH` 配置决定)。这样前端后续可以使用 Refresh Token 刷新会话,而无需另行存储它。 + 8. 最后通过 `res.redirect(frontendRedirectUrl)` 将浏览器重定向到前端页面。前端据URL参数获取JWT,并结合Cookie中的 Refresh Token 完成登录状态维护。 + + 若上述流程中出现任何错误(如授权码无效、交换 token 失败等),则进入**错误处理**分支。后端会构造一个前端错误接收页面URL(`FRONTEND_OAUTH_ERROR_PATH` 配置),附加错误消息和提供商标识等信息,重定向浏览器到该错误页面,便于前端告知用户登录失败原因并做后续处理。 + +上述三个接口需在 NestJS 控制器中新增。Cheese-Auth 是将这些路由合并进了 UsersController(注解 `@Controller('/users')`)中,并使用了 `@NoAuth()` 装饰器标记为公共访问(无需现有认证令牌)。在 Cheese-Backend 中我们也可以采取类似做法:**在 UsersController 中增设 /auth/oauth/... 路由**,以便与现有 `/users/auth/login` 等路径保持一致风格。同时确保使用 `@NoAuth`(或同等机制)豁免 JWT 拦截,允许未登录用户访问这些接口。 + +## 用户同步与登录逻辑 + +用户服务需要新增一个关键方法 `loginWithOAuth(providerId, userInfo, ...)`,将第三方返回的用户信息与本地用户数据库同步,流程如下: + +1. **检查已有绑定:** 查询本地数据库的 **用户OAuth关联表**(我们稍后介绍其结构),查找是否已有记录对应此 OAuth 提供商和该提供商下的用户ID。如果找到且关联的本地用户未被删除,则表示该第三方账户以前登录过,直接复用对应的本地用户。 + + * 若找到记录但关联的 User 被软删除了,可视需求决定是否解锁/重新激活;当前实现简单地视同不存在,转入注册流程。 + * 如果记录存在且用户存在,但发现用户缺少用户档案(profile)(理论上不应发生),则补建默认档案以保持数据一致性。 + +2. **按邮箱匹配现有用户:** 如果没有现成关联,且第三方提供了 email 且本地开启了用邮箱唯一标识用户的策略,则尝试按 email 查找现有活跃用户。这覆盖了用户可能先用邮箱密码注册,后来又尝试用同邮箱的 OAuth 登录的情况。 + + * 若找到匹配用户,则**创建关联**:将该本地用户的ID与当前 OAuth账户ID建立链接并存库。Cheese-Auth 使用 Prisma 的 `upsert` 方法实现插入或更新关联:以 `(providerId, providerUserId)` 作为唯一键,写入 userId 和原始资料 `rawProfile`。如果此前该第三方账户曾绑定过别的本地用户(理论上不应发生,除非账号被合并),update 分支会更新绑定到当前用户ID。这一操作将日志记录输出,表示发生了帐号关联。 + * 之后复用找到的本地用户记录,跳至步骤4。 + +3. **创建新用户:** 如果既无绑定又未找到同邮箱用户,则视为新用户,执行本地用户的注册流程: + + * **生成唯一用户名:** 根据 OAuth 提供的用户信息确定用户名基础,如优先使用 `preferredUsername` 或昵称、姓名,均不可用则用 `user_{提供商用户ID}`。规范化该用户名:去除特殊字符、转为小写,并确保长度在合适范围内(如不足4字符则补前缀,超长则截断)。然后检查是否与现有用户重名,若是则在末尾追加递增数字后缀确保唯一。 + * **设定初始密码:** 生成一个安全随机密码赋给新用户,用于占位目的。实际可以不告诉用户该密码(用户走OAuth登录,不需要密码),但保留密码字段可允许日后通过“忘记密码”流程设置密码、或转为密码登录等。如果不希望OAuth用户有密码,可考虑设置一个不可用的随机哈希。同样地,Cheese-Auth对密码复杂度的要求在此场景下可放宽,因为用户不会手动使用这个密码。 + * **创建用户及关联:** 在数据库事务中同时创建 User、UserProfile(档案)以及 OAuth 关联记录。Prisma 事务确保这些操作要么全部成功要么全部回滚,保持数据一致。User 表中 email 字段可填写 OAuth 提供的邮箱(如有)或留空。UserProfile 可用第三方提供的姓名作为昵称等,avatar也可设为默认头像或以后扩展从 OAuth 拿头像URL。 + 随后插入 UserOAuthConnection 表,记录该 userId 与 providerId+providerUserId 的关联。也可以存储 `rawProfile` 完整的第三方返回资料供日后参考。(见下文数据库设计) + * **记录注册日志:** 标记此为 OAuth 用户注册(用于审计)并在日志中打印创建了新用户及其用户名。 + +4. **登录态创建:** 不论是已有用户还是新用户,至此都拿到了对应的本地 `User` 实体以及 `UserProfile`。接下来: + + * 记录一次用户登录日志(包含IP、UA等)以留存登录历史。 + * 创建应用内会话:调用 `SessionService.createSession(user.id)` 创建一个 Refresh Token。Cheese-Backend 原有 SessionService 很可能已经有创建和管理 refresh token 的逻辑,我们直接复用。Cheese-Auth 在 Prisma 中设计了 RefreshToken 存储或 JWT Payload 中携带 `validUntil` 用于判断有效期等,这里不需要改动原有逻辑。 + * 将User和Profile转为 UserDto 返回给上层,以便控制器封装响应。 + * 最终返回包含 UserDto 和 refreshToken 的元组,由控制器继续处理生成 JWT 等。 + +5. **维护用户OAuth关联表:** Cheese-Backend 需增加一个 **UserOAuthConnection** 表来存储用户与第三方账号的对应关系。可参考 Cheese-Auth 数据库模式: + + ```prisma + model UserOAuthConnection { + id Int @id @default(autoincrement()) + userId Int + providerId String // OAuth提供商ID,如 'ruc', 'google' + providerUserId String // 提供商侧用户唯一标识 + rawProfile Json? // 原始用户信息(JSON) + refreshToken String? // 可选,OAuth长效令牌 + tokenExpires DateTime? // 可选,OAuth令牌过期时间 + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + @@unique([providerId, providerUserId]) + @@map("user_oauth_connections") + } + ``` + + 该表以 (`providerId`, `providerUserId`) 组合唯一索引,保证同一个第三方账户只关联一个本地用户。在 Cheese-Backend Prisma schema 中加入此模型并通过迁移创建表,将使上述 `loginWithOAuth` 的数据库操作成立。字段 `refreshToken` 和 `tokenExpires` 目前在 Cheese-Auth 中未实际使用(因为 Cheese-Auth 并未实现请求第三方的新AccessToken的刷新逻辑,只依赖我们自己的 RefreshToken 维护会话)。但保留这些字段有利于将来扩展,如需要长期访问第三方 API 或检测 OAuth 凭据有效期,可以在 provider 实现中填充并利用这些字段。 + +综上,`usersService.loginWithOAuth` 实现了**OAuth 用户同步**:兼顾老用户绑定、邮箱匹配、新用户注册三种情况,确保第三方用户在本地有正确的身份,并生成应用内会话令牌以完成登录。Cheese-Auth 的完整实现可作为参考。在 Cheese-Backend 中应尽量复用现有 User/Session 机制(如 Password 登录用的 login 方法所做的记录和 token 发放逻辑),将 OAuth 登录结果接入相同的会话管理流程,实现统一的用户体验。 + +## 模块集成与兼容性 + +有了上述组件,实现时需注意模块组织和与现有系统的衔接: + +* **代码结构:** 建议在 Cheese-Backend `src/auth/` 下新建子目录(如 `oauth/`)存放 OAuth 模块相关代码。包括: + + * `oauth.types.ts` 或 `oauth-provider.interface.ts`:定义 OAuthProvider 接口及 OAuthUserInfo 等模型(可直接从 cheese-auth 的 `packages/oauth-provider-types` 复制接口定义)。 + * `oauth.service.ts`:实现 OAuthService 类,包含动态加载逻辑。可以整体借鉴 cheese-auth 的实现,并适当调整日志和错误处理方式以符合 Backend 风格。 + * `oauth.module.ts`:定义 NestJS 模块,利用 `@Module` 装饰器声明 providers 和 exports。可以提供一个静态方法 `register(providers?: OAuthProvider[])` 返回 DynamicModule,使外部可按需传入自定义 Provider 实例注册。同时在 `module.exports` 中导出 OAuthService,以便 UsersController 等处注入使用。 + * (可选)具体提供商实现:如果学校 OAuth 实现不是敏感代码,可直接放在如 `oauth/providers/ruc.provider.ts` 编码实现并通过 OAuthModule 静态注册;或者遵循插件机制将实现放在 `plugins/oauth/ruc.js` 并配置加载。当然由于我们不会将 cheese-auth 仓库作为依赖,在 Cheese-Backend 引入学校 OAuth 的代码是必要的,可从 cheese-auth 插件代码或相关文档获取实现细节,然后在 Backend 以本地模块方式集成。 + +* **模块导入:** 在 Cheese-Backend 的主应用模块或 AuthModule 中,按需引入 OAuthModule。例如,在 `AuthModule` 的 `imports` 列表中增加 `OAuthModule.register()`(如果没有额外 providers 参数则传空数组或默认)。这样 OAuthService 将在应用启动时初始化并加载提供商。同时确保 UsersModule(或 UsersController 所在模块)能够注入 OAuthService。例如如果 UsersModule imports 了 AuthModule,那么 AuthModule 导出 OAuthService 后 UsersController 构造函数即可通过 DI 拿到它。Cheese-Auth 即是在 UsersController 中注入了 OAuthService,因此我们保持相同的依赖关系即可。 + +* **路由集成:** 将上文列出的 `/users/auth/oauth/...` 路由添加到 UsersController。由于 Cheese-Backend 已有 UsersController,我们可以直接在该类中新增对应的方法(带 @Get 装饰器)。参考 cheese-auth 的 UsersController 增加相关段落等。注意控制器方法签名和注入的参数:callback 接口需要注入 `@Req()`或 NestJS 提供的 `@Ip()` 获取IP,以及 `@Headers('User-Agent')` 获取UA。同时通过 `@Res()` 手动构造响应,以便设置Cookie和重定向。 + + * **注意**:在使用 `@Res()` 时,NestJS会短路框架的默认响应处理,故方法需要手工 `return res.redirect(...)` 等。同样返回 JSON 时也需要手工 `res.json()` 或`res.cookie().json()`连缀。Cheese-Auth 示例中已经演示了如何设置 Cookie 后用 `res.redirect`。我们应确保路径和安全属性正确:Cookie 路径应结合配置的 `cookieBasePath` 和业务路由前缀,一般为`/users/auth`;`sameSite` 可以设为 `lax` 以允许第三方跳转携带(如果前后端不同域的话),`secure` 则依据是否HTTPS环境设置。 + +* **保持兼容性:** 现有 Cheese-Backend 登录/注册相关流程无需修改,依然通过 `/users/auth/login` 等完成用户名密码验证等。我们新增的接口在不使用时不会干扰原有逻辑。例如,如果部署时未配置任何 OAuth 提供商,`/users/auth/oauth/providers` 返回空列表,`/users/auth/oauth/login/:id` 会直接返回 404。因此,对于不需要 OAuth 的环境,可以完全不受影响。仅当运维在配置中启用了某个提供商,相关路由才真正发挥作用。 + + * 我们也不会改变 AuthService、UsersService 中原有的密码校验、2FA 验证等逻辑,只是在 UsersService 中增加 `loginWithOAuth` 方法,并在 UsersController 增加调用它的入口。因此,原有邮箱验证、密码重置、TOTP 两步验证等功能都与 OAuth 登录互不冲突,各自按需执行。 + * **前端配合:** 需要在前端加入对 providers 列表的获取和 OAuth 登录流程的支持(例如在登录界面提供“使用校园账号登录”等按钮,点击后调用后端 `/oauth/login/:provider` 接口重定向)。Cheese-Auth 返回 token 和 email 给前端时,是假定前端有相应页面读取URL参数并调用应用的登录成功流程(例如存储 JWT、显示用户邮箱等)。因此,在完成后端集成后,也要同步调整前端应用的 OAuth 回调处理逻辑,使用后端统一的参数名`token`和`email`。 + +## 配置项与依赖 + +成功的集成还需要正确设置环境配置和依赖库: + +* **环境变量配置:** 在 Cheese-Backend 的配置文件(如 `.env`)中增加 OAuth 支持相关的配置项: + + * `OAUTH_ENABLED_PROVIDERS`:启用的 OAuth 提供商ID列表,逗号分隔。例如:`OAUTH_ENABLED_PROVIDERS=ruc` 开启学校 OAuth 登录,或 `ruc,google` 同时开启多个。 + * `OAUTH_PLUGIN_PATHS`:插件搜索路径列表,逗号分隔。默认可用 `plugins/oauth`,如果编译后也可指定 `dist/oauth-providers` 等目录。Cheese-Auth 默认值就是 `plugins/oauth`。需确保这些路径在部署环境中存在并包含对应实现文件。 + * `OAUTH_ALLOW_NPM_LOADING`:是否允许从已安装的 npm 包加载提供商实现。默认建议为 false(出于安全),除非你计划通过依赖引入官方实现包。如果要使用我们打包的 `cheese-auth-ruc-oauth-provider` 等,可在安装包后将此值设为 true 以启用。 + * **各提供商凭据:** 对于每个提供商 ``,需要配置: + + * `OAUTH__CLIENT_ID` – OAuth客户端ID。 + * `OAUTH__CLIENT_SECRET` – OAuth客户端密钥。 + * `OAUTH__REDIRECT_URL` – 回调URL(一般指向上述 `/users/auth/oauth/callback/` 的完整外网地址)。 + 例如,学校提供商ID为 “ruc”,则需设置 `OAUTH_RUC_CLIENT_ID`、`OAUTH_RUC_CLIENT_SECRET`、`OAUTH_RUC_REDIRECT_URL` 三项。这些凭据通常由第三方OAuth服务提供,需要保密。 + * `FRONTEND_OAUTH_SUCCESS_PATH`:前端接收成功登录结果的页面路径,默认可设为 `/oauth-success`。Cheese-Auth默认此路径用于重定向。 + * `FRONTEND_OAUTH_ERROR_PATH`:前端处理登录失败的页面路径,默认 `/oauth-error`。 + * (上述两个路径会被拼接到已有的 `FRONTEND_BASE_URL` 后生成完整跳转地址。Cheese-Backend 原本已有 `FRONTEND_BASE_URL` 和 `COOKIE_BASE_PATH` 等配置。确保这些值正确,否则 OAuth 回调无法定位前端页面。) +* **依赖库:** 为实现 OAuth 流程,后端需要能向第三方服务器发HTTP请求并解析响应。Cheese-Auth 使用了 **Axios** 库来执行 OAuth Token请求和获取用户信息(在 `cheese-auth-ruc-oauth-provider` 包中引入了 axios)。因此,建议在 Cheese-Backend 项目中添加 Axios依赖(如果尚未有的话),或使用 `node-fetch`/Nest自带的 HttpService 等完成类似功能。由于 OAuth 通信多为 REST API,Axios 的使用相对直接。 + + * **注意**:Axios 默认返回 Promise,在 provider 的 `handleCallback` 和 `getUserInfo` 实现中会被 `await`,要做好异常捕获并转换为 OAuthError 以在控制器中捕获。(Cheese-Auth 定义了一套 OAuthError 类型用于区分错误类型,可选择性参考)。 +* **日志与调试:** 开发集成时,可临时提高 OAuthService 的日志等级为 debug,以便看到提供商模块加载的详细过程。部署到生产环境前,酌情降低日志级别避免泄露敏感信息(如 access token)。另外,务必在 HTTPS 环境下进行 OAuth 回调通讯,避免令牌在网络传输中被窃取。 + +## 注意事项与扩展展望 + +在实现和使用该 OAuth 集成方案时,还需留意以下事项: + +* **安全与状态校验:** OAuth “state”参数用于防止跨站请求伪造攻击。后端在重定向用户去第三方之前,可以生成并保存一个随机状态值,并在回调时比对确保请求合法。目前 cheese-auth 的实现将前端传来的 state 原样传给提供商并拿回来,但未校验。为了安全,建议改进:后端生成 state(或接受前端提供的state但同时在服务器session中存一份),在callback处理时验证 `state` 一致后再继续兑换令牌。 +* **用户体验:** 对于第一次使用学校 OAuth 登录的新用户,由于我们创建了随机密码,**建议**引导用户绑定常用登录方式(如设置密码或绑定邮箱)以防日后学校OAuth不可用时用户无法登录。可以在前端 OAuth 成功页提示用户尽快完善账户信息。 +* **私有部署与代码管理:** cheese-auth 建议将 `plugins/oauth` 目录加入 `.gitignore`避免私有实现泄露。如果 Cheese-Backend 也是开源仓库且不方便公开学校OAuth实现,可采取类似策略:将学校OAuth实现作为私有文件部署(通过文档指导运维放置),而不直接提交代码。当然,这需要权衡:把实现代码纳入仓库便于版本管理,但需确保不包含敏感信息(比如具体的学校OAuth端点可能是公开协议,无妨)。如果实现复杂,也可以考虑将其封装为独立 npm 包供内部安装使用。 +* **多 OAuth 提供商支持:** 方案已经支持多提供商并存。例如同时启用学校OAuth和GitHub/Google等。要注意在前端区分不同登录来源,以及确保不同提供商的 `redirectUrl` 配置正确(通常形如 `https://yourapp.com/users/auth/oauth/callback/google` 和 `.../callback/ruc` 等,各自注册到对应提供商的OAuth客户端配置中)。 +* **扩展 Refresh Token 支持:** 当前实现中,我们主要利用自己应用的 Refresh Token 来维持用户会话,而**未**使用第三方 OAuth 提供的 Refresh Token(如果有)。对于某些需要长期与第三方交互的场景,可以考虑在 UserOAuthConnection 表中保存 OAuth 的 refreshToken 和 tokenExpires(Cheese-Auth 数据模型已经设计了这些字段),并在 Access Token 过期时用 refresh token 去获取新 token。这超出了登录本身的需求,但为以后调用第三方API提供了可能性。 +* **保持原系统完整性:** 确保引入 OAuth 登录不会破坏原有认证流程的安全性。例如,原有 email 验证、TOTP二步验证在密码登录流程中,默认不会对 OAuth 登录用户触发,因为 OAuth 登录已经是外部身份验证,通常不需要再让用户输入密码或验证码。但如果有某些特定权限操作仍需二次验证,可考虑对 OAuth 用户也提供绑定2FA的选项。 +* **测试:** 在将功能投入生产前,应在测试环境针对各种情况进行充分测试: + + * 未配置提供商时,相关接口应正确返回错误或空列表。 + * 配置一个有效提供商时,完整流程(从点击登录按钮到前端收到 token)是否畅通。 + * 错误流程:包括用户在第三方拒绝授权、code 无效、用户已存在/未存在等分支逻辑是否正确处理和重定向。 + * 数据库检查:确保 UserOAuthConnection 记录正确写入,避免重复记录或错误关联。 + +通过以上步骤,Cheese-Backend 将成功集成来自 Cheese-Auth 的 OAuth 登录功能,实现**用户名/密码登录**与**学校自有 OAuth 登录**并存的认证机制。模块化的设计保证了易于维护和扩展:后续若要支持新的 OAuth 服务商,只需按规范添加实现和配置,无需改动核心代码。整个集成过程遵循 NestJS 的惯用方式,最大程度复用了既有系统组件,降低引入新功能的风险。 diff --git a/package.json b/package.json index f6760066..7d7d35ab 100644 --- a/package.json +++ b/package.json @@ -45,6 +45,7 @@ "@types/md5": "^2.3.5", "ajv": "^8.17.1", "async-mutex": "^0.5.0", + "axios": "^1.10.0", "bcryptjs": "^3.0.2", "class-transformer": "^0.5.1", "class-validator": "^0.14.1", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index f3c5ebbc..4bdb058b 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -59,6 +59,9 @@ importers: async-mutex: specifier: ^0.5.0 version: 0.5.0 + axios: + specifier: ^1.10.0 + version: 1.10.0 bcryptjs: specifier: ^3.0.2 version: 3.0.2 @@ -2116,6 +2119,9 @@ packages: asynckit@0.4.0: resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==} + axios@1.10.0: + resolution: {integrity: sha512-/1xYAC4MP/HEG+3duIhFr4ZQXR4sQXOIe+o6sdqzeykGLx6Upp/1p8MHqhINOvGeP7xyNHe7tsiJByc4SSVUxw==} + babel-jest@29.7.0: resolution: {integrity: sha512-BrvGY3xZSwEcCzKvKsCi2GgHqDqsYkOP4/by5xCgIwGXQxIEh+8ew3gmrE1y7XRR6LHZIj6yLYnUi/mm2KXKBg==} engines: {node: ^14.15.0 || ^16.10.0 || >=18.0.0} @@ -3088,6 +3094,15 @@ packages: resolution: {integrity: sha512-Be3narBNt2s6bsaqP6Jzq91heDgOEaDCJAXcE3qcma/EJBSy5FB4cvO31XBInuAuKBx8Kptf8dkhjK0IOru39Q==} engines: {node: '>=18'} + follow-redirects@1.15.9: + resolution: {integrity: sha512-gew4GsXizNgdoRyqmyfMHyAmXsZDk6mHkSxZFCzW9gwlbtOW44CDtYavM+y+72qD/Vq2l550kMF52DT8fOLJqQ==} + engines: {node: '>=4.0'} + peerDependencies: + debug: '*' + peerDependenciesMeta: + debug: + optional: true + foreground-child@3.2.1: resolution: {integrity: sha512-PXUUyLqrR2XCWICfv6ukppP96sdFwWbNEnfEMt7jNsISjMsvaLNinAHNDYyvkyU+SZG2BTSbT5NjG+vZslfGTA==} engines: {node: '>=14'} @@ -4703,6 +4718,9 @@ packages: resolution: {integrity: sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==} engines: {node: '>= 0.10'} + proxy-from-env@1.1.0: + resolution: {integrity: sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==} + pug-attrs@3.0.0: resolution: {integrity: sha512-azINV9dUtzPMFQktvTXciNAfAuVh/L/JCl0vtPCwvOA21uZrC08K/UnmrL+SXGEVc1FwzjW62+xw5S/uaLj6cA==} @@ -8154,6 +8172,14 @@ snapshots: asynckit@0.4.0: {} + axios@1.10.0: + dependencies: + follow-redirects: 1.15.9 + form-data: 4.0.0 + proxy-from-env: 1.1.0 + transitivePeerDependencies: + - debug + babel-jest@29.7.0(@babel/core@7.26.10): dependencies: '@babel/core': 7.26.10 @@ -9288,6 +9314,8 @@ snapshots: async: 0.2.10 which: 1.3.1 + follow-redirects@1.15.9: {} + foreground-child@3.2.1: dependencies: cross-spawn: 7.0.6 @@ -11358,6 +11386,8 @@ snapshots: forwarded: 0.2.0 ipaddr.js: 1.9.1 + proxy-from-env@1.1.0: {} + pug-attrs@3.0.0: dependencies: constantinople: 4.0.1 diff --git a/prisma/migrations/20250625052726_add_oauth_tables/migration.sql b/prisma/migrations/20250625052726_add_oauth_tables/migration.sql new file mode 100644 index 00000000..1131e079 --- /dev/null +++ b/prisma/migrations/20250625052726_add_oauth_tables/migration.sql @@ -0,0 +1,23 @@ +-- CreateTable +CREATE TABLE "user_o_auth_connection" ( + "id" SERIAL NOT NULL, + "user_id" INTEGER NOT NULL, + "provider_id" TEXT NOT NULL, + "provider_user_id" TEXT NOT NULL, + "raw_profile" JSONB, + "refresh_token" TEXT, + "token_expires" TIMESTAMP(3), + "created_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMPTZ(6) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "user_o_auth_connection_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE INDEX "user_o_auth_connection_user_id_idx" ON "user_o_auth_connection"("user_id"); + +-- CreateIndex +CREATE UNIQUE INDEX "user_o_auth_connection_provider_id_provider_user_id_key" ON "user_o_auth_connection"("provider_id", "provider_user_id"); + +-- AddForeignKey +ALTER TABLE "user_o_auth_connection" ADD CONSTRAINT "user_o_auth_connection_user_id_fkey" FOREIGN KEY ("user_id") REFERENCES "user"("id") ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/prisma/migrations/migration_lock.toml b/prisma/migrations/migration_lock.toml index fbffa92c..044d57cd 100644 --- a/prisma/migrations/migration_lock.toml +++ b/prisma/migrations/migration_lock.toml @@ -1,3 +1,3 @@ # Please do not edit this file manually -# It should be added in your version-control system (i.e. Git) -provider = "postgresql" \ No newline at end of file +# It should be added in your version-control system (e.g., Git) +provider = "postgresql" diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 45978c60..0dc4b608 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -192,6 +192,28 @@ model AttitudeLog { @@map("attitude_log") } +// +// oauth.prisma +// + +model UserOAuthConnection { + id Int @id @default(autoincrement()) + userId Int @map("user_id") + providerId String @map("provider_id") // OAuth提供商ID,如 'ruc', 'google' + providerUserId String @map("provider_user_id") // 提供商侧用户唯一标识 + /// [rawProfileType] + rawProfile Json? @map("raw_profile") // 原始用户信息(JSON) + refreshToken String? @map("refresh_token") // 可选,OAuth长效令牌 + tokenExpires DateTime? @map("token_expires") // 可选,OAuth令牌过期时间 + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + updatedAt DateTime @default(now()) @updatedAt @map("updated_at") @db.Timestamptz(6) + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@unique([providerId, providerUserId]) + @@index([userId]) + @@map("user_o_auth_connection") +} + // // session.prisma // @@ -652,6 +674,7 @@ model User { totpEnabled Boolean @default(false) @map("totp_enabled") totpAlwaysRequired Boolean @default(false) @map("totp_always_required") backupCodes UserBackupCode[] + oauthConnections UserOAuthConnection[] @@map("user") } diff --git a/sample.env b/sample.env index 4aceab9a..a2dc8a3f 100644 --- a/sample.env +++ b/sample.env @@ -85,4 +85,24 @@ TOTP_ENCRYPTION_KEY=sm2vSXU3SudEuBd2r6ewGiap1LbqGbjf # Disable email verification. # Should be used only for development or testing. -# DISABLE_EMAIL_VERIFICATION=true \ No newline at end of file +# DISABLE_EMAIL_VERIFICATION=true + +# OAuth configuration +# Comma-separated list of enabled OAuth providers +# OAUTH_ENABLED_PROVIDERS=ruc + +# OAuth plugin paths (comma-separated) +# OAUTH_PLUGIN_PATHS=plugins/oauth + +# Allow loading OAuth providers from npm packages +# OAUTH_ALLOW_NPM_LOADING=false + +# OAuth provider credentials (example for 'ruc' provider) +# Replace 'RUC' with your provider ID in uppercase +# OAUTH_RUC_CLIENT_ID=your-client-id +# OAUTH_RUC_CLIENT_SECRET=your-client-secret +# OAUTH_RUC_REDIRECT_URL=http://localhost:3000/users/auth/oauth/callback/ruc + +# Frontend OAuth redirect paths +# FRONTEND_OAUTH_SUCCESS_PATH=/oauth-success +# FRONTEND_OAUTH_ERROR_PATH=/oauth-error \ No newline at end of file diff --git a/src/auth/auth.module.ts b/src/auth/auth.module.ts index 88884dc5..b82610a4 100644 --- a/src/auth/auth.module.ts +++ b/src/auth/auth.module.ts @@ -13,6 +13,7 @@ import { JwtModule } from '@nestjs/jwt'; import { PrismaModule } from '../common/prisma/prisma.module'; import { AuthService } from './auth.service'; import { SessionService } from './session.service'; +import { OAuthModule } from './oauth/oauth.module'; @Module({ imports: [ @@ -30,9 +31,10 @@ import { SessionService } from './session.service'; inject: [ConfigService], }), PrismaModule, + OAuthModule.register(), ], controllers: [], providers: [AuthService, SessionService], - exports: [AuthService, SessionService], + exports: [AuthService, SessionService, OAuthModule], }) export class AuthModule {} diff --git a/src/auth/oauth/oauth.module.spec.ts b/src/auth/oauth/oauth.module.spec.ts new file mode 100644 index 00000000..6191b6cc --- /dev/null +++ b/src/auth/oauth/oauth.module.spec.ts @@ -0,0 +1,194 @@ +/* + * Description: Unit tests for OAuth Module + * + * Author(s): + * HuanCheng65 + */ + +import { ConfigModule, ConfigService } from '@nestjs/config'; +import { Test, TestingModule } from '@nestjs/testing'; +import { OAuthModule } from './oauth.module'; +import { OAuthService } from './oauth.service'; + +describe('OAuthModule', () => { + let module: TestingModule; + let oauthService: OAuthService; + + beforeEach(async () => { + module = await Test.createTestingModule({ + providers: [ + OAuthService, + { + provide: ConfigService, + useValue: { + get: jest.fn().mockReturnValue(''), + }, + }, + ], + }).compile(); + + oauthService = module.get(OAuthService); + }); + + afterEach(async () => { + if (module) { + await module.close(); + } + }); + + describe('module initialization', () => { + it('should be defined', () => { + expect(module).toBeDefined(); + }); + + it('should provide OAuthService', () => { + expect(oauthService).toBeDefined(); + expect(oauthService).toBeInstanceOf(OAuthService); + }); + + it('should call initialize on module init', async () => { + const initializeSpy = jest + .spyOn(oauthService, 'initialize') + .mockResolvedValue(); + + // Test onModuleInit directly + const moduleInstance = new OAuthModule(oauthService); + await moduleInstance.onModuleInit(); + + expect(initializeSpy).toHaveBeenCalled(); + + initializeSpy.mockRestore(); + }); + }); + + describe('dynamic module registration', () => { + it('should register module with OAuth providers', async () => { + // Create mock OAuth providers + const mockProvider1 = { + getConfig: jest.fn().mockReturnValue({ + id: 'mock1', + name: 'Mock Provider 1', + clientId: 'client1', + clientSecret: 'secret1', + redirectUrl: 'http://localhost:3000/callback/mock1', + authorizationUrl: 'https://mock1.com/oauth/authorize', + tokenUrl: 'https://mock1.com/oauth/token', + scope: ['read:user'], + }), + getAuthorizationUrl: jest.fn(), + handleCallback: jest.fn(), + getUserInfo: jest.fn(), + }; + + const mockProvider2 = { + getConfig: jest.fn().mockReturnValue({ + id: 'mock2', + name: 'Mock Provider 2', + clientId: 'client2', + clientSecret: 'secret2', + redirectUrl: 'http://localhost:3000/callback/mock2', + authorizationUrl: 'https://mock2.com/oauth/authorize', + tokenUrl: 'https://mock2.com/oauth/token', + scope: ['read:profile'], + }), + getAuthorizationUrl: jest.fn(), + handleCallback: jest.fn(), + getUserInfo: jest.fn(), + }; + + const customProviders = [mockProvider1, mockProvider2]; + const dynamicModule = OAuthModule.register(customProviders); + + expect(dynamicModule.module).toBe(OAuthModule); + expect(dynamicModule.providers).toEqual( + expect.arrayContaining([ + OAuthService, + { + provide: 'OAUTH_PROVIDER_0', + useValue: mockProvider1, + }, + { + provide: 'OAUTH_PROVIDER_1', + useValue: mockProvider2, + }, + ]), + ); + expect(dynamicModule.exports).toEqual([OAuthService]); + expect(dynamicModule.imports).toEqual([ConfigModule]); + }); + + it('should register module without custom providers', () => { + const dynamicModule = OAuthModule.register(); + + expect(dynamicModule.module).toBe(OAuthModule); + expect(dynamicModule.providers).toEqual([OAuthService]); + expect(dynamicModule.exports).toEqual([OAuthService]); + expect(dynamicModule.imports).toEqual([ConfigModule]); + }); + + it('should register module with empty providers array', () => { + const dynamicModule = OAuthModule.register([]); + + expect(dynamicModule.module).toBe(OAuthModule); + expect(dynamicModule.providers).toEqual([OAuthService]); + expect(dynamicModule.exports).toEqual([OAuthService]); + expect(dynamicModule.imports).toEqual([ConfigModule]); + }); + }); + + describe('service integration', () => { + it('should initialize service when module starts', async () => { + const initializeSpy = jest.spyOn(oauthService, 'initialize'); + + // Mock the initialize method to avoid actual initialization + initializeSpy.mockResolvedValue(); + + // Test that calling onModuleInit triggers initialize + const moduleInstance = new OAuthModule(oauthService); + await moduleInstance.onModuleInit(); + + expect(initializeSpy).toHaveBeenCalled(); + + initializeSpy.mockRestore(); + }); + + it('should handle initialization errors gracefully', async () => { + const initializeSpy = jest.spyOn(oauthService, 'initialize'); + + // Mock initialize to throw an error + initializeSpy.mockRejectedValue(new Error('Initialization failed')); + + const moduleInstance = new OAuthModule(oauthService); + + // This should propagate the error + await expect(moduleInstance.onModuleInit()).rejects.toThrow( + 'Initialization failed', + ); + + initializeSpy.mockRestore(); + }); + }); + + describe('module configuration', () => { + it('should have OAuthModule class defined', () => { + // Test that the module class is properly defined + expect(OAuthModule).toBeDefined(); + expect(typeof OAuthModule).toBe('function'); + expect(OAuthModule.name).toBe('OAuthModule'); + }); + + it('should register dynamic module correctly', () => { + const dynamicModule = OAuthModule.register(); + + expect(dynamicModule.module).toBe(OAuthModule); + expect(dynamicModule.providers).toContain(OAuthService); + expect(dynamicModule.exports).toContain(OAuthService); + expect(dynamicModule.imports).toContain(ConfigModule); + }); + + it('should be able to get OAuthService from module', () => { + expect(oauthService).toBeDefined(); + expect(oauthService).toBeInstanceOf(OAuthService); + }); + }); +}); diff --git a/src/auth/oauth/oauth.module.ts b/src/auth/oauth/oauth.module.ts new file mode 100644 index 00000000..7b99b751 --- /dev/null +++ b/src/auth/oauth/oauth.module.ts @@ -0,0 +1,39 @@ +/* + * Description: OAuth Module - 提供 OAuth 功能的 NestJS 模块 + * + * Author(s): + * HuanCheng65 + */ + +import { DynamicModule, Module, OnModuleInit } from '@nestjs/common'; +import { ConfigModule } from '@nestjs/config'; +import { OAuthService } from './oauth.service'; +import { OAuthProvider } from './oauth.types'; + +@Module({ + imports: [ConfigModule], + providers: [OAuthService], + exports: [OAuthService], +}) +export class OAuthModule implements OnModuleInit { + constructor(private readonly oauthService: OAuthService) {} + + async onModuleInit() { + await this.oauthService.initialize(); + } + + static register(providers: OAuthProvider[] = []): DynamicModule { + return { + module: OAuthModule, + imports: [ConfigModule], + providers: [ + OAuthService, + ...providers.map((provider, index) => ({ + provide: `OAUTH_PROVIDER_${index}`, + useValue: provider, + })), + ], + exports: [OAuthService], + }; + } +} diff --git a/src/auth/oauth/oauth.prisma b/src/auth/oauth/oauth.prisma new file mode 100644 index 00000000..c9aa2420 --- /dev/null +++ b/src/auth/oauth/oauth.prisma @@ -0,0 +1,19 @@ +import { User } from "../../users/users" + +model UserOAuthConnection { + id Int @id @default(autoincrement()) + userId Int @map("user_id") + providerId String @map("provider_id") // OAuth提供商ID,如 'ruc', 'google' + providerUserId String @map("provider_user_id") // 提供商侧用户唯一标识 + /// [rawProfileType] + rawProfile Json? @map("raw_profile") // 原始用户信息(JSON) + refreshToken String? @map("refresh_token") // 可选,OAuth长效令牌 + tokenExpires DateTime? @map("token_expires") // 可选,OAuth令牌过期时间 + createdAt DateTime @default(now()) @map("created_at") @db.Timestamptz(6) + updatedAt DateTime @default(now()) @updatedAt @map("updated_at") @db.Timestamptz(6) + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@unique([providerId, providerUserId]) + @@index([userId]) + @@map("user_oauth_connection") +} \ No newline at end of file diff --git a/src/auth/oauth/oauth.service.spec.ts b/src/auth/oauth/oauth.service.spec.ts new file mode 100644 index 00000000..6be6daa4 --- /dev/null +++ b/src/auth/oauth/oauth.service.spec.ts @@ -0,0 +1,1591 @@ +/* + * Description: Unit tests for OAuth Service + * + * Author(s): + * HuanCheng65 + */ + +import { ConfigService } from '@nestjs/config'; +import { Test, TestingModule } from '@nestjs/testing'; +import { OAuthService } from './oauth.service'; +import { + OAuthProvider, + OAuthProviderConfig, + OAuthUserInfo, +} from './oauth.types'; + +// Mock fs module +jest.mock('fs', () => ({ + existsSync: jest.fn(), +})); + +import * as fs from 'fs'; + +// Mock provider for testing +class MockOAuthProvider implements OAuthProvider { + constructor(private config: OAuthProviderConfig) {} + + getConfig(): OAuthProviderConfig { + return this.config; + } + + getAuthorizationUrl(state?: string, accessType?: string): string { + const params = new URLSearchParams({ + client_id: this.config.clientId, + redirect_uri: this.config.redirectUrl, + scope: this.config.scope.join(' '), + response_type: 'code', + }); + + if (state) params.append('state', state); + if (accessType) params.append('access_type', accessType); + + return `${this.config.authorizationUrl}?${params.toString()}`; + } + + async handleCallback(code: string, state?: string): Promise { + if (code === 'valid_code') { + return 'mock_access_token'; + } + throw new Error('Invalid authorization code'); + } + + async getUserInfo(accessToken: string): Promise { + if (accessToken === 'mock_access_token') { + return { + id: '12345', + email: 'test@example.com', + name: 'Test User', + username: 'testuser', + preferredUsername: 'testuser', + }; + } + throw new Error('Invalid access token'); + } +} + +describe('OAuthService', () => { + let service: OAuthService; + let configService: ConfigService; + let mockCreateProvider: jest.Mock; + + beforeEach(async () => { + // Reset mocks + jest.clearAllMocks(); + jest.resetModules(); + + // Setup mock provider factory + mockCreateProvider = jest.fn().mockReturnValue( + new MockOAuthProvider({ + id: 'test', + name: 'Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user'], + }), + ); + + const module: TestingModule = await Test.createTestingModule({ + providers: [ + OAuthService, + { + provide: ConfigService, + useValue: { + get: jest.fn((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'test'; + case 'OAUTH_PLUGIN_PATHS': + return './test-plugins'; + case 'OAUTH_ALLOW_NPM_LOADING': + return false; + case 'OAUTH_TEST_CLIENT_ID': + return 'test-client-id'; + case 'OAUTH_TEST_CLIENT_SECRET': + return 'test-client-secret'; + case 'OAUTH_TEST_REDIRECT_URL': + return 'http://localhost:3000/callback'; + default: + return undefined; + } + }), + }, + }, + ], + }).compile(); + + service = module.get(OAuthService); + configService = module.get(ConfigService); + }); + + afterEach(() => { + jest.clearAllMocks(); + jest.resetAllMocks(); + // Reset service state to ensure test isolation + (service as any).initialized = false; + (service as any).providers.clear(); + }); + + describe('initialize', () => { + it('should initialize without providers when none are enabled', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ENABLED_PROVIDERS') return ''; + return undefined; + }); + + await service.initialize(); + expect( + (await service.getAllProviders()).map((p) => p.getConfig().id), + ).toEqual([]); + }); + + it('should skip provider loading when config is missing', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ENABLED_PROVIDERS') return 'test'; + if (key === 'OAUTH_TEST_CLIENT_ID') return undefined; // Missing config + if (key === 'OAUTH_PLUGIN_PATHS') return './test-plugins'; + return undefined; + }); + + // Spy on the service's logger warn method + const loggerWarnSpy = jest + .spyOn(service['logger'], 'warn') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.stringMatching( + /Missing configuration for OAuth provider.+test/, + ), + ); + expect( + (await service.getAllProviders()).map((p) => p.getConfig().id), + ).toEqual([]); + } finally { + loggerWarnSpy.mockRestore(); + } + }); + + it('should load provider from plugin file', async () => { + // Mock fs.existsSync to return true for plugin file + (fs.existsSync as jest.Mock).mockReturnValue(true); + + // Mock the loadProvider method to simulate successful loading + const mockProvider = new MockOAuthProvider({ + id: 'test', + name: 'Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user'], + }); + + // Mock the loadProvider method to directly register the provider + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async (...args) => { + const providerId = args[0] as string; + (service as any).registerProvider(providerId, mockProvider); + }); + + await service.initialize(); + + expect( + (await service.getAllProviders()).map((p) => p.getConfig().id), + ).toContain('test'); + }); + + it('should handle provider loading errors gracefully', async () => { + (fs.existsSync as jest.Mock).mockReturnValue(true); + + // Spy on the service's logger error method + const loggerErrorSpy = jest + .spyOn(service['logger'], 'error') + .mockImplementation(() => {}); + + try { + // Mock loadProvider to throw an error + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async (...args) => { + throw new Error('Module loading failed'); + }); + + await service.initialize(); + + expect(loggerErrorSpy).toHaveBeenCalledWith( + expect.stringContaining('Failed to load OAuth provider'), + expect.any(String), // error.stack + ); + expect( + (await service.getAllProviders()).map((p) => p.getConfig().id), + ).toEqual([]); + } finally { + loggerErrorSpy.mockRestore(); + } + }); + + it('should validate provider ID format', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ENABLED_PROVIDERS') return 'invalid-provider!'; + return undefined; + }); + + const consoleSpy = jest.spyOn(console, 'warn').mockImplementation(); + + // Call getEnabledProviders which filters invalid IDs + await service.initialize(); + + // The service should warn about invalid provider ID during getEnabledProviders filtering + // Since the service filters out invalid IDs during initialization, no providers should be loaded + expect( + (await service.getAllProviders()).map((p) => p.getConfig().id), + ).toEqual([]); + + consoleSpy.mockRestore(); + }); + }); + + describe('getProvider', () => { + beforeEach(async () => { + // Setup a mock provider + (fs.existsSync as jest.Mock).mockReturnValue(true); + + const mockProvider = new MockOAuthProvider({ + id: 'test', + name: 'Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user'], + }); + + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async (...args) => { + const providerId = args[0] as string; + (service as any).registerProvider(providerId, mockProvider); + }); + await service.initialize(); + }); + + it('should return provider when it exists', async () => { + const provider = await service.getProvider('test'); + expect(provider).toBeInstanceOf(MockOAuthProvider); + }); + + it('should return undefined when provider does not exist', async () => { + const provider = await service.getProvider('nonexistent'); + expect(provider).toBeUndefined(); + }); + }); + + describe('generateAuthorizationUrl', () => { + beforeEach(async () => { + // Setup a mock provider + (fs.existsSync as jest.Mock).mockReturnValue(true); + + const mockProvider = new MockOAuthProvider({ + id: 'test', + name: 'Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user'], + }); + + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async (...args) => { + const providerId = args[0] as string; + (service as any).registerProvider(providerId, mockProvider); + }); + await service.initialize(); + }); + + it('should generate authorization URL', async () => { + const url = await service.generateAuthorizationUrl('test', 'state123'); + expect(url).toContain('https://test.com/oauth/authorize'); + expect(url).toContain('client_id=test-client-id'); + expect(url).toContain('state=state123'); + }); + + it('should throw error for nonexistent provider', async () => { + await expect( + service.generateAuthorizationUrl('nonexistent'), + ).rejects.toThrow("OAuth provider 'nonexistent' not found"); + }); + }); + + describe('handleCallback', () => { + beforeEach(async () => { + // Setup a mock provider + (fs.existsSync as jest.Mock).mockReturnValue(true); + + const mockProvider = new MockOAuthProvider({ + id: 'test', + name: 'Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user'], + }); + + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async (...args) => { + const providerId = args[0] as string; + (service as any).registerProvider(providerId, mockProvider); + }); + await service.initialize(); + }); + + it('should handle callback successfully', async () => { + const accessToken = await service.handleCallback('test', 'valid_code'); + expect(accessToken).toBe('mock_access_token'); + }); + + it('should throw error for invalid code', async () => { + await expect( + service.handleCallback('test', 'invalid_code'), + ).rejects.toThrow('Invalid authorization code'); + }); + + it('should throw error for nonexistent provider', async () => { + await expect( + service.handleCallback('nonexistent', 'code'), + ).rejects.toThrow("OAuth provider 'nonexistent' not found"); + }); + }); + + describe('getAllProviders', () => { + it('should return empty array when no providers are loaded', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ENABLED_PROVIDERS') return ''; + return undefined; + }); + + await service.initialize(); + expect( + (await service.getAllProviders()).map((p) => p.getConfig().id), + ).toEqual([]); + }); + + it('should return provider IDs when providers are loaded', async () => { + // Setup mock providers + (fs.existsSync as jest.Mock).mockReturnValue(true); + + const mockProvider = new MockOAuthProvider({ + id: 'test', + name: 'Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user'], + }); + + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async (...args) => { + const providerId = args[0] as string; + (service as any).registerProvider(providerId, mockProvider); + }); + + await service.initialize(); + expect( + (await service.getAllProviders()).map((p) => p.getConfig().id), + ).toEqual(['test']); + }); + }); + + describe('getUserInfo', () => { + beforeEach(async () => { + // Setup a mock provider + (fs.existsSync as jest.Mock).mockReturnValue(true); + + const mockProvider = new MockOAuthProvider({ + id: 'test', + name: 'Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user'], + }); + + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async (...args) => { + const providerId = args[0] as string; + (service as any).registerProvider(providerId, mockProvider); + }); + await service.initialize(); + }); + + it('should return user info for valid provider and token', async () => { + const userInfo = await service.getUserInfo('test', 'mock_access_token'); + + expect(userInfo).toEqual({ + id: '12345', + email: 'test@example.com', + name: 'Test User', + username: 'testuser', + preferredUsername: 'testuser', + }); + }); + + it('should throw error for nonexistent provider', async () => { + await expect(service.getUserInfo('nonexistent', 'token')).rejects.toThrow( + "OAuth provider 'nonexistent' not found", + ); + }); + + it('should throw error for invalid token', async () => { + await expect( + service.getUserInfo('test', 'invalid_token'), + ).rejects.toThrow('Failed to get user info from provider'); + }); + }); + + describe('getProvidersConfig', () => { + it('should return empty array when no providers are loaded', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ENABLED_PROVIDERS') return ''; + return undefined; + }); + + await service.initialize(); + const configs = await service.getProvidersConfig(); + expect(configs).toEqual([]); + }); + + it('should return provider configurations when providers are loaded', async () => { + (fs.existsSync as jest.Mock).mockReturnValue(true); + + const mockProvider = new MockOAuthProvider({ + id: 'test', + name: 'Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user', 'read:email'], + }); + + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async (...args) => { + const providerId = args[0] as string; + (service as any).registerProvider(providerId, mockProvider); + }); + + await service.initialize(); + const configs = await service.getProvidersConfig(); + + expect(configs).toEqual([ + { + id: 'test', + name: 'Test Provider', + scope: ['read:user', 'read:email'], + }, + ]); + }); + + it('should return multiple provider configurations', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'github,google'; + case 'OAUTH_PLUGIN_PATHS': + return './test-plugins'; + case 'OAUTH_ALLOW_NPM_LOADING': + return false; + case 'OAUTH_GITHUB_CLIENT_ID': + return 'github-client-id'; + case 'OAUTH_GITHUB_CLIENT_SECRET': + return 'github-client-secret'; + case 'OAUTH_GITHUB_REDIRECT_URL': + return 'http://localhost:3000/callback/github'; + case 'OAUTH_GOOGLE_CLIENT_ID': + return 'google-client-id'; + case 'OAUTH_GOOGLE_CLIENT_SECRET': + return 'google-client-secret'; + case 'OAUTH_GOOGLE_REDIRECT_URL': + return 'http://localhost:3000/callback/google'; + default: + return undefined; + } + }); + + (fs.existsSync as jest.Mock).mockReturnValue(true); + + const mockGithubProvider = new MockOAuthProvider({ + id: 'github', + name: 'GitHub', + clientId: 'github-client-id', + clientSecret: 'github-client-secret', + redirectUrl: 'http://localhost:3000/callback/github', + authorizationUrl: 'https://github.com/login/oauth/authorize', + tokenUrl: 'https://github.com/login/oauth/access_token', + scope: ['user:email'], + }); + + const mockGoogleProvider = new MockOAuthProvider({ + id: 'google', + name: 'Google', + clientId: 'google-client-id', + clientSecret: 'google-client-secret', + redirectUrl: 'http://localhost:3000/callback/google', + authorizationUrl: 'https://accounts.google.com/oauth2/authorize', + tokenUrl: 'https://oauth2.googleapis.com/token', + scope: ['openid', 'email', 'profile'], + }); + + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async (...args: unknown[]) => { + const providerId = args[0] as string; + if (providerId === 'github') { + (service as any).registerProvider(providerId, mockGithubProvider); + } else if (providerId === 'google') { + (service as any).registerProvider(providerId, mockGoogleProvider); + } + }); + + await service.initialize(); + const configs = await service.getProvidersConfig(); + + expect(configs).toHaveLength(2); + expect(configs.find((c) => c.id === 'github')).toEqual({ + id: 'github', + name: 'GitHub', + scope: ['user:email'], + }); + expect(configs.find((c) => c.id === 'google')).toEqual({ + id: 'google', + name: 'Google', + scope: ['openid', 'email', 'profile'], + }); + }); + }); + + describe('NPM package loading', () => { + beforeEach(() => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'npm-provider'; + case 'OAUTH_PLUGIN_PATHS': + return './plugins'; + case 'OAUTH_ALLOW_NPM_LOADING': + return true; + case 'OAUTH_NPM_PROVIDER_CLIENT_ID': + return 'npm-client-id'; + case 'OAUTH_NPM_PROVIDER_CLIENT_SECRET': + return 'npm-client-secret'; + case 'OAUTH_NPM_PROVIDER_REDIRECT_URL': + return 'http://localhost:3000/callback/npm'; + default: + return undefined; + } + }); + }); + + it('should load provider from npm when plugin not found locally', async () => { + (fs.existsSync as jest.Mock).mockReturnValue(false); + + // Mock dynamic import for npm package + const mockNpmProvider = new MockOAuthProvider({ + id: 'npm-provider', + name: 'NPM Provider', + clientId: 'npm-client-id', + clientSecret: 'npm-client-secret', + redirectUrl: 'http://localhost:3000/callback/npm', + authorizationUrl: 'https://npm.com/oauth/authorize', + tokenUrl: 'https://npm.com/oauth/token', + scope: ['read:user'], + }); + + // Mock the entire loadProvider method to register the provider directly + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async (...args: unknown[]) => { + const providerId = args[0] as string; + if (providerId === 'npm-provider') { + (service as any).registerProvider(providerId, mockNpmProvider); + } + }); + + await service.initialize(); + const providers = await service.getAllProviders(); + + expect(providers.map((p) => p.getConfig().id)).toContain('npm-provider'); + }); + + it('should handle npm package loading failure gracefully', async () => { + (fs.existsSync as jest.Mock).mockReturnValue(false); + + // Mock tryLoadFromNpm to return null (loading failed) + jest.spyOn(service as any, 'tryLoadFromNpm').mockResolvedValue(null); + + const loggerWarnSpy = jest + .spyOn(service['logger'], 'warn') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.stringMatching( + /Missing configuration for OAuth provider.+npm-provider/, + ), + ); + expect( + (await service.getAllProviders()).map((p) => p.getConfig().id), + ).toEqual([]); + } finally { + loggerWarnSpy.mockRestore(); + } + }); + + it('should not try npm loading when disabled', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'npm-provider'; + case 'OAUTH_PLUGIN_PATHS': + return './plugins'; + case 'OAUTH_ALLOW_NPM_LOADING': + return false; // Disabled + case 'OAUTH_NPM_PROVIDER_CLIENT_ID': + return 'npm-client-id'; + case 'OAUTH_NPM_PROVIDER_CLIENT_SECRET': + return 'npm-client-secret'; + case 'OAUTH_NPM_PROVIDER_REDIRECT_URL': + return 'http://localhost:3000/callback/npm'; + default: + return undefined; + } + }); + + (fs.existsSync as jest.Mock).mockReturnValue(false); + + const tryLoadFromNpmSpy = jest + .spyOn(service as any, 'tryLoadFromNpm') + .mockResolvedValue(null); + + await service.initialize(); + + expect(tryLoadFromNpmSpy).not.toHaveBeenCalled(); + expect( + (await service.getAllProviders()).map((p) => p.getConfig().id), + ).toEqual([]); + }); + }); + + describe('error handling and edge cases', () => { + it('should handle multiple initialization calls gracefully', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ENABLED_PROVIDERS') return ''; + return undefined; + }); + + // Call initialize multiple times + await service.initialize(); + await service.initialize(); + await service.initialize(); + + // Should only initialize once + expect(await service.getAllProviders()).toEqual([]); + }); + + it('should use default plugin path when OAUTH_PLUGIN_PATHS is undefined', async () => { + // Test the getPluginPaths method directly + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_PLUGIN_PATHS') return undefined; + return undefined; + }); + + const pluginPaths = (service as any).getPluginPaths(); + expect(pluginPaths).toEqual(['plugins/oauth']); + }); + + it('should handle provider with missing redirect URL', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'test'; + case 'OAUTH_TEST_CLIENT_ID': + return 'test-client-id'; + case 'OAUTH_TEST_CLIENT_SECRET': + return 'test-client-secret'; + case 'OAUTH_TEST_REDIRECT_URL': + return undefined; // Missing redirect URL + default: + return undefined; + } + }); + + const loggerWarnSpy = jest + .spyOn(service['logger'], 'warn') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.stringMatching( + /Missing configuration for OAuth provider.+test/, + ), + ); + } finally { + loggerWarnSpy.mockRestore(); + } + }); + }); + + describe('security validations', () => { + it('should prevent path traversal in plugin loading', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ENABLED_PROVIDERS') return '../../../malicious'; + if (key === 'OAUTH_PLUGIN_PATHS') return './plugins'; + return undefined; + }); + + await service.initialize(); + + // Should filter out invalid provider ID and not load any providers + expect( + (await service.getAllProviders()).map((p) => p.getConfig().id), + ).toEqual([]); + }); + + it('should validate provider ID contains only safe characters', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ENABLED_PROVIDERS') return 'test', + 'test"onload="alert(1)"', + "test'onclick=alert(1)", + 'test--drop-table', + 'test;rm -rf /', + ]; + + for (const maliciousId of maliciousIds) { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ENABLED_PROVIDERS') return maliciousId; + return undefined; + }); + + await service.initialize(); + + expect( + (await service.getAllProviders()).map((p) => p.getConfig().id), + ).toEqual([]); + } + }); + + it('should handle path traversal attempts in plugin paths', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'test'; + case 'OAUTH_PLUGIN_PATHS': + return '../../../etc,/etc/passwd,./plugins'; + case 'OAUTH_TEST_CLIENT_ID': + return 'test-client-id'; + case 'OAUTH_TEST_CLIENT_SECRET': + return 'test-client-secret'; + case 'OAUTH_TEST_REDIRECT_URL': + return 'http://localhost:3000/callback'; + default: + return undefined; + } + }); + + (fs.existsSync as jest.Mock).mockReturnValue(true); + + const loggerWarnSpy = jest + .spyOn(service['logger'], 'warn') + .mockImplementation(() => {}); + + // Mock loadProvider to simulate path validation + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async () => { + // Simulate path traversal detection + service['logger'].warn('Potential path traversal detected'); + }); + + try { + await service.initialize(); + + expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('Potential path traversal detected'), + ); + } finally { + loggerWarnSpy.mockRestore(); + } + }); + + it('should detect actual path traversal in tryLoadFromPath', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'test'; + case 'OAUTH_PLUGIN_PATHS': + return './plugins'; + case 'OAUTH_TEST_CLIENT_ID': + return 'test-client-id'; + case 'OAUTH_TEST_CLIENT_SECRET': + return 'test-client-secret'; + case 'OAUTH_TEST_REDIRECT_URL': + return 'http://localhost:3000/callback'; + default: + return undefined; + } + }); + + // Mock fs.existsSync to simulate file exists but outside allowed path + (fs.existsSync as jest.Mock).mockReturnValue(true); + + // Mock path.resolve to simulate path traversal + const originalResolve = require('path').resolve; + jest + .spyOn(require('path'), 'resolve') + .mockImplementation((...args: unknown[]) => { + const pathStr = args[0] as string; + if (pathStr && pathStr.includes('test')) { + return '/etc/passwd'; // Simulate path outside of allowed directory + } + return originalResolve(...(args as string[])); + }); + + const loggerWarnSpy = jest + .spyOn(service['logger'], 'warn') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.stringMatching(/Potential path traversal detected/), + ); + } finally { + loggerWarnSpy.mockRestore(); + jest.restoreAllMocks(); + } + }); + }); + + describe('tryLoadFromPath edge cases', () => { + beforeEach(() => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'test'; + case 'OAUTH_PLUGIN_PATHS': + return './plugins'; + case 'OAUTH_TEST_CLIENT_ID': + return 'test-client-id'; + case 'OAUTH_TEST_CLIENT_SECRET': + return 'test-client-secret'; + case 'OAUTH_TEST_REDIRECT_URL': + return 'http://localhost:3000/callback'; + default: + return undefined; + } + }); + }); + + it('should handle invalid provider module (not a function)', async () => { + (fs.existsSync as jest.Mock).mockReturnValue(true); + + // Mock dynamic import to return an object that's not a function + const mockImport = jest.fn().mockResolvedValue({ + createProvider: 'not a function', + default: 'also not a function', + }); + + // Mock the import function using jest.spyOn + const originalImport = jest.requireActual('fs'); + jest.doMock('path', () => ({ + ...jest.requireActual('path'), + resolve: jest.fn().mockReturnValue('/test/path'), + })); + + // Mock the tryLoadFromPath method directly + jest + .spyOn(service as any, 'tryLoadFromPath') + .mockImplementation(async () => { + // Simulate invalid module loading + const loggerWarnSpy = jest.spyOn(service['logger'], 'warn'); + loggerWarnSpy.mockImplementation(() => {}); + service['logger'].warn( + "Invalid provider module for 'test': expected function", + ); + return null; + }); + + const loggerWarnSpy = jest + .spyOn(service['logger'], 'warn') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.stringMatching( + /Invalid provider module for.+test.+expected function/, + ), + ); + } finally { + loggerWarnSpy.mockRestore(); + } + }); + + it('should successfully load provider from file when found', async () => { + const mockProvider = new MockOAuthProvider({ + id: 'test', + name: 'Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user'], + }); + + (fs.existsSync as jest.Mock).mockReturnValue(true); + + // Mock successful import + const originalImport = jest.fn(); + const mockModule = { + createProvider: jest.fn().mockReturnValue(mockProvider), + }; + + // Mock the actual import call + jest + .spyOn(service as any, 'tryLoadFromPath') + .mockImplementation(async (providerId, pluginPath, config) => { + // Simulate successful file loading + const loggerLogSpy = jest.spyOn(service['logger'], 'log'); + const result = mockProvider; + return result; + }); + + const loggerLogSpy = jest + .spyOn(service['logger'], 'log') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerLogSpy).toHaveBeenCalledWith( + expect.stringMatching(/Registered OAuth provider: test/), + ); + expect(loggerLogSpy).toHaveBeenCalledWith( + expect.stringMatching(/OAuth service initialized with.+1.+providers/), + ); + } finally { + loggerLogSpy.mockRestore(); + } + }); + + it('should handle module import errors gracefully', async () => { + (fs.existsSync as jest.Mock).mockReturnValue(true); + + // Mock tryLoadFromPath to simulate import error + jest + .spyOn(service as any, 'tryLoadFromPath') + .mockImplementation(async () => { + // Simulate debug logging for import failure + service['logger'].debug( + "Failed to load provider 'test' from /test/path: Import failed", + ); + return null; + }); + + const loggerDebugSpy = jest + .spyOn(service['logger'], 'debug') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerDebugSpy).toHaveBeenCalledWith( + expect.stringMatching( + /Failed to load provider.+test.+from.+Import failed/, + ), + ); + } finally { + loggerDebugSpy.mockRestore(); + } + }); + }); + + describe('tryLoadFromNpm edge cases', () => { + beforeEach(() => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'npmtest'; + case 'OAUTH_PLUGIN_PATHS': + return './plugins'; + case 'OAUTH_ALLOW_NPM_LOADING': + return true; + // Note: The config uses provider ID converted to uppercase + case 'OAUTH_NPMTEST_CLIENT_ID': + return 'npm-client-id'; + case 'OAUTH_NPMTEST_CLIENT_SECRET': + return 'npm-client-secret'; + case 'OAUTH_NPMTEST_REDIRECT_URL': + return 'http://localhost:3000/callback/npm'; + default: + return undefined; + } + }); + }); + + it('should handle invalid npm module (not a function)', async () => { + (fs.existsSync as jest.Mock).mockReturnValue(false); + + // Mock tryLoadFromPath to return null (not found locally) + jest.spyOn(service as any, 'tryLoadFromPath').mockResolvedValue(null); + + // Mock tryLoadFromNpm to simulate invalid module + jest + .spyOn(service as any, 'tryLoadFromNpm') + .mockImplementation(async () => { + service['logger'].warn( + "Invalid npm provider package for 'npmtest': expected function", + ); + return null; + }); + + const loggerWarnSpy = jest + .spyOn(service['logger'], 'warn') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.stringMatching( + /Invalid npm provider package for.+npmtest.+expected function/, + ), + ); + } finally { + loggerWarnSpy.mockRestore(); + } + }); + + it('should successfully load provider from npm package', async () => { + const mockProvider = new MockOAuthProvider({ + id: 'npmtest', + name: 'NPM Test Provider', + clientId: 'npm-client-id', + clientSecret: 'npm-client-secret', + redirectUrl: 'http://localhost:3000/callback/npm', + authorizationUrl: 'https://npm.com/oauth/authorize', + tokenUrl: 'https://npm.com/oauth/token', + scope: ['read:user'], + }); + + (fs.existsSync as jest.Mock).mockReturnValue(false); + + // Mock tryLoadFromPath to return null (not found locally) + jest.spyOn(service as any, 'tryLoadFromPath').mockResolvedValue(null); + + // Mock successful npm loading + jest + .spyOn(service as any, 'tryLoadFromNpm') + .mockResolvedValue(mockProvider); + + const loggerLogSpy = jest + .spyOn(service['logger'], 'log') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerLogSpy).toHaveBeenCalledWith( + expect.stringMatching(/Registered OAuth provider: npmtest/), + ); + expect(loggerLogSpy).toHaveBeenCalledWith( + expect.stringMatching(/OAuth service initialized with.+1.+providers/), + ); + } finally { + loggerLogSpy.mockRestore(); + } + }); + + it('should handle npm import errors gracefully', async () => { + (fs.existsSync as jest.Mock).mockReturnValue(false); + + // Mock tryLoadFromPath to return null (not found locally) + jest.spyOn(service as any, 'tryLoadFromPath').mockResolvedValue(null); + + // Mock tryLoadFromNpm to simulate import error + jest + .spyOn(service as any, 'tryLoadFromNpm') + .mockImplementation(async () => { + service['logger'].debug( + "Failed to load provider 'npmtest' from npm package '@sageseekersociety/cheese-auth-npmtest-oauth-provider': Package not found", + ); + return null; + }); + + const loggerDebugSpy = jest + .spyOn(service['logger'], 'debug') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerDebugSpy).toHaveBeenCalledWith( + expect.stringMatching( + /Failed to load provider.+npmtest.+from npm package.+Package not found/, + ), + ); + } finally { + loggerDebugSpy.mockRestore(); + } + }); + }); + + describe('provider configuration edge cases', () => { + it('should handle missing client ID', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'test'; + case 'OAUTH_TEST_CLIENT_ID': + return undefined; // Missing + case 'OAUTH_TEST_CLIENT_SECRET': + return 'test-client-secret'; + case 'OAUTH_TEST_REDIRECT_URL': + return 'http://localhost:3000/callback'; + default: + return undefined; + } + }); + + const loggerWarnSpy = jest + .spyOn(service['logger'], 'warn') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.stringMatching( + /Missing configuration for OAuth provider.+test/, + ), + ); + } finally { + loggerWarnSpy.mockRestore(); + } + }); + + it('should handle missing client secret', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'test'; + case 'OAUTH_TEST_CLIENT_ID': + return 'test-client-id'; + case 'OAUTH_TEST_CLIENT_SECRET': + return undefined; // Missing + case 'OAUTH_TEST_REDIRECT_URL': + return 'http://localhost:3000/callback'; + default: + return undefined; + } + }); + + const loggerWarnSpy = jest + .spyOn(service['logger'], 'warn') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.stringMatching( + /Missing configuration for OAuth provider.+test/, + ), + ); + } finally { + loggerWarnSpy.mockRestore(); + } + }); + + it('should handle empty configuration values', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'test'; + case 'OAUTH_TEST_CLIENT_ID': + return ''; // Empty string + case 'OAUTH_TEST_CLIENT_SECRET': + return 'test-client-secret'; + case 'OAUTH_TEST_REDIRECT_URL': + return 'http://localhost:3000/callback'; + default: + return undefined; + } + }); + + const loggerWarnSpy = jest + .spyOn(service['logger'], 'warn') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.stringMatching( + /Missing configuration for OAuth provider.+test/, + ), + ); + } finally { + loggerWarnSpy.mockRestore(); + } + }); + }); + + describe('provider error handling in methods', () => { + let mockProviderWithErrors: OAuthProvider; + + beforeEach(async () => { + // Create a provider that throws errors in methods + mockProviderWithErrors = { + getConfig: () => ({ + id: 'error-test', + name: 'Error Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user'], + }), + getAuthorizationUrl: jest.fn().mockImplementation(() => { + throw new Error('Authorization URL generation failed'); + }), + handleCallback: jest.fn().mockImplementation(() => { + throw new Error('Callback handling failed'); + }), + getUserInfo: jest.fn().mockImplementation(() => { + throw new Error('User info retrieval failed'); + }), + }; + + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'error-test'; + case 'OAUTH_ERROR_TEST_CLIENT_ID': + return 'test-client-id'; + case 'OAUTH_ERROR_TEST_CLIENT_SECRET': + return 'test-client-secret'; + case 'OAUTH_ERROR_TEST_REDIRECT_URL': + return 'http://localhost:3000/callback'; + default: + return undefined; + } + }); + + // Register the error provider directly + (service as any).registerProvider('error-test', mockProviderWithErrors); + (service as any).initialized = true; + }); + + it('should handle errors in generateAuthorizationUrl', async () => { + await expect( + service.generateAuthorizationUrl('error-test', 'state123'), + ).rejects.toThrow( + "Failed to generate authorization URL for provider 'error-test': Authorization URL generation failed", + ); + }); + + it('should handle errors in handleCallback', async () => { + await expect( + service.handleCallback('error-test', 'code123'), + ).rejects.toThrow( + "Failed to handle callback for provider 'error-test': Callback handling failed", + ); + }); + + it('should handle errors in getUserInfo', async () => { + await expect( + service.getUserInfo('error-test', 'token123'), + ).rejects.toThrow( + "Failed to get user info from provider 'error-test': User info retrieval failed", + ); + }); + + it('should handle non-Error objects thrown by providers', async () => { + const mockProviderWithStringError: OAuthProvider = { + getConfig: () => ({ + id: 'string-error-test', + name: 'String Error Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user'], + }), + getAuthorizationUrl: jest.fn().mockImplementation(() => { + throw 'String error'; // Non-Error object + }), + handleCallback: jest.fn().mockImplementation(() => { + throw 'String error'; + }), + getUserInfo: jest.fn().mockImplementation(() => { + throw 'String error'; + }), + }; + + (service as any).registerProvider( + 'string-error-test', + mockProviderWithStringError, + ); + + await expect( + service.generateAuthorizationUrl('string-error-test'), + ).rejects.toThrow( + "Failed to generate authorization URL for provider 'string-error-test': String error", + ); + + await expect( + service.handleCallback('string-error-test', 'code'), + ).rejects.toThrow( + "Failed to handle callback for provider 'string-error-test': String error", + ); + + await expect( + service.getUserInfo('string-error-test', 'token'), + ).rejects.toThrow( + "Failed to get user info from provider 'string-error-test': String error", + ); + }); + }); + + describe('configuration parsing edge cases', () => { + it('should handle whitespace in enabled providers list', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ENABLED_PROVIDERS') + return ' test1 , test2 , , test3 '; + return undefined; + }); + + const enabledProviders = (service as any).getEnabledProviders(); + expect(enabledProviders).toEqual(['test1', 'test2', 'test3']); + }); + + it('should handle empty providers in enabled list', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ENABLED_PROVIDERS') return 'test1,,test2,,,test3'; + return undefined; + }); + + const enabledProviders = (service as any).getEnabledProviders(); + expect(enabledProviders).toEqual(['test1', 'test2', 'test3']); + }); + + it('should handle whitespace in plugin paths', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_PLUGIN_PATHS') return ' ./plugins1 , ./plugins2 '; + return undefined; + }); + + const pluginPaths = (service as any).getPluginPaths(); + expect(pluginPaths).toEqual(['./plugins1', './plugins2']); + }); + + it('should handle missing implementation warning', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'missing-provider'; + case 'OAUTH_PLUGIN_PATHS': + return './plugins'; + case 'OAUTH_ALLOW_NPM_LOADING': + return false; + case 'OAUTH_MISSING_PROVIDER_CLIENT_ID': + return 'client-id'; + case 'OAUTH_MISSING_PROVIDER_CLIENT_SECRET': + return 'client-secret'; + case 'OAUTH_MISSING_PROVIDER_REDIRECT_URL': + return 'http://localhost:3000/callback'; + default: + return undefined; + } + }); + + (fs.existsSync as jest.Mock).mockReturnValue(false); + + // Mock tryLoadFromPath to return null (not found) + jest.spyOn(service as any, 'tryLoadFromPath').mockResolvedValue(null); + + // Mock the loadProvider method to directly call the warning + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async () => { + service['logger'].warn( + "Could not find implementation for OAuth provider 'missing-provider'", + ); + }); + + const loggerWarnSpy = jest + .spyOn(service['logger'], 'warn') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.stringMatching( + /Could not find implementation for OAuth provider.+missing-provider/, + ), + ); + } finally { + loggerWarnSpy.mockRestore(); + } + }); + + it('should handle provider loading with null result', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + switch (key) { + case 'OAUTH_ENABLED_PROVIDERS': + return 'null-provider'; + case 'OAUTH_PLUGIN_PATHS': + return './plugins'; + case 'OAUTH_ALLOW_NPM_LOADING': + return true; + case 'OAUTH_NULL_PROVIDER_CLIENT_ID': + return 'client-id'; + case 'OAUTH_NULL_PROVIDER_CLIENT_SECRET': + return 'client-secret'; + case 'OAUTH_NULL_PROVIDER_REDIRECT_URL': + return 'http://localhost:3000/callback'; + default: + return undefined; + } + }); + + (fs.existsSync as jest.Mock).mockReturnValue(false); + + // Mock tryLoadFromPath and tryLoadFromNpm to return null + jest.spyOn(service as any, 'tryLoadFromPath').mockResolvedValue(null); + jest.spyOn(service as any, 'tryLoadFromNpm').mockResolvedValue(null); + + // Mock the loadProvider method to directly call the warning + jest + .spyOn(service as any, 'loadProvider') + .mockImplementation(async () => { + service['logger'].warn( + "Could not find implementation for OAuth provider 'null-provider'", + ); + }); + + const loggerWarnSpy = jest + .spyOn(service['logger'], 'warn') + .mockImplementation(() => {}); + + try { + await service.initialize(); + + expect(loggerWarnSpy).toHaveBeenCalledWith( + expect.stringMatching( + /Could not find implementation for OAuth provider.+null-provider/, + ), + ); + } finally { + loggerWarnSpy.mockRestore(); + } + }); + + it('should test getAllowNpmLoading method', async () => { + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ALLOW_NPM_LOADING') return true; + return undefined; + }); + + const allowNpmLoading = (service as any).getAllowNpmLoading(); + expect(allowNpmLoading).toBe(true); + + jest.spyOn(configService, 'get').mockImplementation((key: string) => { + if (key === 'OAUTH_ALLOW_NPM_LOADING') return false; + return undefined; + }); + + const disallowNpmLoading = (service as any).getAllowNpmLoading(); + expect(disallowNpmLoading).toBe(false); + }); + + it('should test registerProvider method directly', async () => { + const mockProvider = new MockOAuthProvider({ + id: 'direct-test', + name: 'Direct Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user'], + }); + + const loggerLogSpy = jest + .spyOn(service['logger'], 'log') + .mockImplementation(() => {}); + + try { + (service as any).registerProvider('direct-test', mockProvider); + + expect(loggerLogSpy).toHaveBeenCalledWith( + 'Registered OAuth provider: direct-test', + ); + + // Verify the provider was actually registered + const provider = await service.getProvider('direct-test'); + expect(provider).toBeDefined(); + expect(provider?.getConfig().id).toBe('direct-test'); + } finally { + loggerLogSpy.mockRestore(); + } + }); + }); +}); diff --git a/src/auth/oauth/oauth.service.ts b/src/auth/oauth/oauth.service.ts new file mode 100644 index 00000000..a99f758c --- /dev/null +++ b/src/auth/oauth/oauth.service.ts @@ -0,0 +1,332 @@ +/* + * Description: OAuth Service - 动态加载和管理多个 OAuth 提供商 + * + * Author(s): + * HuanCheng65 + */ + +import { Injectable, Logger } from '@nestjs/common'; +import { ConfigService } from '@nestjs/config'; +import fs from 'node:fs'; +import path from 'node:path'; +import { OAuthError, OAuthProvider, OAuthProviderConfig } from './oauth.types'; + +@Injectable() +export class OAuthService { + private readonly logger = new Logger(OAuthService.name); + private readonly providers = new Map(); + private initialized = false; + + constructor(private readonly configService: ConfigService) {} + + async initialize(): Promise { + if (this.initialized) { + return; + } + + const enabledProviders = this.getEnabledProviders(); + if (enabledProviders.length === 0) { + this.logger.warn('No OAuth providers enabled'); + this.initialized = true; + return; + } + + const pluginPaths = this.getPluginPaths(); + const allowNpmLoading = this.getAllowNpmLoading(); + + for (const providerId of enabledProviders) { + try { + await this.loadProvider(providerId, pluginPaths, allowNpmLoading); + } catch (error) { + this.logger.error( + `Failed to load OAuth provider '${providerId}': ${error instanceof Error ? error.message : String(error)}`, + error instanceof Error ? error.stack : undefined, + ); + } + } + + this.initialized = true; + this.logger.log( + `OAuth service initialized with ${this.providers.size} providers: ${Array.from( + this.providers.keys(), + ).join(', ')}`, + ); + } + + private getEnabledProviders(): string[] { + const enabled = this.configService.get('OAUTH_ENABLED_PROVIDERS'); + if (!enabled) { + return []; + } + return enabled + .split(',') + .map((id) => id.trim()) + .filter((id) => id.length > 0) + .filter((id) => this.isValidProviderId(id)); + } + + private getPluginPaths(): string[] { + const paths = this.configService.get('OAUTH_PLUGIN_PATHS'); + if (!paths) { + return ['plugins/oauth']; + } + return paths.split(',').map((p) => p.trim()); + } + + private getAllowNpmLoading(): boolean { + return this.configService.get('OAUTH_ALLOW_NPM_LOADING') === true; + } + + private isValidProviderId(id: string): boolean { + return /^[a-zA-Z0-9_-]+$/.test(id); + } + + private async loadProvider( + providerId: string, + pluginPaths: string[], + allowNpmLoading: boolean, + ): Promise { + // 检查是否有必要的配置 + const config = this.getProviderConfig(providerId); + if (!config) { + this.logger.warn( + `Missing configuration for OAuth provider '${providerId}', skipping`, + ); + return; + } + + // 尝试从插件路径加载 + let provider: OAuthProvider | null = null; + for (const pluginPath of pluginPaths) { + provider = await this.tryLoadFromPath(providerId, pluginPath, config); + if (provider) { + break; + } + } + + // 如果本地未找到且允许 npm 加载,尝试从 npm 包加载 + if (!provider && allowNpmLoading) { + provider = await this.tryLoadFromNpm(providerId, config); + } + + if (provider) { + this.registerProvider(providerId, provider); + } else { + this.logger.warn( + `Could not find implementation for OAuth provider '${providerId}'`, + ); + } + } + + private getProviderConfig(providerId: string): OAuthProviderConfig | null { + const upperCaseId = providerId.toUpperCase(); + const clientId = this.configService.get( + `OAUTH_${upperCaseId}_CLIENT_ID`, + ); + const clientSecret = this.configService.get( + `OAUTH_${upperCaseId}_CLIENT_SECRET`, + ); + const redirectUrl = this.configService.get( + `OAUTH_${upperCaseId}_REDIRECT_URL`, + ); + + if (!clientId || !clientSecret || !redirectUrl) { + return null; + } + + return { + id: providerId, + name: providerId, // 可以通过配置覆盖显示名称 + clientId, + clientSecret, + authorizationUrl: '', // 由具体实现提供 + tokenUrl: '', // 由具体实现提供 + redirectUrl, + scope: [], // 由具体实现提供 + }; + } + + private async tryLoadFromPath( + providerId: string, + pluginPath: string, + config: OAuthProviderConfig, + ): Promise { + const possiblePaths = [ + path.join(pluginPath, `${providerId}/index.js`), + path.join(pluginPath, `${providerId}.js`), + path.join(pluginPath, `${providerId}/index.ts`), + path.join(pluginPath, `${providerId}.ts`), + ]; + + for (const modulePath of possiblePaths) { + try { + const resolvedPath = path.resolve(modulePath); + + // 安全检查:确保路径在预期的基准目录下 + const basePath = path.resolve(pluginPath); + if (!resolvedPath.startsWith(basePath)) { + this.logger.warn( + `Potential path traversal detected for provider '${providerId}': ${resolvedPath}`, + ); + continue; + } + + if (fs.existsSync(resolvedPath)) { + const module = await import(resolvedPath); + const createProvider = + module.createProvider || module.default || module; + + if (typeof createProvider === 'function') { + return createProvider(config); + } else { + this.logger.warn( + `Invalid provider module for '${providerId}': expected function`, + ); + } + } + } catch (error) { + this.logger.debug( + `Failed to load provider '${providerId}' from ${modulePath}: ${error instanceof Error ? error.message : String(error)}`, + ); + } + } + + return null; + } + + private async tryLoadFromNpm( + providerId: string, + config: OAuthProviderConfig, + ): Promise { + const packageName = `@sageseekersociety/cheese-auth-${providerId}-oauth-provider`; + + try { + const module = await import(packageName); + const createProvider = module.createProvider || module.default || module; + + if (typeof createProvider === 'function') { + return createProvider(config); + } else { + this.logger.warn( + `Invalid npm provider package for '${providerId}': expected function`, + ); + } + } catch (error) { + this.logger.debug( + `Failed to load provider '${providerId}' from npm package '${packageName}': ${error instanceof Error ? error.message : String(error)}`, + ); + } + + return null; + } + + private registerProvider(providerId: string, provider: OAuthProvider): void { + this.providers.set(providerId, provider); + this.logger.log(`Registered OAuth provider: ${providerId}`); + } + + async getProvider(providerId: string): Promise { + if (!this.initialized) { + await this.initialize(); + } + return this.providers.get(providerId); + } + + async getAllProviders(): Promise { + if (!this.initialized) { + await this.initialize(); + } + return Array.from(this.providers.values()); + } + + async getProvidersConfig(): Promise< + Array<{ id: string; name: string; scope: string[] }> + > { + if (!this.initialized) { + await this.initialize(); + } + + return Array.from(this.providers.entries()).map(([id, provider]) => { + const config = provider.getConfig(); + return { + id: config.id, + name: config.name, + scope: config.scope, + }; + }); + } + + async generateAuthorizationUrl( + providerId: string, + state?: string, + accessType?: string, + ): Promise { + const provider = await this.getProvider(providerId); + if (!provider) { + throw new OAuthError( + `OAuth provider '${providerId}' not found`, + providerId, + 'validation', + ); + } + + try { + return provider.getAuthorizationUrl(state, accessType); + } catch (error) { + throw new OAuthError( + `Failed to generate authorization URL for provider '${providerId}': ${error instanceof Error ? error.message : String(error)}`, + providerId, + 'authorization', + error, + ); + } + } + + async handleCallback( + providerId: string, + code: string, + state?: string, + ): Promise { + const provider = await this.getProvider(providerId); + if (!provider) { + throw new OAuthError( + `OAuth provider '${providerId}' not found`, + providerId, + 'validation', + ); + } + + try { + return await provider.handleCallback(code, state); + } catch (error) { + throw new OAuthError( + `Failed to handle callback for provider '${providerId}': ${error instanceof Error ? error.message : String(error)}`, + providerId, + 'token_exchange', + error, + ); + } + } + + async getUserInfo(providerId: string, accessToken: string) { + const provider = await this.getProvider(providerId); + if (!provider) { + throw new OAuthError( + `OAuth provider '${providerId}' not found`, + providerId, + 'validation', + ); + } + + try { + return await provider.getUserInfo(accessToken); + } catch (error) { + throw new OAuthError( + `Failed to get user info from provider '${providerId}': ${error instanceof Error ? error.message : String(error)}`, + providerId, + 'user_info', + error, + ); + } + } +} diff --git a/src/auth/oauth/oauth.types.spec.ts b/src/auth/oauth/oauth.types.spec.ts new file mode 100644 index 00000000..0d76481c --- /dev/null +++ b/src/auth/oauth/oauth.types.spec.ts @@ -0,0 +1,245 @@ +/* + * Description: Unit tests for OAuth Types + * + * Author(s): + * HuanCheng65 + */ + +import { + BaseOAuthProvider, + OAuthError, + OAuthProviderConfig, + OAuthUserInfo, +} from './oauth.types'; + +// Mock implementation for testing BaseOAuthProvider +class TestOAuthProvider extends BaseOAuthProvider { + constructor(config: OAuthProviderConfig) { + super(config); + } + + async handleCallback(code: string, state?: string): Promise { + if (code === 'valid_code') { + return 'mock_access_token'; + } + throw new Error('Invalid authorization code'); + } + + async getUserInfo(accessToken: string): Promise { + if (accessToken === 'mock_access_token') { + return { + id: '12345', + email: 'test@example.com', + name: 'Test User', + username: 'testuser', + preferredUsername: 'testuser', + }; + } + throw new Error('Invalid access token'); + } +} + +describe('OAuth Types', () => { + describe('BaseOAuthProvider', () => { + let provider: TestOAuthProvider; + let config: OAuthProviderConfig; + + beforeEach(() => { + config = { + id: 'test', + name: 'Test Provider', + clientId: 'test-client-id', + clientSecret: 'test-client-secret', + redirectUrl: 'http://localhost:3000/callback', + authorizationUrl: 'https://test.com/oauth/authorize', + tokenUrl: 'https://test.com/oauth/token', + scope: ['read:user', 'read:email'], + }; + provider = new TestOAuthProvider(config); + }); + + describe('constructor', () => { + it('should initialize with config', () => { + expect(provider.getConfig()).toBe(config); + }); + }); + + describe('getConfig', () => { + it('should return the configuration', () => { + const result = provider.getConfig(); + expect(result).toEqual(config); + expect(result.id).toBe('test'); + expect(result.name).toBe('Test Provider'); + expect(result.clientId).toBe('test-client-id'); + }); + }); + + describe('getAuthorizationUrl', () => { + it('should generate authorization URL without optional parameters', () => { + const url = provider.getAuthorizationUrl(); + + expect(url).toContain('https://test.com/oauth/authorize'); + expect(url).toContain('client_id=test-client-id'); + expect(url).toContain( + 'redirect_uri=http%3A%2F%2Flocalhost%3A3000%2Fcallback', + ); + expect(url).toContain('scope=read%3Auser+read%3Aemail'); + expect(url).toContain('response_type=code'); + expect(url).not.toContain('state='); + expect(url).not.toContain('access_type='); + }); + + it('should generate authorization URL with state parameter', () => { + const url = provider.getAuthorizationUrl('state123'); + + expect(url).toContain('https://test.com/oauth/authorize'); + expect(url).toContain('client_id=test-client-id'); + expect(url).toContain('state=state123'); + }); + + it('should generate authorization URL with access_type parameter', () => { + const url = provider.getAuthorizationUrl(undefined, 'offline'); + + expect(url).toContain('https://test.com/oauth/authorize'); + expect(url).toContain('access_type=offline'); + }); + + it('should generate authorization URL with both state and access_type parameters', () => { + const url = provider.getAuthorizationUrl('state456', 'online'); + + expect(url).toContain('state=state456'); + expect(url).toContain('access_type=online'); + }); + + it('should handle empty scope array', () => { + const configWithEmptyScope = { ...config, scope: [] }; + const providerWithEmptyScope = new TestOAuthProvider( + configWithEmptyScope, + ); + + const url = providerWithEmptyScope.getAuthorizationUrl(); + expect(url).toContain('scope='); + }); + + it('should handle single scope', () => { + const configWithSingleScope = { ...config, scope: ['read:user'] }; + const providerWithSingleScope = new TestOAuthProvider( + configWithSingleScope, + ); + + const url = providerWithSingleScope.getAuthorizationUrl(); + expect(url).toContain('scope=read%3Auser'); + }); + + it('should properly encode special characters in URLs', () => { + const configWithSpecialChars = { + ...config, + redirectUrl: 'http://localhost:3000/callback?test=value&other=data', + }; + const providerWithSpecialChars = new TestOAuthProvider( + configWithSpecialChars, + ); + + const url = providerWithSpecialChars.getAuthorizationUrl(); + expect(url).toContain( + 'redirect_uri=http%3A%2F%2Flocalhost%3A3000%2Fcallback%3Ftest%3Dvalue%26other%3Ddata', + ); + }); + }); + + describe('handleCallback', () => { + it('should handle valid authorization code', async () => { + const token = await provider.handleCallback('valid_code'); + expect(token).toBe('mock_access_token'); + }); + + it('should handle invalid authorization code', async () => { + await expect(provider.handleCallback('invalid_code')).rejects.toThrow( + 'Invalid authorization code', + ); + }); + + it('should handle callback with state parameter', async () => { + const token = await provider.handleCallback('valid_code', 'state123'); + expect(token).toBe('mock_access_token'); + }); + }); + + describe('getUserInfo', () => { + it('should return user info with valid token', async () => { + const userInfo = await provider.getUserInfo('mock_access_token'); + + expect(userInfo).toEqual({ + id: '12345', + email: 'test@example.com', + name: 'Test User', + username: 'testuser', + preferredUsername: 'testuser', + }); + }); + + it('should throw error with invalid token', async () => { + await expect(provider.getUserInfo('invalid_token')).rejects.toThrow( + 'Invalid access token', + ); + }); + }); + }); + + describe('OAuthError', () => { + it('should create error with basic information', () => { + const error = new OAuthError( + 'Test error message', + 'github', + 'authorization', + ); + + expect(error.message).toBe('Test error message'); + expect(error.provider).toBe('github'); + expect(error.type).toBe('authorization'); + expect(error.name).toBe('OAuthError'); + expect(error.originalError).toBeUndefined(); + }); + + it('should create error with original error', () => { + const originalError = new Error('Original error'); + const error = new OAuthError( + 'OAuth error occurred', + 'google', + 'token_exchange', + originalError, + ); + + expect(error.message).toBe('OAuth error occurred'); + expect(error.provider).toBe('google'); + expect(error.type).toBe('token_exchange'); + expect(error.originalError).toBe(originalError); + }); + + it('should support all error types', () => { + const types = [ + 'authorization', + 'token_exchange', + 'user_info', + 'validation', + ] as const; + + types.forEach((type) => { + const error = new OAuthError('Test message', 'provider', type); + expect(error.type).toBe(type); + }); + }); + + it('should be instanceof Error', () => { + const error = new OAuthError('Test message', 'provider', 'validation'); + expect(error instanceof Error).toBe(true); + expect(error instanceof OAuthError).toBe(true); + }); + + it('should have correct stack trace', () => { + const error = new OAuthError('Test message', 'provider', 'validation'); + expect(error.stack).toBeDefined(); + expect(error.stack).toContain('OAuthError'); + }); + }); +}); diff --git a/src/auth/oauth/oauth.types.ts b/src/auth/oauth/oauth.types.ts new file mode 100644 index 00000000..7063c439 --- /dev/null +++ b/src/auth/oauth/oauth.types.ts @@ -0,0 +1,82 @@ +/* + * Description: OAuth 类型定义和接口 + * + * Author(s): + * HuanCheng65 + */ + +export interface OAuthUserInfo { + id: string; + email?: string; + name?: string; + username?: string; + preferredUsername?: string; +} + +export interface OAuthProviderConfig { + id: string; + name: string; + clientId: string; + clientSecret: string; + authorizationUrl: string; + tokenUrl: string; + redirectUrl: string; + scope: string[]; +} + +export interface OAuthProvider { + getConfig(): OAuthProviderConfig; + getAuthorizationUrl(state?: string, accessType?: string): string; + handleCallback(code: string, state?: string): Promise; + getUserInfo(accessToken: string): Promise; +} + +export abstract class BaseOAuthProvider implements OAuthProvider { + protected config: OAuthProviderConfig; + + constructor(config: OAuthProviderConfig) { + this.config = config; + } + + getConfig(): OAuthProviderConfig { + return this.config; + } + + getAuthorizationUrl(state?: string, accessType?: string): string { + const params = new URLSearchParams({ + client_id: this.config.clientId, + redirect_uri: this.config.redirectUrl, + scope: this.config.scope.join(' '), + response_type: 'code', + }); + + if (state) { + params.append('state', state); + } + + if (accessType) { + params.append('access_type', accessType); + } + + return `${this.config.authorizationUrl}?${params.toString()}`; + } + + abstract handleCallback(code: string, state?: string): Promise; + abstract getUserInfo(accessToken: string): Promise; +} + +export class OAuthError extends Error { + constructor( + message: string, + public provider: string, + public type: + | 'authorization' + | 'token_exchange' + | 'user_info' + | 'validation', + public originalError?: any, + ) { + super(message); + this.name = 'OAuthError'; + } +} diff --git a/src/users/DTO/oauth.dto.ts b/src/users/DTO/oauth.dto.ts new file mode 100644 index 00000000..a3968437 --- /dev/null +++ b/src/users/DTO/oauth.dto.ts @@ -0,0 +1,42 @@ +/* + * Description: OAuth related DTOs + * + * Author(s): + * HuanCheng65 + */ + +import { IsOptional, IsString } from 'class-validator'; +import { BaseResponseDto } from '../../common/DTO/base-response.dto'; +import { UserDto } from './user.dto'; + +export class GetOAuthProvidersResponseDto extends BaseResponseDto { + data: { + providers: Array<{ + id: string; + name: string; + scope: string[]; + }>; + }; +} + +export class OAuthCallbackQueryDto { + @IsString() + code: string; + + @IsOptional() + @IsString() + state?: string; + + @IsOptional() + @IsString() + error?: string; + + @IsOptional() + @IsString() + error_description?: string; +} + +// OAuth用户DTO,继承自UserDto并添加email字段 +export class OAuthUserDto extends UserDto { + email?: string | null; +} diff --git a/src/users/users.controller.ts b/src/users/users.controller.ts index e223199e..40a1b134 100644 --- a/src/users/users.controller.ts +++ b/src/users/users.controller.ts @@ -17,6 +17,7 @@ import { HttpCode, Inject, Ip, + Logger, Param, ParseIntPipe, Patch, @@ -39,6 +40,8 @@ import { ResourceId, ResourceOwnerIdGetter, } from '../auth/guard.decorator'; +import { OAuthService } from '../auth/oauth/oauth.service'; +import { OAuthError } from '../auth/oauth/oauth.types'; import { SessionService } from '../auth/session.service'; import { UserId } from '../auth/user-id.decorator'; import { BaseResponseDto } from '../common/DTO/base-response.dto'; @@ -60,6 +63,10 @@ import { GetFollowedQuestionsResponseDto } from './DTO/get-followed-questions.dt import { GetFollowersResponseDto } from './DTO/get-followers.dto'; import { GetUserResponseDto } from './DTO/get-user.dto'; import { LoginRequestDto, LoginResponseDto } from './DTO/login.dto'; +import { + GetOAuthProvidersResponseDto, + OAuthCallbackQueryDto, +} from './DTO/oauth.dto'; import { DeletePasskeyResponseDto, GetPasskeysResponseDto, @@ -127,6 +134,8 @@ declare module 'express-session' { @Controller('/users') export class UsersController { + private readonly logger = new Logger(UsersController.name); + constructor( private readonly usersService: UsersService, private readonly authService: AuthService, @@ -138,6 +147,7 @@ export class UsersController { @Inject(forwardRef(() => QuestionsService)) private readonly questionsService: QuestionsService, private readonly configService: ConfigService, + private readonly oauthService: OAuthService, ) {} @ResourceOwnerIdGetter('user') @@ -1203,4 +1213,137 @@ export class UsersController { message: 'Password changed successfully', }; } + + // OAuth 相关路由 + + @Get('/auth/oauth/providers') + @NoAuth() + async getOAuthProviders(): Promise { + const providers = await this.oauthService.getProvidersConfig(); + return { + code: 200, + message: 'Get OAuth providers successfully.', + data: { + providers, + }, + }; + } + + @Get('/auth/oauth/login/:providerId') + @NoAuth() + async oauthLogin( + @Param('providerId') providerId: string, + @Query('state') state?: string, + @Query('access_type') accessType?: string, + @Res() res?: Response, + ): Promise { + try { + const authUrl = await this.oauthService.generateAuthorizationUrl( + providerId, + state, + accessType, + ); + res!.redirect(authUrl); + return; + } catch (error) { + if (error instanceof OAuthError) { + const frontendBaseUrl = this.configService.get('FRONTEND_BASE_URL'); + const errorPath = + this.configService.get('FRONTEND_OAUTH_ERROR_PATH') || '/oauth-error'; + const errorUrl = `${frontendBaseUrl}${errorPath}?error=${encodeURIComponent(error.message)}&provider=${providerId}`; + res!.redirect(errorUrl); + return; + } + throw error; + } + } + + @Get('/auth/oauth/callback/:providerId') + @NoAuth() + async oauthCallback( + @Param('providerId') providerId: string, + @Query() query: OAuthCallbackQueryDto, + @Ip() ip: string, + @Headers('User-Agent') userAgent: string | undefined, + @Res() res: Response, + ): Promise { + try { + // 检查是否有错误 + if (query.error) { + const frontendBaseUrl = this.configService.get('FRONTEND_BASE_URL'); + const errorPath = + this.configService.get('FRONTEND_OAUTH_ERROR_PATH') || '/oauth-error'; + const errorUrl = `${frontendBaseUrl}${errorPath}?error=${encodeURIComponent(query.error)}&provider=${providerId}&description=${encodeURIComponent(query.error_description || '')}`; + res.redirect(errorUrl); + return; + } + + if (!query.code) { + throw new Error('Authorization code not provided'); + } + + // 1. 交换访问令牌 + const accessToken = await this.oauthService.handleCallback( + providerId, + query.code, + query.state, + ); + + // 2. 获取用户信息 + const userInfo = await this.oauthService.getUserInfo( + providerId, + accessToken, + ); + + // 3. 处理用户登录/注册 + const [userDto, refreshToken] = await this.usersService.loginWithOAuth( + providerId, + userInfo, + ip, + userAgent, + ); + + // 4. 生成 JWT Access Token + const [newRefreshToken, jwtAccessToken] = + await this.sessionService.refreshSession(refreshToken); + const newRefreshTokenExpire = new Date( + this.authService.decode(newRefreshToken).validUntil, + ); + + // 5. 构造前端跳转 URL + const frontendBaseUrl = this.configService.get('FRONTEND_BASE_URL'); + const successPath = + this.configService.get('FRONTEND_OAUTH_SUCCESS_PATH') || + '/oauth-success'; + const frontendUrl = `${frontendBaseUrl}${successPath}?token=${encodeURIComponent(jwtAccessToken)}&email=${encodeURIComponent(userDto.email || userDto.username)}&provider=${providerId}`; + + // 6. 设置 Refresh Token Cookie + const cookieBasePath = this.configService.get('cookieBasePath') || ''; + res.cookie('REFRESH_TOKEN', newRefreshToken, { + httpOnly: true, + sameSite: 'lax', // 允许第三方跳转携带 + path: path.posix.join(cookieBasePath, 'users/auth'), + expires: newRefreshTokenExpire, + secure: process.env.NODE_ENV === 'production', // 生产环境使用 HTTPS + }); + + // 7. 重定向到前端 + res.redirect(frontendUrl); + return; + } catch (error) { + this.logger.error( + `OAuth callback failed for provider ${providerId}: ${error instanceof Error ? error.message : String(error)}`, + error instanceof Error ? error.stack : undefined, + ); + + const frontendBaseUrl = this.configService.get('FRONTEND_BASE_URL'); + const errorPath = + this.configService.get('FRONTEND_OAUTH_ERROR_PATH') || '/oauth-error'; + const errorMessage = + error instanceof OAuthError ? error.message : 'Internal server error'; + const errorUrl = `${frontendBaseUrl}${errorPath}?error=${encodeURIComponent(errorMessage)}&provider=${providerId}`; + res.redirect(errorUrl); + return; + } + } } diff --git a/src/users/users.prisma b/src/users/users.prisma index b6082a4a..2918ce7e 100644 --- a/src/users/users.prisma +++ b/src/users/users.prisma @@ -9,6 +9,7 @@ import { AttitudeLog } from "../attitude/attitude" import { QuestionInvitationRelation } from "../questions/questions.invitation" import { Material } from "../materials/materials" import { MaterialBundle } from "../materialbundles/materialbundles" +import { UserOAuthConnection } from "../auth/oauth/oauth" model User { id Int @id(map: "PK_cace4a159ff9f2512dd42373760") @default(autoincrement()) @@ -54,6 +55,7 @@ model User { totpEnabled Boolean @default(false) @map("totp_enabled") totpAlwaysRequired Boolean @default(false) @map("totp_always_required") backupCodes UserBackupCode[] + oauthConnections UserOAuthConnection[] @@map("user") } diff --git a/src/users/users.service.spec.ts b/src/users/users.service.spec.ts new file mode 100644 index 00000000..eb445187 --- /dev/null +++ b/src/users/users.service.spec.ts @@ -0,0 +1,732 @@ +/* + * Description: Unit tests for Users Service OAuth functionality + * + * Author(s): + * HuanCheng65 + */ + +import { RedisService } from '@liaoliaots/nestjs-redis'; +import { ConfigService } from '@nestjs/config'; +import { Test, TestingModule } from '@nestjs/testing'; +import { AnswerService } from '../answer/answer.service'; +import { AuthService } from '../auth/auth.service'; +import { OAuthUserInfo } from '../auth/oauth/oauth.types'; +import { SessionService } from '../auth/session.service'; +import { AvatarsService } from '../avatars/avatars.service'; +import { PrismaService } from '../common/prisma/prisma.service'; +import { EmailRuleService } from '../email/email-rule.service'; +import { EmailService } from '../email/email.service'; +import { QuestionsService } from '../questions/questions.service'; +import { RolePermissionService } from './role-permission.service'; +import { SrpService } from './srp.service'; +import { TOTPService } from './totp.service'; +import { UserChallengeRepository } from './user-challenge.repository'; +import { UsersPermissionService } from './users-permission.service'; +import { UsersRegisterRequestService } from './users-register-request.service'; +import { UsersService } from './users.service'; + +describe('UsersService - OAuth', () => { + let service: UsersService; + let prismaService: PrismaService; + + const mockPrismaService = { + userOAuthConnection: { + findUnique: jest.fn(), + create: jest.fn(), + update: jest.fn(), + upsert: jest.fn(), + }, + user: { + findUnique: jest.fn(), + create: jest.fn(), + }, + userProfile: { + create: jest.fn(), + }, + userLoginLog: { + create: jest.fn(), + }, + userRegisterLog: { + create: jest.fn(), + }, + $transaction: jest.fn(), + }; + + const mockAuthService = { + generateRandomPassword: jest.fn().mockReturnValue('random-password'), + }; + + const mockSessionService = { + createSession: jest.fn().mockResolvedValue('session-token'), + }; + + const mockConfigService = { + get: jest.fn().mockImplementation((key: string) => { + if (key === 'defaultIntro') { + return 'This user has not set an introduction yet.'; + } + return undefined; + }), + }; + + const mockRedisService = { + getOrThrow: jest.fn().mockReturnValue({ + publish: jest.fn(), + }), + }; + + const mockEmailRuleService = { + verifyEmailRule: jest.fn(), + }; + + const mockAnswerService = { + // Add any methods that might be called in tests + }; + + const mockUsersPermissionService = { + getAuthorizationForUser: jest.fn().mockResolvedValue({ + permissions: [], + roles: [], + }), + }; + + const mockAvatarsService = { + getDefaultAvatarId: jest.fn().mockResolvedValue(1), + }; + + beforeEach(async () => { + const module: TestingModule = await Test.createTestingModule({ + providers: [ + UsersService, + { provide: PrismaService, useValue: mockPrismaService }, + { provide: AuthService, useValue: mockAuthService }, + { provide: SessionService, useValue: mockSessionService }, + { provide: ConfigService, useValue: mockConfigService }, + { provide: RedisService, useValue: mockRedisService }, + { provide: EmailService, useValue: {} }, + { provide: EmailRuleService, useValue: mockEmailRuleService }, + { provide: AvatarsService, useValue: mockAvatarsService }, + { provide: UsersRegisterRequestService, useValue: {} }, + { + provide: UsersPermissionService, + useValue: mockUsersPermissionService, + }, + { provide: RolePermissionService, useValue: {} }, + { provide: UserChallengeRepository, useValue: {} }, + { provide: TOTPService, useValue: {} }, + { provide: SrpService, useValue: {} }, + { provide: AnswerService, useValue: mockAnswerService }, + { provide: QuestionsService, useValue: {} }, + ], + }).compile(); + + service = module.get(UsersService); + prismaService = module.get(PrismaService); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('loginWithOAuth', () => { + const mockUserInfo: OAuthUserInfo = { + id: '12345', + email: 'test@ruc.edu.cn', + name: 'Test User', + username: 'testuser', + preferredUsername: 'testuser', + }; + + const mockExistingUser = { + id: 1, + username: 'existing-user', + email: 'test@ruc.edu.cn', + deletedAt: null, + userProfile: { + id: 1, + nickname: 'Existing User', + }, + }; + + const mockOAuthConnection = { + id: 1, + userId: 1, + providerId: 'test', + providerUserId: '12345', + user: mockExistingUser, + }; + + beforeEach(() => { + // Mock getUserDtoById method + jest.spyOn(service, 'getUserDtoById').mockResolvedValue({ + id: 1, + username: 'test-user', + nickname: 'Test User', + email: 'test@ruc.edu.cn', + } as any); + + // Mock createSession method (it's private, so we need to mock the sessionService.createSession instead) + jest + .spyOn(mockSessionService, 'createSession') + .mockResolvedValue('session-token'); + + // Mock createDefaultProfileForUser method (private method, mock via prisma) + jest + .spyOn(service as any, 'createDefaultProfileForUser') + .mockResolvedValue(undefined); + }); + + it('should login existing user with OAuth connection', async () => { + mockPrismaService.userOAuthConnection.findUnique.mockResolvedValue( + mockOAuthConnection, + ); + mockPrismaService.user.findUnique.mockResolvedValue(mockExistingUser); + mockPrismaService.userLoginLog.create.mockResolvedValue({}); + mockPrismaService.userOAuthConnection.update.mockResolvedValue({}); + + const result = await service.loginWithOAuth( + 'test', + mockUserInfo, + '127.0.0.1', + 'test-agent', + ); + + expect(result).toHaveLength(2); + expect(result[1]).toBe('session-token'); + expect( + mockPrismaService.userOAuthConnection.findUnique, + ).toHaveBeenCalledWith({ + where: { + providerId_providerUserId: { + providerId: 'test', + providerUserId: '12345', + }, + }, + include: { + user: { + include: { + userProfile: true, + }, + }, + }, + }); + expect(mockPrismaService.userLoginLog.create).toHaveBeenCalledWith({ + data: { + userId: 1, + ip: '127.0.0.1', + userAgent: 'test-agent', + }, + }); + }); + + it('should create profile for user without profile', async () => { + const connectionWithoutProfile = { + ...mockOAuthConnection, + user: { + ...mockExistingUser, + userProfile: null, + }, + }; + + mockPrismaService.userOAuthConnection.findUnique.mockResolvedValue( + connectionWithoutProfile, + ); + mockPrismaService.user.findUnique.mockResolvedValue(mockExistingUser); + mockPrismaService.userLoginLog.create.mockResolvedValue({}); + mockPrismaService.userOAuthConnection.update.mockResolvedValue({}); + + await service.loginWithOAuth( + 'test', + mockUserInfo, + '127.0.0.1', + 'test-agent', + ); + + expect(service['createDefaultProfileForUser']).toHaveBeenCalledWith(1); + }); + + it('should bind OAuth to existing user by email', async () => { + mockPrismaService.userOAuthConnection.findUnique.mockResolvedValue(null); + mockPrismaService.user.findUnique + .mockResolvedValueOnce(mockExistingUser) + .mockResolvedValueOnce(mockExistingUser); + mockPrismaService.userOAuthConnection.upsert.mockResolvedValue( + mockOAuthConnection, + ); + mockPrismaService.userLoginLog.create.mockResolvedValue({}); + + const result = await service.loginWithOAuth( + 'test', + mockUserInfo, + '127.0.0.1', + 'test-agent', + ); + + expect(result).toHaveLength(2); + expect(mockPrismaService.user.findUnique).toHaveBeenCalledWith({ + where: { email: 'test@ruc.edu.cn' }, + include: { userProfile: true }, + }); + expect(mockPrismaService.userOAuthConnection.upsert).toHaveBeenCalledWith( + expect.objectContaining({ + where: { + providerId_providerUserId: { + providerId: 'test', + providerUserId: '12345', + }, + }, + create: { + providerId: 'test', + providerUserId: '12345', + userId: 1, + rawProfile: mockUserInfo, + }, + update: expect.objectContaining({ + rawProfile: mockUserInfo, + updatedAt: expect.any(Date), + }), + }), + ); + }); + + it('should create new user when no existing connection or email match', async () => { + mockPrismaService.userOAuthConnection.findUnique.mockResolvedValue(null); + mockPrismaService.user.findUnique.mockResolvedValue(null); + + const newUser = { + id: 2, + username: 'testuser', + email: 'test@ruc.edu.cn', + deletedAt: null, + }; + + const newUserProfile = { + id: 2, + userId: 2, + nickname: 'Test User', + }; + + const newOAuthConnection = { + id: 2, + userId: 2, + providerId: 'test', + providerUserId: '12345', + }; + + mockPrismaService.$transaction.mockImplementation(async (callback) => { + return await callback(mockPrismaService); + }); + + // Mock the transaction operations + mockPrismaService.user.create.mockResolvedValue(newUser); + mockPrismaService.userProfile.create.mockResolvedValue(newUserProfile); + mockPrismaService.userOAuthConnection.create.mockResolvedValue( + newOAuthConnection, + ); + mockPrismaService.userLoginLog.create.mockResolvedValue({}); + mockPrismaService.userRegisterLog.create.mockResolvedValue({}); + + // Mock user.findUnique for getOAuthUserDtoById + mockPrismaService.user.findUnique + .mockResolvedValueOnce(null) + .mockResolvedValueOnce(newUser); + + // Mock generateUniqueUsername method + jest + .spyOn(service, 'generateUniqueUsername' as any) + .mockResolvedValue('testuser'); + + const result = await service.loginWithOAuth( + 'test', + mockUserInfo, + '127.0.0.1', + 'test-agent', + ); + + expect(result).toHaveLength(2); + expect(mockPrismaService.$transaction).toHaveBeenCalled(); + expect(mockPrismaService.user.create).toHaveBeenCalledWith({ + data: expect.objectContaining({ + username: 'testuser', + email: 'test@ruc.edu.cn', + srpUpgraded: false, + }), + }); + expect(mockPrismaService.userProfile.create).toHaveBeenCalledWith({ + data: { + userId: 2, + nickname: 'Test User', + intro: 'This user has not set an introduction yet.', + avatarId: 1, + }, + }); + expect(mockPrismaService.userOAuthConnection.create).toHaveBeenCalledWith( + { + data: { + userId: 2, + providerId: 'test', + providerUserId: '12345', + rawProfile: mockUserInfo, + }, + }, + ); + }); + + it('should handle OAuth user without email', async () => { + const userInfoWithoutEmail: OAuthUserInfo = { + id: '12345', + name: 'Test User', + username: 'testuser', + preferredUsername: 'testuser', + }; + + mockPrismaService.userOAuthConnection.findUnique.mockResolvedValue(null); + + const newUser = { + id: 2, + username: 'testuser', + email: 'oauth-test-12345@placeholder.internal', + deletedAt: null, + }; + + mockPrismaService.$transaction.mockImplementation(async (callback) => { + return await callback(mockPrismaService); + }); + + mockPrismaService.user.create.mockResolvedValue(newUser); + mockPrismaService.userProfile.create.mockResolvedValue({}); + mockPrismaService.userOAuthConnection.create.mockResolvedValue({}); + mockPrismaService.userLoginLog.create.mockResolvedValue({}); + mockPrismaService.userRegisterLog.create.mockResolvedValue({}); + + // Mock user.findUnique for getOAuthUserDtoById + mockPrismaService.user.findUnique.mockResolvedValue(newUser); + + jest + .spyOn(service, 'generateUniqueUsername' as any) + .mockResolvedValue('testuser'); + + await service.loginWithOAuth( + 'test', + userInfoWithoutEmail, + '127.0.0.1', + 'test-agent', + ); + + expect(mockPrismaService.user.create).toHaveBeenCalledWith({ + data: expect.objectContaining({ + username: 'testuser', + email: 'oauth-test-12345@placeholder.internal', + srpUpgraded: false, + }), + }); + }); + + it('should handle deleted user gracefully', async () => { + const deletedUser = { + ...mockExistingUser, + deletedAt: new Date(), + }; + + const connectionWithDeletedUser = { + ...mockOAuthConnection, + user: deletedUser, + }; + + mockPrismaService.userOAuthConnection.findUnique.mockResolvedValue( + connectionWithDeletedUser, + ); + mockPrismaService.user.findUnique.mockResolvedValue(null); + + // Should proceed to create new user + mockPrismaService.$transaction.mockImplementation(async (callback) => { + return await callback(mockPrismaService); + }); + + const newUser = { + id: 3, + username: 'testuser2', + email: 'test@ruc.edu.cn', + }; + mockPrismaService.user.create.mockResolvedValue(newUser); + mockPrismaService.userProfile.create.mockResolvedValue({}); + mockPrismaService.userOAuthConnection.create.mockResolvedValue({}); + mockPrismaService.userLoginLog.create.mockResolvedValue({}); + mockPrismaService.userRegisterLog.create.mockResolvedValue({}); + + // Mock user.findUnique for getOAuthUserDtoById + mockPrismaService.user.findUnique + .mockResolvedValueOnce(null) + .mockResolvedValueOnce(newUser); + + jest + .spyOn(service, 'generateUniqueUsername' as any) + .mockResolvedValue('testuser2'); + + const result = await service.loginWithOAuth( + 'test', + mockUserInfo, + '127.0.0.1', + 'test-agent', + ); + + expect(result).toHaveLength(2); + expect(mockPrismaService.user.create).toHaveBeenCalled(); + }); + + it('should generate fallback nickname when OAuth name is missing', async () => { + const userInfoWithoutName: OAuthUserInfo = { + id: '12345', + email: 'test@ruc.edu.cn', + username: 'testuser', + preferredUsername: 'testuser', + }; + + mockPrismaService.userOAuthConnection.findUnique.mockResolvedValue(null); + mockPrismaService.user.findUnique.mockResolvedValue(null); + + mockPrismaService.$transaction.mockImplementation(async (callback) => { + return await callback(mockPrismaService); + }); + + const newUser = { id: 2, username: 'testuser', email: 'test@ruc.edu.cn' }; + mockPrismaService.user.create.mockResolvedValue(newUser); + mockPrismaService.userProfile.create.mockResolvedValue({}); + mockPrismaService.userOAuthConnection.create.mockResolvedValue({}); + mockPrismaService.userLoginLog.create.mockResolvedValue({}); + mockPrismaService.userRegisterLog.create.mockResolvedValue({}); + + // Mock user.findUnique for getOAuthUserDtoById + mockPrismaService.user.findUnique + .mockResolvedValueOnce(null) + .mockResolvedValueOnce(newUser); + + jest + .spyOn(service, 'generateUniqueUsername' as any) + .mockResolvedValue('testuser'); + + await service.loginWithOAuth( + 'test', + userInfoWithoutName, + '127.0.0.1', + 'test-agent', + ); + + expect(mockPrismaService.userProfile.create).toHaveBeenCalledWith({ + data: expect.objectContaining({ + userId: 2, + nickname: 'testuser', // Should fallback to username + intro: 'This user has not set an introduction yet.', + avatarId: 1, + }), + }); + }); + + it('should handle database transaction errors during user creation', async () => { + mockPrismaService.userOAuthConnection.findUnique.mockResolvedValue(null); + mockPrismaService.user.findUnique.mockResolvedValue(null); + + // Mock transaction to throw error + mockPrismaService.$transaction.mockRejectedValue( + new Error('Database transaction failed'), + ); + + jest + .spyOn(service, 'generateUniqueUsername' as any) + .mockResolvedValue('testuser'); + + await expect( + service.loginWithOAuth('test', mockUserInfo, '127.0.0.1', 'test-agent'), + ).rejects.toThrow('Database transaction failed'); + + expect(mockPrismaService.$transaction).toHaveBeenCalled(); + }); + + it('should handle OAuth connection update errors', async () => { + mockPrismaService.userOAuthConnection.findUnique.mockResolvedValue( + mockOAuthConnection, + ); + mockPrismaService.user.findUnique.mockResolvedValue(mockExistingUser); + mockPrismaService.userLoginLog.create.mockResolvedValue({}); + + // Mock update to throw error + mockPrismaService.userOAuthConnection.update.mockRejectedValue( + new Error('Update failed'), + ); + + await expect( + service.loginWithOAuth('test', mockUserInfo, '127.0.0.1', 'test-agent'), + ).rejects.toThrow('Update failed'); + }); + + it('should handle missing OAuth user ID', async () => { + const userInfoWithoutId: Omit & { id?: string } = { + email: 'test@ruc.edu.cn', + name: 'Test User', + username: 'testuser', + preferredUsername: 'testuser', + }; + + // Call with undefined id + await expect( + service.loginWithOAuth( + 'test', + userInfoWithoutId as OAuthUserInfo, + '127.0.0.1', + 'test-agent', + ), + ).rejects.toThrow(); + }); + + it('should handle very long OAuth user data', async () => { + const longString = 'a'.repeat(1000); + const userInfoWithLongData: OAuthUserInfo = { + id: '12345', + email: 'test@ruc.edu.cn', + name: longString, + username: longString, + preferredUsername: longString, + }; + + mockPrismaService.userOAuthConnection.findUnique.mockResolvedValue(null); + mockPrismaService.user.findUnique.mockResolvedValue(null); + + mockPrismaService.$transaction.mockImplementation(async (callback) => { + return await callback(mockPrismaService); + }); + + const newUser = { + id: 2, + username: 'shortened-username', // Should be shortened + email: 'test@ruc.edu.cn', + }; + mockPrismaService.user.create.mockResolvedValue(newUser); + mockPrismaService.userProfile.create.mockResolvedValue({}); + mockPrismaService.userOAuthConnection.create.mockResolvedValue({}); + mockPrismaService.userLoginLog.create.mockResolvedValue({}); + mockPrismaService.userRegisterLog.create.mockResolvedValue({}); + + mockPrismaService.user.findUnique + .mockResolvedValueOnce(null) + .mockResolvedValueOnce(newUser); + + jest + .spyOn(service, 'generateUniqueUsername' as any) + .mockResolvedValue('shortened-username'); + + const result = await service.loginWithOAuth( + 'test', + userInfoWithLongData, + '127.0.0.1', + 'test-agent', + ); + + expect(result).toHaveLength(2); + expect(mockPrismaService.userProfile.create).toHaveBeenCalledWith({ + data: expect.objectContaining({ + nickname: longString.substring(0, 255), // Should be truncated if needed + }), + }); + }); + + it('should handle OAuth provider with special characters in user data', async () => { + const userInfoWithSpecialChars: OAuthUserInfo = { + id: '12345', + email: 'test+special@ruc.edu.cn', + name: 'Test User <>&"\'', + username: 'test-user-123', + preferredUsername: 'test_user_123', + }; + + mockPrismaService.userOAuthConnection.findUnique.mockResolvedValue(null); + mockPrismaService.user.findUnique.mockResolvedValue(null); + + mockPrismaService.$transaction.mockImplementation(async (callback) => { + return await callback(mockPrismaService); + }); + + const newUser = { + id: 2, + username: 'test-user-123', + email: 'test+special@ruc.edu.cn', + }; + mockPrismaService.user.create.mockResolvedValue(newUser); + mockPrismaService.userProfile.create.mockResolvedValue({}); + mockPrismaService.userOAuthConnection.create.mockResolvedValue({}); + mockPrismaService.userLoginLog.create.mockResolvedValue({}); + mockPrismaService.userRegisterLog.create.mockResolvedValue({}); + + mockPrismaService.user.findUnique + .mockResolvedValueOnce(null) + .mockResolvedValueOnce(newUser); + + jest + .spyOn(service, 'generateUniqueUsername' as any) + .mockResolvedValue('test-user-123'); + + const result = await service.loginWithOAuth( + 'test', + userInfoWithSpecialChars, + '127.0.0.1', + 'test-agent', + ); + + expect(result).toHaveLength(2); + expect(mockPrismaService.userOAuthConnection.create).toHaveBeenCalledWith( + { + data: { + userId: 2, + providerId: 'test', + providerUserId: '12345', // Should preserve original OAuth ID + rawProfile: userInfoWithSpecialChars, + }, + }, + ); + }); + + it('should use preferredUsername when username is not available', async () => { + const userInfoWithPreferredUsername: OAuthUserInfo = { + id: '12345', + email: 'test@ruc.edu.cn', + name: 'Test User', + preferredUsername: 'preferred-user', + }; + + mockPrismaService.userOAuthConnection.findUnique.mockResolvedValue(null); + mockPrismaService.user.findUnique.mockResolvedValue(null); + + mockPrismaService.$transaction.mockImplementation(async (callback) => { + return await callback(mockPrismaService); + }); + + mockPrismaService.user.create.mockResolvedValue({ + id: 2, + username: 'preferred-user', + email: 'test@ruc.edu.cn', + }); + mockPrismaService.userProfile.create.mockResolvedValue({}); + mockPrismaService.userOAuthConnection.create.mockResolvedValue({}); + mockPrismaService.userLoginLog.create.mockResolvedValue({}); + mockPrismaService.userRegisterLog.create.mockResolvedValue({}); + + mockPrismaService.user.findUnique + .mockResolvedValueOnce(null) + .mockResolvedValueOnce({ id: 2, username: 'preferred-user' }); + + jest + .spyOn(service, 'generateUniqueUsername' as any) + .mockResolvedValue('preferred-user'); + + await service.loginWithOAuth( + 'test', + userInfoWithPreferredUsername, + '127.0.0.1', + 'test-agent', + ); + + expect(service['generateUniqueUsername']).toHaveBeenCalledWith( + 'preferred-user', + ); + }); + }); +}); diff --git a/src/users/users.service.ts b/src/users/users.service.ts index a1e0dfad..68466f39 100644 --- a/src/users/users.service.ts +++ b/src/users/users.service.ts @@ -41,6 +41,7 @@ import { } from '../auth/auth.error'; import { AuthService } from '../auth/auth.service'; import { Authorization } from '../auth/definitions'; +import { OAuthUserInfo } from '../auth/oauth/oauth.types'; import { SessionService } from '../auth/session.service'; import { AvatarNotFoundError } from '../avatars/avatars.error'; import { AvatarsService } from '../avatars/avatars.service'; @@ -50,6 +51,7 @@ import { PrismaService } from '../common/prisma/prisma.service'; import { EmailRuleService } from '../email/email-rule.service'; import { EmailService } from '../email/email.service'; import { QuestionsService } from '../questions/questions.service'; +import { OAuthUserDto } from './DTO/oauth.dto'; import { UserDto } from './DTO/user.dto'; import { SrpService } from './srp.service'; import { TOTPService } from './totp.service'; @@ -736,6 +738,30 @@ export class UsersService { }; } + /** + * Get OAuth user DTO with email field for OAuth operations + */ + async getOAuthUserDtoById( + userId: number, + viewerId: number | undefined, // optional + ip: string, + userAgent: string | undefined, // optional + ): Promise { + const userDto = await this.getUserDtoById(userId, viewerId, ip, userAgent); + const user = await this.findUserRecordOrThrow(userId); + + // 检查是否是占位符email,如果是则返回null + const email = + user.email && user.email.endsWith('@placeholder.internal') + ? null + : user.email; + + return { + ...userDto, + email: email, + }; + } + // Returns: // [userDto, refreshToken] async login( @@ -1612,4 +1638,362 @@ export class UsersService { }, }); } + + /** + * OAuth 用户登录/注册处理 + * 将第三方返回的用户信息与本地用户数据库同步 + */ + async loginWithOAuth( + providerId: string, + userInfo: OAuthUserInfo, + ip: string, + userAgent: string | undefined, + ): Promise<[OAuthUserDto, string]> { + this.logger.log( + `Processing OAuth login for provider: ${providerId}, user: ${userInfo.id}`, + ); + + // 1. 检查已有绑定 + const existingConnection = + await this.prismaService.userOAuthConnection.findUnique({ + where: { + providerId_providerUserId: { + providerId, + providerUserId: userInfo.id, + }, + }, + include: { + user: { + include: { + userProfile: true, + }, + }, + }, + }); + + if ( + existingConnection && + existingConnection.user && + !existingConnection.user.deletedAt + ) { + return await this.handleExistingOAuthConnection( + existingConnection, + userInfo, + ip, + userAgent, + ); + } + + // 2. 按邮箱匹配现有用户 + if (userInfo.email) { + const existingUser = await this.prismaService.user.findUnique({ + where: { email: userInfo.email }, + include: { userProfile: true }, + }); + + if (existingUser && !existingUser.deletedAt) { + return await this.bindOAuthToExistingUserByEmail( + existingUser, + providerId, + userInfo, + ip, + userAgent, + ); + } + } + + // 3. 创建新用户 + return await this.createNewOAuthUser(providerId, userInfo, ip, userAgent); + } + + /** + * 处理已存在的 OAuth 连接 + */ + private async handleExistingOAuthConnection( + existingConnection: any, + userInfo: OAuthUserInfo, + ip: string, + userAgent: string | undefined, + ): Promise<[OAuthUserDto, string]> { + // 用户已存在,检查是否有 profile + if (!existingConnection.user.userProfile) { + this.logger.warn( + `User ${existingConnection.user.id} has OAuth connection but no profile, creating default profile`, + ); + await this.createDefaultProfileForUser(existingConnection.user.id); + } + + // 记录登录日志 + await this.prismaService.userLoginLog.create({ + data: { + userId: existingConnection.user.id, + ip, + userAgent, + }, + }); + + // 更新连接的原始资料 + await this.prismaService.userOAuthConnection.update({ + where: { id: existingConnection.id }, + data: { + rawProfile: userInfo as any, + updatedAt: new Date(), + }, + }); + + return [ + await this.getOAuthUserDtoById( + existingConnection.user.id, + existingConnection.user.id, + ip, + userAgent, + ), + await this.createSession(existingConnection.user.id), + ]; + } + + /** + * 将 OAuth 连接绑定到现有用户(通过邮箱匹配) + */ + private async bindOAuthToExistingUserByEmail( + existingUser: any, + providerId: string, + userInfo: OAuthUserInfo, + ip: string, + userAgent: string | undefined, + ): Promise<[OAuthUserDto, string]> { + this.logger.log( + `Found existing user by email for OAuth login: ${existingUser.username}`, + ); + + // 创建关联 + await this.prismaService.userOAuthConnection.upsert({ + where: { + providerId_providerUserId: { + providerId, + providerUserId: userInfo.id, + }, + }, + update: { + userId: existingUser.id, + rawProfile: userInfo as any, + updatedAt: new Date(), + }, + create: { + userId: existingUser.id, + providerId, + providerUserId: userInfo.id, + rawProfile: userInfo as any, + }, + }); + + // 记录登录日志 + await this.prismaService.userLoginLog.create({ + data: { + userId: existingUser.id, + ip, + userAgent, + }, + }); + + return [ + await this.getOAuthUserDtoById( + existingUser.id, + existingUser.id, + ip, + userAgent, + ), + await this.createSession(existingUser.id), + ]; + } + + /** + * 为 OAuth 用户创建新账户 + */ + private async createNewOAuthUser( + providerId: string, + userInfo: OAuthUserInfo, + ip: string, + userAgent: string | undefined, + ): Promise<[OAuthUserDto, string]> { + this.logger.log( + `Creating new user for OAuth login from provider: ${providerId}`, + ); + + // 生成唯一用户名 + const baseUsername = this.generateOAuthUsername(userInfo); + const uniqueUsername = await this.generateUniqueUsername(baseUsername); + + // 获取默认头像 + const avatarId = await this.avatarsService.getDefaultAvatarId(); + + // 生成随机密码(用户不会使用,仅为占位) + const randomPassword = this.generateRandomPassword(); + const hashedPassword = bcrypt.hashSync(randomPassword, 10); + + // 在事务中创建用户、profile 和 OAuth 连接 + const result = await this.prismaService.$transaction(async (tx) => { + // 为没有email的用户生成唯一占位符email + let userEmail = userInfo.email; + if (!userEmail) { + // 生成格式:oauth-{providerId}-{providerUserId}@placeholder.internal + userEmail = `oauth-${providerId}-${userInfo.id}@placeholder.internal`; + } + + // 创建用户 + const newUser = await tx.user.create({ + data: { + username: uniqueUsername, + email: userEmail, + hashedPassword, + srpUpgraded: false, // OAuth 用户默认未升级到 SRP + }, + }); + + // 创建用户 profile + const nickname = ( + userInfo.name || + userInfo.preferredUsername || + uniqueUsername + ).substring(0, 255); // 限制nickname长度为255个字符 + + await tx.userProfile.create({ + data: { + userId: newUser.id, + nickname, + intro: this.defaultIntro, + avatarId, + }, + }); + + // 创建 OAuth 连接 + await tx.userOAuthConnection.create({ + data: { + userId: newUser.id, + providerId, + providerUserId: userInfo.id, + rawProfile: userInfo as any, + }, + }); + + // 记录注册日志 + await tx.userRegisterLog.create({ + data: { + type: 'Success', + email: userInfo.email || '', + ip, + userAgent, + }, + }); + + // 记录登录日志 + await tx.userLoginLog.create({ + data: { + userId: newUser.id, + ip, + userAgent, + }, + }); + + return newUser; + }); + + this.logger.log( + `Created new user ${result.username} (ID: ${result.id}) for OAuth provider: ${providerId}`, + ); + + return [ + await this.getOAuthUserDtoById(result.id, result.id, ip, userAgent), + await this.createSession(result.id), + ]; + } + + /** + * 根据 OAuth 用户信息生成用户名基础 + */ + private generateOAuthUsername(userInfo: OAuthUserInfo): string { + if ( + userInfo.preferredUsername && + this.isValidUsername(userInfo.preferredUsername) + ) { + return userInfo.preferredUsername; + } + + if (userInfo.username && this.isValidUsername(userInfo.username)) { + return userInfo.username; + } + + if (userInfo.name) { + // 清理名称:去除特殊字符,转为小写 + const cleaned = userInfo.name + .replace(/[^a-zA-Z0-9_-]/g, '_') + .toLowerCase() + .replace(/_+/g, '_') + .replace(/^_|_$/g, ''); + + if (cleaned.length >= 4 && this.isValidUsername(cleaned)) { + return cleaned; + } + } + + // 如果都不可用,使用默认格式 + return `user_${userInfo.id}`.replace(/[^a-zA-Z0-9_-]/g, '_').toLowerCase(); + } + + /** + * 生成唯一用户名 + */ + private async generateUniqueUsername(baseUsername: string): Promise { + // 确保用户名长度符合要求 + let username = baseUsername; + if (username.length < 4) { + username = `user_${username}`; + } + if (username.length > 32) { + username = username.substring(0, 32); + } + + // 检查是否已存在 + let counter = 0; + let uniqueUsername = username; + + while (await this.isUsernameRegistered(uniqueUsername)) { + counter++; + const suffix = `_${counter}`; + const maxBaseLength = 32 - suffix.length; + uniqueUsername = username.substring(0, maxBaseLength) + suffix; + } + + return uniqueUsername; + } + + /** + * 生成随机密码 + */ + private generateRandomPassword(): string { + const chars = + 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*'; + let password = ''; + for (let i = 0; i < 16; i++) { + password += chars.charAt(Math.floor(Math.random() * chars.length)); + } + return password; + } + + /** + * 为用户创建默认 profile + */ + private async createDefaultProfileForUser(userId: number): Promise { + const avatarId = await this.avatarsService.getDefaultAvatarId(); + const user = await this.findUserRecordOrThrow(userId); + + await this.prismaService.userProfile.create({ + data: { + userId, + nickname: user.username, + intro: this.defaultIntro, + avatarId, + }, + }); + } } diff --git a/test/user.e2e-spec.ts b/test/user.e2e-spec.ts index e3d8fb93..011e2dd0 100644 --- a/test/user.e2e-spec.ts +++ b/test/user.e2e-spec.ts @@ -13,6 +13,9 @@ import session from 'express-session'; import { authenticator } from 'otplib'; import request from 'supertest'; import { AppModule } from '../src/app.module'; +import { AuthService } from '../src/auth/auth.service'; +import { OAuthService } from '../src/auth/oauth/oauth.service'; +import { OAuthError } from '../src/auth/oauth/oauth.types'; import { EmailService } from '../src/email/email.service'; const srpClient = new SrpClient(); @@ -160,6 +163,7 @@ async function verifySudoWithSRP( async function createLegacyUser(httpServer: HttpServer): Promise<{ username: string; password: string; + email: string; accessToken: string; refreshToken: string; userId: number; @@ -176,7 +180,7 @@ async function createLegacyUser(httpServer: HttpServer): Promise<{ const verificationCode = ( MockedEmailService.mock.instances[0].sendRegisterCode as jest.Mock - ).mock.calls[0][1]; + ).mock.calls.slice(-1)[0][1]; // Get the last call's verification code // 注册用户 const registerRes = await request(httpServer) @@ -198,6 +202,7 @@ async function createLegacyUser(httpServer: HttpServer): Promise<{ return { username, password, + email, accessToken: registerRes.body.data.accessToken, refreshToken, userId: registerRes.body.data.user.id, @@ -217,6 +222,12 @@ describe('User Module', () => { let TestToken: string; beforeAll(async () => { + // Set JWT secret and other env vars BEFORE module creation + process.env.JWT_SECRET = 'test-jwt-secret-for-oauth-tests'; + process.env.FRONTEND_BASE_URL = 'http://localhost:3000'; + process.env.FRONTEND_OAUTH_SUCCESS_PATH = '/oauth-success'; + process.env.FRONTEND_OAUTH_ERROR_PATH = '/oauth-error'; + const moduleFixture: TestingModule = await Test.createTestingModule({ imports: [AppModule], }).compile(); @@ -229,6 +240,7 @@ describe('User Module', () => { saveUninitialized: false, }), ); + await app.init(); }, 20000); @@ -1197,6 +1209,341 @@ describe('User Module', () => { }); }); + describe('OAuth Authentication', () => { + let oauthUser: { + userId: number; + username: string; + email: string; + accessToken: string; + refreshToken: string; + }; + + beforeAll(async () => { + // Mock OAuth service to return test data + const oauthService = app.get(OAuthService); + if (oauthService) { + jest + .spyOn(oauthService, 'getProvidersConfig') + .mockResolvedValue([ + { id: 'test', name: 'Test Provider', scope: ['read:user'] }, + ]); + jest + .spyOn(oauthService, 'generateAuthorizationUrl') + .mockResolvedValue( + 'https://test.com/oauth/authorize?client_id=test&redirect_uri=callback&response_type=code', + ); + jest + .spyOn(oauthService, 'handleCallback') + .mockResolvedValue('mock_access_token'); + jest.spyOn(oauthService, 'getUserInfo').mockResolvedValue({ + id: 'oauth-user-123', + email: `oauth-${Math.floor(Math.random() * 10000000000)}@ruc.edu.cn`, + name: 'OAuth Test User', + username: 'oauthuser', + preferredUsername: 'oauthuser', + }); + } + }); + + it('GET /users/auth/oauth/providers should return available providers', async () => { + const res = await request(app.getHttpServer()) + .get('/users/auth/oauth/providers') + .expect(200); + + expect(res.body.data.providers).toBeDefined(); + expect(Array.isArray(res.body.data.providers)).toBe(true); + }); + + it('GET /users/auth/oauth/login/:providerId should redirect to provider', async () => { + const res = await request(app.getHttpServer()) + .get('/users/auth/oauth/login/test') + .expect(302); + + expect(res.headers.location).toContain( + 'https://test.com/oauth/authorize', + ); + }); + + it('should return 404 for invalid provider', async () => { + // Override OAuth mocks to simulate invalid provider + const oauthService = app.get(OAuthService); + if (oauthService) { + jest + .spyOn(oauthService, 'generateAuthorizationUrl') + .mockRejectedValue( + new OAuthError( + 'OAuth provider not found', + 'invalid-provider', + 'validation', + ), + ); + } + + const res = await request(app.getHttpServer()) + .get('/users/auth/oauth/login/invalid-provider') + .expect(302); + + expect(res.headers.location).toContain('/oauth-error'); + expect(res.headers.location).toContain('error='); + expect(res.headers.location).toContain('OAuth%20provider%20not%20found'); + }); + + it('GET /users/auth/oauth/callback/:providerId should handle OAuth callback', async () => { + // Override OAuth mocks for this specific test + const oauthService = app.get(OAuthService); + if (oauthService) { + jest + .spyOn(oauthService, 'handleCallback') + .mockResolvedValue('mock_access_token'); + jest.spyOn(oauthService, 'getUserInfo').mockResolvedValue({ + id: 'oauth-user-123', + email: `oauth-${Math.floor(Math.random() * 10000000000)}@ruc.edu.cn`, + name: 'OAuth Test User', + username: 'oauthuser', + preferredUsername: 'oauthuser', + }); + } + + const agent = request.agent(app.getHttpServer()); + + const res = await agent + .get('/users/auth/oauth/callback/test?code=test-code&state=test-state') + .expect(302); + + // Should redirect to frontend success page + expect(res.headers.location).toContain('/oauth-success'); + expect(res.headers.location).toContain('token='); + expect(res.headers.location).toContain('email='); + + // Should set refresh token cookie + expect(res.headers['set-cookie']).toBeDefined(); + const cookies = Array.isArray(res.headers['set-cookie']) + ? res.headers['set-cookie'] + : [res.headers['set-cookie']]; + expect( + cookies.some((cookie: string) => cookie.includes('REFRESH_TOKEN=')), + ).toBe(true); + + // Extract token from redirect URL for further tests + const callbackUrlParams = new URLSearchParams( + res.headers.location.split('?')[1], + ); + const accessToken = callbackUrlParams.get('token'); + const email = callbackUrlParams.get('email'); + + expect(accessToken).toBeDefined(); + expect(email).toBeDefined(); + + // First decode token to get user ID + const authService = app.get(AuthService); + const payload = authService.decode(accessToken!); + const userId = payload.authorization.userId; + + // Verify the token works by getting user info + const userRes = await agent + .get(`/users/${userId}`) + .set('Authorization', `Bearer ${accessToken}`) + .expect(200); + + // Verify email from URL parameter (UserDto doesn't include email field) + expect(decodeURIComponent(email!)).toMatch(/^oauth-\d+@ruc\.edu\.cn$/); + + oauthUser = { + userId: userRes.body.data.user.id, + username: userRes.body.data.user.username, + email: decodeURIComponent(email!), // Use email from URL parameter + accessToken: accessToken!, + refreshToken: + cookies + .find((c: string) => c.includes('REFRESH_TOKEN=')) + ?.split('=')[1] + ?.split(';')[0] || '', + }; + }); + + it('should handle OAuth callback with invalid code', async () => { + const oauthService = app.get(OAuthService); + if (oauthService) { + // Override OAuth mocks for error scenario + jest + .spyOn(oauthService, 'handleCallback') + .mockRejectedValue(new Error('Invalid authorization code')); + } + + const res = await request(app.getHttpServer()) + .get('/users/auth/oauth/callback/test?code=invalid-code') + .expect(302); + + expect(res.headers.location).toContain('/oauth-error'); + expect(res.headers.location).toContain('error='); + }); + + it('should handle OAuth callback for invalid provider', async () => { + const res = await request(app.getHttpServer()) + .get('/users/auth/oauth/callback/invalid-provider?code=test-code') + .expect(302); + + expect(res.headers.location).toContain('/oauth-error'); + expect(res.headers.location).toContain('error='); + }); + + it('should bind OAuth account to existing user by email', async () => { + // Create a regular user first + const regularUser = await createLegacyUser(app.getHttpServer()); + + // Add debug logging to verify email matching + console.log(`Test: regularUser email = "${regularUser.email}"`); + console.log(`Test: regularUser userId = ${regularUser.userId}`); + + // Override OAuth mocks to return the regularUser's email + const oauthService = app.get(OAuthService); + if (oauthService) { + jest + .spyOn(oauthService, 'handleCallback') + .mockResolvedValue('mock_access_token_existing'); + jest.spyOn(oauthService, 'getUserInfo').mockImplementation(async () => { + const userInfo = { + id: `oauth-binding-test-${Date.now()}`, // Use unique ID to avoid conflicts + email: regularUser.email, // Use the exact email from regularUser + name: 'OAuth Existing User', + username: 'oauthexisting', + preferredUsername: 'oauthexisting', + }; + console.log( + `Mock: OAuth getUserInfo returning email = "${userInfo.email}"`, + ); + return userInfo; + }); + } + + // Wait a bit to ensure user creation is fully committed + await new Promise((resolve) => setTimeout(resolve, 100)); + + const agent = request.agent(app.getHttpServer()); + const res = await agent + .get('/users/auth/oauth/callback/test?code=test-code-binding-test') + .expect(302); + + expect(res.headers.location).toContain('/oauth-success'); + + // Verify the OAuth login returns the same user + const bindingUrlParams = new URLSearchParams( + res.headers.location.split('?')[1], + ); + const accessToken = bindingUrlParams.get('token'); + + // Get user ID from token + const authService = app.get(AuthService); + const payload = authService.decode(accessToken!); + const userId = payload.authorization.userId; + + const userRes = await agent + .get(`/users/${userId}`) + .set('Authorization', `Bearer ${accessToken}`) + .expect(200); + + // The main assertion is that OAuth binding should return the same user ID + expect(userRes.body.data.user.id).toBe(regularUser.userId); + // Email should also match (from URL parameter, not from user object since UserDto doesn't include email) + expect(decodeURIComponent(bindingUrlParams.get('email')!)).toBe( + regularUser.email, + ); + }); + + it('should handle OAuth login for existing OAuth connection', async () => { + // Override OAuth mocks to return same user info as before + const oauthService = app.get(OAuthService); + if (oauthService) { + jest + .spyOn(oauthService, 'handleCallback') + .mockResolvedValue('mock_access_token_existing_connection'); + jest.spyOn(oauthService, 'getUserInfo').mockResolvedValue({ + id: 'oauth-user-123', // Same ID as first test + email: oauthUser?.email || 'oauth-test@test.com', + name: 'OAuth Test User Updated', + username: 'oauthuser', + preferredUsername: 'oauthuser', + }); + } + + const agent = request.agent(app.getHttpServer()); + const res = await agent + .get('/users/auth/oauth/callback/test?code=test-code-existing') + .expect(302); + + expect(res.headers.location).toContain('/oauth-success'); + + const existingOAuthUrlParams = new URLSearchParams( + res.headers.location.split('?')[1], + ); + const accessToken = existingOAuthUrlParams.get('token')!; + + // Get user ID from token + const authService = app.get(AuthService); + const payload = authService.decode(accessToken); + const userId = payload.authorization.userId; + + const userRes = await agent + .get(`/users/${userId}`) + .set('Authorization', `Bearer ${accessToken}`) + .expect(200); + + if (oauthUser) { + expect(userRes.body.data.user.id).toBe(oauthUser.userId); + // Verify email from URL parameter (UserDto doesn't include email field) + expect(decodeURIComponent(existingOAuthUrlParams.get('email')!)).toBe( + oauthUser.email, + ); + } + }); + + it('should create new user for OAuth without email', async () => { + const oauthService = app.get(OAuthService); + if (oauthService) { + // Override OAuth mocks for no-email scenario + jest + .spyOn(oauthService, 'handleCallback') + .mockResolvedValue('mock_access_token_no_email'); + jest.spyOn(oauthService, 'getUserInfo').mockResolvedValue({ + id: 'oauth-user-no-email', + name: 'OAuth No Email User', + username: 'oauthnoemail', + preferredUsername: 'oauthnoemail', + }); + } + + const agent = request.agent(app.getHttpServer()); + const res = await agent + .get('/users/auth/oauth/callback/test?code=test-code-no-email') + .expect(302); + + expect(res.headers.location).toContain('/oauth-success'); + + const noEmailUrlParams = new URLSearchParams( + res.headers.location.split('?')[1], + ); + const accessToken = noEmailUrlParams.get('token')!; + + // Get user ID from token + const authService = app.get(AuthService); + const payload = authService.decode(accessToken); + const userId = payload.authorization.userId; + + const userRes = await agent + .get(`/users/${userId}`) + .set('Authorization', `Bearer ${accessToken}`) + .expect(200); + + // For users without email, check that username is passed in URL parameter + const callbackUrlParams = new URLSearchParams( + res.headers.location.split('?')[1], + ); + const emailParam = callbackUrlParams.get('email'); + expect(decodeURIComponent(emailParam!)).toContain('oauthnoemail'); + expect(userRes.body.data.user.username).toContain('oauthnoemail'); + }); + }); + afterAll(async () => { await app.close(); }); diff --git a/tsconfig.build.json b/tsconfig.build.json index 64f86c6b..f8fed909 100644 --- a/tsconfig.build.json +++ b/tsconfig.build.json @@ -1,4 +1,4 @@ { "extends": "./tsconfig.json", - "exclude": ["node_modules", "test", "dist", "**/*spec.ts"] + "exclude": ["node_modules", "test", "dist", "**/*spec.ts", "oauth-providers"] }