diff --git a/cmd/graph.go b/cmd/graph.go index 38ee10b2..d3a67832 100644 --- a/cmd/graph.go +++ b/cmd/graph.go @@ -103,9 +103,22 @@ func NewGraphRunCommand(factory *providerfactory.ProviderFactory, scenarioOrches return err } + outputDir, err := cmd.Flags().GetString("output") if err != nil { return err } + var outputDirPath *string + if outputDir != "" { + expandedOutputDir, err := commonutils.ExpandFolder(outputDir, nil) + if err != nil { + return err + } + // Ensure output directory exists + if err = EnsureDirectory(*expandedOutputDir); err != nil { + return fmt.Errorf("failed to create output directory %s: %v", *expandedOutputDir, err) + } + outputDirPath = expandedOutputDir + } kubeconfigPath, err := utils.PrepareKubeconfig(&kubeconfig, config) if err != nil { @@ -114,6 +127,12 @@ func NewGraphRunCommand(factory *providerfactory.ProviderFactory, scenarioOrches if kubeconfigPath == nil { return fmt.Errorf("kubeconfig not found: %s", kubeconfig) } + // Clean up kubeconfig file on exit + defer func() { + if kubeconfigPath != nil { + _ = os.Remove(*kubeconfigPath) + } + }() volumes[*kubeconfigPath] = config.KubeconfigPath if metricsProfile != "" { @@ -199,7 +218,7 @@ func NewGraphRunCommand(factory *providerfactory.ProviderFactory, scenarioOrches commChannel := make(chan *models.GraphCommChannel) go func() { - (*scenarioOrchestrator).RunGraph(nodes, executionPlan, environment, volumes, false, commChannel, registrySettings, nil) + (*scenarioOrchestrator).RunGraph(nodes, executionPlan, environment, volumes, false, commChannel, registrySettings, nil, outputDirPath) }() for { @@ -222,6 +241,10 @@ func NewGraphRunCommand(factory *providerfactory.ProviderFactory, scenarioOrches } } if exitOnerror { + // Clean up kubeconfig before exiting + if kubeconfigPath != nil { + _ = os.Remove(*kubeconfigPath) + } _, err = color.New(color.FgHiRed).Println(fmt.Sprintf("aborting chaos run with exit status %d", staterr.ExitStatus)) if err != nil { return err diff --git a/cmd/random.go b/cmd/random.go index a8d6e96c..7b9bb95c 100644 --- a/cmd/random.go +++ b/cmd/random.go @@ -206,7 +206,7 @@ func NewRandomRunCommand(factory *providerfactory.ProviderFactory, scenarioOrche commChannel := make(chan *models.GraphCommChannel) go func() { - (*scenarioOrchestrator).RunGraph(nodes, executionPlan, environment, volumes, false, commChannel, registrySettings, nil) + (*scenarioOrchestrator).RunGraph(nodes, executionPlan, environment, volumes, false, commChannel, registrySettings, nil, nil) }() for { diff --git a/cmd/root.go b/cmd/root.go index 1db48361..46428205 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -76,6 +76,7 @@ func Execute(providerFactory *factory.ProviderFactory, scenarioOrchestrator *sce graphRunCmd.Flags().String("alerts-profile", "", "custom alerts profile file path") graphRunCmd.Flags().String("metrics-profile", "", "custom metrics profile file path") graphRunCmd.Flags().Bool("exit-on-error", false, "if set this flag will the workflow will be interrupted and the tool will exit with a status greater than 0") + graphRunCmd.Flags().StringP("output", "o", "", "output directory for log files (default: current working directory)") graphScaffoldCmd := NewGraphScaffoldCommand(providerFactory, config) graphScaffoldCmd.Flags().Bool("global-env", false, "if set this flag will add global environment variables to each scenario in the graph") graphCmd.AddCommand(graphRunCmd) diff --git a/cmd/utils.go b/cmd/utils.go index dfbb5669..4e953e37 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -83,6 +83,12 @@ func CheckFileExists(filePath string) bool { return true } +// EnsureDirectory creates a directory if it doesn't exist +// Returns an error if the directory cannot be created +func EnsureDirectory(dirPath string) error { + return os.MkdirAll(dirPath, 0755) +} + func ParseFlags(scenarioDetail *models.ScenarioDetail, args []string, scenarioCollectedFlags map[string]*string, skipDefault bool) (env *map[string]ParsedField, vol *map[string]string, err error) { environment := make(map[string]ParsedField) volumes := make(map[string]string) diff --git a/pkg/scenarioorchestrator/common_functions.go b/pkg/scenarioorchestrator/common_functions.go index 786d6b5d..f4b7e0a8 100644 --- a/pkg/scenarioorchestrator/common_functions.go +++ b/pkg/scenarioorchestrator/common_functions.go @@ -28,6 +28,7 @@ func CommonRunGraph( config config.Config, registry *providermodels.RegistryV2, userID *int, + outputDir *string, ) { for step, s := range resolvedGraph { var wg sync.WaitGroup @@ -66,20 +67,26 @@ func CommonRunGraph( containerName := utils.GenerateContainerName(config, scenario.Name, &scID) filename := fmt.Sprintf("%s.log", containerName) - file, err := os.Create(path.Clean(filename)) + var logPath string + if outputDir != nil && *outputDir != "" { + logPath = path.Join(*outputDir, filename) + } else { + logPath = path.Clean(filename) + } + file, err := os.Create(logPath) if err != nil { commChannel <- &models.GraphCommChannel{Layer: nil, ScenarioID: nil, ScenarioLogFile: nil, Err: err} return } - commChannel <- &models.GraphCommChannel{Layer: &step, ScenarioID: &scID, ScenarioLogFile: &filename, Err: nil} + commChannel <- &models.GraphCommChannel{Layer: &step, ScenarioID: &scID, ScenarioLogFile: &logPath, Err: nil} wg.Add(1) go func() { defer wg.Done() _, err = orchestrator.RunAttached(scenario.Image, containerName, env, cache, volumes, file, file, nil, ctx, registry) if err != nil { - commChannel <- &models.GraphCommChannel{Layer: &step, ScenarioID: &scID, ScenarioLogFile: &filename, Err: err} + commChannel <- &models.GraphCommChannel{Layer: &step, ScenarioID: &scID, ScenarioLogFile: &logPath, Err: err} return } }() diff --git a/pkg/scenarioorchestrator/docker/scenario_orchestrator.go b/pkg/scenarioorchestrator/docker/scenario_orchestrator.go index f81fbbcf..6acfdaa9 100644 --- a/pkg/scenarioorchestrator/docker/scenario_orchestrator.go +++ b/pkg/scenarioorchestrator/docker/scenario_orchestrator.go @@ -466,8 +466,9 @@ func (c *ScenarioOrchestrator) RunGraph( commChannel chan *orchestratormodels.GraphCommChannel, registry *providermodels.RegistryV2, userID *int, + outputDir *string, ) { - scenarioorchestrator.CommonRunGraph(scenarios, resolvedGraph, extraEnv, extraVolumeMounts, cache, commChannel, c, c.Config, registry, userID) + scenarioorchestrator.CommonRunGraph(scenarios, resolvedGraph, extraEnv, extraVolumeMounts, cache, commChannel, c, c.Config, registry, userID, outputDir) } func (c *ScenarioOrchestrator) PrintContainerRuntime() { diff --git a/pkg/scenarioorchestrator/podman/scenario_orchestrator.go b/pkg/scenarioorchestrator/podman/scenario_orchestrator.go index 9d6bf1b0..ee823a16 100644 --- a/pkg/scenarioorchestrator/podman/scenario_orchestrator.go +++ b/pkg/scenarioorchestrator/podman/scenario_orchestrator.go @@ -373,9 +373,10 @@ func (c *ScenarioOrchestrator) RunGraph( commChannel chan *orchestratormodels.GraphCommChannel, registry *providermodels.RegistryV2, userID *int, + outputDir *string, ) { //TODO: add a getconfig method in scenarioOrchestrator - scenarioorchestrator.CommonRunGraph(scenarios, resolvedGraph, extraEnv, extraVolumeMounts, cache, commChannel, c, c.Config, registry, userID) + scenarioorchestrator.CommonRunGraph(scenarios, resolvedGraph, extraEnv, extraVolumeMounts, cache, commChannel, c, c.Config, registry, userID, outputDir) } func (c *ScenarioOrchestrator) PrintContainerRuntime() { diff --git a/pkg/scenarioorchestrator/scenario_orchestrator.go b/pkg/scenarioorchestrator/scenario_orchestrator.go index 6d9805ae..ed23f8c6 100644 --- a/pkg/scenarioorchestrator/scenario_orchestrator.go +++ b/pkg/scenarioorchestrator/scenario_orchestrator.go @@ -46,6 +46,7 @@ type ScenarioOrchestrator interface { commChannel chan *orchestrator_models.GraphCommChannel, registry *models.RegistryV2, userID *int, + outputDir *string, ) CleanContainers(ctx context.Context) (*int, error) diff --git a/pkg/scenarioorchestrator/scenarioorchestratortest/common_test_functions.go b/pkg/scenarioorchestrator/scenarioorchestratortest/common_test_functions.go index 83ae9a61..54ef8767 100644 --- a/pkg/scenarioorchestrator/scenarioorchestratortest/common_test_functions.go +++ b/pkg/scenarioorchestrator/scenarioorchestratortest/common_test_functions.go @@ -330,7 +330,7 @@ func CommonTestScenarioOrchestratorRunGraph(t *testing.T, so scenarioorchestrato commChannel := make(chan *models.GraphCommChannel) go func() { - so.RunGraph(nodes, executionPlan, map[string]string{}, map[string]string{}, false, commChannel, nil, uid) + so.RunGraph(nodes, executionPlan, map[string]string{}, map[string]string{}, false, commChannel, nil, uid, nil) }() for { @@ -406,7 +406,7 @@ func CommonTestScenarioOrchestratorRunGraph(t *testing.T, so scenarioorchestrato commChannel = make(chan *models.GraphCommChannel) go func() { - so.RunGraph(nodes, executionPlan, map[string]string{}, map[string]string{}, false, commChannel, nil, uid) + so.RunGraph(nodes, executionPlan, map[string]string{}, map[string]string{}, false, commChannel, nil, uid, nil) }() for {