Skip to content

Commit 6d669b0

Browse files
authored
Add signal notify context to backend and validators (#46)
* refactor: add sourceType to runoptions * refactor: add context to backend and validator packages * fix: honor cancellation on localsource getconfig * fix: honor cancellation on s3 source get config
1 parent e591ef2 commit 6d669b0

19 files changed

+83
-46
lines changed

cmd/appstreamfile/main.go

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package main
22

33
import (
4+
"context"
45
"flag"
56
"fmt"
67
"os"
8+
"os/signal"
79

810
"github.com/aslamcodes/appstreamfile/internal/backend"
911
"github.com/aslamcodes/appstreamfile/internal/logger"
@@ -12,13 +14,17 @@ import (
1214
)
1315

1416
type RunOptions struct {
15-
location string
16-
bucket string
17-
key string
18-
versionId string
17+
location string
18+
SourceType string
19+
bucket string
20+
key string
21+
versionId string
1922
}
2023

2124
func main() {
25+
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
26+
defer cancel()
27+
2228
source := flag.String("source", "", "Configuration source: s3 or local")
2329
location := flag.String("location", "", "Local filesystem path to the config file")
2430
bucket := flag.String("bucket", "", "S3 bucket containing the config file")
@@ -30,24 +36,25 @@ func main() {
3036
logger.Init()
3137

3238
runOptions := &RunOptions{
33-
location: *location,
34-
bucket: *bucket,
35-
key: *key,
36-
versionId: *versionId,
39+
SourceType: *source,
40+
location: *location,
41+
bucket: *bucket,
42+
key: *key,
43+
versionId: *versionId,
3744
}
3845

39-
if err := run(*source, runOptions); err != nil {
46+
if err := run(ctx, runOptions); err != nil {
4047
fmt.Fprintln(os.Stderr, err)
4148
os.Exit(1)
4249
}
4350
}
4451

45-
func run(sourceType string, opts *RunOptions) error {
52+
func run(ctx context.Context, opts *RunOptions) error {
4653
var backendSource backend.BackendSource
4754
var err error
4855

49-
switch sourceType {
50-
case "local":
56+
switch opts.SourceType {
57+
case "local":
5158
if opts.location == "" {
5259
return fmt.Errorf("location of config file must be provided")
5360
}
@@ -58,7 +65,6 @@ func run(sourceType string, opts *RunOptions) error {
5865
}
5966
backendSource, err = backend.NewS3Backend(opts.bucket, opts.key, opts.versionId, "appstream_machine_role")
6067

61-
6268
default:
6369
return fmt.Errorf("invalid source provided")
6470
}
@@ -67,13 +73,13 @@ func run(sourceType string, opts *RunOptions) error {
6773
return fmt.Errorf("unable to create backend source: %w", err)
6874
}
6975

70-
config, err := backendSource.GetConfig()
76+
config, err := backendSource.GetConfig(ctx)
7177

7278
if err != nil {
7379
return fmt.Errorf("failed to fetch config from backend: %w", err)
7480
}
7581

76-
if err := validator.ValidateConfig(config); err != nil {
82+
if err := validator.ValidateConfig(ctx, config); err != nil {
7783
return fmt.Errorf("config file validation failed: %w", err)
7884
}
7985

internal/backend/backend.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
package backend
22

3-
import c "github.com/aslamcodes/appstreamfile/internal/config"
3+
import (
4+
"context"
5+
6+
c "github.com/aslamcodes/appstreamfile/internal/config"
7+
)
48

59
type BackendSource interface {
6-
GetConfig() (*c.Config, error)
10+
GetConfig(ctx context.Context) (*c.Config, error)
711
}

internal/backend/local.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package backend
22

33
import (
4+
"context"
45
"fmt"
56
"os"
67

@@ -12,7 +13,11 @@ type LocalBackend struct {
1213
Location string
1314
}
1415

15-
func (lb *LocalBackend) GetConfig() (*config.Config, error) {
16+
func (lb *LocalBackend) GetConfig(ctx context.Context) (*config.Config, error) {
17+
if err := ctx.Err(); err != nil {
18+
return nil, err
19+
}
20+
1621
fmt.Printf("Attempting to fetch config from local backend at %s\n", lb.Location)
1722

1823
data, err := os.ReadFile(lb.Location)

internal/backend/local_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package backend_test
22

33
import (
4+
"context"
45
"fmt"
56
"os"
67
"reflect"
@@ -15,7 +16,7 @@ func TestGetConfig(t *testing.T) {
1516
Location: "../../testdata/config_win.yaml",
1617
}
1718

18-
actual, err := localBackend.GetConfig()
19+
actual, err := localBackend.GetConfig(context.TODO())
1920

2021
if err != nil {
2122
t.Fatal(err)

internal/backend/s3.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ type S3Backend struct {
1616
Client S3Client
1717
}
1818

19-
func (s3Backend *S3Backend) GetConfig() (*config.Config, error) {
20-
ctx := context.Background()
19+
func (s3Backend *S3Backend) GetConfig(ctx context.Context) (*config.Config, error) {
20+
if err := ctx.Err(); err != nil {
21+
return nil, err
22+
}
2123

2224
if s3Backend.Client == nil {
2325
return nil, fmt.Errorf("client is nil")

internal/backend/s3_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ installers:
5252
Client: client,
5353
}
5454

55-
actual, err := backend.GetConfig()
55+
actual, err := backend.GetConfig(context.TODO())
5656

5757
if err != nil {
5858
t.Errorf("error fetching the config: %v", err)
@@ -76,7 +76,7 @@ func TestGetConfigFail(t *testing.T) {
7676
},
7777
}
7878

79-
_, err := backend.GetConfig()
79+
_, err := backend.GetConfig(context.TODO())
8080

8181
if err == nil {
8282
t.Errorf("expected %v, got nil", expectedErr)

internal/validator/catalog_validator.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package validator
22

33
import (
4+
"context"
45
"errors"
56

67
"github.com/aslamcodes/appstreamfile/internal/config"
@@ -12,7 +13,7 @@ var (
1213
ErrEmptyCatalogNamePath = errors.New("catalog application path and name cannot be empty")
1314
)
1415

15-
func ValidateCatalogApplications(configData *config.Config) error {
16+
func ValidateCatalogApplications(ctx context.Context, configData *config.Config) error {
1617
for _, c := range configData.Catalogs {
1718
if c.Name == "" && c.Path == "" {
1819
return ErrEmptyCatalogNamePath

internal/validator/catalog_validator_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package validator_test
22

33
import (
4+
"context"
45
"testing"
56

67
"github.com/aslamcodes/appstreamfile/internal/backend"
@@ -40,13 +41,13 @@ func TestValidateCatalogConfig(t *testing.T) {
4041
Location: tC.filename,
4142
}
4243

43-
config, err := lb.GetConfig()
44+
config, err := lb.GetConfig(context.TODO())
4445

4546
if err != nil {
4647
t.Errorf("unable to load config: %s", err.Error())
4748
}
4849

49-
actual := validator.ValidateCatalogApplications(config)
50+
actual := validator.ValidateCatalogApplications(context.TODO(), config)
5051

5152
if actual != tC.expected {
5253
t.Errorf("expected %v, actual %v", tC.expected, actual)

internal/validator/config_validator.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package validator
22

33
import (
4+
"context"
45
"errors"
56
"slices"
67

@@ -12,8 +13,7 @@ var (
1213
ErrInvalidPlatform = errors.New("Platform not supported")
1314
)
1415

15-
16-
func ValidatePlatforms(c *config.Config) error {
16+
func ValidatePlatforms(ctx context.Context, c *config.Config) error {
1717
if c.Platform == "" {
1818
return ErrPlatformMissing
1919
}

internal/validator/config_validator_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package validator_test
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"os"
@@ -46,6 +47,8 @@ installers:
4647
}
4748
for _, tC := range testCases {
4849
t.Run(tC.desc, func(t *testing.T) {
50+
ctx := context.TODO()
51+
4952
file, err := os.CreateTemp("../../testdata", fmt.Sprintf("test_%s_", tC.desc))
5053

5154
if err != nil {
@@ -58,13 +61,13 @@ installers:
5861
Location: file.Name(),
5962
}
6063

61-
configData, err := lb.GetConfig()
64+
configData, err := lb.GetConfig(ctx)
6265

6366
if err != nil {
6467
t.Errorf("unable to fetch config data: %v", err)
6568
}
6669

67-
err = validator.ValidatePlatforms(configData)
70+
err = validator.ValidatePlatforms(ctx, configData)
6871

6972
if !errors.Is(err, tC.expected) {
7073
t.Errorf("expected %v, got %v", tC.expected, err)

0 commit comments

Comments
 (0)