diff --git a/auth/tenant_mgt_test.go b/auth/tenant_mgt_test.go index 77c26851..e7d56cc9 100644 --- a/auth/tenant_mgt_test.go +++ b/auth/tenant_mgt_test.go @@ -90,6 +90,34 @@ func TestTenantGetUser(t *testing.T) { } } +func TestTenantQueryUsers(t *testing.T) { + resp := `{ + "usersInfo": [], + "recordsCount": "0" + }` + s := echoServer([]byte(resp), t) + defer s.Close() + + tenantClient, err := s.Client.TenantManager.AuthForTenant("test-tenant") + if err != nil { + t.Fatalf("Failed to create tenant client: %v", err) + } + + query := &QueryUsersRequest{ + ReturnUserInfo: true, + } + + _, err = tenantClient.QueryUsers(context.Background(), query) + if err != nil { + t.Fatalf("QueryUsers() with tenant client = %v", err) + } + + wantPath := "/projects/mock-project-id/tenants/test-tenant/accounts:query" + if s.Req[0].RequestURI != wantPath { + t.Errorf("QueryUsers() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) + } +} + func TestTenantGetUserByEmail(t *testing.T) { s := echoServer(testGetUserResponse, t) defer s.Close() diff --git a/auth/user_mgt.go b/auth/user_mgt.go index 63a5c381..6e8101d9 100644 --- a/auth/user_mgt.go +++ b/auth/user_mgt.go @@ -838,6 +838,61 @@ type getAccountInfoResponse struct { Users []*userQueryResponse `json:"users"` } +// QueryUserInfoResponse is the response structure for the accounts:query endpoint. +type QueryUserInfoResponse struct { + Users []*UserRecord + Count string +} + +type queryUsersResponse struct { + Users []*userQueryResponse `json:"usersInfo,omitempty"` + Count string `json:"recordsCount,omitempty"` +} + +// SQLExpression is a query condition used to filter results. +type SQLExpression struct { + Email string `json:"email,omitempty"` + UserID string `json:"userId,omitempty"` + PhoneNumber string `json:"phoneNumber,omitempty"` +} + +// QueryUsersRequest is the request structure for the accounts:query endpoint. +type QueryUsersRequest struct { + ReturnUserInfo bool `json:"returnUserInfo"` + Limit string `json:"limit,omitempty"` + Offset string `json:"offset,omitempty"` + SortBy string `json:"sortBy,omitempty"` + Order string `json:"order,omitempty"` + TenantID string `json:"tenantId,omitempty"` + Expression []*SQLExpression `json:"expression,omitempty"` +} + +// SortByField is a field to use for sorting user accounts. +type SortByField string + +const ( + // UserID sorts results by userId. + UserID SortByField = "USER_ID" + // Name sorts results by name. + Name SortByField = "NAME" + // CreatedAt sorts results by createdAt. + CreatedAt SortByField = "CREATED_AT" + // LastLoginAt sorts results by lastLoginAt. + LastLoginAt SortByField = "LAST_LOGIN_AT" + // UserEmail sorts results by userEmail. + UserEmail SortByField = "USER_EMAIL" +) + +// Order is an order for sorting query results. +type Order string + +const ( + // Asc sorts in ascending order. + Asc Order = "ASC" + // Desc sorts in descending order. + Desc Order = "DESC" +) + func (c *baseClient) getUser(ctx context.Context, query *userQuery) (*UserRecord, error) { var parsed getAccountInfoResponse resp, err := c.post(ctx, "/accounts:lookup", query.build(), &parsed) @@ -1311,6 +1366,33 @@ type DeleteUsersErrorInfo struct { // array of errors that correspond to the failed deletions. An error is // returned if any of the identifiers are invalid or if more than 1000 // identifiers are specified. +// QueryUsers queries for user accounts based on the provided query configuration. +func (c *baseClient) QueryUsers(ctx context.Context, query *QueryUsersRequest) (*QueryUserInfoResponse, error) { + if query == nil { + return nil, fmt.Errorf("query request must not be nil") + } + + var parsed queryUsersResponse + _, err := c.post(ctx, "/accounts:query", query, &parsed) + if err != nil { + return nil, err + } + + var userRecords []*UserRecord + for _, user := range parsed.Users { + userRecord, err := user.makeUserRecord() + if err != nil { + return nil, fmt.Errorf("error while parsing response: %w", err) + } + userRecords = append(userRecords, userRecord) + } + + return &QueryUserInfoResponse{ + Users: userRecords, + Count: parsed.Count, + }, nil +} + func (c *baseClient) DeleteUsers(ctx context.Context, uids []string) (*DeleteUsersResult, error) { if len(uids) == 0 { return &DeleteUsersResult{}, nil diff --git a/auth/user_mgt_test.go b/auth/user_mgt_test.go index 53ccdc58..922bf4ac 100644 --- a/auth/user_mgt_test.go +++ b/auth/user_mgt_test.go @@ -1899,6 +1899,112 @@ func TestDeleteUsers(t *testing.T) { }) } +func TestQueryUsers(t *testing.T) { + resp := `{ + "usersInfo": [{ + "localId": "testuser", + "email": "testuser@example.com", + "phoneNumber": "+1234567890", + "emailVerified": true, + "displayName": "Test User", + "photoUrl": "http://www.example.com/testuser/photo.png", + "validSince": "1494364393", + "disabled": false, + "createdAt": "1234567890000", + "lastLoginAt": "1233211232000", + "customAttributes": "{\"admin\": true, \"package\": \"gold\"}", + "tenantId": "testTenant", + "providerUserInfo": [{ + "providerId": "password", + "displayName": "Test User", + "photoUrl": "http://www.example.com/testuser/photo.png", + "email": "testuser@example.com", + "rawId": "testuid" + }, { + "providerId": "phone", + "phoneNumber": "+1234567890", + "rawId": "testuid" + }], + "mfaInfo": [{ + "phoneInfo": "+1234567890", + "mfaEnrollmentId": "enrolledPhoneFactor", + "displayName": "My MFA Phone", + "enrolledAt": "2021-03-03T13:06:20.542896Z" + }, { + "totpInfo": {}, + "mfaEnrollmentId": "enrolledTOTPFactor", + "displayName": "My MFA TOTP", + "enrolledAt": "2021-03-03T13:06:20.542896Z" + }] + }], + "recordsCount": "1" + }` + s := echoServer([]byte(resp), t) + defer s.Close() + + query := &QueryUsersRequest{ + ReturnUserInfo: true, + Limit: "1", + SortBy: string(UserEmail), + Order: string(Asc), + Expression: []*SQLExpression{ + { + Email: "testuser@example.com", + }, + }, + } + + result, err := s.Client.QueryUsers(context.Background(), query) + if err != nil { + t.Fatalf("QueryUsers() = %v", err) + } + + if len(result.Users) != 1 { + t.Fatalf("QueryUsers() returned %d users; want 1", len(result.Users)) + } + + if result.Count != "1" { + t.Errorf("QueryUsers() returned count %q; want '1'", result.Count) + } + + if !reflect.DeepEqual(result.Users[0], testUser) { + t.Errorf("QueryUsers() = %#v; want = %#v", result.Users[0], testUser) + } + + wantPath := "/projects/mock-project-id/accounts:query" + if s.Req[0].RequestURI != wantPath { + t.Errorf("QueryUsers() URL = %q; want = %q", s.Req[0].RequestURI, wantPath) + } +} + +func TestQueryUsersError(t *testing.T) { + resp := `{ + "error": { + "message": "INVALID_QUERY" + } + }` + s := echoServer([]byte(resp), t) + defer s.Close() + s.Status = http.StatusBadRequest + + query := &QueryUsersRequest{ + ReturnUserInfo: true, + Limit: "1", + SortBy: "USER_EMAIL", + Order: "ASC", + Expression: []*SQLExpression{ + { + Email: "testuser@example.com", + }, + }, + } + + result, err := s.Client.QueryUsers(context.Background(), query) + if result != nil || err == nil { + t.Fatalf("QueryUsers() = (%v, %v); want = (nil, error)", result, err) + } +} + func TestMakeExportedUser(t *testing.T) { queryResponse := &userQueryResponse{ UID: "testuser",