diff --git a/pkg/types/types.go b/pkg/types/types.go index fd39fde73..bbf8d6dcc 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -156,10 +156,11 @@ type ProbeContext struct { // AppDetails contains all the application related envs type AppDetails struct { - Namespace string - Labels []string - Kind string - Names []string + Namespace string + Labels []string + Kind string + Names []string + LabelMatchMode string } func GetTargets(targets string) []AppDetails { @@ -171,9 +172,18 @@ func GetTargets(targets string) []AppDetails { for _, k := range t { val := strings.Split(strings.TrimSpace(k), ":") data := AppDetails{ - Kind: val[0], - Namespace: val[1], + Kind: val[0], + Namespace: val[1], + LabelMatchMode: "union", + } + + if len(val) > 3 { + mode := strings.TrimSpace(val[3]) + if mode == "intersection" || mode == "union" { + data.LabelMatchMode = mode + } } + if strings.Contains(val[2], "=") { data.Labels = parse(val[2]) } else { diff --git a/pkg/types/types_test.go b/pkg/types/types_test.go new file mode 100644 index 000000000..e279f75d1 --- /dev/null +++ b/pkg/types/types_test.go @@ -0,0 +1,140 @@ +package types + +import ( + "testing" +) + +func TestGetTargets_UnionModeDefault(t *testing.T) { + // Test that union is the default mode when not specified + targets := "deployment:default:[app=nginx,tier=frontend]" + result := GetTargets(targets) + + if len(result) != 1 { + t.Errorf("Expected 1 AppDetail, got %d", len(result)) + } + + if result[0].LabelMatchMode != "union" { + t.Errorf("Expected default LabelMatchMode to be 'union', got '%s'", result[0].LabelMatchMode) + } + + if result[0].Kind != "deployment" { + t.Errorf("Expected Kind to be 'deployment', got '%s'", result[0].Kind) + } + + if result[0].Namespace != "default" { + t.Errorf("Expected Namespace to be 'default', got '%s'", result[0].Namespace) + } + + if len(result[0].Labels) != 2 { + t.Errorf("Expected 2 labels, got %d", len(result[0].Labels)) + } +} + +func TestGetTargets_ExplicitUnionMode(t *testing.T) { + // Test explicit union mode + targets := "statefulset:prod:[app=postgres,role=primary]:union" + result := GetTargets(targets) + + if len(result) != 1 { + t.Errorf("Expected 1 AppDetail, got %d", len(result)) + } + + if result[0].LabelMatchMode != "union" { + t.Errorf("Expected LabelMatchMode to be 'union', got '%s'", result[0].LabelMatchMode) + } + + if result[0].Kind != "statefulset" { + t.Errorf("Expected Kind to be 'statefulset', got '%s'", result[0].Kind) + } +} + +func TestGetTargets_IntersectionMode(t *testing.T) { + // Test intersection mode + targets := "cluster:default:[cnpg.io/instanceRole=primary,cnpg.io/cluster=pg-eu]:intersection" + result := GetTargets(targets) + + if len(result) != 1 { + t.Errorf("Expected 1 AppDetail, got %d", len(result)) + } + + if result[0].LabelMatchMode != "intersection" { + t.Errorf("Expected LabelMatchMode to be 'intersection', got '%s'", result[0].LabelMatchMode) + } + + if result[0].Kind != "cluster" { + t.Errorf("Expected Kind to be 'cluster', got '%s'", result[0].Kind) + } + + if result[0].Namespace != "default" { + t.Errorf("Expected Namespace to be 'default', got '%s'", result[0].Namespace) + } + + if len(result[0].Labels) != 2 { + t.Errorf("Expected 2 labels, got %d", len(result[0].Labels)) + } + + expectedLabels := []string{"cnpg.io/instanceRole=primary", "cnpg.io/cluster=pg-eu"} + for i, label := range result[0].Labels { + if label != expectedLabels[i] { + t.Errorf("Expected label[%d] to be '%s', got '%s'", i, expectedLabels[i], label) + } + } +} + +func TestGetTargets_InvalidMode(t *testing.T) { + // Test that invalid mode falls back to union + targets := "deployment:default:[app=nginx]:invalid" + result := GetTargets(targets) + + if len(result) != 1 { + t.Errorf("Expected 1 AppDetail, got %d", len(result)) + } + + // Invalid mode should fall back to union + if result[0].LabelMatchMode != "union" { + t.Errorf("Expected invalid mode to fall back to 'union', got '%s'", result[0].LabelMatchMode) + } +} + +func TestGetTargets_MultipleSemicolonSeparated(t *testing.T) { + // Test multiple targets with different modes + targets := "deployment:ns1:[app=web]:union;statefulset:ns2:[db=postgres,env=prod]:intersection" + result := GetTargets(targets) + + if len(result) != 2 { + t.Errorf("Expected 2 AppDetails, got %d", len(result)) + } + + // First target - union + if result[0].LabelMatchMode != "union" { + t.Errorf("Expected first target LabelMatchMode to be 'union', got '%s'", result[0].LabelMatchMode) + } + + // Second target - intersection + if result[1].LabelMatchMode != "intersection" { + t.Errorf("Expected second target LabelMatchMode to be 'intersection', got '%s'", result[1].LabelMatchMode) + } +} + +func TestGetTargets_WithNames(t *testing.T) { + // Test that Names parsing still works with the new field + targets := "pod:default:[pod1,pod2,pod3]" + result := GetTargets(targets) + + if len(result) != 1 { + t.Errorf("Expected 1 AppDetail, got %d", len(result)) + } + + if len(result[0].Names) != 3 { + t.Errorf("Expected 3 names, got %d", len(result[0].Names)) + } + + if len(result[0].Labels) != 0 { + t.Errorf("Expected 0 labels when Names are provided, got %d", len(result[0].Labels)) + } + + // Default mode should still be union + if result[0].LabelMatchMode != "union" { + t.Errorf("Expected default LabelMatchMode to be 'union', got '%s'", result[0].LabelMatchMode) + } +} diff --git a/pkg/utils/common/pods.go b/pkg/utils/common/pods.go index 5de53affe..07f4e63e6 100644 --- a/pkg/utils/common/pods.go +++ b/pkg/utils/common/pods.go @@ -306,16 +306,30 @@ func GetTargetPodsWhenTargetPodsENVNotSet(podAffPerc int, clients clients.Client } finalPods.Items = append(finalPods.Items, pods.Items...) } else { - for _, label := range target.Labels { - pods, err := clients.KubeClient.CoreV1().Pods(target.Namespace).List(context.Background(), v1.ListOptions{LabelSelector: label}) + // Check label match mode to determine union vs intersection + if target.LabelMatchMode == "intersection" { + // INTERSECTION: Get pods matching ALL labels + pods, err := getPodsWithIntersectionLabels(target, clients) if err != nil { - return finalPods, cerrors.Error{ErrorCode: cerrors.ErrorTypeTargetSelection, Target: fmt.Sprintf("{podLabel: %s, namespace: %s}", label, target.Namespace), Reason: err.Error()} + return finalPods, err } - filteredPods, err := filterPodsByOwnerKind(pods.Items, target, clients) + filteredPods, err := filterPodsByOwnerKind(pods, target, clients) if err != nil { return finalPods, stacktrace.Propagate(err, "could not identify parent type from pod") } finalPods.Items = append(finalPods.Items, filteredPods...) + } else { + for _, label := range target.Labels { + pods, err := clients.KubeClient.CoreV1().Pods(target.Namespace).List(context.Background(), v1.ListOptions{LabelSelector: label}) + if err != nil { + return finalPods, cerrors.Error{ErrorCode: cerrors.ErrorTypeTargetSelection, Target: fmt.Sprintf("{podLabel: %s, namespace: %s}", label, target.Namespace), Reason: err.Error()} + } + filteredPods, err := filterPodsByOwnerKind(pods.Items, target, clients) + if err != nil { + return finalPods, stacktrace.Propagate(err, "could not identify parent type from pod") + } + finalPods.Items = append(finalPods.Items, filteredPods...) + } } } } @@ -331,6 +345,42 @@ func GetTargetPodsWhenTargetPodsENVNotSet(podAffPerc int, clients clients.Client return filterPodsByPercentage(finalPods, podAffPerc), nil } +// getPodsWithIntersectionLabels returns pods that match ALL the provided labels (intersection) +func getPodsWithIntersectionLabels(target types.AppDetails, clients clients.ClientSets) ([]core_v1.Pod, error) { + if len(target.Labels) == 0 { + return nil, cerrors.Error{ + ErrorCode: cerrors.ErrorTypeTargetSelection, + Target: fmt.Sprintf("{namespace: %s, kind: %s}", target.Namespace, target.Kind), + Reason: "no labels provided for intersection", + } + } + + // Build comma-separated label selector for intersection (AND logic) + // e.g., "app=nginx,env=prod,role=primary" + labelSelector := strings.Join(target.Labels, ",") + + pods, err := clients.KubeClient.CoreV1().Pods(target.Namespace).List(context.Background(), v1.ListOptions{ + LabelSelector: labelSelector, + }) + if err != nil { + return nil, cerrors.Error{ + ErrorCode: cerrors.ErrorTypeTargetSelection, + Target: fmt.Sprintf("{labels: %v, namespace: %s}", target.Labels, target.Namespace), + Reason: err.Error(), + } + } + + if len(pods.Items) == 0 { + return nil, cerrors.Error{ + ErrorCode: cerrors.ErrorTypeTargetSelection, + Target: fmt.Sprintf("{labels: %v, namespace: %s}", target.Labels, target.Namespace), + Reason: "no pods found matching all labels", + } + } + + return pods.Items, nil +} + func filterPodsByOwnerKind(pods []core_v1.Pod, target types.AppDetails, clients clients.ClientSets) ([]core_v1.Pod, error) { var filteredPods []core_v1.Pod for _, pod := range pods {