Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 41 additions & 15 deletions pkg/cmd/login/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type LoginStore interface {
GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error)
GetActiveOrganizationOrDefault() (*entity.Organization, error)
CreateOrganization(req store.CreateOrganizationRequest) (*entity.Organization, error)
SetDefaultOrganization(org *entity.Organization) error
GetServerSockFile() string
GetWorkspaces(organizationID string, options *store.GetWorkspacesOptions) ([]entity.Workspace, error)
UpdateUser(userID string, updatedUser *entity.UpdateUser) (*entity.User, error)
Expand All @@ -58,6 +59,7 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra.
var skipBrowser bool
var emailFlag string
var authProviderFlag string
var orgFlag string

cmd := &cobra.Command{
Annotations: map[string]string{"housekeeping": ""},
Expand All @@ -82,7 +84,7 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra.
},
Args: cmderrors.TransformToValidationError(cobra.NoArgs),
RunE: func(cmd *cobra.Command, args []string) error {
err := opts.RunLogin(t, loginToken, skipBrowser, emailFlag, authProviderFlag)
err := opts.RunLogin(t, loginToken, skipBrowser, emailFlag, authProviderFlag, orgFlag)
if err != nil {
// if err is ImportIDEConfigError, log err with sentry but continue
if _, ok := err.(*importideconfig.ImportIDEConfigError); !ok {
Expand All @@ -102,6 +104,7 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra.
cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "print url instead of auto opening browser")
cmd.Flags().StringVar(&emailFlag, "email", "", "email to use for authentication")
cmd.Flags().StringVar(&authProviderFlag, "auth", "", "authentication provider to use (nvidia or legacy, default is nvidia)")
cmd.Flags().StringVarP(&orgFlag, "org", "o", "", "organization to use (must exist)")
return cmd
}

Expand Down Expand Up @@ -130,27 +133,50 @@ func (o LoginOptions) loginAndGetOrCreateUser(loginToken string, skipBrowser boo
return user, nil
}

func (o LoginOptions) getOrCreateOrg(username string) (*entity.Organization, error) {
org, err := o.LoginStore.GetActiveOrganizationOrDefault()
if err != nil {
return nil, breverrors.WrapAndTrace(err)
}
func (o LoginOptions) getOrCreateOrg(username string, orgFlag string) (*entity.Organization, error) {
var org *entity.Organization
var err error

if org == nil {
newOrgName := makeFirstOrgName(username)
fmt.Printf("Creating your first org %s ...\n", newOrgName)
org, err = o.LoginStore.CreateOrganization(store.CreateOrganizationRequest{
Name: newOrgName,
})
if orgFlag != "" {
var orgs []entity.Organization
orgs, err = o.LoginStore.GetOrganizations(&store.GetOrganizationsOptions{Name: orgFlag})
if err != nil {
return nil, breverrors.WrapAndTrace(err)
}
if len(orgs) == 0 {
return nil, breverrors.NewValidationError(fmt.Sprintf("no org found with name %s", orgFlag))
} else if len(orgs) > 1 {
return nil, breverrors.NewValidationError(fmt.Sprintf("more than one org found with name %s", orgFlag))
}
org = &orgs[0]

err = o.LoginStore.SetDefaultOrganization(org)
if err != nil {
return nil, breverrors.WrapAndTrace(err)
}
fmt.Println("done!")
} else {
org, err = o.LoginStore.GetActiveOrganizationOrDefault()
if err != nil {
return nil, breverrors.WrapAndTrace(err)
}

if org == nil {
newOrgName := makeFirstOrgName(username)
fmt.Printf("Creating your first org %s ...\n", newOrgName)
org, err = o.LoginStore.CreateOrganization(store.CreateOrganizationRequest{
Name: newOrgName,
})
if err != nil {
return nil, breverrors.WrapAndTrace(err)
}
fmt.Println("done!")
}
}

return org, nil
}

func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrowser bool, emailFlag string, authProviderFlag string) error {
func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrowser bool, emailFlag string, authProviderFlag string, orgFlag string) error {
tokens, _ := o.LoginStore.GetAuthTokens()

if authProviderFlag != "" && authProviderFlag != "nvidia" && authProviderFlag != "legacy" {
Expand All @@ -175,7 +201,7 @@ func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrow
return breverrors.WrapAndTrace(err)
}

org, err := o.getOrCreateOrg(user.Username)
org, err := o.getOrCreateOrg(user.Username, orgFlag)
if err != nil {
return breverrors.WrapAndTrace(err)
}
Expand Down
Loading