Skip to content

Commit c693a61

Browse files
RockChinQautofix-ci[bot]Yessenia-d
authored andcommitted
feat: oauth provider (#24206)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: yessenia <[email protected]>
1 parent 045792e commit c693a61

File tree

32 files changed

+757
-22
lines changed

32 files changed

+757
-22
lines changed

api/controllers/console/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
)
7171

7272
# Import auth controllers
73-
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth
73+
from .auth import activate, data_source_bearer_auth, data_source_oauth, forgot_password, login, oauth, oauth_server
7474

7575
# Import billing controllers
7676
from .billing import billing, compliance
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
from functools import wraps
2+
from typing import cast
3+
4+
import flask_login
5+
from flask import request
6+
from flask_restx import Resource, reqparse
7+
from werkzeug.exceptions import BadRequest, NotFound
8+
9+
from controllers.console.wraps import account_initialization_required, setup_required
10+
from core.model_runtime.utils.encoders import jsonable_encoder
11+
from libs.login import login_required
12+
from models.account import Account
13+
from models.model import OAuthProviderApp
14+
from services.oauth_server import OAUTH_ACCESS_TOKEN_EXPIRES_IN, OAuthGrantType, OAuthServerService
15+
16+
from .. import api
17+
18+
19+
def oauth_server_client_id_required(view):
20+
@wraps(view)
21+
def decorated(*args, **kwargs):
22+
parser = reqparse.RequestParser()
23+
parser.add_argument("client_id", type=str, required=True, location="json")
24+
parsed_args = parser.parse_args()
25+
client_id = parsed_args.get("client_id")
26+
if not client_id:
27+
raise BadRequest("client_id is required")
28+
29+
oauth_provider_app = OAuthServerService.get_oauth_provider_app(client_id)
30+
if not oauth_provider_app:
31+
raise NotFound("client_id is invalid")
32+
33+
kwargs["oauth_provider_app"] = oauth_provider_app
34+
35+
return view(*args, **kwargs)
36+
37+
return decorated
38+
39+
40+
def oauth_server_access_token_required(view):
41+
@wraps(view)
42+
def decorated(*args, **kwargs):
43+
oauth_provider_app = kwargs.get("oauth_provider_app")
44+
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
45+
raise BadRequest("Invalid oauth_provider_app")
46+
47+
if not request.headers.get("Authorization"):
48+
raise BadRequest("Authorization is required")
49+
50+
authorization_header = request.headers.get("Authorization")
51+
if not authorization_header:
52+
raise BadRequest("Authorization header is required")
53+
54+
parts = authorization_header.split(" ")
55+
if len(parts) != 2:
56+
raise BadRequest("Invalid Authorization header format")
57+
58+
token_type = parts[0]
59+
if token_type != "Bearer":
60+
raise BadRequest("token_type is invalid")
61+
62+
access_token = parts[1]
63+
if not access_token:
64+
raise BadRequest("access_token is required")
65+
66+
account = OAuthServerService.validate_oauth_access_token(oauth_provider_app.client_id, access_token)
67+
if not account:
68+
raise BadRequest("access_token or client_id is invalid")
69+
70+
kwargs["account"] = account
71+
72+
return view(*args, **kwargs)
73+
74+
return decorated
75+
76+
77+
class OAuthServerAppApi(Resource):
78+
@setup_required
79+
@oauth_server_client_id_required
80+
def post(self, oauth_provider_app: OAuthProviderApp):
81+
parser = reqparse.RequestParser()
82+
parser.add_argument("redirect_uri", type=str, required=True, location="json")
83+
parsed_args = parser.parse_args()
84+
redirect_uri = parsed_args.get("redirect_uri")
85+
86+
# check if redirect_uri is valid
87+
if redirect_uri not in oauth_provider_app.redirect_uris:
88+
raise BadRequest("redirect_uri is invalid")
89+
90+
return jsonable_encoder(
91+
{
92+
"app_icon": oauth_provider_app.app_icon,
93+
"app_label": oauth_provider_app.app_label,
94+
"scope": oauth_provider_app.scope,
95+
}
96+
)
97+
98+
99+
class OAuthServerUserAuthorizeApi(Resource):
100+
@setup_required
101+
@login_required
102+
@account_initialization_required
103+
@oauth_server_client_id_required
104+
def post(self, oauth_provider_app: OAuthProviderApp):
105+
account = cast(Account, flask_login.current_user)
106+
user_account_id = account.id
107+
108+
code = OAuthServerService.sign_oauth_authorization_code(oauth_provider_app.client_id, user_account_id)
109+
return jsonable_encoder(
110+
{
111+
"code": code,
112+
}
113+
)
114+
115+
116+
class OAuthServerUserTokenApi(Resource):
117+
@setup_required
118+
@oauth_server_client_id_required
119+
def post(self, oauth_provider_app: OAuthProviderApp):
120+
parser = reqparse.RequestParser()
121+
parser.add_argument("grant_type", type=str, required=True, location="json")
122+
parser.add_argument("code", type=str, required=False, location="json")
123+
parser.add_argument("client_secret", type=str, required=False, location="json")
124+
parser.add_argument("redirect_uri", type=str, required=False, location="json")
125+
parser.add_argument("refresh_token", type=str, required=False, location="json")
126+
parsed_args = parser.parse_args()
127+
128+
grant_type = OAuthGrantType(parsed_args["grant_type"])
129+
130+
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
131+
if not parsed_args["code"]:
132+
raise BadRequest("code is required")
133+
134+
if parsed_args["client_secret"] != oauth_provider_app.client_secret:
135+
raise BadRequest("client_secret is invalid")
136+
137+
if parsed_args["redirect_uri"] not in oauth_provider_app.redirect_uris:
138+
raise BadRequest("redirect_uri is invalid")
139+
140+
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
141+
grant_type, code=parsed_args["code"], client_id=oauth_provider_app.client_id
142+
)
143+
return jsonable_encoder(
144+
{
145+
"access_token": access_token,
146+
"token_type": "Bearer",
147+
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
148+
"refresh_token": refresh_token,
149+
}
150+
)
151+
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
152+
if not parsed_args["refresh_token"]:
153+
raise BadRequest("refresh_token is required")
154+
155+
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
156+
grant_type, refresh_token=parsed_args["refresh_token"], client_id=oauth_provider_app.client_id
157+
)
158+
return jsonable_encoder(
159+
{
160+
"access_token": access_token,
161+
"token_type": "Bearer",
162+
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
163+
"refresh_token": refresh_token,
164+
}
165+
)
166+
else:
167+
raise BadRequest("invalid grant_type")
168+
169+
170+
class OAuthServerUserAccountApi(Resource):
171+
@setup_required
172+
@oauth_server_client_id_required
173+
@oauth_server_access_token_required
174+
def post(self, oauth_provider_app: OAuthProviderApp, account: Account):
175+
return jsonable_encoder(
176+
{
177+
"name": account.name,
178+
"email": account.email,
179+
"avatar": account.avatar,
180+
"interface_language": account.interface_language,
181+
"timezone": account.timezone,
182+
}
183+
)
184+
185+
186+
api.add_resource(OAuthServerAppApi, "/oauth/provider")
187+
api.add_resource(OAuthServerUserAuthorizeApi, "/oauth/provider/authorize")
188+
api.add_resource(OAuthServerUserTokenApi, "/oauth/provider/token")
189+
api.add_resource(OAuthServerUserAccountApi, "/oauth/provider/account")
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""empty message
2+
3+
Revision ID: 8d289573e1da
4+
Revises: fa8b0fa6f407
5+
Create Date: 2025-08-20 17:47:17.015695
6+
7+
"""
8+
from alembic import op
9+
import models as models
10+
import sqlalchemy as sa
11+
12+
13+
# revision identifiers, used by Alembic.
14+
revision = '8d289573e1da'
15+
down_revision = '0e154742a5fa'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade():
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.create_table('oauth_provider_apps',
23+
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
24+
sa.Column('app_icon', sa.String(length=255), nullable=False),
25+
sa.Column('app_label', sa.JSON(), server_default='{}', nullable=False),
26+
sa.Column('client_id', sa.String(length=255), nullable=False),
27+
sa.Column('client_secret', sa.String(length=255), nullable=False),
28+
sa.Column('redirect_uris', sa.JSON(), server_default='[]', nullable=False),
29+
sa.Column('scope', sa.String(length=255), server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"), nullable=False),
30+
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
31+
sa.PrimaryKeyConstraint('id', name='oauth_provider_app_pkey')
32+
)
33+
with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op:
34+
batch_op.create_index('oauth_provider_app_client_id_idx', ['client_id'], unique=False)
35+
36+
# ### end Alembic commands ###
37+
38+
39+
def downgrade():
40+
# ### commands auto generated by Alembic - please adjust! ###
41+
with op.batch_alter_table('oauth_provider_apps', schema=None) as batch_op:
42+
batch_op.drop_index('oauth_provider_app_client_id_idx')
43+
44+
op.drop_table('oauth_provider_apps')
45+
# ### end Alembic commands ###

api/models/model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,32 @@ def tenant(self):
580580
return tenant
581581

582582

583+
class OAuthProviderApp(Base):
584+
"""
585+
Globally shared OAuth provider app information.
586+
Only for Dify Cloud.
587+
"""
588+
589+
__tablename__ = "oauth_provider_apps"
590+
__table_args__ = (
591+
sa.PrimaryKeyConstraint("id", name="oauth_provider_app_pkey"),
592+
sa.Index("oauth_provider_app_client_id_idx", "client_id"),
593+
)
594+
595+
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
596+
app_icon = mapped_column(String(255), nullable=False)
597+
app_label = mapped_column(sa.JSON, nullable=False, server_default="{}")
598+
client_id = mapped_column(String(255), nullable=False)
599+
client_secret = mapped_column(String(255), nullable=False)
600+
redirect_uris = mapped_column(sa.JSON, nullable=False, server_default="[]")
601+
scope = mapped_column(
602+
String(255),
603+
nullable=False,
604+
server_default=sa.text("'read:name read:email read:avatar read:interface_language read:timezone'"),
605+
)
606+
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))
607+
608+
583609
class Conversation(Base):
584610
__tablename__ = "conversations"
585611
__table_args__ = (

api/services/oauth_server.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import enum
2+
import uuid
3+
4+
from sqlalchemy import select
5+
from sqlalchemy.orm import Session
6+
from werkzeug.exceptions import BadRequest
7+
8+
from extensions.ext_database import db
9+
from extensions.ext_redis import redis_client
10+
from models.account import Account
11+
from models.model import OAuthProviderApp
12+
from services.account_service import AccountService
13+
14+
15+
class OAuthGrantType(enum.StrEnum):
16+
AUTHORIZATION_CODE = "authorization_code"
17+
REFRESH_TOKEN = "refresh_token"
18+
19+
20+
OAUTH_AUTHORIZATION_CODE_REDIS_KEY = "oauth_provider:{client_id}:authorization_code:{code}"
21+
OAUTH_ACCESS_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:access_token:{token}"
22+
OAUTH_ACCESS_TOKEN_EXPIRES_IN = 60 * 60 * 12 # 12 hours
23+
OAUTH_REFRESH_TOKEN_REDIS_KEY = "oauth_provider:{client_id}:refresh_token:{token}"
24+
OAUTH_REFRESH_TOKEN_EXPIRES_IN = 60 * 60 * 24 * 30 # 30 days
25+
26+
27+
class OAuthServerService:
28+
@staticmethod
29+
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
30+
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
31+
32+
with Session(db.engine) as session:
33+
return session.execute(query).scalar_one_or_none()
34+
35+
@staticmethod
36+
def sign_oauth_authorization_code(client_id: str, user_account_id: str) -> str:
37+
code = str(uuid.uuid4())
38+
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
39+
redis_client.set(redis_key, user_account_id, ex=60 * 10) # 10 minutes
40+
return code
41+
42+
@staticmethod
43+
def sign_oauth_access_token(
44+
grant_type: OAuthGrantType,
45+
code: str = "",
46+
client_id: str = "",
47+
refresh_token: str = "",
48+
) -> tuple[str, str]:
49+
match grant_type:
50+
case OAuthGrantType.AUTHORIZATION_CODE:
51+
redis_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id=client_id, code=code)
52+
user_account_id = redis_client.get(redis_key)
53+
if not user_account_id:
54+
raise BadRequest("invalid code")
55+
56+
# delete code
57+
redis_client.delete(redis_key)
58+
59+
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
60+
refresh_token = OAuthServerService._sign_oauth_refresh_token(client_id, user_account_id)
61+
return access_token, refresh_token
62+
case OAuthGrantType.REFRESH_TOKEN:
63+
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=refresh_token)
64+
user_account_id = redis_client.get(redis_key)
65+
if not user_account_id:
66+
raise BadRequest("invalid refresh token")
67+
68+
access_token = OAuthServerService._sign_oauth_access_token(client_id, user_account_id)
69+
return access_token, refresh_token
70+
71+
@staticmethod
72+
def _sign_oauth_access_token(client_id: str, user_account_id: str) -> str:
73+
token = str(uuid.uuid4())
74+
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
75+
redis_client.set(redis_key, user_account_id, ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN)
76+
return token
77+
78+
@staticmethod
79+
def _sign_oauth_refresh_token(client_id: str, user_account_id: str) -> str:
80+
token = str(uuid.uuid4())
81+
redis_key = OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
82+
redis_client.set(redis_key, user_account_id, ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN)
83+
return token
84+
85+
@staticmethod
86+
def validate_oauth_access_token(client_id: str, token: str) -> Account | None:
87+
redis_key = OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id=client_id, token=token)
88+
user_account_id = redis_client.get(redis_key)
89+
if not user_account_id:
90+
return None
91+
92+
user_id_str = user_account_id.decode("utf-8")
93+
94+
return AccountService.load_user(user_id_str)

web/app/account/account-page/AvatarWithEdit.tsx renamed to web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx

File renamed without changes.

web/app/account/account-page/email-change-modal.tsx renamed to web/app/account/(commonLayout)/account-page/email-change-modal.tsx

File renamed without changes.
File renamed without changes.

web/app/account/delete-account/components/check-email.tsx renamed to web/app/account/(commonLayout)/delete-account/components/check-email.tsx

File renamed without changes.

0 commit comments

Comments
 (0)