diff --git a/go/fn/errors.go b/go/fn/errors.go index 8c80a412a..d2b322701 100644 --- a/go/fn/errors.go +++ b/go/fn/errors.go @@ -18,6 +18,8 @@ import ( "fmt" "log" "strings" + + "github.com/GoogleContainerTools/kpt-functions-sdk/go/fn/internal" ) // ErrMissingFnConfig raises error if a required functionConfig is missing. @@ -40,9 +42,12 @@ func (e *ErrOpOrDie) Error() string { func handleOptOrDieErr() { if v := recover(); v != nil { - if e, ok := v.(ErrOpOrDie); ok { - log.Fatalf(e.Error()) - } else { + switch v.(type) { + case ErrOpOrDie: + log.Fatalf(v.(*ErrOpOrDie).Error()) + case internal.ErrEqualSelector: + log.Fatalf(v.(*internal.ErrEqualSelector).Error()) + default: panic(v) } } diff --git a/go/fn/examples/example_read_field_test.go b/go/fn/examples/example_read_field_test.go index 46a50a03b..bee226855 100644 --- a/go/fn/examples/example_read_field_test.go +++ b/go/fn/examples/example_read_field_test.go @@ -30,10 +30,13 @@ func Example_aReadField() { func readField(rl *fn.ResourceList) error { for _, obj := range rl.Items { - if obj.GetAPIVersion() == "apps/v1" && obj.GetKind() == "Deployment" { + switch{ + case obj.IsGVK("apps/v1", "Deployment"): var replicas int obj.GetOrDie(&replicas, "spec", "replicas") - fn.Logf("replicas is %v\n", replicas) + case obj.IsGVK("rbac.authorization.k8s.io/v1", "ClusterRoleBinding"): + var namespace string + obj.GetOrDie(&namespace, "subjects", "kind=ServiceAccount", "namespace") } } return nil diff --git a/go/fn/examples/example_set_field_test.go b/go/fn/examples/example_set_field_test.go index 9bd27f4cf..8bf16c54f 100644 --- a/go/fn/examples/example_set_field_test.go +++ b/go/fn/examples/example_set_field_test.go @@ -30,9 +30,13 @@ func Example_cSetField() { func setField(rl *fn.ResourceList) error { for _, obj := range rl.Items { - if obj.GetAPIVersion() == "apps/v1" && obj.GetKind() == "Deployment" { + switch{ + case obj.IsGVK("apps/v1", "Deployment"): replicas := 10 obj.SetOrDie(&replicas, "spec", "replicas") + case obj.IsGVK("rbac.authorization.k8s.io/v1", "ClusterRoleBinding"): + namespace := "test" + obj.SetOrDie(&namespace, "subjects", "kind=ServiceAccount", "namespace") } } return nil diff --git a/go/fn/internal/errors.go b/go/fn/internal/errors.go new file mode 100644 index 000000000..2785dfa5f --- /dev/null +++ b/go/fn/internal/errors.go @@ -0,0 +1,16 @@ +package internal + +import ( + "fmt" +) + +// ErrOpOrDie raises if the KubeObject operation panics. +type ErrEqualSelector struct { + field string + expected string + actual string +} + +func (e *ErrEqualSelector) Error() string { + return fmt.Sprintf("invalid selector syntax: expect %v, got %v ", e.expected, e.actual) +} diff --git a/go/fn/internal/maphelpers.go b/go/fn/internal/maphelpers.go index 124998c02..258151db9 100644 --- a/go/fn/internal/maphelpers.go +++ b/go/fn/internal/maphelpers.go @@ -16,13 +16,15 @@ package internal import ( "fmt" + "strings" ) func (o *MapVariant) GetNestedValue(fields ...string) (variant, bool, error) { - current := o + var current variant + current = o n := len(fields) for i := 0; i < n; i++ { - entry, found := current.getVariant(fields[i]) + entry, found := current.(*MapVariant).getVariant(fields[i]) if !found { return nil, found, nil } @@ -30,26 +32,67 @@ func (o *MapVariant) GetNestedValue(fields ...string) (variant, bool, error) { if i == n-1 { return entry, true, nil } - entryM, ok := entry.(*MapVariant) - if !ok { + switch entry.(type){ + case *MapVariant: + current, _ = entry.(*MapVariant) + case *sliceVariant: + if i+1 < n && IsEqualSelector(fields[i+1]) { + foundMapVar, err := entry.(*sliceVariant).GetSliceElementBySelector(fields[i+1]) + if err != nil { + return nil, found, err + } + if foundMapVar != nil { + current = foundMapVar + i += 1 + } else { + return nil, false, nil + } + } else { + return nil, found, fmt.Errorf("wrong type, got: %T", entry) + } + default: return nil, found, fmt.Errorf("wrong type, got: %T", entry) } - current = entryM } return nil, false, fmt.Errorf("unexpected code reached") } func (o *MapVariant) SetNestedValue(val variant, fields ...string) error { - current := o + var current variant + current = o n := len(fields) var err error for i := 0; i < n; i++ { if i == n-1 { - current.set(fields[i], val) + current.(*MapVariant).set(fields[i], val) } else { - current, _, err = current.getMap(fields[i], true) - if err != nil { - return err + entry, found := current.(*MapVariant).getVariant(fields[i]) + if !found{ + return fmt.Errorf("fields not exist %v", strings.Join(fields, ".")) + } + switch entry.(type){ + case *sliceVariant: + if i+1 < n && IsEqualSelector(fields[i+1]) { + mapVar, err := entry.(*sliceVariant).GetSliceElementBySelector(fields[i+1]) + if err != nil { + return err + } + if mapVar == nil { + valueNode := buildMappingNode() + valueVariant := &MapVariant{node: valueNode} + entry.(*sliceVariant).Add(valueVariant) + current = valueVariant + return fmt.Errorf("no matching element in selector %v", fields[i+1]) + } else { + current = mapVar + i += 1 + } + } + default: // MapVariant or not exist + current, _, err = current.(*MapVariant).getMap(fields[i], true) + if err != nil { + return err + } } } } diff --git a/go/fn/internal/selector.go b/go/fn/internal/selector.go new file mode 100644 index 000000000..c2f5cf898 --- /dev/null +++ b/go/fn/internal/selector.go @@ -0,0 +1,37 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package internal + +import ( + "strings" +) + +const equalDelimiter = "=" + +func IsEqualSelector(field string) bool{ + if strings.Contains(field, equalDelimiter) { + return true + } + // TODO: extend and support multi "equal" syntax. e.g. `kind=Project,name=kit` + return false +} + +func GetEqualSelector(field string) (string, string){ + segments := strings.Split(field, equalDelimiter) + if len(segments) != 2 { + panic(ErrEqualSelector{expected: "YOUR_KEY=YOUR_VALUE", actual: field}) + } + return segments[0], segments[1] +} + diff --git a/go/fn/internal/slice.go b/go/fn/internal/slice.go index a7b5e1180..59f17c747 100644 --- a/go/fn/internal/slice.go +++ b/go/fn/internal/slice.go @@ -49,3 +49,24 @@ func (v *sliceVariant) Objects() ([]*MapVariant, error) { func (v *sliceVariant) Add(node variant) { v.node.Content = append(v.node.Content, node.Node()) } + +func (o *sliceVariant) GetSliceElementBySelector(field string) (*MapVariant, error) { + key, expected := GetEqualSelector(field) + elements, err := o.Objects() + if err != nil { + return nil, err + } + for _, mapElement := range elements { + actual, found, err := mapElement.GetNestedString(key) + if !found { + continue + } + if err != nil { + return nil, err + } + if actual == expected { + return mapElement, nil + } + } + return nil, nil +}