Skip to content

Commit 982a9a2

Browse files
authored
[-] fix assignment to entry in nil map, fixes #968 (#969)
Add `NewSourceConn()` constructor and use direct access in `FetchRuntimeInfo()` without temp `RuntimeInfo` var
1 parent 03dfd4b commit 982a9a2

File tree

4 files changed

+29
-34
lines changed

4 files changed

+29
-34
lines changed

internal/cmdopts/cmdsource.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func (cmd *SourcePingCommand) Execute(args []string) error {
5757
case sources.SourcePostgresContinuous:
5858
_, e = sources.ResolveDatabasesFromPostgres(s)
5959
default:
60-
mdb := &sources.SourceConn{Source: s}
60+
mdb := sources.NewSourceConn(s)
6161
e = mdb.Connect(context.Background(), cmd.owner.Sources)
6262
}
6363
if e != nil {

internal/sources/conn.go

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@ type (
5454
SourceConns []*SourceConn
5555
)
5656

57+
func NewSourceConn(s Source) *SourceConn {
58+
return &SourceConn{
59+
Source: s,
60+
RuntimeInfo: RuntimeInfo{
61+
Extensions: make(map[string]int),
62+
ChangeState: make(map[string]map[string]string),
63+
},
64+
}
65+
}
66+
5767
// Ping will try to ping the server to ensure the connection is still alive
5868
func (md *SourceConn) Ping(ctx context.Context) (err error) {
5969
if md.Kind == SourcePgBouncer {
@@ -155,15 +165,9 @@ func (md *SourceConn) FetchRuntimeInfo(ctx context.Context, forceRefetch bool) (
155165
if !forceRefetch && md.LastCheckedOn.After(time.Now().Add(time.Minute*-2)) { // use cached version for 2 min
156166
return nil
157167
}
158-
159-
dbNewSettings := RuntimeInfo{
160-
Extensions: make(map[string]int),
161-
ChangeState: make(map[string]map[string]string),
162-
}
163-
164168
switch md.Kind {
165169
case SourcePgBouncer, SourcePgPool:
166-
if dbNewSettings.VersionStr, dbNewSettings.Version, err = md.FetchVersion(ctx, func() string {
170+
if md.VersionStr, md.Version, err = md.FetchVersion(ctx, func() string {
167171
if md.Kind == SourcePgBouncer {
168172
return "SHOW VERSION"
169173
}
@@ -183,15 +187,15 @@ FROM
183187
pg_control_system()`
184188

185189
err = md.Conn.QueryRow(ctx, sql).
186-
Scan(&dbNewSettings.Version, &dbNewSettings.VersionStr,
187-
&dbNewSettings.IsInRecovery, &dbNewSettings.RealDbname,
188-
&dbNewSettings.SystemIdentifier, &dbNewSettings.IsSuperuser)
190+
Scan(&md.Version, &md.VersionStr,
191+
&md.IsInRecovery, &md.RealDbname,
192+
&md.SystemIdentifier, &md.IsSuperuser)
189193
if err != nil {
190194
return err
191195
}
192196

193-
dbNewSettings.ExecEnv = md.DiscoverPlatform(ctx)
194-
dbNewSettings.ApproxDbSize = md.FetchApproxSize(ctx)
197+
md.ExecEnv = md.DiscoverPlatform(ctx)
198+
md.ApproxDbSize = md.FetchApproxSize(ctx)
195199

196200
sqlExtensions := `select /* pgwatch_generated */ extname::text, (regexp_matches(extversion, $$\d+\.?\d+?$$))[1]::text as extversion from pg_extension order by 1;`
197201
var res pgx.Rows
@@ -204,14 +208,13 @@ FROM
204208
if extver == 0 {
205209
return fmt.Errorf("unexpected extension %s version input: %s", ext, ver)
206210
}
207-
dbNewSettings.Extensions[ext] = extver
211+
md.Extensions[ext] = extver
208212
return nil
209213
})
210214
}
211215

212216
}
213-
dbNewSettings.LastCheckedOn = time.Now()
214-
md.RuntimeInfo = dbNewSettings // store the new settings in the struct
217+
md.LastCheckedOn = time.Now()
215218
return err
216219
}
217220

internal/sources/conn_test.go

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,8 @@ func TestSourceConn_FetchRuntimeInfo(t *testing.T) {
243243
t.Run("pgbouncer version fetch", func(t *testing.T) {
244244
mock, err := pgxmock.NewPool()
245245
require.NoError(t, err)
246-
md := &sources.SourceConn{
247-
Conn: mock,
248-
Source: sources.Source{Kind: sources.SourcePgBouncer},
249-
}
246+
md := sources.NewSourceConn(sources.Source{Kind: sources.SourcePgBouncer})
247+
md.Conn = mock
250248
mock.ExpectQuery("SHOW VERSION").
251249
WithArgs(pgx.QueryExecModeSimpleProtocol).
252250
WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("PgBouncer 1.12.0"))
@@ -260,10 +258,8 @@ func TestSourceConn_FetchRuntimeInfo(t *testing.T) {
260258
t.Run("pgpool version fetch", func(t *testing.T) {
261259
mock, err := pgxmock.NewPool()
262260
require.NoError(t, err)
263-
md := &sources.SourceConn{
264-
Conn: mock,
265-
Source: sources.Source{Kind: sources.SourcePgPool},
266-
}
261+
md := sources.NewSourceConn(sources.Source{Kind: sources.SourcePgPool})
262+
md.Conn = mock
267263
mock.ExpectQuery("SHOW POOL_VERSION").
268264
WithArgs(pgx.QueryExecModeSimpleProtocol).
269265
WillReturnRows(pgxmock.NewRows([]string{"version"}).AddRow("4.1.2"))
@@ -277,10 +273,8 @@ func TestSourceConn_FetchRuntimeInfo(t *testing.T) {
277273
t.Run("postgres version and extensions", func(t *testing.T) {
278274
mock, err := pgxmock.NewPool()
279275
require.NoError(t, err)
280-
md := &sources.SourceConn{
281-
Conn: mock,
282-
Source: sources.Source{Kind: sources.SourcePostgres},
283-
}
276+
md := sources.NewSourceConn(sources.Source{Kind: sources.SourcePostgres})
277+
md.Conn = mock
284278
mock.ExpectQuery("select").WillReturnRows(
285279
pgxmock.NewRows([]string{"ver", "version", "pg_is_in_recovery", "current_database", "system_identifier", "is_superuser"}).
286280
AddRow(13, "PostgreSQL 13.3", false, "testdb", "42424242", true),
@@ -306,10 +300,8 @@ func TestSourceConn_FetchRuntimeInfo(t *testing.T) {
306300
t.Run("query error", func(t *testing.T) {
307301
mock, err := pgxmock.NewPool()
308302
require.NoError(t, err)
309-
md := &sources.SourceConn{
310-
Conn: mock,
311-
Source: sources.Source{Kind: sources.SourcePgBouncer},
312-
}
303+
md := sources.NewSourceConn(sources.Source{Kind: sources.SourcePgBouncer})
304+
md.Conn = mock
313305
mock.ExpectQuery("SHOW VERSION").
314306
WithArgs(pgx.QueryExecModeSimpleProtocol).
315307
WillReturnError(fmt.Errorf("db error"))

internal/sources/resolver.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func (s Source) ResolveDatabases() (SourceConns, error) {
4747
case SourcePostgresContinuous:
4848
return ResolveDatabasesFromPostgres(s)
4949
}
50-
return SourceConns{&SourceConn{Source: s}}, nil
50+
return SourceConns{NewSourceConn(s)}, nil
5151
}
5252

5353
type PatroniClusterMember struct {
@@ -329,7 +329,7 @@ func ResolveDatabasesFromPostgres(s Source) (resolvedDbs SourceConns, err error)
329329
if err = rows.Scan(&dbname); err != nil {
330330
return nil, err
331331
}
332-
rdb := &SourceConn{Source: *s.Clone()}
332+
rdb := NewSourceConn(*s.Clone())
333333
rdb.Name += "_" + dbname
334334
rdb.SetDatabaseName(dbname)
335335
resolvedDbs = append(resolvedDbs, rdb)

0 commit comments

Comments
 (0)