Skip to content

Commit 11b2236

Browse files
authored
Merge pull request #278 from 0x0elliot/main
New openID flow
2 parents 513836e + 0c73d9c commit 11b2236

File tree

4 files changed

+1248
-220
lines changed

4 files changed

+1248
-220
lines changed

db-connector.go

Lines changed: 182 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@ import (
2525
"sync"
2626
"time"
2727

28+
runtimeDebug "runtime/debug"
29+
2830
"cloud.google.com/go/datastore"
2931
"github.com/Masterminds/semver"
3032
"github.com/bradfitz/slice"
3133
uuid "github.com/satori/go.uuid"
32-
runtimeDebug "runtime/debug"
3334

3435
//"github.com/frikky/kin-openapi/openapi3"
3536
"github.com/patrickmn/go-cache"
@@ -1739,7 +1740,7 @@ func Fixexecution(ctx context.Context, workflowExecution WorkflowExecution) (Wor
17391740
result = innerresult
17401741
break
17411742

1742-
//} else if innerresult.Status == "WAITING" || innerresult.Status == "SUCCESS" && (action.AppName == "AI Agent" || action.AppName == "Shuffle Agent") {
1743+
//} else if innerresult.Status == "WAITING" || innerresult.Status == "SUCCESS" && (action.AppName == "AI Agent" || action.AppName == "Shuffle Agent") {
17431744
} else if (innerresult.Status == "WAITING" || innerresult.Status == "SUCCESS") && (innerresult.Action.AppName == "AI Agent" || innerresult.Action.AppName == "Shuffle Agent") {
17441745
// Auto fixing decision data based on cache for better decisionmaking
17451746
// Map the result into AgentOutput to check decisions
@@ -4025,7 +4026,7 @@ func GetOrg(ctx context.Context, id string) (*Org, error) {
40254026
if id == "public" {
40264027
//return &Org{}, errors.New("'public' org is used for Singul action without being logged in. Not relevant.")
40274028
return &Org{
4028-
Id: "public",
4029+
Id: "public",
40294030
Name: "Public",
40304031
}, nil
40314032
}
@@ -5371,6 +5372,157 @@ func FindWorkflowAppByName(ctx context.Context, appName string) ([]WorkflowApp,
53715372
return apps, nil
53725373
}
53735374

5375+
// FindUserBySSOIdentity finds a user by their SSO identity using efficient database queries
5376+
// Also validates that the clientID matches the org's configured SSO
5377+
func FindUserBySSOIdentity(ctx context.Context, sub, clientID, orgID, email string) (User, error) {
5378+
var emptyUser User
5379+
5380+
// Check if Sub is empty - user hasn't connected SSO yet
5381+
if sub == "" {
5382+
return emptyUser, errors.New("connect user account with SSO first")
5383+
}
5384+
5385+
if clientID == "" || orgID == "" || email == "" {
5386+
return emptyUser, errors.New("clientID, orgID, and email are all required")
5387+
}
5388+
5389+
// Verify the clientID actually matches the org's SSO configuration
5390+
org, err := GetOrg(ctx, orgID)
5391+
if err != nil {
5392+
return emptyUser, fmt.Errorf("failed to get org %s: %w", orgID, err)
5393+
}
5394+
5395+
if org.SSOConfig.OpenIdClientId != clientID {
5396+
return emptyUser, fmt.Errorf("clientID %s does not match org's configured SSO client ID %s", clientID, org.SSOConfig.OpenIdClientId)
5397+
}
5398+
5399+
// Normalize email for comparison
5400+
normalizedEmail := strings.ToLower(strings.TrimSpace(email))
5401+
5402+
nameKey := "Users"
5403+
var users []User
5404+
5405+
if project.DbType == "opensearch" {
5406+
// OpenSearch query to find users with matching SSO info
5407+
var buf bytes.Buffer
5408+
query := map[string]interface{}{
5409+
"size": 10,
5410+
"query": map[string]interface{}{
5411+
"bool": map[string]interface{}{
5412+
"must": []map[string]interface{}{
5413+
{
5414+
"term": map[string]interface{}{
5415+
"username.keyword": normalizedEmail,
5416+
},
5417+
},
5418+
{
5419+
"nested": map[string]interface{}{
5420+
"path": "sso_infos",
5421+
"query": map[string]interface{}{
5422+
"bool": map[string]interface{}{
5423+
"must": []map[string]interface{}{
5424+
{
5425+
"term": map[string]interface{}{
5426+
"sso_infos.sub.keyword": sub,
5427+
},
5428+
},
5429+
{
5430+
"term": map[string]interface{}{
5431+
"sso_infos.client_id.keyword": clientID,
5432+
},
5433+
},
5434+
{
5435+
"term": map[string]interface{}{
5436+
"sso_infos.org_id.keyword": orgID,
5437+
},
5438+
},
5439+
},
5440+
},
5441+
},
5442+
},
5443+
},
5444+
},
5445+
},
5446+
},
5447+
}
5448+
5449+
if err := json.NewEncoder(&buf).Encode(query); err != nil {
5450+
return emptyUser, fmt.Errorf("failed to encode opensearch query: %w", err)
5451+
}
5452+
5453+
resp, err := project.Es.Search(ctx, &opensearchapi.SearchReq{
5454+
Indices: []string{strings.ToLower(GetESIndexPrefix(nameKey))},
5455+
Body: &buf,
5456+
Params: opensearchapi.SearchParams{
5457+
TrackTotalHits: true,
5458+
},
5459+
})
5460+
if err != nil {
5461+
return emptyUser, fmt.Errorf("opensearch query failed: %w", err)
5462+
}
5463+
5464+
res := resp.Inspect().Response
5465+
defer res.Body.Close()
5466+
if res.StatusCode != 200 && res.StatusCode != 201 {
5467+
return emptyUser, fmt.Errorf("opensearch error response: %d", res.StatusCode)
5468+
}
5469+
5470+
var r map[string]interface{}
5471+
if err := json.NewDecoder(res.Body).Decode(&r); err != nil {
5472+
return emptyUser, fmt.Errorf("failed to parse opensearch response: %w", err)
5473+
}
5474+
5475+
hits, ok := r["hits"].(map[string]interface{})["hits"].([]interface{})
5476+
if !ok {
5477+
return emptyUser, errors.New("no matching user found")
5478+
}
5479+
5480+
for _, hit := range hits {
5481+
if source, ok := hit.(map[string]interface{})["_source"]; ok {
5482+
data, _ := json.Marshal(source)
5483+
var user User
5484+
if err := json.Unmarshal(data, &user); err == nil {
5485+
users = append(users, user)
5486+
}
5487+
}
5488+
}
5489+
} else {
5490+
// Datastore query - need to get by email first then validate SSO info
5491+
// (Datastore doesn't support nested queries efficiently)
5492+
q := datastore.NewQuery(nameKey).Filter("Username =", normalizedEmail).Limit(10)
5493+
_, err := project.Dbclient.GetAll(ctx, q, &users)
5494+
if err != nil {
5495+
return emptyUser, fmt.Errorf("datastore query failed: %w", err)
5496+
}
5497+
5498+
// Filter users to find exact SSO match
5499+
var matchingUsers []User
5500+
for _, user := range users {
5501+
for _, ssoInfo := range user.SSOInfos {
5502+
if ssoInfo.Sub == sub &&
5503+
ssoInfo.ClientID == clientID &&
5504+
ssoInfo.OrgID == orgID {
5505+
matchingUsers = append(matchingUsers, user)
5506+
break
5507+
}
5508+
}
5509+
}
5510+
users = matchingUsers
5511+
}
5512+
5513+
if len(users) == 0 {
5514+
return emptyUser, fmt.Errorf("no user found with Sub=%s, ClientID=%s, OrgID=%s, Email=%s", sub, clientID, orgID, normalizedEmail)
5515+
}
5516+
5517+
if len(users) > 1 {
5518+
log.Printf("[CRITICAL] Multiple users found with same SSO identity: Sub=%s, ClientID=%s, OrgID=%s, Email=%s",
5519+
sub, clientID, orgID, normalizedEmail)
5520+
return emptyUser, errors.New("multiple users found with same SSO identity - data integrity issue")
5521+
}
5522+
5523+
return users[0], nil
5524+
}
5525+
53745526
func FindGeneratedUser(ctx context.Context, username string) ([]User, error) {
53755527
var users []User
53765528

@@ -5480,7 +5632,7 @@ func FindUser(ctx context.Context, username string) ([]User, error) {
54805632
query := map[string]interface{}{
54815633
"size": 1000,
54825634
"query": map[string]interface{}{
5483-
"bool": map[string]interface{} {
5635+
"bool": map[string]interface{}{
54845636
"must": map[string]interface{}{
54855637
"match": map[string]interface{}{
54865638
"username": username,
@@ -5664,6 +5816,32 @@ func GetUser(ctx context.Context, username string) (*User, error) {
56645816
return curUser, nil
56655817
}
56665818

5819+
func (u *User) GetSSOInfo(orgID string) (SSOInfo, bool) {
5820+
for _, sso := range u.SSOInfos {
5821+
if sso.OrgID == orgID {
5822+
return sso, true
5823+
}
5824+
}
5825+
return SSOInfo{}, false
5826+
}
5827+
5828+
func (u *User) SetSSOInfo(orgID string, ssoInfo SSOInfo) {
5829+
ssoInfo.OrgID = orgID
5830+
for i, sso := range u.SSOInfos {
5831+
if sso.OrgID == orgID {
5832+
u.SSOInfos[i] = ssoInfo
5833+
return
5834+
}
5835+
}
5836+
u.SSOInfos = append(u.SSOInfos, ssoInfo)
5837+
}
5838+
5839+
func (u *User) InitSSOInfos() {
5840+
if u.SSOInfos == nil {
5841+
u.SSOInfos = []SSOInfo{}
5842+
}
5843+
}
5844+
56675845
func SetUser(ctx context.Context, user *User, updateOrg bool) error {
56685846
log.Printf("[INFO] Updating user %s (%s) that has the role %s with %d apps and %d orgs. Org updater: %t", user.Username, user.Id, user.Role, len(user.PrivateApps), len(user.Orgs), updateOrg)
56695847
parsedKey := user.Id

0 commit comments

Comments
 (0)