diff --git a/pkg/cli/admin/mustgather/mustgather.go b/pkg/cli/admin/mustgather/mustgather.go index 4f03e5ea5e..f38ca3b4aa 100644 --- a/pkg/cli/admin/mustgather/mustgather.go +++ b/pkg/cli/admin/mustgather/mustgather.go @@ -8,12 +8,15 @@ import ( "io" "math/rand" "os" + "os/signal" "path" "regexp" "sort" "strconv" "strings" "sync" + "sync/atomic" + "syscall" "time" "github.com/spf13/cobra" @@ -33,7 +36,6 @@ import ( "k8s.io/client-go/rest" "k8s.io/client-go/util/workqueue" "k8s.io/klog/v2" - "k8s.io/kubectl/pkg/cmd/logs" kcmdutil "k8s.io/kubectl/pkg/cmd/util" "k8s.io/kubectl/pkg/polymorphichelpers" "k8s.io/kubectl/pkg/scheme" @@ -48,6 +50,7 @@ import ( imagereference "github.com/openshift/library-go/pkg/image/reference" "github.com/openshift/library-go/pkg/operator/resource/retry" "github.com/openshift/oc/pkg/cli/admin/inspect" + "github.com/openshift/oc/pkg/cli/internal/logs" "github.com/openshift/oc/pkg/cli/rsync" ocmdhelpers "github.com/openshift/oc/pkg/helpers/cmd" ) @@ -242,7 +245,7 @@ func (o *MustGatherOptions) Complete(f kcmdutil.Factory, cmd *cobra.Command, arg return fmt.Errorf("--run-namespace %s", errStr) } } - if err := o.completeImages(); err != nil { + if err := o.completeImages(context.Background()); err != nil { return err } o.PrinterCreated, err = printers.NewTypeSetter(scheme.Scheme).WrapToPrinter(&printers.NamePrinter{Operation: "created"}, nil) @@ -257,9 +260,9 @@ func (o *MustGatherOptions) Complete(f kcmdutil.Factory, cmd *cobra.Command, arg return nil } -func (o *MustGatherOptions) completeImages() error { +func (o *MustGatherOptions) completeImages(ctx context.Context) error { for _, imageStream := range o.ImageStreams { - if image, err := o.resolveImageStreamTagString(imageStream); err == nil { + if image, err := o.resolveImageStreamTagString(ctx, imageStream); err == nil { o.Images = append(o.Images, image) } else { return fmt.Errorf("unable to resolve image stream '%v': %v", imageStream, err) @@ -268,7 +271,7 @@ func (o *MustGatherOptions) completeImages() error { if len(o.Images) == 0 || o.AllImages { var image string var err error - if image, err = o.resolveImageStreamTag("openshift", "must-gather", "latest"); err != nil { + if image, err = o.resolveImageStreamTag(ctx, "openshift", "must-gather", "latest"); err != nil { o.log("%v\n", err) image = "registry.redhat.io/openshift4/ose-must-gather:latest" } @@ -279,12 +282,12 @@ func (o *MustGatherOptions) completeImages() error { pluginImages := make(map[string]struct{}) var err error - pluginImages, err = o.annotatedCSVs() + pluginImages, err = o.annotatedCSVs(ctx) if err != nil { return err } - cos, err := o.ConfigClient.ConfigV1().ClusterOperators().List(context.TODO(), metav1.ListOptions{}) + cos, err := o.ConfigClient.ConfigV1().ClusterOperators().List(ctx, metav1.ListOptions{}) if err != nil { return err } @@ -305,7 +308,7 @@ func (o *MustGatherOptions) completeImages() error { return nil } -func (o *MustGatherOptions) annotatedCSVs() (map[string]struct{}, error) { +func (o *MustGatherOptions) annotatedCSVs(ctx context.Context) (map[string]struct{}, error) { csvGVR := schema.GroupVersionResource{ Group: "operators.coreos.com", Version: "v1alpha1", @@ -313,7 +316,7 @@ func (o *MustGatherOptions) annotatedCSVs() (map[string]struct{}, error) { } pluginImages := make(map[string]struct{}) - csvs, err := o.DynamicClient.Resource(csvGVR).List(context.TODO(), metav1.ListOptions{}) + csvs, err := o.DynamicClient.Resource(csvGVR).List(ctx, metav1.ListOptions{}) if err != nil { return nil, err } @@ -327,12 +330,12 @@ func (o *MustGatherOptions) annotatedCSVs() (map[string]struct{}, error) { return pluginImages, nil } -func (o *MustGatherOptions) resolveImageStreamTagString(s string) (string, error) { +func (o *MustGatherOptions) resolveImageStreamTagString(ctx context.Context, s string) (string, error) { namespace, name, tag := parseImageStreamTagString(s) if len(namespace) == 0 { return "", fmt.Errorf("expected namespace/name:tag") } - return o.resolveImageStreamTag(namespace, name, tag) + return o.resolveImageStreamTag(ctx, namespace, name, tag) } func parseImageStreamTagString(s string) (string, string, string) { @@ -349,8 +352,8 @@ func parseImageStreamTagString(s string) (string, string, string) { return namespace, name, tag } -func (o *MustGatherOptions) resolveImageStreamTag(namespace, name, tag string) (string, error) { - imageStream, err := o.ImageClient.ImageStreams(namespace).Get(context.TODO(), name, metav1.GetOptions{}) +func (o *MustGatherOptions) resolveImageStreamTag(ctx context.Context, namespace, name, tag string) (string, error) { + imageStream, err := o.ImageClient.ImageStreams(namespace).Get(ctx, name, metav1.GetOptions{}) if err != nil { return "", err } @@ -398,6 +401,9 @@ type MustGatherOptions struct { LogWriter *os.File LogWriterMux sync.Mutex + + cleanupHooks []func(context.Context) + cleanupMux sync.Mutex } func (o *MustGatherOptions) Validate() error { @@ -578,37 +584,62 @@ func (o *MustGatherOptions) Run() error { }() } - // print at both the beginning and at the end. This information is important enough to be in both spots. - o.PrintBasicClusterState(context.TODO()) - defer func() { - fmt.Fprintf(o.RawOut, "\n\n") - fmt.Fprintf(o.RawOut, "Reprinting Cluster State:\n") - o.PrintBasicClusterState(context.TODO()) - }() + // Perform cleanup on termination. + // handlerRegistered is needed so that we don't print the signal received message + // while no signal was actually received, only the context cancelled in defer. + var ctx context.Context + { + var ( + unregisterHandler func() + handlerRegistered atomic.Bool + ) + handlerRegistered.Store(true) + ctx, unregisterHandler = signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer func() { + if handlerRegistered.CompareAndSwap(true, false) { + unregisterHandler() + } + }() + go func() { + <-ctx.Done() + if handlerRegistered.CompareAndSwap(true, false) { + o.log("Signal received, terminating. Another signal will cause immediate shutdown.") + unregisterHandler() + } + }() + } - // Ensure resource cleanup unless instructed otherwise ... - var cleanupNamespace func() + // Ensure to clean up resources unless instructed otherwise. if !o.Keep { defer func() { - if cleanupNamespace != nil { - cleanupNamespace() - } + o.log("Cleaning up cluster resources") + o.cleanup(context.Background()) }() } - // Due to 'stack unwiding', this should happen after 'clusterState' printing, to ensure that we always - // print our ClusterState information. + // print at both the beginning and at the end. This information is important enough to be in both spots. + o.PrintBasicClusterState(ctx) + defer func() { + if ctx.Err() != nil { + return + } + fmt.Fprintf(o.RawOut, "\n\n") + fmt.Fprintf(o.RawOut, "Reprinting Cluster State:\n") + o.PrintBasicClusterState(ctx) + }() + + // Due to 'stack unwinding', this should happen after 'clusterState' printing, to ensure that we always + // print our ClusterState information. runBackCollection := true defer func() { - if !runBackCollection { + if ctx.Err() != nil || !runBackCollection { return } - o.BackupGathering(context.TODO(), errs) + o.BackupGathering(ctx, errs) }() // Get or create "working" namespace ... - var ns *corev1.Namespace - ns, cleanupNamespace, err = o.getNamespace() + ns, err := o.getNamespace(ctx) if err != nil { // ensure the errors bubble up to BackupGathering method for display errs = []error{err} @@ -617,7 +648,7 @@ func (o *MustGatherOptions) Run() error { // Prefer to run in master if there's any but don't be explicit otherwise. // This enables the command to run by default in hypershift where there's no masters. - nodes, err := o.Client.CoreV1().Nodes().List(context.TODO(), metav1.ListOptions{ + nodes, err := o.Client.CoreV1().Nodes().List(ctx, metav1.ListOptions{ LabelSelector: o.NodeSelector, }) if err != nil { @@ -645,7 +676,7 @@ func (o *MustGatherOptions) Run() error { return err } if o.NodeSelector != "" { - nodes, err := o.Client.CoreV1().Nodes().List(context.TODO(), metav1.ListOptions{ + nodes, err := o.Client.CoreV1().Nodes().List(ctx, metav1.ListOptions{ LabelSelector: o.NodeSelector, }) if err != nil { @@ -658,7 +689,7 @@ func (o *MustGatherOptions) Run() error { } } else { if o.NodeName != "" { - if _, err := o.Client.CoreV1().Nodes().Get(context.TODO(), o.NodeName, metav1.GetOptions{}); err != nil { + if _, err := o.Client.CoreV1().Nodes().Get(ctx, o.NodeName, metav1.GetOptions{}); err != nil { // ensure the errors bubble up to BackupGathering method for display errs = []error{err} return err @@ -676,7 +707,7 @@ func (o *MustGatherOptions) Run() error { } defer o.logTimestamp() - queue := workqueue.NewRateLimitingQueue(workqueue.DefaultControllerRateLimiter()) + queue := workqueue.NewTypedRateLimitingQueue[*corev1.Pod](workqueue.DefaultTypedControllerRateLimiter[*corev1.Pod]()) var wg sync.WaitGroup errCh := make(chan error, len(pods)) @@ -685,6 +716,10 @@ func (o *MustGatherOptions) Run() error { queue.Add(pod) } queue.ShutDownWithDrain() + go func() { + <-ctx.Done() + queue.ShutDown() + }() wg.Add(concurrentMG) for i := 0; i < concurrentMG; i++ { @@ -695,10 +730,12 @@ func (o *MustGatherOptions) Run() error { if quit { return } - defer queue.Done(pod) - if err := o.processNextWorkItem(ns.Name, pod.(*corev1.Pod)); err != nil { - errCh <- err - } + func() { + defer queue.Done(pod) + if err := o.processNextWorkItem(ctx, ns.Name, pod); err != nil { + errCh <- err + } + }() } }() } @@ -726,9 +763,9 @@ func (o *MustGatherOptions) Run() error { } // processNextWorkItem creates & processes the must-gather pod and returns error if any -func (o *MustGatherOptions) processNextWorkItem(ns string, pod *corev1.Pod) error { +func (o *MustGatherOptions) processNextWorkItem(ctx context.Context, ns string, pod *corev1.Pod) error { var err error - pod, err = o.Client.CoreV1().Pods(ns).Create(context.TODO(), pod, metav1.CreateOptions{}) + pod, err = o.Client.CoreV1().Pods(ns).Create(ctx, pod, metav1.CreateOptions{}) if err != nil { return err } @@ -737,35 +774,33 @@ func (o *MustGatherOptions) processNextWorkItem(ns string, pod *corev1.Pod) erro } else { o.log("pod for plug-in image %s created", pod.Spec.Containers[0].Image) } - if len(o.RunNamespace) > 0 && !o.Keep { - defer func() { - // must-gather runs in its own separate namespace as default , so after it is completed - // it deletes this namespace and all the pods are removed by garbage collector. - // However, if user specifies namespace via `run-namespace`, these pods need to - // be deleted manually. - err = o.Client.CoreV1().Pods(o.RunNamespace).Delete(context.TODO(), pod.Name, metav1.DeleteOptions{}) - if err != nil { + if len(o.RunNamespace) > 0 { + // must-gather runs in its own separate namespace as default , so after it is completed + // it deletes this namespace and all the pods are removed by garbage collector. + // However, if user specifies namespace via `run-namespace`, these pods need to + // be deleted manually. + o.addCleanupHook(func(ctx context.Context) { + if err := o.Client.CoreV1().Pods(o.RunNamespace).Delete(ctx, pod.Name, metav1.DeleteOptions{}); err != nil { klog.V(4).Infof("pod deletion error %v", err) } - }() + }) } log := o.newPodOutLogger(o.Out, pod.Name) // wait for gather container to be running (gather is running) - if err := o.waitForGatherContainerRunning(pod); err != nil { + if err := o.waitForGatherContainerRunning(ctx, pod); err != nil { log("gather did not start: %s", err) - return fmt.Errorf("gather did not start for pod %s: %s", pod.Name, err) - + return fmt.Errorf("gather did not start for pod %s: %w", pod.Name, err) } // stream gather container logs - if err := o.getGatherContainerLogs(pod); err != nil { + if err := o.getGatherContainerLogs(ctx, pod); err != nil { log("gather logs unavailable: %v", err) } // wait for pod to be running (gather has completed) log("waiting for gather to complete") - if err := o.waitForGatherToComplete(pod); err != nil { + if err := o.waitForGatherToComplete(ctx, pod); err != nil { log("gather never finished: %v", err) if exiterr, ok := err.(*exec.CodeExitError); ok { return exiterr @@ -775,12 +810,12 @@ func (o *MustGatherOptions) processNextWorkItem(ns string, pod *corev1.Pod) erro // copy the gathered files into the local destination dir log("downloading gather output") - pod, err = o.Client.CoreV1().Pods(pod.Namespace).Get(context.TODO(), pod.Name, metav1.GetOptions{}) + pod, err = o.Client.CoreV1().Pods(pod.Namespace).Get(ctx, pod.Name, metav1.GetOptions{}) if err != nil { log("gather output not downloaded: %v\n", err) return fmt.Errorf("unable to download output from pod %s: %s", pod.Name, err) } - if err := o.copyFilesFromPod(pod); err != nil { + if err := o.copyFilesFromPod(ctx, pod); err != nil { log("gather output not downloaded: %v\n", err) return fmt.Errorf("unable to download output from pod %s: %s", pod.Name, err) } @@ -807,7 +842,7 @@ func (o *MustGatherOptions) logTimestamp() error { return err } -func (o *MustGatherOptions) copyFilesFromPod(pod *corev1.Pod) error { +func (o *MustGatherOptions) copyFilesFromPod(ctx context.Context, pod *corev1.Pod) error { streams := o.IOStreams streams.Out = o.newPrefixWriter(streams.Out, fmt.Sprintf("[%s] OUT", pod.Name), false, true) imageFolder := regexp.MustCompile("[^A-Za-z0-9]+").ReplaceAllString(pod.Status.ContainerStatuses[0].ImageID, "-") @@ -835,7 +870,7 @@ func (o *MustGatherOptions) copyFilesFromPod(pod *corev1.Pod) error { Container: gatherContainerName, Timestamps: true, } - readCloser, err := o.Client.CoreV1().Pods(pod.Namespace).GetLogs(pod.Name, logOptions).Stream(context.TODO()) + readCloser, err := o.Client.CoreV1().Pods(pod.Namespace).GetLogs(pod.Name, logOptions).Stream(ctx) if err != nil { return err } @@ -869,11 +904,12 @@ func (o *MustGatherOptions) copyFilesFromPod(pod *corev1.Pod) error { return kutilerrors.NewAggregate(errs) } -func (o *MustGatherOptions) getGatherContainerLogs(pod *corev1.Pod) error { +func (o *MustGatherOptions) getGatherContainerLogs(ctx context.Context, pod *corev1.Pod) error { since2s := int64(2) opts := &logs.LogsOptions{ Namespace: pod.Namespace, ResourceArg: pod.Name, + Follow: true, Options: &corev1.PodLogOptions{ Follow: true, Container: pod.Spec.Containers[0].Name, @@ -890,13 +926,13 @@ func (o *MustGatherOptions) getGatherContainerLogs(pod *corev1.Pod) error { // gather script might take longer than the default API server time, // so we should check if the gather script still runs and re-run logs // thus we run this in a loop - if err := opts.RunLogs(); err != nil { + if err := opts.RunLogsContext(ctx); err != nil { return err } // to ensure we don't print all of history set since to past 2 seconds opts.Options.(*corev1.PodLogOptions).SinceSeconds = &since2s - if done, _ := o.isGatherDone(pod); done { + if done, _ := o.isGatherDone(ctx, pod); done { return nil } klog.V(4).Infof("lost logs, re-trying...") @@ -927,15 +963,15 @@ func (o *MustGatherOptions) newPrefixWriter(out io.Writer, prefix string, ignore return writer } -func (o *MustGatherOptions) waitForGatherToComplete(pod *corev1.Pod) error { - return wait.PollUntilContextTimeout(context.TODO(), 10*time.Second, o.Timeout, true, func(ctx context.Context) (bool, error) { - return o.isGatherDone(pod) +func (o *MustGatherOptions) waitForGatherToComplete(ctx context.Context, pod *corev1.Pod) error { + return wait.PollUntilContextTimeout(ctx, 10*time.Second, o.Timeout, true, func(ctx context.Context) (bool, error) { + return o.isGatherDone(ctx, pod) }) } -func (o *MustGatherOptions) isGatherDone(pod *corev1.Pod) (bool, error) { +func (o *MustGatherOptions) isGatherDone(ctx context.Context, pod *corev1.Pod) (bool, error) { var err error - if pod, err = o.Client.CoreV1().Pods(pod.Namespace).Get(context.TODO(), pod.Name, metav1.GetOptions{}); err != nil { + if pod, err = o.Client.CoreV1().Pods(pod.Namespace).Get(ctx, pod.Name, metav1.GetOptions{}); err != nil { // at this stage pod should exist, we've been gathering container logs, so error if not found if kerrors.IsNotFound(err) { return true, err @@ -967,10 +1003,10 @@ func (o *MustGatherOptions) isGatherDone(pod *corev1.Pod) (bool, error) { return false, nil } -func (o *MustGatherOptions) waitForGatherContainerRunning(pod *corev1.Pod) error { - return wait.PollUntilContextTimeout(context.TODO(), 10*time.Second, o.Timeout, true, func(ctx context.Context) (bool, error) { +func (o *MustGatherOptions) waitForGatherContainerRunning(ctx context.Context, pod *corev1.Pod) error { + return wait.PollUntilContextTimeout(ctx, 10*time.Second, o.Timeout, true, func(ctx context.Context) (bool, error) { var err error - if pod, err = o.Client.CoreV1().Pods(pod.Namespace).Get(context.TODO(), pod.Name, metav1.GetOptions{}); err == nil { + if pod, err = o.Client.CoreV1().Pods(pod.Namespace).Get(ctx, pod.Name, metav1.GetOptions{}); err == nil { if len(pod.Status.ContainerStatuses) == 0 { return false, nil } @@ -992,41 +1028,40 @@ func (o *MustGatherOptions) waitForGatherContainerRunning(pod *corev1.Pod) error }) } -func (o *MustGatherOptions) getNamespace() (*corev1.Namespace, func(), error) { +func (o *MustGatherOptions) getNamespace(ctx context.Context) (*corev1.Namespace, error) { if o.RunNamespace == "" { - return o.createTempNamespace() + return o.createTempNamespace(ctx) } - ns, err := o.Client.CoreV1().Namespaces().Get(context.TODO(), o.RunNamespace, metav1.GetOptions{}) + ns, err := o.Client.CoreV1().Namespaces().Get(ctx, o.RunNamespace, metav1.GetOptions{}) if err != nil { - return nil, nil, fmt.Errorf("retrieving namespace %q: %w", o.RunNamespace, err) + return nil, fmt.Errorf("retrieving namespace %q: %w", o.RunNamespace, err) } - return ns, func() {}, nil + return ns, nil } -func (o *MustGatherOptions) createTempNamespace() (*corev1.Namespace, func(), error) { - ns, err := o.Client.CoreV1().Namespaces().Create(context.TODO(), newNamespace(), metav1.CreateOptions{}) +func (o *MustGatherOptions) createTempNamespace(ctx context.Context) (*corev1.Namespace, error) { + ns, err := o.Client.CoreV1().Namespaces().Create(ctx, newNamespace(), metav1.CreateOptions{}) if err != nil { - return nil, nil, fmt.Errorf("creating temp namespace: %w", err) + return nil, fmt.Errorf("creating temp namespace: %w", err) } o.PrinterCreated.PrintObj(ns, o.LogOut) - crb, err := o.Client.RbacV1().ClusterRoleBindings().Create(context.TODO(), newClusterRoleBinding(ns), metav1.CreateOptions{}) + crb, err := o.Client.RbacV1().ClusterRoleBindings().Create(ctx, newClusterRoleBinding(ns), metav1.CreateOptions{}) if err != nil { - return nil, nil, fmt.Errorf("creating temp clusterRoleBinding: %w", err) + return nil, fmt.Errorf("creating temp clusterRoleBinding: %w", err) } o.PrinterCreated.PrintObj(crb, o.LogOut) - cleanup := func() { - if err := o.Client.CoreV1().Namespaces().Delete(context.TODO(), ns.Name, metav1.DeleteOptions{}); err != nil { + o.addCleanupHook(func(ctx context.Context) { + if err := o.Client.CoreV1().Namespaces().Delete(ctx, ns.Name, metav1.DeleteOptions{}); err != nil { fmt.Printf("%v\n", err) } else { o.PrinterDeleted.PrintObj(ns, o.LogOut) } - } - - return ns, cleanup, nil + }) + return ns, nil } func newNamespace() *corev1.Namespace { @@ -1237,19 +1272,45 @@ func (o *MustGatherOptions) BackupGathering(ctx context.Context, errs []error) { streams.Out = o.newPrefixWriter(streams.Out, fmt.Sprintf("[must-gather ] OUT"), false, true) destDir := path.Join(o.DestDir, fmt.Sprintf("inspect.local.%06d", rand.Int63())) - if err := runInspect(streams, rest.CopyConfig(o.Config), destDir, []string{typeTargets}); err != nil { + if err := runInspect(ctx, streams, rest.CopyConfig(o.Config), destDir, []string{typeTargets}); err != nil { fmt.Fprintf(o.ErrOut, "error completing cluster type inspection: %v\n", err) } fmt.Fprintf(o.ErrOut, "Falling back to `oc adm inspect %s` to collect basic cluster named resources.\n", strings.Join(namedTargets, " ")) - if err := runInspect(streams, rest.CopyConfig(o.Config), destDir, namedTargets); err != nil { + if err := runInspect(ctx, streams, rest.CopyConfig(o.Config), destDir, namedTargets); err != nil { fmt.Fprintf(o.ErrOut, "error completing cluster named resource inspection: %v\n", err) } return } -func runInspect(streams genericiooptions.IOStreams, config *rest.Config, destDir string, arguments []string) error { +func (o *MustGatherOptions) addCleanupHook(hook func(context.Context)) { + if hook == nil { + return + } + + o.cleanupMux.Lock() + defer o.cleanupMux.Unlock() + o.cleanupHooks = append(o.cleanupHooks, hook) +} + +func (o *MustGatherOptions) cleanup(ctx context.Context) { + o.cleanupMux.Lock() + hooks := append([]func(context.Context){}, o.cleanupHooks...) + o.cleanupMux.Unlock() + + var wg sync.WaitGroup + wg.Add(len(o.cleanupHooks)) + for _, hook := range hooks { + go func() { + defer wg.Done() + hook(ctx) + }() + } + wg.Wait() +} + +func runInspect(ctx context.Context, streams genericiooptions.IOStreams, config *rest.Config, destDir string, arguments []string) error { inspectOptions := inspect.NewInspectOptions(streams) inspectOptions.RESTConfig = config inspectOptions.DestDir = destDir @@ -1260,7 +1321,7 @@ func runInspect(streams genericiooptions.IOStreams, config *rest.Config, destDir if err := inspectOptions.Validate(); err != nil { return fmt.Errorf("error validating backup collection: %w", err) } - if err := inspectOptions.Run(); err != nil { + if err := inspectOptions.RunContext(ctx); err != nil { return fmt.Errorf("error running backup collection: %w", err) } return nil diff --git a/pkg/cli/admin/mustgather/mustgather_test.go b/pkg/cli/admin/mustgather/mustgather_test.go index b854ef1c8d..41643def08 100644 --- a/pkg/cli/admin/mustgather/mustgather_test.go +++ b/pkg/cli/admin/mustgather/mustgather_test.go @@ -119,7 +119,7 @@ func TestImagesAndImageStreams(t *testing.T) { LogOut: genericiooptions.NewTestIOStreamsDiscard().Out, AllImages: tc.allImages, } - err := options.completeImages() + err := options.completeImages(context.TODO()) if err != nil { t.Fatal(err) } @@ -214,7 +214,7 @@ func TestGetNamespace(t *testing.T) { tc.Options.PrinterCreated = printers.NewDiscardingPrinter() tc.Options.PrinterDeleted = printers.NewDiscardingPrinter() - ns, cleanup, err := tc.Options.getNamespace() + ns, err := tc.Options.getNamespace(context.TODO()) if err != nil { if tc.ShouldFail { return @@ -231,7 +231,7 @@ func TestGetNamespace(t *testing.T) { t.Error("namespace should exist") } - cleanup() + tc.Options.cleanup(context.Background()) if _, err = tc.Options.Client.CoreV1().Namespaces().Get(context.TODO(), ns.Name, metav1.GetOptions{}); err != nil { if !k8sapierrors.IsNotFound(err) { diff --git a/pkg/cli/internal/logs/doc.go b/pkg/cli/internal/logs/doc.go new file mode 100644 index 0000000000..418a61e5c3 --- /dev/null +++ b/pkg/cli/internal/logs/doc.go @@ -0,0 +1,6 @@ +// Package logs is a copy of kubernetes/staging/src/k8s.io/kubectl/pkg/cmd/logs, +// which contains LogOptions.RunLogsContext function that is needed for proper signal handling. +// This is not yet available in v33. +// +// TODO: Remove and replace once deps are updated to future v34. +package logs diff --git a/pkg/cli/internal/logs/logs.go b/pkg/cli/internal/logs/logs.go new file mode 100644 index 0000000000..af702362cb --- /dev/null +++ b/pkg/cli/internal/logs/logs.go @@ -0,0 +1,516 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package logs + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "regexp" + "sync" + "time" + + "github.com/spf13/cobra" + + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/cli-runtime/pkg/genericclioptions" + "k8s.io/cli-runtime/pkg/genericiooptions" + "k8s.io/client-go/rest" + cmdutil "k8s.io/kubectl/pkg/cmd/util" + "k8s.io/kubectl/pkg/polymorphichelpers" + "k8s.io/kubectl/pkg/scheme" + "k8s.io/kubectl/pkg/util" + "k8s.io/kubectl/pkg/util/completion" + "k8s.io/kubectl/pkg/util/i18n" + "k8s.io/kubectl/pkg/util/interrupt" + "k8s.io/kubectl/pkg/util/templates" +) + +const ( + logsUsageStr = "logs [-f] [-p] (POD | TYPE/NAME) [-c CONTAINER]" +) + +var ( + logsLong = templates.LongDesc(i18n.T(` + Print the logs for a container in a pod or specified resource. + If the pod has only one container, the container name is optional.`)) + + logsExample = templates.Examples(i18n.T(` + # Return snapshot logs from pod nginx with only one container + kubectl logs nginx + + # Return snapshot logs from pod nginx, prefixing each line with the source pod and container name + kubectl logs nginx --prefix + + # Return snapshot logs from pod nginx, limiting output to 500 bytes + kubectl logs nginx --limit-bytes=500 + + # Return snapshot logs from pod nginx, waiting up to 20 seconds for it to start running. + kubectl logs nginx --pod-running-timeout=20s + + # Return snapshot logs from pod nginx with multi containers + kubectl logs nginx --all-containers=true + + # Return snapshot logs from all pods in the deployment nginx + kubectl logs deployment/nginx --all-pods=true + + # Return snapshot logs from all containers in pods defined by label app=nginx + kubectl logs -l app=nginx --all-containers=true + + # Return snapshot logs from all pods defined by label app=nginx, limiting concurrent log requests to 10 pods + kubectl logs -l app=nginx --max-log-requests=10 + + # Return snapshot of previous terminated ruby container logs from pod web-1 + kubectl logs -p -c ruby web-1 + + # Begin streaming the logs from pod nginx, continuing even if errors occur + kubectl logs nginx -f --ignore-errors=true + + # Begin streaming the logs of the ruby container in pod web-1 + kubectl logs -f -c ruby web-1 + + # Begin streaming the logs from all containers in pods defined by label app=nginx + kubectl logs -f -l app=nginx --all-containers=true + + # Display only the most recent 20 lines of output in pod nginx + kubectl logs --tail=20 nginx + + # Show all logs from pod nginx written in the last hour + kubectl logs --since=1h nginx + + # Show all logs with timestamps from pod nginx starting from August 30, 2024, at 06:00:00 UTC + kubectl logs nginx --since-time=2024-08-30T06:00:00Z --timestamps=true + + # Show logs from a kubelet with an expired serving certificate + kubectl logs --insecure-skip-tls-verify-backend nginx + + # Return snapshot logs from first container of a job named hello + kubectl logs job/hello + + # Return snapshot logs from container nginx-1 of a deployment named nginx + kubectl logs deployment/nginx -c nginx-1`)) + + selectorTail int64 = 10 + logsUsageErrStr = fmt.Sprintf("expected '%s'.\nPOD or TYPE/NAME is a required argument for the logs command", logsUsageStr) +) + +const ( + defaultPodLogsTimeout = 20 * time.Second +) + +type LogsOptions struct { + Namespace string + ResourceArg string + AllContainers bool + AllPods bool + Options runtime.Object + Resources []string + + ConsumeRequestFn func(context.Context, rest.ResponseWrapper, io.Writer) error + + // PodLogOptions + SinceTime string + SinceSeconds time.Duration + Follow bool + Previous bool + Timestamps bool + IgnoreLogErrors bool + LimitBytes int64 + Tail int64 + Container string + InsecureSkipTLSVerifyBackend bool + + // whether or not a container name was given via --container + ContainerNameSpecified bool + Selector string + MaxFollowConcurrency int + Prefix bool + + Object runtime.Object + GetPodTimeout time.Duration + RESTClientGetter genericclioptions.RESTClientGetter + LogsForObject polymorphichelpers.LogsForObjectFunc + AllPodLogsForObject polymorphichelpers.AllPodLogsForObjectFunc + + genericiooptions.IOStreams + + TailSpecified bool + + containerNameFromRefSpecRegexp *regexp.Regexp +} + +func NewLogsOptions(streams genericiooptions.IOStreams) *LogsOptions { + return &LogsOptions{ + IOStreams: streams, + Tail: -1, + MaxFollowConcurrency: 5, + + containerNameFromRefSpecRegexp: regexp.MustCompile(`spec\.(?:initContainers|containers|ephemeralContainers){(.+)}`), + } +} + +// NewCmdLogs creates a new pod logs command +func NewCmdLogs(f cmdutil.Factory, streams genericiooptions.IOStreams) *cobra.Command { + o := NewLogsOptions(streams) + + cmd := &cobra.Command{ + Use: logsUsageStr, + DisableFlagsInUseLine: true, + Short: i18n.T("Print the logs for a container in a pod"), + Long: logsLong, + Example: logsExample, + ValidArgsFunction: completion.PodResourceNameAndContainerCompletionFunc(f), + Run: func(cmd *cobra.Command, args []string) { + cmdutil.CheckErr(o.Complete(f, cmd, args)) + cmdutil.CheckErr(o.Validate()) + cmdutil.CheckErr(o.RunLogs()) + }, + } + o.AddFlags(cmd) + return cmd +} + +func (o *LogsOptions) AddFlags(cmd *cobra.Command) { + cmd.Flags().BoolVar(&o.AllPods, "all-pods", o.AllPods, "Get logs from all pod(s). Sets prefix to true.") + cmd.Flags().BoolVar(&o.AllContainers, "all-containers", o.AllContainers, "Get all containers' logs in the pod(s).") + cmd.Flags().BoolVarP(&o.Follow, "follow", "f", o.Follow, "Specify if the logs should be streamed.") + cmd.Flags().BoolVar(&o.Timestamps, "timestamps", o.Timestamps, "Include timestamps on each line in the log output") + cmd.Flags().Int64Var(&o.LimitBytes, "limit-bytes", o.LimitBytes, "Maximum bytes of logs to return. Defaults to no limit.") + cmd.Flags().BoolVarP(&o.Previous, "previous", "p", o.Previous, "If true, print the logs for the previous instance of the container in a pod if it exists.") + cmd.Flags().Int64Var(&o.Tail, "tail", o.Tail, "Lines of recent log file to display. Defaults to -1 with no selector, showing all log lines otherwise 10, if a selector is provided.") + cmd.Flags().BoolVar(&o.IgnoreLogErrors, "ignore-errors", o.IgnoreLogErrors, "If watching / following pod logs, allow for any errors that occur to be non-fatal") + cmd.Flags().StringVar(&o.SinceTime, "since-time", o.SinceTime, i18n.T("Only return logs after a specific date (RFC3339). Defaults to all logs. Only one of since-time / since may be used.")) + cmd.Flags().DurationVar(&o.SinceSeconds, "since", o.SinceSeconds, "Only return logs newer than a relative duration like 5s, 2m, or 3h. Defaults to all logs. Only one of since-time / since may be used.") + cmd.Flags().StringVarP(&o.Container, "container", "c", o.Container, "Print the logs of this container") + cmd.Flags().BoolVar(&o.InsecureSkipTLSVerifyBackend, "insecure-skip-tls-verify-backend", o.InsecureSkipTLSVerifyBackend, + "Skip verifying the identity of the kubelet that logs are requested from. In theory, an attacker could provide invalid log content back. You might want to use this if your kubelet serving certificates have expired.") + cmdutil.AddPodRunningTimeoutFlag(cmd, defaultPodLogsTimeout) + cmdutil.AddLabelSelectorFlagVar(cmd, &o.Selector) + cmd.Flags().IntVar(&o.MaxFollowConcurrency, "max-log-requests", o.MaxFollowConcurrency, "Specify maximum number of concurrent logs to follow when using by a selector. Defaults to 5.") + cmd.Flags().BoolVar(&o.Prefix, "prefix", o.Prefix, "Prefix each log line with the log source (pod name and container name)") +} + +func (o *LogsOptions) ToLogOptions() (*corev1.PodLogOptions, error) { + logOptions := &corev1.PodLogOptions{ + Container: o.Container, + Follow: o.Follow, + Previous: o.Previous, + Timestamps: o.Timestamps, + InsecureSkipTLSVerifyBackend: o.InsecureSkipTLSVerifyBackend, + } + + if len(o.SinceTime) > 0 { + t, err := util.ParseRFC3339(o.SinceTime, metav1.Now) + if err != nil { + return nil, err + } + + logOptions.SinceTime = &t + } + + if o.LimitBytes != 0 { + logOptions.LimitBytes = &o.LimitBytes + } + + if o.SinceSeconds != 0 { + // round up to the nearest second + sec := int64(o.SinceSeconds.Round(time.Second).Seconds()) + logOptions.SinceSeconds = &sec + } + + if len(o.Selector) > 0 && o.Tail == -1 && !o.TailSpecified { + logOptions.TailLines = &selectorTail + } else if o.Tail != -1 { + logOptions.TailLines = &o.Tail + } + + return logOptions, nil +} + +func (o *LogsOptions) Complete(f cmdutil.Factory, cmd *cobra.Command, args []string) error { + o.ContainerNameSpecified = cmd.Flag("container").Changed + o.TailSpecified = cmd.Flag("tail").Changed + o.Resources = args + + switch len(args) { + case 0: + if len(o.Selector) == 0 { + return cmdutil.UsageErrorf(cmd, "%s", logsUsageErrStr) + } + case 1: + o.ResourceArg = args[0] + if len(o.Selector) != 0 { + return cmdutil.UsageErrorf(cmd, "only a selector (-l) or a POD name is allowed") + } + case 2: + o.ResourceArg = args[0] + o.Container = args[1] + default: + return cmdutil.UsageErrorf(cmd, "%s", logsUsageErrStr) + } + + if o.AllPods { + o.Prefix = true + } + + var err error + o.Namespace, _, err = f.ToRawKubeConfigLoader().Namespace() + if err != nil { + return err + } + + o.ConsumeRequestFn = DefaultConsumeRequest + + o.GetPodTimeout, err = cmdutil.GetPodRunningTimeoutFlag(cmd) + if err != nil { + return err + } + + o.Options, err = o.ToLogOptions() + if err != nil { + return err + } + + o.RESTClientGetter = f + o.LogsForObject = polymorphichelpers.LogsForObjectFn + o.AllPodLogsForObject = polymorphichelpers.AllPodLogsForObjectFn + + if o.Object == nil { + builder := f.NewBuilder(). + WithScheme(scheme.Scheme, scheme.Scheme.PrioritizedVersionsAllGroups()...). + NamespaceParam(o.Namespace).DefaultNamespace(). + SingleResourceType() + if o.ResourceArg != "" { + builder.ResourceNames("pods", o.ResourceArg) + } + if o.Selector != "" { + builder.ResourceTypes("pods").LabelSelectorParam(o.Selector) + } + infos, err := builder.Do().Infos() + if err != nil { + if apierrors.IsNotFound(err) { + err = fmt.Errorf("error from server (NotFound): %w in namespace %q", err, o.Namespace) + } + return err + } + if o.Selector == "" && len(infos) != 1 { + return errors.New("expected a resource") + } + o.Object = infos[0].Object + if o.Selector != "" && len(o.Object.(*corev1.PodList).Items) == 0 { + fmt.Fprintf(o.ErrOut, "No resources found in %s namespace.\n", o.Namespace) + } + } + + return nil +} + +func (o LogsOptions) Validate() error { + if len(o.SinceTime) > 0 && o.SinceSeconds != 0 { + return fmt.Errorf("at most one of `sinceTime` or `sinceSeconds` may be specified") + } + + logsOptions, ok := o.Options.(*corev1.PodLogOptions) + if !ok { + return errors.New("unexpected logs options object") + } + if o.AllContainers && len(logsOptions.Container) > 0 { + return fmt.Errorf("--all-containers=true should not be specified with container name %s", logsOptions.Container) + } + + if o.ContainerNameSpecified && len(o.Resources) == 2 { + return fmt.Errorf("only one of -c or an inline [CONTAINER] arg is allowed") + } + + if o.LimitBytes < 0 { + return fmt.Errorf("--limit-bytes must be greater than 0") + } + + if logsOptions.SinceSeconds != nil && *logsOptions.SinceSeconds < int64(0) { + return fmt.Errorf("--since must be greater than 0") + } + + if logsOptions.TailLines != nil && *logsOptions.TailLines < -1 { + return fmt.Errorf("--tail must be greater than or equal to -1") + } + + return nil +} + +// RunLogs wraps RunLogsContext with signal handling. +// When a signal is received, streaming is stopped, then followed by os.Exit(1). +func (o LogsOptions) RunLogs() error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + intr := interrupt.New(nil, cancel) + return intr.Run(func() error { + return o.RunLogsContext(ctx) + }) +} + +// RunLogsContext retrieves a pod log. +func (o LogsOptions) RunLogsContext(ctx context.Context) error { + var requests map[corev1.ObjectReference]rest.ResponseWrapper + var err error + if o.AllPods { + requests, err = o.AllPodLogsForObject(o.RESTClientGetter, o.Object, o.Options, o.GetPodTimeout, o.AllContainers) + } else { + requests, err = o.LogsForObject(o.RESTClientGetter, o.Object, o.Options, o.GetPodTimeout, o.AllContainers) + } + if err != nil { + return err + } + + if o.Follow && len(requests) > 1 { + if len(requests) > o.MaxFollowConcurrency { + return fmt.Errorf( + "you are attempting to follow %d log streams, but maximum allowed concurrency is %d, use --max-log-requests to increase the limit", + len(requests), o.MaxFollowConcurrency, + ) + } + } + + if o.Follow && len(requests) > 1 { + return o.parallelConsumeRequest(ctx, requests) + } + return o.sequentialConsumeRequest(ctx, requests) +} + +func (o LogsOptions) parallelConsumeRequest(ctx context.Context, requests map[corev1.ObjectReference]rest.ResponseWrapper) error { + reader, writer := io.Pipe() + wg := &sync.WaitGroup{} + wg.Add(len(requests)) + for objRef, request := range requests { + go func(objRef corev1.ObjectReference, request rest.ResponseWrapper) { + defer wg.Done() + out := o.addPrefixIfNeeded(objRef, writer) + if err := o.ConsumeRequestFn(ctx, request, out); err != nil { + if !o.IgnoreLogErrors { + writer.CloseWithError(err) + + // It's important to return here to propagate the error via the pipe + return + } + + fmt.Fprintf(writer, "error: %v\n", err) + } + + }(objRef, request) + } + + go func() { + wg.Wait() + writer.Close() + }() + + _, err := io.Copy(o.Out, reader) + return err +} + +func (o LogsOptions) sequentialConsumeRequest(ctx context.Context, requests map[corev1.ObjectReference]rest.ResponseWrapper) error { + for objRef, request := range requests { + out := o.addPrefixIfNeeded(objRef, o.Out) + if err := o.ConsumeRequestFn(ctx, request, out); err != nil { + if !o.IgnoreLogErrors { + return err + } + + fmt.Fprintf(o.Out, "error: %v\n", err) + } + } + + return nil +} + +func (o LogsOptions) addPrefixIfNeeded(ref corev1.ObjectReference, writer io.Writer) io.Writer { + if !o.Prefix || ref.FieldPath == "" || ref.Name == "" { + return writer + } + + // We rely on ref.FieldPath to contain a reference to a container + // including a container name (not an index) so we can get a container name + // without making an extra API request. + var containerName string + containerNameMatches := o.containerNameFromRefSpecRegexp.FindStringSubmatch(ref.FieldPath) + if len(containerNameMatches) == 2 { + containerName = containerNameMatches[1] + } + + prefix := fmt.Sprintf("[pod/%s/%s] ", ref.Name, containerName) + return &prefixingWriter{ + prefix: []byte(prefix), + writer: writer, + } +} + +// DefaultConsumeRequest reads the data from request and writes into +// the out writer. It buffers data from requests until the newline or io.EOF +// occurs in the data, so it doesn't interleave logs sub-line +// when running concurrently. +// +// A successful read returns err == nil, not err == io.EOF. +// Because the function is defined to read from request until io.EOF, it does +// not treat an io.EOF as an error to be reported. +func DefaultConsumeRequest(ctx context.Context, request rest.ResponseWrapper, out io.Writer) error { + readCloser, err := request.Stream(ctx) + if err != nil { + return err + } + defer readCloser.Close() + + r := bufio.NewReader(readCloser) + for { + bytes, err := r.ReadBytes('\n') + if _, err := out.Write(bytes); err != nil { + return err + } + + if err != nil { + if err != io.EOF { + return err + } + return nil + } + } +} + +type prefixingWriter struct { + prefix []byte + writer io.Writer +} + +func (pw *prefixingWriter) Write(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + + // Perform an "atomic" write of a prefix and p to make sure that it doesn't interleave + // sub-line when used concurrently with io.PipeWrite. + n, err := pw.writer.Write(append(pw.prefix, p...)) + if n > len(p) { + // To comply with the io.Writer interface requirements we must + // return a number of bytes written from p (0 <= n <= len(p)), + // so we are ignoring the length of the prefix here. + return len(p), err + } + return n, err +} diff --git a/pkg/cli/internal/logs/logs_test.go b/pkg/cli/internal/logs/logs_test.go new file mode 100644 index 0000000000..ba02561249 --- /dev/null +++ b/pkg/cli/internal/logs/logs_test.go @@ -0,0 +1,970 @@ +/* +Copyright 2014 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package logs + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync" + "testing" + "testing/iotest" + "time" + + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/cli-runtime/pkg/genericclioptions" + "k8s.io/cli-runtime/pkg/genericiooptions" + restclient "k8s.io/client-go/rest" + "k8s.io/client-go/rest/fake" + cmdtesting "k8s.io/kubectl/pkg/cmd/testing" + "k8s.io/kubectl/pkg/scheme" +) + +func TestLog(t *testing.T) { + tests := []struct { + name string + opts func(genericiooptions.IOStreams) *LogsOptions + expectedErr string + expectedOutSubstrings []string + }{ + { + name: "v1 - pod log", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "some-pod", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content\n")}, + }, + } + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + + return o + }, + expectedOutSubstrings: []string{"test log content\n"}, + }, + { + name: "pod logs with prefix", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "test-pod", + FieldPath: "spec.containers{test-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content\n")}, + }, + } + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + o.Prefix = true + + return o + }, + expectedOutSubstrings: []string{"[pod/test-pod/test-container] test log content\n"}, + }, + { + name: "stateful set logs with all pods", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "test-sts-0", + FieldPath: "spec.containers{test-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content for pod test-sts-0\n")}, + { + Kind: "Pod", + Name: "test-sts-1", + FieldPath: "spec.containers{test-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content for pod test-sts-1\n")}, + }, + } + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + o.Prefix = true + return o + }, + expectedOutSubstrings: []string{ + "[pod/test-sts-0/test-container] test log content for pod test-sts-0\n", + "[pod/test-sts-1/test-container] test log content for pod test-sts-1\n", + }, + }, + { + name: "pod logs with prefix: init container", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "test-pod", + FieldPath: "spec.initContainers{test-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content\n")}, + }, + } + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + o.Prefix = true + + return o + }, + expectedOutSubstrings: []string{"[pod/test-pod/test-container] test log content\n"}, + }, + { + name: "pod logs with prefix: ephemeral container", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "test-pod", + FieldPath: "spec.ephemeralContainers{test-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content\n")}, + }, + } + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + o.Prefix = true + + return o + }, + expectedOutSubstrings: []string{"[pod/test-pod/test-container] test log content\n"}, + }, + { + name: "get logs from multiple requests sequentially", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "some-pod-1", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 1\n")}, + { + Kind: "Pod", + Name: "some-pod-2", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 2\n")}, + { + Kind: "Pod", + Name: "some-pod-3", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 3\n")}, + }, + } + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + return o + }, + expectedOutSubstrings: []string{ + "test log content from source 1\n", + "test log content from source 2\n", + "test log content from source 3\n", + }, + }, + { + name: "follow logs from multiple requests concurrently", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + wg := &sync.WaitGroup{} + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "some-pod-1", + FieldPath: "spec.containers{some-container-1}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 1\n")}, + { + Kind: "Pod", + Name: "some-pod-2", + FieldPath: "spec.containers{some-container-2}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 2\n")}, + { + Kind: "Pod", + Name: "some-pod-3", + FieldPath: "spec.containers{some-container-3}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 3\n")}, + }, + wg: wg, + } + wg.Add(3) + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + o.Follow = true + return o + }, + expectedOutSubstrings: []string{ + "test log content from source 1\n", + "test log content from source 2\n", + "test log content from source 3\n", + }, + }, + { + name: "fail to follow logs from multiple requests when there are more logs sources then MaxFollowConcurrency allows", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + wg := &sync.WaitGroup{} + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "test-pod-1", + FieldPath: "spec.containers{test-container-1}", + }: &responseWrapperMock{data: strings.NewReader("test log content\n")}, + { + Kind: "Pod", + Name: "test-pod-2", + FieldPath: "spec.containers{test-container-2}", + }: &responseWrapperMock{data: strings.NewReader("test log content\n")}, + { + Kind: "Pod", + Name: "test-pod-3", + FieldPath: "spec.containers{test-container-3}", + }: &responseWrapperMock{data: strings.NewReader("test log content\n")}, + }, + wg: wg, + } + wg.Add(3) + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + o.MaxFollowConcurrency = 2 + o.Follow = true + return o + }, + expectedErr: "you are attempting to follow 3 log streams, but maximum allowed concurrency is 2, use --max-log-requests to increase the limit", + }, + { + name: "fail if LogsForObject fails", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + o := NewLogsOptions(streams) + o.LogsForObject = func(restClientGetter genericclioptions.RESTClientGetter, object, options runtime.Object, timeout time.Duration, allContainers bool) (map[corev1.ObjectReference]restclient.ResponseWrapper, error) { + return nil, errors.New("Error from the LogsForObject") + } + return o + }, + expectedErr: "Error from the LogsForObject", + }, + { + name: "fail to get logs, if ConsumeRequestFn fails", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "test-pod-1", + FieldPath: "spec.containers{test-container-1}", + }: &responseWrapperMock{}, + { + Kind: "Pod", + Name: "test-pod-2", + FieldPath: "spec.containers{test-container-1}", + }: &responseWrapperMock{}, + }, + } + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = func(ctx context.Context, req restclient.ResponseWrapper, out io.Writer) error { + return errors.New("Error from the ConsumeRequestFn") + } + return o + }, + expectedErr: "Error from the ConsumeRequestFn", + }, + { + name: "follow logs from multiple requests concurrently with prefix", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + wg := &sync.WaitGroup{} + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "test-pod-1", + FieldPath: "spec.containers{test-container-1}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 1\n")}, + { + Kind: "Pod", + Name: "test-pod-2", + FieldPath: "spec.containers{test-container-2}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 2\n")}, + { + Kind: "Pod", + Name: "test-pod-3", + FieldPath: "spec.containers{test-container-3}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 3\n")}, + }, + wg: wg, + } + wg.Add(3) + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + o.Follow = true + o.Prefix = true + return o + }, + expectedOutSubstrings: []string{ + "[pod/test-pod-1/test-container-1] test log content from source 1\n", + "[pod/test-pod-2/test-container-2] test log content from source 2\n", + "[pod/test-pod-3/test-container-3] test log content from source 3\n", + }, + }, + { + name: "fail to follow logs from multiple requests, if ConsumeRequestFn fails", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + wg := &sync.WaitGroup{} + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "test-pod-1", + FieldPath: "spec.containers{test-container-1}", + }: &responseWrapperMock{}, + { + Kind: "Pod", + Name: "test-pod-2", + FieldPath: "spec.containers{test-container-2}", + }: &responseWrapperMock{}, + { + Kind: "Pod", + Name: "test-pod-3", + FieldPath: "spec.containers{test-container-3}", + }: &responseWrapperMock{}, + }, + wg: wg, + } + wg.Add(3) + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = func(ctx context.Context, req restclient.ResponseWrapper, out io.Writer) error { + return errors.New("Error from the ConsumeRequestFn") + } + o.Follow = true + return o + }, + expectedErr: "Error from the ConsumeRequestFn", + }, + { + name: "fail to follow logs, if ConsumeRequestFn fails", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "test-pod-1", + FieldPath: "spec.containers{test-container-1}", + }: &responseWrapperMock{}, + }, + } + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = func(ctx context.Context, req restclient.ResponseWrapper, out io.Writer) error { + return errors.New("Error from the ConsumeRequestFn") + } + o.Follow = true + return o + }, + expectedErr: "Error from the ConsumeRequestFn", + }, + { + name: "get logs from multiple requests and ignores the error if the container fails", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "some-pod-error-container", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{err: errors.New("error-container")}, + { + Kind: "Pod", + Name: "some-pod-1", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 1\n")}, + { + Kind: "Pod", + Name: "some-pod-2", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 2\n")}, + }, + } + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + o.IgnoreLogErrors = true + return o + }, + expectedOutSubstrings: []string{ + "error-container\n", + "test log content from source 1\n", + "test log content from source 2\n", + }, + }, + { + name: "get logs from multiple requests and an container fails", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "some-pod-error-container", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{err: errors.New("error-container")}, + { + Kind: "Pod", + Name: "some-pod", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source\n")}, + }, + } + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + return o + }, + expectedErr: "error-container", + }, + { + name: "follow logs from multiple requests and ignores the error if the container fails", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "some-pod-error-container", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{err: errors.New("error-container")}, + { + Kind: "Pod", + Name: "some-pod-1", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 1\n")}, + { + Kind: "Pod", + Name: "some-pod-2", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source 2\n")}, + }, + } + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + o.IgnoreLogErrors = true + o.Follow = true + return o + }, + expectedOutSubstrings: []string{ + "error-container\n", + "test log content from source 1\n", + "test log content from source 2\n", + }, + }, + { + name: "follow logs from multiple requests and an container fails", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + mock := &logTestMock{ + logsForObjectRequests: map[corev1.ObjectReference]restclient.ResponseWrapper{ + { + Kind: "Pod", + Name: "some-pod-error-container", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{err: errors.New("error-container")}, + { + Kind: "Pod", + Name: "some-pod", + FieldPath: "spec.containers{some-container}", + }: &responseWrapperMock{data: strings.NewReader("test log content from source\n")}, + }, + } + + o := NewLogsOptions(streams) + o.LogsForObject = mock.mockLogsForObject + o.ConsumeRequestFn = mock.mockConsumeRequest + o.Follow = true + return o + }, + expectedErr: "error-container", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + tf := cmdtesting.NewTestFactory().WithNamespace("test") + defer tf.Cleanup() + + streams, _, buf, _ := genericiooptions.NewTestIOStreams() + + opts := test.opts(streams) + opts.Namespace = "test" + opts.Object = testPod() + opts.Options = &corev1.PodLogOptions{} + err := opts.RunLogs() + + if err == nil && len(test.expectedErr) > 0 { + t.Fatalf("expected error %q, got none", test.expectedErr) + } + + if err != nil && !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("%s: expected to find:\n\t%s\nfound:\n\t%s\n", test.name, test.expectedErr, err.Error()) + } + + bufStr := buf.String() + if test.expectedOutSubstrings != nil { + for _, substr := range test.expectedOutSubstrings { + if !strings.Contains(bufStr, substr) { + t.Errorf("%s: expected to contain %#v. Output: %#v", test.name, substr, bufStr) + } + } + } + }) + } +} + +func testPod() *corev1.Pod { + return &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "foo", Namespace: "test", ResourceVersion: "10"}, + Spec: corev1.PodSpec{ + RestartPolicy: corev1.RestartPolicyAlways, + DNSPolicy: corev1.DNSClusterFirst, + Containers: []corev1.Container{ + { + Name: "bar", + }, + }, + }, + } +} + +func TestValidateLogOptions(t *testing.T) { + f := cmdtesting.NewTestFactory() + defer f.Cleanup() + f.WithNamespace("") + + tests := []struct { + name string + args []string + opts func(genericiooptions.IOStreams) *LogsOptions + expected string + }{ + { + name: "since & since-time", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + o := NewLogsOptions(streams) + o.SinceSeconds = time.Hour + o.SinceTime = "2006-01-02T15:04:05Z" + + var err error + o.Options, err = o.ToLogOptions() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + return o + }, + args: []string{"foo"}, + expected: "at most one of `sinceTime` or `sinceSeconds` may be specified", + }, + { + name: "negative since-time", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + o := NewLogsOptions(streams) + o.SinceSeconds = -1 * time.Second + + var err error + o.Options, err = o.ToLogOptions() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + return o + }, + args: []string{"foo"}, + expected: "must be greater than 0", + }, + { + name: "negative limit-bytes", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + o := NewLogsOptions(streams) + o.LimitBytes = -100 + + var err error + o.Options, err = o.ToLogOptions() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + return o + }, + args: []string{"foo"}, + expected: "must be greater than 0", + }, + { + name: "negative tail", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + o := NewLogsOptions(streams) + o.Tail = -100 + + var err error + o.Options, err = o.ToLogOptions() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + return o + }, + args: []string{"foo"}, + expected: "--tail must be greater than or equal to -1", + }, + { + name: "container name combined with --all-containers", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + o := NewLogsOptions(streams) + o.AllContainers = true + o.Container = "my-container" + + var err error + o.Options, err = o.ToLogOptions() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + return o + }, + args: []string{"my-pod", "my-container"}, + expected: "--all-containers=true should not be specified with container", + }, + { + name: "container name combined with second argument", + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + o := NewLogsOptions(streams) + o.Container = "my-container" + o.ContainerNameSpecified = true + + var err error + o.Options, err = o.ToLogOptions() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + return o + }, + args: []string{"my-pod", "my-container"}, + expected: "only one of -c or an inline", + }, + } + for _, test := range tests { + streams := genericiooptions.NewTestIOStreamsDiscard() + + o := test.opts(streams) + o.Resources = test.args + + err := o.Validate() + if err == nil { + t.Fatalf("expected error %q, got none", test.expected) + } + + if !strings.Contains(err.Error(), test.expected) { + t.Errorf("%s: expected to find:\n\t%s\nfound:\n\t%s\n", test.name, test.expected, err.Error()) + } + } +} + +func TestLogComplete(t *testing.T) { + f := cmdtesting.NewTestFactory() + defer f.Cleanup() + + tests := []struct { + name string + args []string + opts func(genericiooptions.IOStreams) *LogsOptions + expected string + }{ + { + name: "One args case", + args: []string{"foo"}, + opts: func(streams genericiooptions.IOStreams) *LogsOptions { + o := NewLogsOptions(streams) + o.Selector = "foo" + return o + }, + expected: "only a selector (-l) or a POD name is allowed", + }, + } + for _, test := range tests { + cmd := NewCmdLogs(f, genericiooptions.NewTestIOStreamsDiscard()) + out := "" + + // checkErr breaks tests in case of errors, plus we just + // need to check errors returned by the command validation + o := test.opts(genericiooptions.NewTestIOStreamsDiscard()) + err := o.Complete(f, cmd, test.args) + if err == nil { + t.Fatalf("expected error %q, got none", test.expected) + } + + out = err.Error() + if !strings.Contains(out, test.expected) { + t.Errorf("%s: expected to find:\n\t%s\nfound:\n\t%s\n", test.name, test.expected, out) + } + } +} + +func TestDefaultConsumeRequest(t *testing.T) { + tests := []struct { + name string + request restclient.ResponseWrapper + expectedErr string + expectedOut string + }{ + { + name: "error from request stream", + request: &responseWrapperMock{ + err: errors.New("err from the stream"), + }, + expectedErr: "err from the stream", + }, + { + name: "error while reading", + request: &responseWrapperMock{ + data: iotest.TimeoutReader(strings.NewReader("Some data")), + }, + expectedErr: iotest.ErrTimeout.Error(), + expectedOut: "Some data", + }, + { + name: "read with empty string", + request: &responseWrapperMock{ + data: strings.NewReader(""), + }, + expectedOut: "", + }, + { + name: "read without new lines", + request: &responseWrapperMock{ + data: strings.NewReader("some string without a new line"), + }, + expectedOut: "some string without a new line", + }, + { + name: "read with newlines in the middle", + request: &responseWrapperMock{ + data: strings.NewReader("foo\nbar"), + }, + expectedOut: "foo\nbar", + }, + { + name: "read with newline at the end", + request: &responseWrapperMock{ + data: strings.NewReader("foo\n"), + }, + expectedOut: "foo\n", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + buf := &bytes.Buffer{} + err := DefaultConsumeRequest(context.TODO(), test.request, buf) + + if err != nil && !strings.Contains(err.Error(), test.expectedErr) { + t.Errorf("%s: expected to find:\n\t%s\nfound:\n\t%s\n", test.name, test.expectedErr, err.Error()) + } + + if buf.String() != test.expectedOut { + t.Errorf("%s: did not get expected log content. Got: %s", test.name, buf.String()) + } + }) + } +} + +func TestNoResourceFoundMessage(t *testing.T) { + tf := cmdtesting.NewTestFactory().WithNamespace("test") + defer tf.Cleanup() + + ns := scheme.Codecs.WithoutConversion() + codec := scheme.Codecs.LegacyCodec(scheme.Scheme.PrioritizedVersionsAllGroups()...) + pods, _, _ := cmdtesting.EmptyTestData() + tf.UnstructuredClient = &fake.RESTClient{ + NegotiatedSerializer: ns, + Client: fake.CreateHTTPClient(func(req *http.Request) (*http.Response, error) { + switch req.URL.Path { + case "/namespaces/test/pods": + if req.URL.Query().Get("labelSelector") == "foo" { + return &http.Response{StatusCode: http.StatusOK, Header: cmdtesting.DefaultHeader(), Body: cmdtesting.ObjBody(codec, pods)}, nil + } + t.Fatalf("unexpected request: %#v\n%#v", req.URL, req) + return nil, nil + default: + t.Fatalf("unexpected request: %#v\n%#v", req.URL, req) + return nil, nil + } + }), + } + + streams, _, buf, errbuf := genericiooptions.NewTestIOStreams() + cmd := NewCmdLogs(tf, streams) + o := NewLogsOptions(streams) + o.Selector = "foo" + err := o.Complete(tf, cmd, []string{}) + + if err != nil { + t.Fatalf("Unexpected error, expected none, got %v", err) + } + + expected := "" + if e, a := expected, buf.String(); e != a { + t.Errorf("expected to find:\n\t%s\nfound:\n\t%s\n", e, a) + } + + expectedErr := "No resources found in test namespace.\n" + if e, a := expectedErr, errbuf.String(); e != a { + t.Errorf("expected to find:\n\t%s\nfound:\n\t%s\n", e, a) + } +} + +func TestNoPodInNamespaceFoundMessage(t *testing.T) { + namespace, podName := "test", "bar" + + tf := cmdtesting.NewTestFactory().WithNamespace(namespace) + defer tf.Cleanup() + + ns := scheme.Codecs.WithoutConversion() + codec := scheme.Codecs.LegacyCodec(scheme.Scheme.PrioritizedVersionsAllGroups()...) + errStatus := apierrors.NewNotFound(schema.GroupResource{Resource: "pods"}, podName).Status() + + tf.UnstructuredClient = &fake.RESTClient{ + NegotiatedSerializer: ns, + Client: fake.CreateHTTPClient(func(req *http.Request) (*http.Response, error) { + switch req.URL.Path { + case fmt.Sprintf("/namespaces/%s/pods/%s", namespace, podName): + fallthrough + case fmt.Sprintf("/namespaces/%s/pods", namespace): + fallthrough + case fmt.Sprintf("/api/v1/namespaces/%s", namespace): + return &http.Response{StatusCode: http.StatusNotFound, Header: cmdtesting.DefaultHeader(), Body: cmdtesting.ObjBody(codec, &errStatus)}, nil + default: + t.Fatalf("unexpected request: %#v\n%#v", req.URL, req) + return nil, nil + } + }), + } + + streams, _, _, _ := genericiooptions.NewTestIOStreams() + cmd := NewCmdLogs(tf, streams) + o := NewLogsOptions(streams) + err := o.Complete(tf, cmd, []string{podName}) + + if err == nil { + t.Fatal("Expected NotFound error, got nil") + } + + expected := fmt.Sprintf("error from server (NotFound): pods %q not found in namespace %q", podName, namespace) + if e, a := expected, err.Error(); e != a { + t.Errorf("expected to find:\n\t%s\nfound:\n\t%s\n", e, a) + } +} + +type responseWrapperMock struct { + data io.Reader + err error +} + +func (r *responseWrapperMock) DoRaw(context.Context) ([]byte, error) { + data, _ := io.ReadAll(r.data) + return data, r.err +} + +func (r *responseWrapperMock) Stream(context.Context) (io.ReadCloser, error) { + return io.NopCloser(r.data), r.err +} + +type logTestMock struct { + logsForObjectRequests map[corev1.ObjectReference]restclient.ResponseWrapper + + // We need a WaitGroup in some test cases to make sure that we fetch logs concurrently. + // These test cases will finish successfully without the WaitGroup, but the WaitGroup + // will help us to identify regression when someone accidentally changes + // concurrent fetching to sequential + wg *sync.WaitGroup +} + +func (l *logTestMock) mockConsumeRequest(ctx context.Context, request restclient.ResponseWrapper, out io.Writer) error { + readCloser, err := request.Stream(ctx) + if err != nil { + return err + } + defer readCloser.Close() + + // Just copy everything for a test sake + _, err = io.Copy(out, readCloser) + if l.wg != nil { + l.wg.Done() + l.wg.Wait() + } + return err +} + +func (l *logTestMock) mockLogsForObject(restClientGetter genericclioptions.RESTClientGetter, object, options runtime.Object, timeout time.Duration, allContainers bool) (map[corev1.ObjectReference]restclient.ResponseWrapper, error) { + switch object.(type) { + case *appsv1.Deployment: + _, ok := options.(*corev1.PodLogOptions) + if !ok { + return nil, errors.New("provided options object is not a PodLogOptions") + } + + return l.logsForObjectRequests, nil + case *corev1.Pod: + _, ok := options.(*corev1.PodLogOptions) + if !ok { + return nil, errors.New("provided options object is not a PodLogOptions") + } + + return l.logsForObjectRequests, nil + default: + return nil, fmt.Errorf("cannot get the logs from %T", object) + } +}