diff --git a/cmd/appstreamfile/main.go b/cmd/appstreamfile/main.go index 124592f..d520f43 100644 --- a/cmd/appstreamfile/main.go +++ b/cmd/appstreamfile/main.go @@ -1,9 +1,11 @@ package main import ( + "context" "flag" "fmt" "os" + "os/signal" "github.com/aslamcodes/appstreamfile/internal/backend" "github.com/aslamcodes/appstreamfile/internal/logger" @@ -12,13 +14,17 @@ import ( ) type RunOptions struct { - location string - bucket string - key string - versionId string + location string + SourceType string + bucket string + key string + versionId string } func main() { + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + source := flag.String("source", "", "Configuration source: s3 or local") location := flag.String("location", "", "Local filesystem path to the config file") bucket := flag.String("bucket", "", "S3 bucket containing the config file") @@ -30,24 +36,25 @@ func main() { logger.Init() runOptions := &RunOptions{ - location: *location, - bucket: *bucket, - key: *key, - versionId: *versionId, + SourceType: *source, + location: *location, + bucket: *bucket, + key: *key, + versionId: *versionId, } - if err := run(*source, runOptions); err != nil { + if err := run(ctx, runOptions); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } } -func run(sourceType string, opts *RunOptions) error { +func run(ctx context.Context, opts *RunOptions) error { var backendSource backend.BackendSource var err error - switch sourceType { - case "local": + switch opts.SourceType { + case "local": if opts.location == "" { return fmt.Errorf("location of config file must be provided") } @@ -58,7 +65,6 @@ func run(sourceType string, opts *RunOptions) error { } backendSource, err = backend.NewS3Backend(opts.bucket, opts.key, opts.versionId, "appstream_machine_role") - default: return fmt.Errorf("invalid source provided") } @@ -67,13 +73,13 @@ func run(sourceType string, opts *RunOptions) error { return fmt.Errorf("unable to create backend source: %w", err) } - config, err := backendSource.GetConfig() + config, err := backendSource.GetConfig(ctx) if err != nil { return fmt.Errorf("failed to fetch config from backend: %w", err) } - if err := validator.ValidateConfig(config); err != nil { + if err := validator.ValidateConfig(ctx, config); err != nil { return fmt.Errorf("config file validation failed: %w", err) } diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 018e86b..1cfde00 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -1,7 +1,11 @@ package backend -import c "github.com/aslamcodes/appstreamfile/internal/config" +import ( + "context" + + c "github.com/aslamcodes/appstreamfile/internal/config" +) type BackendSource interface { - GetConfig() (*c.Config, error) + GetConfig(ctx context.Context) (*c.Config, error) } diff --git a/internal/backend/local.go b/internal/backend/local.go index ce69adf..65f2940 100644 --- a/internal/backend/local.go +++ b/internal/backend/local.go @@ -1,6 +1,7 @@ package backend import ( + "context" "fmt" "os" @@ -12,7 +13,11 @@ type LocalBackend struct { Location string } -func (lb *LocalBackend) GetConfig() (*config.Config, error) { +func (lb *LocalBackend) GetConfig(ctx context.Context) (*config.Config, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + fmt.Printf("Attempting to fetch config from local backend at %s\n", lb.Location) data, err := os.ReadFile(lb.Location) diff --git a/internal/backend/local_test.go b/internal/backend/local_test.go index 5d13189..c13682d 100644 --- a/internal/backend/local_test.go +++ b/internal/backend/local_test.go @@ -1,6 +1,7 @@ package backend_test import ( + "context" "fmt" "os" "reflect" @@ -15,7 +16,7 @@ func TestGetConfig(t *testing.T) { Location: "../../testdata/config_win.yaml", } - actual, err := localBackend.GetConfig() + actual, err := localBackend.GetConfig(context.TODO()) if err != nil { t.Fatal(err) diff --git a/internal/backend/s3.go b/internal/backend/s3.go index 508de10..fc3db96 100644 --- a/internal/backend/s3.go +++ b/internal/backend/s3.go @@ -16,8 +16,10 @@ type S3Backend struct { Client S3Client } -func (s3Backend *S3Backend) GetConfig() (*config.Config, error) { - ctx := context.Background() +func (s3Backend *S3Backend) GetConfig(ctx context.Context) (*config.Config, error) { + if err := ctx.Err(); err != nil { + return nil, err + } if s3Backend.Client == nil { return nil, fmt.Errorf("client is nil") diff --git a/internal/backend/s3_test.go b/internal/backend/s3_test.go index f42fb39..452dd8e 100644 --- a/internal/backend/s3_test.go +++ b/internal/backend/s3_test.go @@ -52,7 +52,7 @@ installers: Client: client, } - actual, err := backend.GetConfig() + actual, err := backend.GetConfig(context.TODO()) if err != nil { t.Errorf("error fetching the config: %v", err) @@ -76,7 +76,7 @@ func TestGetConfigFail(t *testing.T) { }, } - _, err := backend.GetConfig() + _, err := backend.GetConfig(context.TODO()) if err == nil { t.Errorf("expected %v, got nil", expectedErr) diff --git a/internal/validator/catalog_validator.go b/internal/validator/catalog_validator.go index 945bb3a..84502a2 100644 --- a/internal/validator/catalog_validator.go +++ b/internal/validator/catalog_validator.go @@ -1,6 +1,7 @@ package validator import ( + "context" "errors" "github.com/aslamcodes/appstreamfile/internal/config" @@ -12,7 +13,7 @@ var ( ErrEmptyCatalogNamePath = errors.New("catalog application path and name cannot be empty") ) -func ValidateCatalogApplications(configData *config.Config) error { +func ValidateCatalogApplications(ctx context.Context, configData *config.Config) error { for _, c := range configData.Catalogs { if c.Name == "" && c.Path == "" { return ErrEmptyCatalogNamePath diff --git a/internal/validator/catalog_validator_test.go b/internal/validator/catalog_validator_test.go index 83baaba..d834b37 100644 --- a/internal/validator/catalog_validator_test.go +++ b/internal/validator/catalog_validator_test.go @@ -1,6 +1,7 @@ package validator_test import ( + "context" "testing" "github.com/aslamcodes/appstreamfile/internal/backend" @@ -40,13 +41,13 @@ func TestValidateCatalogConfig(t *testing.T) { Location: tC.filename, } - config, err := lb.GetConfig() + config, err := lb.GetConfig(context.TODO()) if err != nil { t.Errorf("unable to load config: %s", err.Error()) } - actual := validator.ValidateCatalogApplications(config) + actual := validator.ValidateCatalogApplications(context.TODO(), config) if actual != tC.expected { t.Errorf("expected %v, actual %v", tC.expected, actual) diff --git a/internal/validator/config_validator.go b/internal/validator/config_validator.go index dab7069..5f2a636 100644 --- a/internal/validator/config_validator.go +++ b/internal/validator/config_validator.go @@ -1,6 +1,7 @@ package validator import ( + "context" "errors" "slices" @@ -12,8 +13,7 @@ var ( ErrInvalidPlatform = errors.New("Platform not supported") ) - -func ValidatePlatforms(c *config.Config) error { +func ValidatePlatforms(ctx context.Context, c *config.Config) error { if c.Platform == "" { return ErrPlatformMissing } diff --git a/internal/validator/config_validator_test.go b/internal/validator/config_validator_test.go index a480b8c..8e11149 100644 --- a/internal/validator/config_validator_test.go +++ b/internal/validator/config_validator_test.go @@ -1,6 +1,7 @@ package validator_test import ( + "context" "errors" "fmt" "os" @@ -46,6 +47,8 @@ installers: } for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { + ctx := context.TODO() + file, err := os.CreateTemp("../../testdata", fmt.Sprintf("test_%s_", tC.desc)) if err != nil { @@ -58,13 +61,13 @@ installers: Location: file.Name(), } - configData, err := lb.GetConfig() + configData, err := lb.GetConfig(ctx) if err != nil { t.Errorf("unable to fetch config data: %v", err) } - err = validator.ValidatePlatforms(configData) + err = validator.ValidatePlatforms(ctx, configData) if !errors.Is(err, tC.expected) { t.Errorf("expected %v, got %v", tC.expected, err) diff --git a/internal/validator/file_validator.go b/internal/validator/file_validator.go index 6badf4a..478bb72 100644 --- a/internal/validator/file_validator.go +++ b/internal/validator/file_validator.go @@ -1,6 +1,7 @@ package validator import ( + "context" "errors" "github.com/aslamcodes/appstreamfile/internal/config" @@ -10,7 +11,7 @@ var ( ErrFileDeployPathMissing = errors.New("file deploy path should not be null") ) -func ValidateFileDeploys(c *config.Config) error { +func ValidateFileDeploys(ctx context.Context, c *config.Config) error { for _, file := range c.Files { if file.Path == "" { return ErrFileDeployPathMissing diff --git a/internal/validator/file_validator_test.go b/internal/validator/file_validator_test.go index 0ea5c76..5c3f850 100644 --- a/internal/validator/file_validator_test.go +++ b/internal/validator/file_validator_test.go @@ -1,6 +1,7 @@ package validator_test import ( + "context" "errors" "fmt" "os" @@ -52,13 +53,15 @@ files: Location: file.Name(), } - configData, err := lb.GetConfig() + ctx := context.TODO() + + configData, err := lb.GetConfig(ctx) if err != nil { t.Errorf("unable to fetch config data: %v", err) } - err = validator.ValidateFileDeploys(configData) + err = validator.ValidateFileDeploys(ctx, configData) if !errors.Is(err, tC.expected) { t.Errorf("expected %v, got %v", tC.expected, err) diff --git a/internal/validator/image_validator.go b/internal/validator/image_validator.go index f7a8627..9569f4b 100644 --- a/internal/validator/image_validator.go +++ b/internal/validator/image_validator.go @@ -1,6 +1,7 @@ package validator import ( + "context" "errors" "strings" @@ -12,7 +13,7 @@ var ( ErrInvalidTagsCreateImage = errors.New("format invalid for create-image tags (key1:value1)") ) -func ValidateImage(c *config.Config) error { +func ValidateImage(ctx context.Context, c *config.Config) error { if c.Image.Name == "" { return ErrInvalidParametersCreateImage } diff --git a/internal/validator/image_validator_test.go b/internal/validator/image_validator_test.go index 0a0ee93..b13f838 100644 --- a/internal/validator/image_validator_test.go +++ b/internal/validator/image_validator_test.go @@ -1,6 +1,7 @@ package validator_test import ( + "context" "errors" "fmt" "os" @@ -64,6 +65,8 @@ image: } for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { + ctx := context.TODO() + file, err := os.CreateTemp("../../testdata", fmt.Sprintf("test_%s_", tC.desc)) if err != nil { @@ -76,13 +79,13 @@ image: Location: file.Name(), } - configData, err := lb.GetConfig() + configData, err := lb.GetConfig(ctx) if err != nil { t.Errorf("unable to fetch config data: %v", err) } - err = validator.ValidateImage(configData) + err = validator.ValidateImage(ctx, configData) if !errors.Is(err, tC.expected) { t.Errorf("expected %v, got %v", tC.expected, err) diff --git a/internal/validator/installer_validator.go b/internal/validator/installer_validator.go index 8d84df1..62d83de 100644 --- a/internal/validator/installer_validator.go +++ b/internal/validator/installer_validator.go @@ -1,6 +1,7 @@ package validator import ( + "context" "errors" "fmt" "slices" @@ -12,7 +13,7 @@ var ( ErrInvalidExecutableForPlatform = errors.New("Invalid executable for given platform") ) -func InstallerValidator(configData *config.Config) error { +func InstallerValidator(ctx context.Context, configData *config.Config) error { platform := configData.Platform platformExecs, exists := ExecPlatformMap[platform] diff --git a/internal/validator/installer_validator_test.go b/internal/validator/installer_validator_test.go index 6f17066..11902c2 100644 --- a/internal/validator/installer_validator_test.go +++ b/internal/validator/installer_validator_test.go @@ -1,6 +1,7 @@ package validator_test import ( + "context" "errors" "fmt" "os" @@ -48,14 +49,14 @@ installers: lb := backend.LocalBackend{ Location: file.Name(), } - - configData, err := lb.GetConfig() +ctx := context.TODO() + configData, err := lb.GetConfig(ctx) if err != nil { t.Errorf("unable to fetch config data: %v", err) } - err = validator.InstallerValidator(configData) + err = validator.InstallerValidator(ctx,configData) if !errors.Is(err, tC.expected) { t.Errorf("expected %v, got %v", tC.expected, err) diff --git a/internal/validator/path_validator.go b/internal/validator/path_validator.go index f85895c..25d6c4c 100644 --- a/internal/validator/path_validator.go +++ b/internal/validator/path_validator.go @@ -1,6 +1,7 @@ package validator import ( + "context" "errors" "fmt" "regexp" @@ -14,7 +15,7 @@ var ( ErrInvalidPathForWindows = errors.New("path is invalid for windows") ) -func ValidatePaths(c *config.Config) error { +func ValidatePaths(ctx context.Context, c *config.Config) error { var ( winDrive = regexp.MustCompile(`^[A-Za-z]:`) uncPath = regexp.MustCompile(`^\\\\`) diff --git a/internal/validator/path_validator_test.go b/internal/validator/path_validator_test.go index d5b7211..ef7094c 100644 --- a/internal/validator/path_validator_test.go +++ b/internal/validator/path_validator_test.go @@ -1,6 +1,7 @@ package validator_test import ( + "context" "errors" "fmt" "os" @@ -49,13 +50,14 @@ files: Location: file.Name(), } - configData, err := lb.GetConfig() + ctx := context.TODO() + configData, err := lb.GetConfig(ctx) if err != nil { t.Errorf("unable to fetch config data: %v", err) } - err = validator.ValidatePaths(configData) + err = validator.ValidatePaths(ctx, configData) if !errors.Is(err, tC.expected) { t.Errorf("expected %v, got %v", tC.expected, err) diff --git a/internal/validator/validator.go b/internal/validator/validator.go index 1152209..31ae748 100644 --- a/internal/validator/validator.go +++ b/internal/validator/validator.go @@ -1,6 +1,7 @@ package validator import ( + "context" "fmt" "github.com/aslamcodes/appstreamfile/internal/config" @@ -11,8 +12,8 @@ var ExecPlatformMap = map[string][]string{ "unix": {"bash"}, } -func ValidateConfig(c *config.Config) error { - validators := []func(*config.Config) error{ +func ValidateConfig(ctx context.Context, c *config.Config) error { + validators := []func(context.Context, *config.Config) error{ ValidateCatalogApplications, ValidateFileDeploys, ValidateImage, @@ -22,7 +23,7 @@ func ValidateConfig(c *config.Config) error { } for _, v := range validators { - if err := v(c); err != nil { + if err := v(ctx, c); err != nil { return err } }