Skip to content

Commit c690a5f

Browse files
Updating baton-demo with latest SDK (#69)
1 parent 7bdf400 commit c690a5f

File tree

4 files changed

+421
-25
lines changed

4 files changed

+421
-25
lines changed

pkg/client/client.go

Lines changed: 141 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package client
33
import (
44
"context"
55
"database/sql"
6+
"encoding/json"
67
"fmt"
78
"os"
89
"slices"
@@ -32,9 +33,11 @@ import (
3233
// Projects always have a single User as the owner, and can be assigned to Groups
3334

3435
type User struct {
35-
Id string
36-
Name string
37-
Email string
36+
Id string
37+
Name string
38+
Email string
39+
Enabled bool
40+
Attrs map[string]string
3841
}
3942

4043
type Group struct {
@@ -105,6 +108,12 @@ func NewClient(ctx context.Context, dc *config.Demo) (*Client, error) {
105108
}
106109
}
107110

111+
// Add migration for enabled and attrs columns if they don't exist
112+
err = c.migrateUsersTable(ctx)
113+
if err != nil {
114+
return nil, err
115+
}
116+
108117
if c.config.InitDb {
109118
err = c.initDB(ctx)
110119
if err != nil {
@@ -119,6 +128,55 @@ func (c *Client) Close() error {
119128
return c.rawDB.Close()
120129
}
121130

131+
func (c *Client) migrateUsersTable(ctx context.Context) error {
132+
// Check if the enabled and attrs columns exist
133+
query := "PRAGMA table_info(users)"
134+
rows, err := c.rawDB.QueryContext(ctx, query)
135+
if err != nil {
136+
return err
137+
}
138+
defer rows.Close()
139+
140+
hasEnabledColumn := false
141+
hasAttrsColumn := false
142+
for rows.Next() {
143+
var cid int
144+
var name, dataType string
145+
var notNull, pk int
146+
var defaultValue interface{}
147+
err := rows.Scan(&cid, &name, &dataType, &notNull, &defaultValue, &pk)
148+
if err != nil {
149+
return err
150+
}
151+
if name == "enabled" {
152+
hasEnabledColumn = true
153+
}
154+
if name == "attrs" {
155+
hasAttrsColumn = true
156+
}
157+
}
158+
159+
// If the enabled column doesn't exist, add it
160+
if !hasEnabledColumn {
161+
alterQuery := "ALTER TABLE users ADD COLUMN enabled BOOLEAN DEFAULT 1"
162+
_, err = c.rawDB.ExecContext(ctx, alterQuery)
163+
if err != nil {
164+
return err
165+
}
166+
}
167+
168+
// If the attrs column doesn't exist, add it
169+
if !hasAttrsColumn {
170+
alterQuery := "ALTER TABLE users ADD COLUMN attrs BLOB"
171+
_, err = c.rawDB.ExecContext(ctx, alterQuery)
172+
if err != nil {
173+
return err
174+
}
175+
}
176+
177+
return nil
178+
}
179+
122180
func (c *Client) validateDB() error {
123181
// Check if the database is already initialized
124182
if c.db == nil {
@@ -154,10 +212,16 @@ func (c *Client) initDB(ctx context.Context) error {
154212

155213
switch {
156214
case dbResource.User != nil:
215+
attrs, err := json.Marshal(dbResource.User.Attrs)
216+
if err != nil {
217+
return err
218+
}
157219
row := goqu.Record{
158-
"id": dbResource.User.Id,
159-
"name": dbResource.User.Name,
160-
"email": dbResource.User.Email,
220+
"id": dbResource.User.Id,
221+
"name": dbResource.User.Name,
222+
"email": dbResource.User.Email,
223+
"enabled": dbResource.User.Enabled,
224+
"attrs": attrs,
161225
}
162226
baseUserQ := c.db.Insert(users.Name()).Prepared(true)
163227
baseUserQ = baseUserQ.Rows(row)
@@ -272,7 +336,7 @@ func (c *Client) ListUsers(ctx context.Context, pToken *pagination.Token) ([]*Us
272336
}
273337

274338
q := c.db.From(users.Name()).Prepared(true)
275-
q = q.Select("id", "name", "email").
339+
q = q.Select("id", "name", "email", "enabled", "attrs").
276340
Order(goqu.C("id").Asc()).
277341
Limit(uint(limit)). //nolint:gosec // This won't underflow
278342
Offset(uint(offset)) //nolint:gosec // This won't underflow
@@ -290,10 +354,19 @@ func (c *Client) ListUsers(ctx context.Context, pToken *pagination.Token) ([]*Us
290354
usersList := []*User{}
291355
for rows.Next() {
292356
user := &User{}
293-
err = rows.Scan(&user.Id, &user.Name, &user.Email)
357+
attrsBytes := []byte{}
358+
err = rows.Scan(&user.Id, &user.Name, &user.Email, &user.Enabled, &attrsBytes)
294359
if err != nil {
295360
return nil, "", err
296361
}
362+
if len(attrsBytes) > 0 {
363+
err = json.Unmarshal(attrsBytes, &user.Attrs)
364+
if err != nil {
365+
return nil, "", err
366+
}
367+
} else {
368+
user.Attrs = make(map[string]string)
369+
}
297370
usersList = append(usersList, user)
298371
}
299372

@@ -313,7 +386,7 @@ func (c *Client) GetUser(ctx context.Context, userID string) (*User, error) {
313386
}
314387

315388
q := c.db.From(users.Name()).Prepared(true)
316-
q = q.Select("id", "name", "email")
389+
q = q.Select("id", "name", "email", "enabled", "attrs")
317390
q = q.Where(goqu.C("id").Eq(userID))
318391

319392
query, args, err := q.ToSQL()
@@ -323,11 +396,21 @@ func (c *Client) GetUser(ctx context.Context, userID string) (*User, error) {
323396

324397
row := c.db.QueryRowContext(ctx, query, args...)
325398
user := &User{}
326-
err = row.Scan(&user.Id, &user.Name, &user.Email)
399+
attrsBytes := []byte{}
400+
err = row.Scan(&user.Id, &user.Name, &user.Email, &user.Enabled, &attrsBytes)
327401
if err != nil {
328402
return nil, err
329403
}
330404

405+
if len(attrsBytes) > 0 {
406+
err = json.Unmarshal(attrsBytes, &user.Attrs)
407+
if err != nil {
408+
return nil, err
409+
}
410+
} else {
411+
user.Attrs = make(map[string]string)
412+
}
413+
331414
return user, nil
332415
}
333416

@@ -367,16 +450,25 @@ func (c *Client) CreateUser(ctx context.Context, name, email, password string) (
367450
}
368451

369452
user := &User{
370-
Id: ksuid.New().String(),
371-
Name: name,
372-
Email: email,
453+
Id: ksuid.New().String(),
454+
Name: name,
455+
Email: email,
456+
Enabled: true, // Default to enabled
457+
Attrs: make(map[string]string),
458+
}
459+
460+
attrs, err := json.Marshal(user.Attrs)
461+
if err != nil {
462+
return nil, err
373463
}
374464

375465
q := c.db.Insert(users.Name()).Prepared(true)
376466
q = q.Rows(goqu.Record{
377-
"id": user.Id,
378-
"name": user.Name,
379-
"email": user.Email,
467+
"id": user.Id,
468+
"name": user.Name,
469+
"email": user.Email,
470+
"enabled": user.Enabled,
471+
"attrs": attrs,
380472
})
381473

382474
query, args, err := q.ToSQL()
@@ -407,6 +499,39 @@ func (c *Client) CreateUser(ctx context.Context, name, email, password string) (
407499
return user, nil
408500
}
409501

502+
func (c *Client) UpdateUser(ctx context.Context, user *User) error {
503+
err := c.validateDB()
504+
if err != nil {
505+
return err
506+
}
507+
508+
attrs, err := json.Marshal(user.Attrs)
509+
if err != nil {
510+
return err
511+
}
512+
513+
q := c.db.Update(users.Name()).Prepared(true)
514+
q = q.Set(goqu.Record{
515+
"name": user.Name,
516+
"email": user.Email,
517+
"enabled": user.Enabled,
518+
"attrs": attrs,
519+
})
520+
q = q.Where(goqu.C("id").Eq(user.Id))
521+
522+
query, args, err := q.ToSQL()
523+
if err != nil {
524+
return err
525+
}
526+
527+
_, err = c.db.ExecContext(ctx, query, args...)
528+
if err != nil {
529+
return err
530+
}
531+
532+
return nil
533+
}
534+
410535
func (c *Client) ChangePassword(ctx context.Context, userID, password string) error {
411536
err := c.validateDB()
412537
if err != nil {

pkg/client/data.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,18 @@ func (g *generator) Next() (*dbResource, bool) {
120120
return db, true
121121
}
122122
if g.currentUser < g.config.Users {
123+
userFullName := fmt.Sprintf("User %07d", g.currentUser)
124+
userEmail := fmt.Sprintf("user-%d@example.com", g.currentUser)
123125
db := &dbResource{
124126
User: &User{
125-
Id: userId(g.currentUser),
126-
Name: fmt.Sprintf("User %07d", g.currentUser),
127-
Email: fmt.Sprintf("user-%d@example.com", g.currentUser),
127+
Id: userId(g.currentUser),
128+
Name: userFullName,
129+
Email: userEmail,
130+
Enabled: true, // Default to enabled
131+
Attrs: map[string]string{
132+
"full_name": userFullName,
133+
"email": userEmail,
134+
},
128135
},
129136
}
130137
g.currentUser++
@@ -167,7 +174,7 @@ func (t *usersTable) Name() string {
167174
}
168175

169176
func (t *usersTable) Schema() (string, []any) {
170-
return "CREATE TABLE IF NOT EXISTS users (id TEXT PRIMARY KEY, name TEXT NOT NULL UNIQUE, email TEXT)", []any{}
177+
return "CREATE TABLE IF NOT EXISTS users (id TEXT PRIMARY KEY, name TEXT NOT NULL UNIQUE, email TEXT, enabled BOOLEAN DEFAULT 1, attrs BLOB)", []any{}
171178
}
172179

173180
var groups = (*groupsTable)(nil)

0 commit comments

Comments
 (0)