diff --git a/pkg/cli/admin/mustgather/mustgather.go b/pkg/cli/admin/mustgather/mustgather.go index 4f03e5ea5e..30d0785fb7 100644 --- a/pkg/cli/admin/mustgather/mustgather.go +++ b/pkg/cli/admin/mustgather/mustgather.go @@ -242,7 +242,7 @@ func (o *MustGatherOptions) Complete(f kcmdutil.Factory, cmd *cobra.Command, arg return fmt.Errorf("--run-namespace %s", errStr) } } - if err := o.completeImages(); err != nil { + if err := o.completeImages(context.TODO()); err != nil { return err } o.PrinterCreated, err = printers.NewTypeSetter(scheme.Scheme).WrapToPrinter(&printers.NamePrinter{Operation: "created"}, nil) @@ -257,9 +257,9 @@ func (o *MustGatherOptions) Complete(f kcmdutil.Factory, cmd *cobra.Command, arg return nil } -func (o *MustGatherOptions) completeImages() error { +func (o *MustGatherOptions) completeImages(ctx context.Context) error { for _, imageStream := range o.ImageStreams { - if image, err := o.resolveImageStreamTagString(imageStream); err == nil { + if image, err := o.resolveImageStreamTagString(ctx, imageStream); err == nil { o.Images = append(o.Images, image) } else { return fmt.Errorf("unable to resolve image stream '%v': %v", imageStream, err) @@ -268,7 +268,7 @@ func (o *MustGatherOptions) completeImages() error { if len(o.Images) == 0 || o.AllImages { var image string var err error - if image, err = o.resolveImageStreamTag("openshift", "must-gather", "latest"); err != nil { + if image, err = o.resolveImageStreamTag(ctx, "openshift", "must-gather", "latest"); err != nil { o.log("%v\n", err) image = "registry.redhat.io/openshift4/ose-must-gather:latest" } @@ -279,12 +279,12 @@ func (o *MustGatherOptions) completeImages() error { pluginImages := make(map[string]struct{}) var err error - pluginImages, err = o.annotatedCSVs() + pluginImages, err = o.annotatedCSVs(ctx) if err != nil { return err } - cos, err := o.ConfigClient.ConfigV1().ClusterOperators().List(context.TODO(), metav1.ListOptions{}) + cos, err := o.ConfigClient.ConfigV1().ClusterOperators().List(ctx, metav1.ListOptions{}) if err != nil { return err } @@ -305,7 +305,7 @@ func (o *MustGatherOptions) completeImages() error { return nil } -func (o *MustGatherOptions) annotatedCSVs() (map[string]struct{}, error) { +func (o *MustGatherOptions) annotatedCSVs(ctx context.Context) (map[string]struct{}, error) { csvGVR := schema.GroupVersionResource{ Group: "operators.coreos.com", Version: "v1alpha1", @@ -313,7 +313,7 @@ func (o *MustGatherOptions) annotatedCSVs() (map[string]struct{}, error) { } pluginImages := make(map[string]struct{}) - csvs, err := o.DynamicClient.Resource(csvGVR).List(context.TODO(), metav1.ListOptions{}) + csvs, err := o.DynamicClient.Resource(csvGVR).List(ctx, metav1.ListOptions{}) if err != nil { return nil, err } @@ -327,12 +327,12 @@ func (o *MustGatherOptions) annotatedCSVs() (map[string]struct{}, error) { return pluginImages, nil } -func (o *MustGatherOptions) resolveImageStreamTagString(s string) (string, error) { +func (o *MustGatherOptions) resolveImageStreamTagString(ctx context.Context, s string) (string, error) { namespace, name, tag := parseImageStreamTagString(s) if len(namespace) == 0 { return "", fmt.Errorf("expected namespace/name:tag") } - return o.resolveImageStreamTag(namespace, name, tag) + return o.resolveImageStreamTag(ctx, namespace, name, tag) } func parseImageStreamTagString(s string) (string, string, string) { @@ -349,8 +349,8 @@ func parseImageStreamTagString(s string) (string, string, string) { return namespace, name, tag } -func (o *MustGatherOptions) resolveImageStreamTag(namespace, name, tag string) (string, error) { - imageStream, err := o.ImageClient.ImageStreams(namespace).Get(context.TODO(), name, metav1.GetOptions{}) +func (o *MustGatherOptions) resolveImageStreamTag(ctx context.Context, namespace, name, tag string) (string, error) { + imageStream, err := o.ImageClient.ImageStreams(namespace).Get(ctx, name, metav1.GetOptions{}) if err != nil { return "", err } @@ -557,6 +557,10 @@ func getCandidateNodeNames(nodes *corev1.NodeList, hasMaster bool) []string { func (o *MustGatherOptions) Run() error { var errs []error + // The following context is being used for now until proper signal handling is implemented. + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + if err := os.MkdirAll(o.DestDir, os.ModePerm); err != nil { // ensure the errors bubble up to BackupGathering method for display errs = []error{err} @@ -579,14 +583,20 @@ func (o *MustGatherOptions) Run() error { } // print at both the beginning and at the end. This information is important enough to be in both spots. - o.PrintBasicClusterState(context.TODO()) + o.PrintBasicClusterState(ctx) defer func() { + // Shortcircuit this in case the context is already cancelled. + if ctx.Err() != nil { + klog.Warning("Reprinting cluster state skipped, terminating...") + return + } fmt.Fprintf(o.RawOut, "\n\n") fmt.Fprintf(o.RawOut, "Reprinting Cluster State:\n") - o.PrintBasicClusterState(context.TODO()) + o.PrintBasicClusterState(ctx) }() // Ensure resource cleanup unless instructed otherwise ... + // There is no context passed into the cleanup function as it's unrelated to the main context. var cleanupNamespace func() if !o.Keep { defer func() { @@ -596,19 +606,20 @@ func (o *MustGatherOptions) Run() error { }() } - // Due to 'stack unwiding', this should happen after 'clusterState' printing, to ensure that we always - // print our ClusterState information. + // Due to 'stack unwinding', this should happen after 'clusterState' printing, to ensure that we always + // print our ClusterState information. runBackCollection := true defer func() { - if !runBackCollection { + // Shortcircuit this in case the context is already cancelled. + if ctx.Err() != nil || !runBackCollection { return } - o.BackupGathering(context.TODO(), errs) + o.BackupGathering(ctx, errs) }() // Get or create "working" namespace ... var ns *corev1.Namespace - ns, cleanupNamespace, err = o.getNamespace() + ns, cleanupNamespace, err = o.getNamespace(ctx) if err != nil { // ensure the errors bubble up to BackupGathering method for display errs = []error{err} @@ -617,7 +628,7 @@ func (o *MustGatherOptions) Run() error { // Prefer to run in master if there's any but don't be explicit otherwise. // This enables the command to run by default in hypershift where there's no masters. - nodes, err := o.Client.CoreV1().Nodes().List(context.TODO(), metav1.ListOptions{ + nodes, err := o.Client.CoreV1().Nodes().List(ctx, metav1.ListOptions{ LabelSelector: o.NodeSelector, }) if err != nil { @@ -645,7 +656,7 @@ func (o *MustGatherOptions) Run() error { return err } if o.NodeSelector != "" { - nodes, err := o.Client.CoreV1().Nodes().List(context.TODO(), metav1.ListOptions{ + nodes, err := o.Client.CoreV1().Nodes().List(ctx, metav1.ListOptions{ LabelSelector: o.NodeSelector, }) if err != nil { @@ -658,7 +669,7 @@ func (o *MustGatherOptions) Run() error { } } else { if o.NodeName != "" { - if _, err := o.Client.CoreV1().Nodes().Get(context.TODO(), o.NodeName, metav1.GetOptions{}); err != nil { + if _, err := o.Client.CoreV1().Nodes().Get(ctx, o.NodeName, metav1.GetOptions{}); err != nil { // ensure the errors bubble up to BackupGathering method for display errs = []error{err} return err @@ -676,7 +687,7 @@ func (o *MustGatherOptions) Run() error { } defer o.logTimestamp() - queue := workqueue.NewRateLimitingQueue(workqueue.DefaultControllerRateLimiter()) + queue := workqueue.NewTypedRateLimitingQueue[*corev1.Pod](workqueue.DefaultTypedControllerRateLimiter[*corev1.Pod]()) var wg sync.WaitGroup errCh := make(chan error, len(pods)) @@ -686,6 +697,11 @@ func (o *MustGatherOptions) Run() error { } queue.ShutDownWithDrain() + go func() { + <-ctx.Done() + queue.ShutDown() + }() + wg.Add(concurrentMG) for i := 0; i < concurrentMG; i++ { go func() { @@ -695,10 +711,14 @@ func (o *MustGatherOptions) Run() error { if quit { return } - defer queue.Done(pod) - if err := o.processNextWorkItem(ns.Name, pod.(*corev1.Pod)); err != nil { - errCh <- err - } + func() { + // Make sure queue.Done is called inside the current iteration, + // not just when the whole worker thread exits. + defer queue.Done(pod) + if err := o.processNextWorkItem(ctx, ns.Name, pod); err != nil { + errCh <- err + } + }() } }() } @@ -726,9 +746,9 @@ func (o *MustGatherOptions) Run() error { } // processNextWorkItem creates & processes the must-gather pod and returns error if any -func (o *MustGatherOptions) processNextWorkItem(ns string, pod *corev1.Pod) error { +func (o *MustGatherOptions) processNextWorkItem(ctx context.Context, ns string, pod *corev1.Pod) error { var err error - pod, err = o.Client.CoreV1().Pods(ns).Create(context.TODO(), pod, metav1.CreateOptions{}) + pod, err = o.Client.CoreV1().Pods(ns).Create(ctx, pod, metav1.CreateOptions{}) if err != nil { return err } @@ -743,8 +763,7 @@ func (o *MustGatherOptions) processNextWorkItem(ns string, pod *corev1.Pod) erro // it deletes this namespace and all the pods are removed by garbage collector. // However, if user specifies namespace via `run-namespace`, these pods need to // be deleted manually. - err = o.Client.CoreV1().Pods(o.RunNamespace).Delete(context.TODO(), pod.Name, metav1.DeleteOptions{}) - if err != nil { + if err := o.Client.CoreV1().Pods(o.RunNamespace).Delete(ctx, pod.Name, metav1.DeleteOptions{}); err != nil { klog.V(4).Infof("pod deletion error %v", err) } }() @@ -753,19 +772,18 @@ func (o *MustGatherOptions) processNextWorkItem(ns string, pod *corev1.Pod) erro log := o.newPodOutLogger(o.Out, pod.Name) // wait for gather container to be running (gather is running) - if err := o.waitForGatherContainerRunning(pod); err != nil { + if err := o.waitForGatherContainerRunning(ctx, pod); err != nil { log("gather did not start: %s", err) - return fmt.Errorf("gather did not start for pod %s: %s", pod.Name, err) - + return fmt.Errorf("gather did not start for pod %s: %w", pod.Name, err) } // stream gather container logs - if err := o.getGatherContainerLogs(pod); err != nil { + if err := o.getGatherContainerLogs(ctx, pod); err != nil { log("gather logs unavailable: %v", err) } // wait for pod to be running (gather has completed) log("waiting for gather to complete") - if err := o.waitForGatherToComplete(pod); err != nil { + if err := o.waitForGatherToComplete(ctx, pod); err != nil { log("gather never finished: %v", err) if exiterr, ok := err.(*exec.CodeExitError); ok { return exiterr @@ -775,12 +793,12 @@ func (o *MustGatherOptions) processNextWorkItem(ns string, pod *corev1.Pod) erro // copy the gathered files into the local destination dir log("downloading gather output") - pod, err = o.Client.CoreV1().Pods(pod.Namespace).Get(context.TODO(), pod.Name, metav1.GetOptions{}) + pod, err = o.Client.CoreV1().Pods(pod.Namespace).Get(ctx, pod.Name, metav1.GetOptions{}) if err != nil { log("gather output not downloaded: %v\n", err) return fmt.Errorf("unable to download output from pod %s: %s", pod.Name, err) } - if err := o.copyFilesFromPod(pod); err != nil { + if err := o.copyFilesFromPod(ctx, pod); err != nil { log("gather output not downloaded: %v\n", err) return fmt.Errorf("unable to download output from pod %s: %s", pod.Name, err) } @@ -807,7 +825,7 @@ func (o *MustGatherOptions) logTimestamp() error { return err } -func (o *MustGatherOptions) copyFilesFromPod(pod *corev1.Pod) error { +func (o *MustGatherOptions) copyFilesFromPod(ctx context.Context, pod *corev1.Pod) error { streams := o.IOStreams streams.Out = o.newPrefixWriter(streams.Out, fmt.Sprintf("[%s] OUT", pod.Name), false, true) imageFolder := regexp.MustCompile("[^A-Za-z0-9]+").ReplaceAllString(pod.Status.ContainerStatuses[0].ImageID, "-") @@ -835,7 +853,7 @@ func (o *MustGatherOptions) copyFilesFromPod(pod *corev1.Pod) error { Container: gatherContainerName, Timestamps: true, } - readCloser, err := o.Client.CoreV1().Pods(pod.Namespace).GetLogs(pod.Name, logOptions).Stream(context.TODO()) + readCloser, err := o.Client.CoreV1().Pods(pod.Namespace).GetLogs(pod.Name, logOptions).Stream(ctx) if err != nil { return err } @@ -869,7 +887,7 @@ func (o *MustGatherOptions) copyFilesFromPod(pod *corev1.Pod) error { return kutilerrors.NewAggregate(errs) } -func (o *MustGatherOptions) getGatherContainerLogs(pod *corev1.Pod) error { +func (o *MustGatherOptions) getGatherContainerLogs(ctx context.Context, pod *corev1.Pod) error { since2s := int64(2) opts := &logs.LogsOptions{ Namespace: pod.Namespace, @@ -890,13 +908,15 @@ func (o *MustGatherOptions) getGatherContainerLogs(pod *corev1.Pod) error { // gather script might take longer than the default API server time, // so we should check if the gather script still runs and re-run logs // thus we run this in a loop + // + // TODO: Use opts.RunLogsContext once Kubernetes v1.35 is available. if err := opts.RunLogs(); err != nil { return err } // to ensure we don't print all of history set since to past 2 seconds opts.Options.(*corev1.PodLogOptions).SinceSeconds = &since2s - if done, _ := o.isGatherDone(pod); done { + if done, _ := o.isGatherDone(ctx, pod); done { return nil } klog.V(4).Infof("lost logs, re-trying...") @@ -927,15 +947,15 @@ func (o *MustGatherOptions) newPrefixWriter(out io.Writer, prefix string, ignore return writer } -func (o *MustGatherOptions) waitForGatherToComplete(pod *corev1.Pod) error { - return wait.PollUntilContextTimeout(context.TODO(), 10*time.Second, o.Timeout, true, func(ctx context.Context) (bool, error) { - return o.isGatherDone(pod) +func (o *MustGatherOptions) waitForGatherToComplete(ctx context.Context, pod *corev1.Pod) error { + return wait.PollUntilContextTimeout(ctx, 10*time.Second, o.Timeout, true, func(ctx context.Context) (bool, error) { + return o.isGatherDone(ctx, pod) }) } -func (o *MustGatherOptions) isGatherDone(pod *corev1.Pod) (bool, error) { +func (o *MustGatherOptions) isGatherDone(ctx context.Context, pod *corev1.Pod) (bool, error) { var err error - if pod, err = o.Client.CoreV1().Pods(pod.Namespace).Get(context.TODO(), pod.Name, metav1.GetOptions{}); err != nil { + if pod, err = o.Client.CoreV1().Pods(pod.Namespace).Get(ctx, pod.Name, metav1.GetOptions{}); err != nil { // at this stage pod should exist, we've been gathering container logs, so error if not found if kerrors.IsNotFound(err) { return true, err @@ -967,10 +987,10 @@ func (o *MustGatherOptions) isGatherDone(pod *corev1.Pod) (bool, error) { return false, nil } -func (o *MustGatherOptions) waitForGatherContainerRunning(pod *corev1.Pod) error { - return wait.PollUntilContextTimeout(context.TODO(), 10*time.Second, o.Timeout, true, func(ctx context.Context) (bool, error) { +func (o *MustGatherOptions) waitForGatherContainerRunning(ctx context.Context, pod *corev1.Pod) error { + return wait.PollUntilContextTimeout(ctx, 10*time.Second, o.Timeout, true, func(ctx context.Context) (bool, error) { var err error - if pod, err = o.Client.CoreV1().Pods(pod.Namespace).Get(context.TODO(), pod.Name, metav1.GetOptions{}); err == nil { + if pod, err = o.Client.CoreV1().Pods(pod.Namespace).Get(ctx, pod.Name, metav1.GetOptions{}); err == nil { if len(pod.Status.ContainerStatuses) == 0 { return false, nil } @@ -992,12 +1012,12 @@ func (o *MustGatherOptions) waitForGatherContainerRunning(pod *corev1.Pod) error }) } -func (o *MustGatherOptions) getNamespace() (*corev1.Namespace, func(), error) { +func (o *MustGatherOptions) getNamespace(ctx context.Context) (*corev1.Namespace, func(), error) { if o.RunNamespace == "" { - return o.createTempNamespace() + return o.createTempNamespace(ctx) } - ns, err := o.Client.CoreV1().Namespaces().Get(context.TODO(), o.RunNamespace, metav1.GetOptions{}) + ns, err := o.Client.CoreV1().Namespaces().Get(ctx, o.RunNamespace, metav1.GetOptions{}) if err != nil { return nil, nil, fmt.Errorf("retrieving namespace %q: %w", o.RunNamespace, err) } @@ -1005,19 +1025,13 @@ func (o *MustGatherOptions) getNamespace() (*corev1.Namespace, func(), error) { return ns, func() {}, nil } -func (o *MustGatherOptions) createTempNamespace() (*corev1.Namespace, func(), error) { - ns, err := o.Client.CoreV1().Namespaces().Create(context.TODO(), newNamespace(), metav1.CreateOptions{}) +func (o *MustGatherOptions) createTempNamespace(ctx context.Context) (*corev1.Namespace, func(), error) { + ns, err := o.Client.CoreV1().Namespaces().Create(ctx, newNamespace(), metav1.CreateOptions{}) if err != nil { return nil, nil, fmt.Errorf("creating temp namespace: %w", err) } o.PrinterCreated.PrintObj(ns, o.LogOut) - crb, err := o.Client.RbacV1().ClusterRoleBindings().Create(context.TODO(), newClusterRoleBinding(ns), metav1.CreateOptions{}) - if err != nil { - return nil, nil, fmt.Errorf("creating temp clusterRoleBinding: %w", err) - } - o.PrinterCreated.PrintObj(crb, o.LogOut) - cleanup := func() { if err := o.Client.CoreV1().Namespaces().Delete(context.TODO(), ns.Name, metav1.DeleteOptions{}); err != nil { fmt.Printf("%v\n", err) @@ -1026,6 +1040,12 @@ func (o *MustGatherOptions) createTempNamespace() (*corev1.Namespace, func(), er } } + crb, err := o.Client.RbacV1().ClusterRoleBindings().Create(ctx, newClusterRoleBinding(ns), metav1.CreateOptions{}) + if err != nil { + return nil, cleanup, fmt.Errorf("creating temp clusterRoleBinding: %w", err) + } + o.PrinterCreated.PrintObj(crb, o.LogOut) + return ns, cleanup, nil } @@ -1237,19 +1257,19 @@ func (o *MustGatherOptions) BackupGathering(ctx context.Context, errs []error) { streams.Out = o.newPrefixWriter(streams.Out, fmt.Sprintf("[must-gather ] OUT"), false, true) destDir := path.Join(o.DestDir, fmt.Sprintf("inspect.local.%06d", rand.Int63())) - if err := runInspect(streams, rest.CopyConfig(o.Config), destDir, []string{typeTargets}); err != nil { + if err := runInspect(ctx, streams, rest.CopyConfig(o.Config), destDir, []string{typeTargets}); err != nil { fmt.Fprintf(o.ErrOut, "error completing cluster type inspection: %v\n", err) } fmt.Fprintf(o.ErrOut, "Falling back to `oc adm inspect %s` to collect basic cluster named resources.\n", strings.Join(namedTargets, " ")) - if err := runInspect(streams, rest.CopyConfig(o.Config), destDir, namedTargets); err != nil { + if err := runInspect(ctx, streams, rest.CopyConfig(o.Config), destDir, namedTargets); err != nil { fmt.Fprintf(o.ErrOut, "error completing cluster named resource inspection: %v\n", err) } return } -func runInspect(streams genericiooptions.IOStreams, config *rest.Config, destDir string, arguments []string) error { +func runInspect(ctx context.Context, streams genericiooptions.IOStreams, config *rest.Config, destDir string, arguments []string) error { inspectOptions := inspect.NewInspectOptions(streams) inspectOptions.RESTConfig = config inspectOptions.DestDir = destDir @@ -1260,7 +1280,7 @@ func runInspect(streams genericiooptions.IOStreams, config *rest.Config, destDir if err := inspectOptions.Validate(); err != nil { return fmt.Errorf("error validating backup collection: %w", err) } - if err := inspectOptions.Run(); err != nil { + if err := inspectOptions.RunContext(ctx); err != nil { return fmt.Errorf("error running backup collection: %w", err) } return nil diff --git a/pkg/cli/admin/mustgather/mustgather_test.go b/pkg/cli/admin/mustgather/mustgather_test.go index b854ef1c8d..f1fa1622d9 100644 --- a/pkg/cli/admin/mustgather/mustgather_test.go +++ b/pkg/cli/admin/mustgather/mustgather_test.go @@ -119,7 +119,7 @@ func TestImagesAndImageStreams(t *testing.T) { LogOut: genericiooptions.NewTestIOStreamsDiscard().Out, AllImages: tc.allImages, } - err := options.completeImages() + err := options.completeImages(context.TODO()) if err != nil { t.Fatal(err) } @@ -214,7 +214,7 @@ func TestGetNamespace(t *testing.T) { tc.Options.PrinterCreated = printers.NewDiscardingPrinter() tc.Options.PrinterDeleted = printers.NewDiscardingPrinter() - ns, cleanup, err := tc.Options.getNamespace() + ns, cleanup, err := tc.Options.getNamespace(context.TODO()) if err != nil { if tc.ShouldFail { return