-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstore.go
More file actions
145 lines (127 loc) · 3.85 KB
/
store.go
File metadata and controls
145 lines (127 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package main
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"os"
	"strings"

	"go.uber.org/zap"
)
// Permission is a set of access rights stored as bit flags in a uint8.
// Each PermXxx constant occupies a distinct bit, so rights can be OR-ed
// together and tested with a bitwise AND.
type Permission uint8

// Individual permission bits; the trailing comments give the numeric value.
const (
	PermRead   Permission = 1 << iota // 1
	PermList                          // 2
	PermWrite                         // 4
	PermDelete                        // 8
)
// User is one account row of the sftp_users table, as populated by
// UserStore.FetchUserByUsername.
type User struct {
	ID int

	DisplayName, GroupName, Username string

	// PasswordHash is a bcrypt hash. PublicKey presumably holds the user's
	// SSH public key (TODO confirm). Both columns may be NULL in the database.
	PasswordHash, PublicKey sql.NullString

	RootPath string // presumably the directory the user is confined to — verify against callers
	Perms    Permission
	Disabled bool
}
// UserStore looks up SFTP user accounts in a SQL database.
// Construct it with NewUserStore.
type UserStore struct {
	// dbType is the database/sql driver name taken from DB_TYPE; it also
	// decides which bind-placeholder syntax queries use.
	dbType string
	db     *sql.DB            // shared connection pool
	logger *zap.SugaredLogger // destination for diagnostic messages
}
// NewUserStore opens the user database identified by dsn and returns a store
// backed by it.
//
// The database/sql driver name comes from the DB_TYPE environment variable
// (default "sqlite"). The same value picks the "<DB_TYPE>_ddl.sql" schema file
// that applyDDLIfNeeded applies when the sftp_users table is missing.
//
// NewUserStore is meant for process start-up. Any failure terminates the
// process instead of returning an error. That covers initialising the logger,
// opening or reaching the database, and applying the schema.
func NewUserStore(dsn string) *UserStore {
	// Create the logger first so that database failures below are reported
	// through it, like every other start-up failure.
	logger, err := initLogger()
	if err != nil {
		log.Fatalf("Failed to initialize logger: %v", err)
	}

	dbType := getEnvOrDefault("DB_TYPE", "sqlite")
	db, err := sql.Open(dbType, dsn)
	if err != nil {
		// For example, an unregistered driver name (missing driver import).
		logger.Fatalf("Failed to open %s database: %v", dbType, err)
	}

	// sql.Open may only validate its arguments without connecting. Without a
	// ping, an unreachable database would surface inside applyDDLIfNeeded as
	// a "missing table" and trigger a pointless DDL attempt.
	if err := db.Ping(); err != nil {
		logger.Fatalf("Failed to connect to %s database: %v", dbType, err)
	}

	// Ensure the schema exists; applies <dbType>_ddl.sql if sftp_users is missing.
	if err := applyDDLIfNeeded(dbType, db, logger); err != nil {
		logger.Fatalf("Failed to apply DDL: %v", err)
	}
	return &UserStore{db: db, logger: logger, dbType: dbType}
}
// applyDDLIfNeeded makes sure the sftp_users table exists. When it does not,
// it creates the schema from the file "<dbType>_ddl.sql", read relative to the
// current working directory.
//
// A probe query decides whether the table exists. The table counts as present
// when the probe succeeds or returns no rows (an empty table). Any other probe
// error is taken to mean the table is missing.
//
// The DDL file is first run with a single Exec. If that fails (for example,
// the driver does not accept several statements per call), the file is split
// into statements and run one by one inside a transaction.
//
// It returns nil when the table already exists or the DDL was applied.
// Otherwise it returns an error naming the step that failed.
func applyDDLIfNeeded(dbType string, db *sql.DB, logger *zap.SugaredLogger) error {
	logger.Infof("Checking for sftp_users table")
	var tmp int
	probeErr := db.QueryRow("SELECT 1 FROM sftp_users LIMIT 1").Scan(&tmp)
	// Scan reports sql.ErrNoRows for an existing but empty table. That still
	// proves the table exists, so it must not trigger re-applying the DDL.
	if probeErr == nil || errors.Is(probeErr, sql.ErrNoRows) {
		logger.Infof("sftp_users table exists")
		return nil
	}

	ddlFile := dbType + "_ddl.sql"
	logger.Warnf("sftp_users table not found or inaccessible (%v). Attempting to apply %s", probeErr, ddlFile)
	ddlBytes, err := os.ReadFile(ddlFile)
	if err != nil {
		return fmt.Errorf("reading %s: %w", ddlFile, err)
	}
	ddl := string(ddlBytes)

	// Some drivers reject multiple statements in one Exec; try the whole
	// script first and fall back to statement-by-statement execution.
	_, execErr := db.Exec(ddl)
	if execErr == nil {
		logger.Infof("Applied %s successfully", ddlFile)
		return nil
	}
	logger.Warnf("Exec of %s failed: %v — attempting split-exec", ddlFile, execErr)
	if err := execStatementsInTx(db, splitSQLStatements(ddl)); err != nil {
		return fmt.Errorf("applying %s statement by statement: %w", ddlFile, err)
	}
	logger.Infof("Applied %s successfully (split-exec)", ddlFile)
	return nil
}

// execStatementsInTx runs each non-blank statement, in order, inside a single
// transaction. It rolls back and returns on the first failing statement.
//
// NOTE: some databases (e.g. MySQL) commit DDL implicitly, so on those the
// transaction may not make the whole script atomic.
func execStatementsInTx(db *sql.DB, statements []string) error {
	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	for _, stmt := range statements {
		if stmt = trimWhitespace(stmt); stmt == "" {
			continue
		}
		if _, err := tx.Exec(stmt); err != nil {
			// Rollback's own error is ignored: the Exec failure is what matters.
			_ = tx.Rollback()
			return fmt.Errorf("executing statement %q: %w", stmt, err)
		}
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("committing transaction: %w", err)
	}
	return nil
}
// splitSQLStatements splits a SQL script into individual statements at
// semicolons. A semicolon is not treated as a terminator when it appears in:
//   - a single-quoted, double-quoted or backtick-quoted literal (a doubled
//     quote such as 'it''s' stays inside the literal), or
//   - a "--" line comment or a "/* */" block comment.
//
// Pieces consisting only of white space are dropped. Other pieces are returned
// untrimmed, including ones that contain only a comment; callers trim them.
//
// This is still not a full SQL parser. It does not recognise dialect-specific
// syntax such as PostgreSQL dollar-quoted bodies ($$ ... $$), backslash
// escapes inside MySQL strings, or MySQL DELIMITER directives.
func splitSQLStatements(ddl string) []string {
	var (
		pieces []string
		start  int  // index where the current statement begins
		quote  byte // active quote character, or 0 outside a literal
	)
	for i := 0; i < len(ddl); i++ {
		c := ddl[i]
		switch {
		case quote != 0:
			// A doubled quote closes and immediately reopens the literal,
			// which keeps its contents inside the literal as intended.
			if c == quote {
				quote = 0
			}
		case c == '\'' || c == '"' || c == '`':
			quote = c
		case c == '-' && i+1 < len(ddl) && ddl[i+1] == '-':
			// Line comment: jump to the newline (or the end of input).
			if nl := strings.IndexByte(ddl[i:], '\n'); nl >= 0 {
				i += nl
			} else {
				i = len(ddl)
			}
		case c == '/' && i+1 < len(ddl) && ddl[i+1] == '*':
			// Block comment: jump to the closing '/' of "*/" (or the end).
			if end := strings.Index(ddl[i+2:], "*/"); end >= 0 {
				i += 2 + end + 1
			} else {
				i = len(ddl)
			}
		case c == ';':
			pieces = append(pieces, ddl[start:i])
			start = i + 1
		}
	}
	pieces = append(pieces, ddl[start:])
	return filterEmpty(pieces)
}
// trimWhitespace strips leading and trailing Unicode white space from s.
// It is a thin named wrapper around strings.TrimSpace.
func trimWhitespace(s string) string {
	trimmed := strings.TrimSpace(s)
	return trimmed
}
// filterEmpty returns a new slice containing the elements of in that hold
// something other than white space. Kept elements are unmodified and stay in
// their original order. The input slice is not changed, and the result is
// never nil, even for nil or empty input.
func filterEmpty(in []string) []string {
	kept := make([]string, 0, len(in))
	for i := range in {
		if blank := strings.TrimSpace(in[i]) == ""; blank {
			continue
		}
		kept = append(kept, in[i])
	}
	return kept
}
// FetchUserByUsername loads the sftp_users row whose username column equals
// username and returns it as a User.
//
// The bind placeholder is "$1" when the store's DB type is "postgres"
// (case-insensitive) and "?" otherwise. When no row matches, the returned
// error is sql.ErrNoRows itself, unwrapped, so callers can check it with
// errors.Is or ==. Other query or scan errors are also returned unchanged.
// On any error the returned *User is nil.
func (s *UserStore) FetchUserByUsername(ctx context.Context, username string) (*User, error) {
	// username presumably comes from the connecting client. %q escapes
	// control characters so it cannot inject extra log lines.
	s.logger.Infof("Fetching user by username: %q", username)

	// The placeholder is a fixed token chosen from s.dbType, never user data,
	// so building the query with Sprintf is safe; username is bound as an argument.
	placeholder := "?"
	if strings.EqualFold(s.dbType, "postgres") {
		placeholder = "$1"
	}
	query := fmt.Sprintf(`SELECT id, display_name, group_name, username, password_hash, public_key, root_path, perms, disabled FROM sftp_users WHERE username = %s`, placeholder)

	var user User
	err := s.db.QueryRowContext(ctx, query, username).Scan(
		&user.ID, &user.DisplayName, &user.GroupName, &user.Username,
		&user.PasswordHash, &user.PublicKey, &user.RootPath, &user.Perms, &user.Disabled,
	)
	switch {
	case err == nil:
		return &user, nil
	case errors.Is(err, sql.ErrNoRows):
		// An unknown username is an expected outcome (e.g. a failed login
		// attempt), not a server fault, so it is logged below error level.
		s.logger.Warnf("No user found with username %q", username)
	default:
		s.logger.Errorf("Error fetching user: %v", err)
	}
	return nil, err
}