@@ -3,6 +3,7 @@ package client
33import (
44 "context"
55 "database/sql"
6+ "encoding/json"
67 "fmt"
78 "os"
89 "slices"
@@ -32,9 +33,11 @@ import (
3233// Projects always have a single User as the owner, and can be assigned to Groups
3334
3435type User struct {
35- Id string
36- Name string
37- Email string
36+ Id string
37+ Name string
38+ Email string
39+ Enabled bool
40+ Attrs map [string ]string
3841}
3942
4043type Group struct {
@@ -105,6 +108,12 @@ func NewClient(ctx context.Context, dc *config.Demo) (*Client, error) {
105108 }
106109 }
107110
111+ // Add migration for enabled and attrs columns if they don't exist
112+ err = c .migrateUsersTable (ctx )
113+ if err != nil {
114+ return nil , err
115+ }
116+
108117 if c .config .InitDb {
109118 err = c .initDB (ctx )
110119 if err != nil {
@@ -119,6 +128,55 @@ func (c *Client) Close() error {
119128 return c .rawDB .Close ()
120129}
121130
131+ func (c * Client ) migrateUsersTable (ctx context.Context ) error {
132+ // Check if the enabled and attrs columns exist
133+ query := "PRAGMA table_info(users)"
134+ rows , err := c .rawDB .QueryContext (ctx , query )
135+ if err != nil {
136+ return err
137+ }
138+ defer rows .Close ()
139+
140+ hasEnabledColumn := false
141+ hasAttrsColumn := false
142+ for rows .Next () {
143+ var cid int
144+ var name , dataType string
145+ var notNull , pk int
146+ var defaultValue interface {}
147+ err := rows .Scan (& cid , & name , & dataType , & notNull , & defaultValue , & pk )
148+ if err != nil {
149+ return err
150+ }
151+ if name == "enabled" {
152+ hasEnabledColumn = true
153+ }
154+ if name == "attrs" {
155+ hasAttrsColumn = true
156+ }
157+ }
158+
159+ // If the enabled column doesn't exist, add it
160+ if ! hasEnabledColumn {
161+ alterQuery := "ALTER TABLE users ADD COLUMN enabled BOOLEAN DEFAULT 1"
162+ _ , err = c .rawDB .ExecContext (ctx , alterQuery )
163+ if err != nil {
164+ return err
165+ }
166+ }
167+
168+ // If the attrs column doesn't exist, add it
169+ if ! hasAttrsColumn {
170+ alterQuery := "ALTER TABLE users ADD COLUMN attrs BLOB"
171+ _ , err = c .rawDB .ExecContext (ctx , alterQuery )
172+ if err != nil {
173+ return err
174+ }
175+ }
176+
177+ return nil
178+ }
179+
122180func (c * Client ) validateDB () error {
123181 // Check if the database is already initialized
124182 if c .db == nil {
@@ -154,10 +212,16 @@ func (c *Client) initDB(ctx context.Context) error {
154212
155213 switch {
156214 case dbResource .User != nil :
215+ attrs , err := json .Marshal (dbResource .User .Attrs )
216+ if err != nil {
217+ return err
218+ }
157219 row := goqu.Record {
158- "id" : dbResource .User .Id ,
159- "name" : dbResource .User .Name ,
160- "email" : dbResource .User .Email ,
220+ "id" : dbResource .User .Id ,
221+ "name" : dbResource .User .Name ,
222+ "email" : dbResource .User .Email ,
223+ "enabled" : dbResource .User .Enabled ,
224+ "attrs" : attrs ,
161225 }
162226 baseUserQ := c .db .Insert (users .Name ()).Prepared (true )
163227 baseUserQ = baseUserQ .Rows (row )
@@ -272,7 +336,7 @@ func (c *Client) ListUsers(ctx context.Context, pToken *pagination.Token) ([]*Us
272336 }
273337
274338 q := c .db .From (users .Name ()).Prepared (true )
275- q = q .Select ("id" , "name" , "email" ).
339+ q = q .Select ("id" , "name" , "email" , "enabled" , "attrs" ).
276340 Order (goqu .C ("id" ).Asc ()).
277341 Limit (uint (limit )). //nolint:gosec // This won't underflow
278342 Offset (uint (offset )) //nolint:gosec // This won't underflow
@@ -290,10 +354,19 @@ func (c *Client) ListUsers(ctx context.Context, pToken *pagination.Token) ([]*Us
290354 usersList := []* User {}
291355 for rows .Next () {
292356 user := & User {}
293- err = rows .Scan (& user .Id , & user .Name , & user .Email )
357+ attrsBytes := []byte {}
358+ err = rows .Scan (& user .Id , & user .Name , & user .Email , & user .Enabled , & attrsBytes )
294359 if err != nil {
295360 return nil , "" , err
296361 }
362+ if len (attrsBytes ) > 0 {
363+ err = json .Unmarshal (attrsBytes , & user .Attrs )
364+ if err != nil {
365+ return nil , "" , err
366+ }
367+ } else {
368+ user .Attrs = make (map [string ]string )
369+ }
297370 usersList = append (usersList , user )
298371 }
299372
@@ -313,7 +386,7 @@ func (c *Client) GetUser(ctx context.Context, userID string) (*User, error) {
313386 }
314387
315388 q := c .db .From (users .Name ()).Prepared (true )
316- q = q .Select ("id" , "name" , "email" )
389+ q = q .Select ("id" , "name" , "email" , "enabled" , "attrs" )
317390 q = q .Where (goqu .C ("id" ).Eq (userID ))
318391
319392 query , args , err := q .ToSQL ()
@@ -323,11 +396,21 @@ func (c *Client) GetUser(ctx context.Context, userID string) (*User, error) {
323396
324397 row := c .db .QueryRowContext (ctx , query , args ... )
325398 user := & User {}
326- err = row .Scan (& user .Id , & user .Name , & user .Email )
399+ attrsBytes := []byte {}
400+ err = row .Scan (& user .Id , & user .Name , & user .Email , & user .Enabled , & attrsBytes )
327401 if err != nil {
328402 return nil , err
329403 }
330404
405+ if len (attrsBytes ) > 0 {
406+ err = json .Unmarshal (attrsBytes , & user .Attrs )
407+ if err != nil {
408+ return nil , err
409+ }
410+ } else {
411+ user .Attrs = make (map [string ]string )
412+ }
413+
331414 return user , nil
332415}
333416
@@ -367,16 +450,25 @@ func (c *Client) CreateUser(ctx context.Context, name, email, password string) (
367450 }
368451
369452 user := & User {
370- Id : ksuid .New ().String (),
371- Name : name ,
372- Email : email ,
453+ Id : ksuid .New ().String (),
454+ Name : name ,
455+ Email : email ,
456+ Enabled : true , // Default to enabled
457+ Attrs : make (map [string ]string ),
458+ }
459+
460+ attrs , err := json .Marshal (user .Attrs )
461+ if err != nil {
462+ return nil , err
373463 }
374464
375465 q := c .db .Insert (users .Name ()).Prepared (true )
376466 q = q .Rows (goqu.Record {
377- "id" : user .Id ,
378- "name" : user .Name ,
379- "email" : user .Email ,
467+ "id" : user .Id ,
468+ "name" : user .Name ,
469+ "email" : user .Email ,
470+ "enabled" : user .Enabled ,
471+ "attrs" : attrs ,
380472 })
381473
382474 query , args , err := q .ToSQL ()
@@ -407,6 +499,39 @@ func (c *Client) CreateUser(ctx context.Context, name, email, password string) (
407499 return user , nil
408500}
409501
502+ func (c * Client ) UpdateUser (ctx context.Context , user * User ) error {
503+ err := c .validateDB ()
504+ if err != nil {
505+ return err
506+ }
507+
508+ attrs , err := json .Marshal (user .Attrs )
509+ if err != nil {
510+ return err
511+ }
512+
513+ q := c .db .Update (users .Name ()).Prepared (true )
514+ q = q .Set (goqu.Record {
515+ "name" : user .Name ,
516+ "email" : user .Email ,
517+ "enabled" : user .Enabled ,
518+ "attrs" : attrs ,
519+ })
520+ q = q .Where (goqu .C ("id" ).Eq (user .Id ))
521+
522+ query , args , err := q .ToSQL ()
523+ if err != nil {
524+ return err
525+ }
526+
527+ _ , err = c .db .ExecContext (ctx , query , args ... )
528+ if err != nil {
529+ return err
530+ }
531+
532+ return nil
533+ }
534+
410535func (c * Client ) ChangePassword (ctx context.Context , userID , password string ) error {
411536 err := c .validateDB ()
412537 if err != nil {
0 commit comments