From d6848b343407d6e4a39d57abad527804df5ac642 Mon Sep 17 00:00:00 2001 From: Dmitriy Matrenichev Date: Mon, 13 Nov 2023 19:38:13 +0300 Subject: [PATCH] feat: provide an example of typed resource type This is something that I wanted to try for a while - attach the resource type directly to the resource.Type constant. The (partial) result is this PR. While it may look interesting, I don't think that we should merge it ATM. Because: 1. There are a LOT of `Naked()` calls around the code. 2. This will require a lot of refactoring. Even here it required a lot of places to be changed. 3. I'm not sure that this level of generic abuse is worth it (will it increase the resulting binary size because of the bloat?). Still, this is the good starting point if anyone is interested in how it could be done. Signed-off-by: Dmitriy Matrenichev --- pkg/controller/conformance/controllers.go | 72 +++++++------------ pkg/controller/conformance/resource.go | 4 +- pkg/controller/conformance/resources.go | 9 +-- pkg/controller/conformance/runtime.go | 18 ++--- pkg/controller/protobuf/protobuf_test.go | 4 +- pkg/resource/protobuf/registry_test.go | 8 +-- pkg/safe/reader.go | 5 ++ pkg/safe/safe.go | 31 ++++++++ pkg/safe/state.go | 22 ++++++ pkg/safe/state_test.go | 32 ++++++--- pkg/state/conformance/resources.go | 5 +- pkg/state/conformance/state.go | 26 +++---- pkg/state/filter_test.go | 2 +- pkg/state/impl/inmem/local_test.go | 6 +- pkg/state/impl/store/bolt/bbolt_test.go | 2 +- .../store/compression/compression_test.go | 2 +- .../impl/store/encryption/marshaler_test.go | 2 +- pkg/state/impl/store/protobuf_test.go | 2 +- pkg/state/protobuf/client/client_test.go | 2 +- pkg/state/protobuf/protobuf_test.go | 2 +- 20 files changed, 159 insertions(+), 97 deletions(-) diff --git a/pkg/controller/conformance/controllers.go b/pkg/controller/conformance/controllers.go index e67b0b8d..e5f10f15 100644 --- a/pkg/controller/conformance/controllers.go +++ b/pkg/controller/conformance/controllers.go @@ -33,12 +33,12 @@ func (ctrl *IntToStrController) Inputs() []controller.Input { return []controller.Input{ { Namespace: ctrl.SourceNamespace, - Type: IntResourceType, + Type: IntResourceType.Naked(), Kind: controller.InputStrong, }, { Namespace: ctrl.TargetNamespace, - Type: StrResourceType, + Type: StrResourceType.Naked(), Kind: controller.InputDestroyReady, }, } @@ -48,7 +48,7 @@ func (ctrl *IntToStrController) Inputs() []controller.Input { func (ctrl *IntToStrController) Outputs() []controller.Output { return []controller.Output{ { - Type: StrResourceType, + Type: StrResourceType.Naked(), Kind: controller.OutputExclusive, }, } @@ -58,7 +58,7 @@ func (ctrl *IntToStrController) Outputs() []controller.Output { // //nolint:gocognit func (ctrl *IntToStrController) Run(ctx context.Context, r controller.Runtime, _ *zap.Logger) error { - sourceMd := resource.NewMetadata(ctrl.SourceNamespace, IntResourceType, "", resource.VersionUndefined) + sourceMd := safe.NewTaggedMD(ctrl.SourceNamespace, IntResourceType, "", resource.VersionUndefined) for { select { @@ -67,10 +67,7 @@ func (ctrl *IntToStrController) Run(ctx context.Context, r controller.Runtime, _ case <-r.EventCh(): } - intList, err := safe.ReaderList[interface { - IntegerResource - resource.Resource - }](ctx, r, sourceMd) + intList, err := safe.ReaderListByMD(ctx, r, sourceMd) if err != nil { return fmt.Errorf("error listing objects: %w", err) } @@ -154,7 +151,7 @@ func (ctrl *StrToSentenceController) Run(ctx context.Context, r controller.Runti if err := r.UpdateInputs([]controller.Input{ { Namespace: ctrl.SourceNamespace, - Type: StrResourceType, + Type: StrResourceType.Naked(), Kind: controller.InputStrong, }, { @@ -166,7 +163,7 @@ func (ctrl *StrToSentenceController) Run(ctx context.Context, r controller.Runti return fmt.Errorf("error setting up dependencies: %w", err) } - sourceMd := resource.NewMetadata(ctrl.SourceNamespace, StrResourceType, "", resource.VersionUndefined) + sourceMd := safe.NewTaggedMD(ctrl.SourceNamespace, StrResourceType, "", resource.VersionUndefined) for { select { @@ -175,10 +172,7 @@ func (ctrl *StrToSentenceController) Run(ctx context.Context, r controller.Runti case <-r.EventCh(): } - strList, err := safe.ReaderList[interface { - StringResource - resource.Resource - }](ctx, r, sourceMd) + strList, err := safe.ReaderListByMD(ctx, r, sourceMd) if err != nil { return fmt.Errorf("error listing objects: %w", err) } @@ -250,7 +244,7 @@ func (ctrl *SumController) Inputs() []controller.Input { func (ctrl *SumController) Outputs() []controller.Output { return []controller.Output{ { - Type: IntResourceType, + Type: IntResourceType.Naked(), Kind: controller.OutputShared, }, } @@ -261,14 +255,14 @@ func (ctrl *SumController) Run(ctx context.Context, r controller.Runtime, _ *zap if err := r.UpdateInputs([]controller.Input{ { Namespace: ctrl.SourceNamespace, - Type: IntResourceType, + Type: IntResourceType.Naked(), Kind: controller.InputWeak, }, }); err != nil { return fmt.Errorf("error setting up dependencies: %w", err) } - sourceMd := resource.NewMetadata(ctrl.SourceNamespace, IntResourceType, "", resource.VersionUndefined) + sourceMd := safe.NewTaggedMD(ctrl.SourceNamespace, IntResourceType, "", resource.VersionUndefined) for { select { @@ -277,10 +271,7 @@ func (ctrl *SumController) Run(ctx context.Context, r controller.Runtime, _ *zap case <-r.EventCh(): } - intList, err := safe.ReaderList[interface { - IntegerResource - resource.Resource - }](ctx, r, sourceMd, state.WithLabelQuery(resource.RawLabelQuery(ctrl.SourceLabelQuery))) + intList, err := safe.ReaderListByMD(ctx, r, sourceMd, state.WithLabelQuery(resource.RawLabelQuery(ctrl.SourceLabelQuery))) if err != nil { return fmt.Errorf("error listing objects: %w", err) } @@ -325,7 +316,7 @@ func (ctrl *FailingController) Inputs() []controller.Input { func (ctrl *FailingController) Outputs() []controller.Output { return []controller.Output{ { - Type: IntResourceType, + Type: IntResourceType.Naked(), Kind: controller.OutputExclusive, }, } @@ -372,7 +363,7 @@ func (ctrl *IntDoublerController) Inputs() []controller.Input { return []controller.Input{ { Namespace: ctrl.SourceNamespace, - Type: IntResourceType, + Type: IntResourceType.Naked(), Kind: controller.InputStrong, }, } @@ -382,7 +373,7 @@ func (ctrl *IntDoublerController) Inputs() []controller.Input { func (ctrl *IntDoublerController) Outputs() []controller.Output { return []controller.Output{ { - Type: IntResourceType, + Type: IntResourceType.Naked(), Kind: controller.OutputShared, }, } @@ -390,7 +381,7 @@ func (ctrl *IntDoublerController) Outputs() []controller.Output { // Run implements controller.Controller interface. func (ctrl *IntDoublerController) Run(ctx context.Context, r controller.Runtime, _ *zap.Logger) error { - sourceMd := resource.NewMetadata(ctrl.SourceNamespace, IntResourceType, "", resource.VersionUndefined) + sourceMd := safe.NewTaggedMD(ctrl.SourceNamespace, IntResourceType, "", resource.VersionUndefined) for { select { @@ -401,10 +392,7 @@ func (ctrl *IntDoublerController) Run(ctx context.Context, r controller.Runtime, r.StartTrackingOutputs() - intList, err := safe.ReaderList[interface { - IntegerResource - resource.Resource - }](ctx, r, sourceMd) + intList, err := safe.ReaderListByMD(ctx, r, sourceMd) if err != nil { return fmt.Errorf("error listing objects: %w", err) } @@ -423,7 +411,7 @@ func (ctrl *IntDoublerController) Run(ctx context.Context, r controller.Runtime, } } - if err = r.CleanupOutputs(ctx, resource.NewMetadata(ctrl.TargetNamespace, IntResourceType, "", resource.VersionUndefined)); err != nil { + if err = r.CleanupOutputs(ctx, safe.NewTaggedMD(ctrl.TargetNamespace, IntResourceType, "", resource.VersionUndefined)); err != nil { return fmt.Errorf("error cleaning up outputs: %w", err) } } @@ -445,7 +433,7 @@ func (ctrl *ModifyWithResultController) Inputs() []controller.Input { return []controller.Input{ { Namespace: ctrl.SourceNamespace, - Type: StrResourceType, + Type: StrResourceType.Naked(), Kind: controller.InputStrong, }, } @@ -455,7 +443,7 @@ func (ctrl *ModifyWithResultController) Inputs() []controller.Input { func (ctrl *ModifyWithResultController) Outputs() []controller.Output { return []controller.Output{ { - Type: StrResourceType, + Type: StrResourceType.Naked(), Kind: controller.OutputExclusive, }, } @@ -463,7 +451,7 @@ func (ctrl *ModifyWithResultController) Outputs() []controller.Output { // Run implements controller.Controller interface. func (ctrl *ModifyWithResultController) Run(ctx context.Context, r controller.Runtime, _ *zap.Logger) error { - sourceMd := resource.NewMetadata(ctrl.SourceNamespace, StrResourceType, "", resource.VersionUndefined) + sourceMd := safe.NewTaggedMD(ctrl.SourceNamespace, StrResourceType, "", resource.VersionUndefined) for { select { @@ -472,10 +460,7 @@ func (ctrl *ModifyWithResultController) Run(ctx context.Context, r controller.Ru case <-r.EventCh(): } - strList, err := safe.ReaderList[interface { - StringResource - resource.Resource - }](ctx, r, sourceMd) + strList, err := safe.ReaderListByMD(ctx, r, sourceMd) if err != nil { return fmt.Errorf("error listing objects: %w", err) } @@ -535,12 +520,12 @@ func (ctrl *MetricsController) Inputs() []controller.Input { return []controller.Input{ { Namespace: ctrl.SourceNamespace, - Type: IntResourceType, + Type: IntResourceType.Naked(), Kind: controller.InputStrong, }, { Namespace: ctrl.TargetNamespace, - Type: StrResourceType, + Type: StrResourceType.Naked(), Kind: controller.InputDestroyReady, }, } @@ -550,7 +535,7 @@ func (ctrl *MetricsController) Inputs() []controller.Input { func (ctrl *MetricsController) Outputs() []controller.Output { return []controller.Output{ { - Type: StrResourceType, + Type: StrResourceType.Naked(), Kind: controller.OutputExclusive, }, } @@ -558,7 +543,7 @@ func (ctrl *MetricsController) Outputs() []controller.Output { // Run implements controller.Controller interface. func (ctrl *MetricsController) Run(ctx context.Context, r controller.Runtime, _ *zap.Logger) error { - sourceMd := resource.NewMetadata(ctrl.SourceNamespace, IntResourceType, "", resource.VersionUndefined) + sourceMd := safe.NewTaggedMD(ctrl.SourceNamespace, IntResourceType, "", resource.VersionUndefined) for { select { @@ -567,10 +552,7 @@ func (ctrl *MetricsController) Run(ctx context.Context, r controller.Runtime, _ case <-r.EventCh(): } - intList, err := safe.ReaderList[interface { - IntegerResource - resource.Resource - }](ctx, r, sourceMd) + intList, err := safe.ReaderListByMD(ctx, r, sourceMd) if err != nil { return fmt.Errorf("error listing objects: %w", err) } diff --git a/pkg/controller/conformance/resource.go b/pkg/controller/conformance/resource.go index fcff0938..3cd5769a 100644 --- a/pkg/controller/conformance/resource.go +++ b/pkg/controller/conformance/resource.go @@ -4,7 +4,9 @@ package conformance -import "github.com/cosi-project/runtime/pkg/resource" +import ( + "github.com/cosi-project/runtime/pkg/resource" +) // Resource represents some T value. type Resource[T any, S Spec[T], SS SpecPtr[T, S]] struct { diff --git a/pkg/controller/conformance/resources.go b/pkg/controller/conformance/resources.go index f719622a..1016d8aa 100644 --- a/pkg/controller/conformance/resources.go +++ b/pkg/controller/conformance/resources.go @@ -8,6 +8,7 @@ import ( "encoding/binary" "github.com/cosi-project/runtime/pkg/resource" + "github.com/cosi-project/runtime/pkg/safe" ) // IntegerResource is implemented by resources holding ints. @@ -23,14 +24,14 @@ type StringResource interface { } // IntResourceType is the type of IntResource. -const IntResourceType = resource.Type("test/int") +const IntResourceType = safe.TaggedType[*IntResource]("test/int") // IntResource represents some integer value. type IntResource = Resource[int, intSpec, *intSpec] // NewIntResource creates new IntResource. func NewIntResource(ns resource.Namespace, id resource.ID, value int) *IntResource { - return NewResource[int, intSpec, *intSpec](resource.NewMetadata(ns, IntResourceType, id, resource.VersionUndefined), value) + return NewResource[int, intSpec, *intSpec](safe.NewTaggedMD(ns, IntResourceType, id, resource.VersionUndefined).Naked(), value) } type intSpec struct{ ValueGetSet[int] } @@ -48,14 +49,14 @@ func (is intSpec) MarshalProto() ([]byte, error) { } // StrResourceType is the type of StrResource. -const StrResourceType = resource.Type("test/str") +const StrResourceType = safe.TaggedType[*StrResource]("test/str") // StrResource represents some string value. type StrResource = Resource[string, strSpec, *strSpec] // NewStrResource creates new StrResource. func NewStrResource(ns resource.Namespace, id resource.ID, value string) *StrResource { - return NewResource[string, strSpec, *strSpec](resource.NewMetadata(ns, StrResourceType, id, resource.VersionUndefined), value) + return NewResource[string, strSpec, *strSpec](resource.NewMetadata(ns, StrResourceType.Naked(), id, resource.VersionUndefined), value) } type strSpec struct{ ValueGetSet[string] } diff --git a/pkg/controller/conformance/runtime.go b/pkg/controller/conformance/runtime.go index 28f4ceeb..1e54c8e3 100644 --- a/pkg/controller/conformance/runtime.go +++ b/pkg/controller/conformance/runtime.go @@ -113,13 +113,13 @@ func (suite *RuntimeSuite) assertIntObjects(ids []string, values []int) retry.Re typ := IntResourceType return func() error { - items, err := suite.State.List(suite.ctx, resource.NewMetadata(ns, typ, "", resource.VersionUndefined)) + items, err := safe.StateListByMD(suite.ctx, suite.State, safe.NewTaggedMD(ns, typ, "", resource.VersionUndefined)) if err != nil { return err } - if len(items.Items) != len(ids) { - return retry.ExpectedErrorf("expected %d objects, got %d", len(ids), len(items.Items)) + if items.Len() != len(ids) { + return retry.ExpectedErrorf("expected %d objects, got %d", len(ids), items.Len()) } for i, id := range ids { @@ -186,13 +186,13 @@ func (suite *RuntimeSuite) TestIntToStrControllers() { suite.Assert().NoError(suite.State.Create(suite.ctx, NewIntResource("default", "two", 2))) suite.Assert().NoError(retry.Constant(10*time.Second, retry.WithUnits(10*time.Millisecond)). - Retry(suite.assertStrObjects("default", StrResourceType, []string{"one", "two"}, []string{"1", "2"}))) + Retry(suite.assertStrObjects("default", StrResourceType.Naked(), []string{"one", "two"}, []string{"1", "2"}))) three := NewIntResource("default", "three", 3) suite.Assert().NoError(suite.State.Create(suite.ctx, three)) suite.Assert().NoError(retry.Constant(10*time.Second, retry.WithUnits(10*time.Millisecond)). - Retry(suite.assertStrObjects("default", StrResourceType, []string{"one", "two", "three"}, []string{"1", "2", "3"}))) + Retry(suite.assertStrObjects("default", StrResourceType.Naked(), []string{"one", "two", "three"}, []string{"1", "2", "3"}))) type integerResource interface { IntegerResource @@ -207,7 +207,7 @@ func (suite *RuntimeSuite) TestIntToStrControllers() { suite.Assert().NoError(err) suite.Assert().NoError(retry.Constant(10*time.Second, retry.WithUnits(10*time.Millisecond)). - Retry(suite.assertStrObjects("default", StrResourceType, []string{"one", "two", "three"}, []string{"1", "2", "33"}))) + Retry(suite.assertStrObjects("default", StrResourceType.Naked(), []string{"one", "two", "three"}, []string{"1", "2", "33"}))) ready, err := suite.State.Teardown(suite.ctx, three.Metadata()) suite.Assert().NoError(err) @@ -217,7 +217,7 @@ func (suite *RuntimeSuite) TestIntToStrControllers() { suite.Assert().NoError(err) suite.Assert().NoError(retry.Constant(10*time.Second, retry.WithUnits(10*time.Millisecond)). - Retry(suite.assertStrObjects("default", StrResourceType, []string{"one", "two"}, []string{"1", "2"}))) + Retry(suite.assertStrObjects("default", StrResourceType.Naked(), []string{"one", "two"}, []string{"1", "2"}))) suite.Assert().NoError(suite.State.Destroy(suite.ctx, three.Metadata())) } @@ -474,7 +474,7 @@ func (suite *RuntimeSuite) TestModifyWithResultController() { suite.Require().NoError(suite.State.Create(suite.ctx, NewStrResource(srcNS, "id", "val-1"))) suite.Assert().NoError(retry.Constant(10*time.Second, retry.WithUnits(10*time.Millisecond)).Retry( - suite.assertStrObjects(targetNS, StrResourceType, + suite.assertStrObjects(targetNS, StrResourceType.Naked(), []string{"id-out", "id-out-modify-result"}, []string{"val-1-modified", "val-1-valid"}, ), @@ -490,7 +490,7 @@ func (suite *RuntimeSuite) TestModifyWithResultController() { suite.Require().NoError(err) suite.Assert().NoError(retry.Constant(10*time.Second, retry.WithUnits(10*time.Millisecond)).Retry( - suite.assertStrObjects(targetNS, StrResourceType, + suite.assertStrObjects(targetNS, StrResourceType.Naked(), []string{"id-out", "id-out-modify-result"}, []string{"val-2-modified", "val-2-valid"}, ), diff --git a/pkg/controller/protobuf/protobuf_test.go b/pkg/controller/protobuf/protobuf_test.go index d6fafddc..7e32ddbb 100644 --- a/pkg/controller/protobuf/protobuf_test.go +++ b/pkg/controller/protobuf/protobuf_test.go @@ -38,8 +38,8 @@ type ProtobufConformanceSuite struct { } func TestProtobufConformance(t *testing.T) { - require.NoError(t, protobuf.RegisterResource(conformance.IntResourceType, &conformance.IntResource{})) - require.NoError(t, protobuf.RegisterResource(conformance.StrResourceType, &conformance.StrResource{})) + require.NoError(t, protobuf.RegisterResource(conformance.IntResourceType.Naked(), &conformance.IntResource{})) + require.NoError(t, protobuf.RegisterResource(conformance.StrResourceType.Naked(), &conformance.StrResource{})) require.NoError(t, protobuf.RegisterResource(conformance.SentenceResourceType, &conformance.SentenceResource{})) suite := &ProtobufConformanceSuite{ diff --git a/pkg/resource/protobuf/registry_test.go b/pkg/resource/protobuf/registry_test.go index 9904e3ac..5632c591 100644 --- a/pkg/resource/protobuf/registry_test.go +++ b/pkg/resource/protobuf/registry_test.go @@ -16,12 +16,12 @@ import ( ) func BenchmarkCreateResource(b *testing.B) { - _ = protobuf.RegisterResource(conformance.PathResourceType, &conformance.PathResource{}) //nolint:errcheck + _ = protobuf.RegisterResource(conformance.PathResourceType.Naked(), &conformance.PathResource{}) //nolint:errcheck protoR := &v1alpha1.Resource{ Metadata: &v1alpha1.Metadata{ Namespace: "ns", - Type: conformance.PathResourceType, + Type: conformance.PathResourceType.Naked(), Id: "a/b", Version: "3", Phase: "running", @@ -53,12 +53,12 @@ func BenchmarkCreateResource(b *testing.B) { func TestRegistry(t *testing.T) { t.Parallel() - require.NoError(t, protobuf.RegisterResource(conformance.PathResourceType, &conformance.PathResource{})) + require.NoError(t, protobuf.RegisterResource(conformance.PathResourceType.Naked(), &conformance.PathResource{})) protoR := &v1alpha1.Resource{ Metadata: &v1alpha1.Metadata{ Namespace: "ns", - Type: conformance.PathResourceType, + Type: conformance.PathResourceType.Naked(), Id: "a/b", Version: "3", Phase: "running", diff --git a/pkg/safe/reader.go b/pkg/safe/reader.go index a61bd9ff..95fb3408 100644 --- a/pkg/safe/reader.go +++ b/pkg/safe/reader.go @@ -123,3 +123,8 @@ func ReaderWatchFor[T resource.Resource](ctx context.Context, rdr controller.Rea func ReaderWatchForResource[T resource.Resource](ctx context.Context, rdr controller.Reader, r T, conds ...state.WatchForConditionFunc) (T, error) { //nolint:ireturn return ReaderWatchFor[T](ctx, rdr, r.Metadata(), conds...) } + +// ReaderListByMD is a type safe wrapper around Reader.List. +func ReaderListByMD[T resource.Resource](ctx context.Context, rdr controller.Reader, md TaggedMD[T], opts ...state.ListOption) (List[T], error) { + return ReaderList[T](ctx, rdr, md, opts...) +} diff --git a/pkg/safe/safe.go b/pkg/safe/safe.go index 0f4260d7..a0910915 100644 --- a/pkg/safe/safe.go +++ b/pkg/safe/safe.go @@ -23,3 +23,34 @@ func typeAssertOrZero[T resource.Resource](got resource.Resource, err error) (T, return result, nil } + +// TaggedType is a type safe wrapper around [resource.Type]. +type TaggedType[T resource.Resource] resource.Type + +// Naked returns the underlying [resource.Type]. +func (t TaggedType[T]) Naked() resource.Type { + return resource.Type(t) +} + +// TaggedMD is a type safe wrapper around [resource.Metadata]. +type TaggedMD[T resource.Resource] resource.Metadata + +// Namespace returns the namespace of the resource. +func (t TaggedMD[T]) Namespace() resource.Namespace { + return resource.Metadata(t).Namespace() +} + +// Type returns the type of the resource. +func (t TaggedMD[T]) Type() resource.Type { + return resource.Metadata(t).Type() +} + +// Naked returns the underlying [resource.Metadata]. +func (t TaggedMD[T]) Naked() resource.Metadata { + return resource.Metadata(t) +} + +// NewTaggedMD creates a new [TaggedMD]. +func NewTaggedMD[T resource.Resource](ns resource.Namespace, typ TaggedType[T], id resource.ID, ver resource.Version) TaggedMD[T] { + return TaggedMD[T](resource.NewMetadata(ns, typ.Naked(), id, ver)) +} diff --git a/pkg/safe/state.go b/pkg/safe/state.go index 6116cb3e..129a48b3 100644 --- a/pkg/safe/state.go +++ b/pkg/safe/state.go @@ -315,3 +315,25 @@ func (it *ListIterator[T]) Next() bool { func (it *ListIterator[T]) Value() T { //nolint:ireturn return it.list.Get(it.pos - 1) } + +// StateWatchByMD is a type safe wrapper around State.Watch. +func StateWatchByMD[T resource.Resource](ctx context.Context, st state.CoreState, md TaggedMD[T], eventCh chan<- WrappedStateEvent[T], opts ...state.WatchOption) error { + return StateWatch[T](ctx, st, md.Naked(), eventCh, opts...) +} + +// StateWatchForByMD is a type safe wrapper around State.WatchFor. +func StateWatchForByMD[T resource.Resource](ctx context.Context, st state.State, md TaggedMD[T], opts ...state.WatchForConditionFunc) (T, error) { //nolint:ireturn + got, err := st.WatchFor(ctx, md.Naked(), opts...) + + return typeAssertOrZero[T](got, err) +} + +// StateListByMD is a type safe wrapper around state.List. +func StateListByMD[T resource.Resource](ctx context.Context, st state.CoreState, md TaggedMD[T], options ...state.ListOption) (List[T], error) { + return StateList[T](ctx, st, md.Naked(), options...) +} + +// StateGetByMD is a type safe wrapper around state.Get. +func StateGetByMD[T resource.Resource](ctx context.Context, st state.CoreState, md TaggedMD[T], options ...state.GetOption) (T, error) { //nolint:ireturn + return StateGet[T](ctx, st, md.Naked(), options...) +} diff --git a/pkg/safe/state_test.go b/pkg/safe/state_test.go index 5b5b0071..987d7818 100644 --- a/pkg/safe/state_test.go +++ b/pkg/safe/state_test.go @@ -41,16 +41,32 @@ func setup(t *testing.T) (context.Context, string, string, *conformance.IntResou return ctx, testNamespace, testID, r, s, safeEventCh, unsafeEventCh } +func TestStateGet(t *testing.T) { + ctx, testNamespace, testID, r, s, _, _ := setup(t) + + metadata := safe.NewTaggedMD(testNamespace, conformance.IntResourceType, testID, resource.VersionUndefined) + + assert.NoError(t, s.Create(ctx, r)) + + intRes, err := safe.StateGetByMD(ctx, s, metadata) + assert.NoError(t, err) + + naked := metadata.Naked() + naked.SetVersion(intRes.Metadata().Version()) + + assert.True(t, naked.Equal(*intRes.Metadata())) +} + func TestStateWatch(t *testing.T) { ctx, testNamespace, testID, r, s, safeEventCh, unsafeEventCh := setup(t) - metadata := resource.NewMetadata(testNamespace, conformance.IntResourceType, testID, resource.VersionUndefined) + metadata := safe.NewTaggedMD(testNamespace, conformance.IntResourceType, testID, resource.VersionUndefined) assert.NoError(t, s.Create(ctx, r)) - assert.NoError(t, s.Watch(ctx, metadata, unsafeEventCh)) + assert.NoError(t, s.Watch(ctx, metadata.Naked(), unsafeEventCh)) - assert.NoError(t, safe.StateWatch(ctx, s, metadata, safeEventCh)) + assert.NoError(t, safe.StateWatchByMD(ctx, s, metadata, safeEventCh)) unsafeEvent := <-unsafeEventCh @@ -72,14 +88,14 @@ func TestStateWatch(t *testing.T) { func TestStateWatchFor(t *testing.T) { ctx, testNamespace, testID, r, s, _, _ := setup(t) - metadata := resource.NewMetadata(testNamespace, conformance.IntResourceType, testID, resource.VersionUndefined) + metadata := safe.NewTaggedMD(testNamespace, conformance.IntResourceType, testID, resource.VersionUndefined) assert.NoError(t, s.Create(ctx, r)) - unsafeResult, unsafeWatchForErr := s.WatchFor(ctx, metadata) + unsafeResult, unsafeWatchForErr := s.WatchFor(ctx, metadata.Naked()) assert.NoError(t, unsafeWatchForErr) - safeResult, safeWatchForErr := safe.StateWatchFor[*conformance.IntResource](ctx, s, metadata) + safeResult, safeWatchForErr := safe.StateWatchForByMD(ctx, s, metadata) assert.NoError(t, safeWatchForErr) assert.Equal(t, unsafeResult, safeResult) @@ -88,7 +104,7 @@ func TestStateWatchFor(t *testing.T) { func TestStateWatchKind(t *testing.T) { ctx, testNamespace, _, r, s, safeEventCh, unsafeEventCh := setup(t) - metadata := resource.NewMetadata(testNamespace, conformance.IntResourceType, "", resource.VersionUndefined) + metadata := safe.NewTaggedMD(testNamespace, conformance.IntResourceType, "", resource.VersionUndefined) assert.NoError(t, s.WatchKind(ctx, metadata, unsafeEventCh)) @@ -129,7 +145,7 @@ func TestListFilter(t *testing.T) { require.NoError(t, s.Create(ctx, r)) } - all, err := safe.StateList[*conformance.IntResource](ctx, s, resource.NewMetadata(testNamespace, conformance.IntResourceType, "", resource.VersionUndefined)) + all, err := safe.StateListByMD(ctx, s, safe.NewTaggedMD(testNamespace, conformance.IntResourceType, "", resource.VersionUndefined)) require.NoError(t, err) assert.Equal(t, 3, all.Len()) diff --git a/pkg/state/conformance/resources.go b/pkg/state/conformance/resources.go index c9b4a251..c4b85859 100644 --- a/pkg/state/conformance/resources.go +++ b/pkg/state/conformance/resources.go @@ -8,10 +8,11 @@ import ( "fmt" "github.com/cosi-project/runtime/pkg/resource" + "github.com/cosi-project/runtime/pkg/safe" ) // PathResourceType is the type of PathResource. -const PathResourceType = resource.Type("os/path") +const PathResourceType = safe.TaggedType[*PathResource]("os/path") // PathResource represents a path in the filesystem. // @@ -29,7 +30,7 @@ func (spec pathSpec) MarshalProto() ([]byte, error) { // NewPathResource creates new PathResource. func NewPathResource(ns resource.Namespace, path string) *PathResource { r := &PathResource{ - md: resource.NewMetadata(ns, PathResourceType, path, resource.VersionUndefined), + md: resource.NewMetadata(ns, PathResourceType.Naked(), path, resource.VersionUndefined), } return r diff --git a/pkg/state/conformance/state.go b/pkg/state/conformance/state.go index d7fb30cf..ad009388 100644 --- a/pkg/state/conformance/state.go +++ b/pkg/state/conformance/state.go @@ -979,16 +979,18 @@ func (suite *StateSuite) TestLabels() { err = suite.State.Create(ctx, path3) suite.Require().NoError(err) - r, err := suite.State.Get(ctx, path1.Metadata()) + pmd := safe.TaggedMD[*PathResource](*path1.Metadata()) + + r, err := safe.StateGetByMD(ctx, suite.State, pmd) suite.Require().NoError(err) - path1Copy := r.(*PathResource) //nolint:errcheck,forcetypeassert + path1Copy := r v, ok := path1Copy.Metadata().Labels().Get("app") suite.Assert().True(ok) suite.Assert().Equal("app1", v) - list, err := safe.StateList[*PathResource](ctx, suite.State, path1.Metadata(), state.WithLabelQuery(resource.LabelExists("frozen"))) + list, err := safe.StateListByMD(ctx, suite.State, pmd, state.WithLabelQuery(resource.LabelExists("frozen"))) suite.Require().NoError(err) suite.Require().Equal(2, list.Len()) @@ -996,57 +998,57 @@ func (suite *StateSuite) TestLabels() { suite.Assert().True(resourceEqualIgnoreVersion(path1, list.Get(0))) suite.Assert().True(resourceEqualIgnoreVersion(path2, list.Get(1))) - list, err = safe.StateList[*PathResource](ctx, suite.State, path1.Metadata(), state.WithLabelQuery(resource.LabelExists("frozen"), resource.LabelEqual("app", "app2"))) + list, err = safe.StateListByMD(ctx, suite.State, pmd, state.WithLabelQuery(resource.LabelExists("frozen"), resource.LabelEqual("app", "app2"))) suite.Require().NoError(err) suite.Require().Equal(1, list.Len()) suite.Assert().True(resourceEqualIgnoreVersion(path2, list.Get(0))) - list, err = safe.StateList[*PathResource](ctx, suite.State, path1.Metadata(), state.WithLabelQuery(resource.LabelExists("frozen"), resource.LabelEqual("app", "app3"))) + list, err = safe.StateListByMD(ctx, suite.State, pmd, state.WithLabelQuery(resource.LabelExists("frozen"), resource.LabelEqual("app", "app3"))) suite.Require().NoError(err) suite.Require().Equal(0, list.Len()) - list, err = safe.StateList[*PathResource](ctx, suite.State, path1.Metadata(), state.WithLabelQuery(resource.LabelEqual("app", "app3"))) + list, err = safe.StateListByMD(ctx, suite.State, pmd, state.WithLabelQuery(resource.LabelEqual("app", "app3"))) suite.Require().NoError(err) suite.Require().Equal(1, list.Len()) suite.Assert().True(resourceEqualIgnoreVersion(path3, list.Get(0))) - list, err = safe.StateList[*PathResource](ctx, suite.State, path1.Metadata(), state.WithLabelQuery(resource.LabelIn("app", []string{"app2", "app3"}))) + list, err = safe.StateListByMD(ctx, suite.State, pmd, state.WithLabelQuery(resource.LabelIn("app", []string{"app2", "app3"}))) suite.Require().NoError(err) suite.Require().Equal(2, list.Len()) suite.Assert().True(resourceEqualIgnoreVersion(path2, list.Get(0))) suite.Assert().True(resourceEqualIgnoreVersion(path3, list.Get(1))) - list, err = safe.StateList[*PathResource](ctx, suite.State, path1.Metadata(), state.WithLabelQuery(resource.LabelLTNumeric("weight", "12000"))) + list, err = safe.StateListByMD(ctx, suite.State, pmd, state.WithLabelQuery(resource.LabelLTNumeric("weight", "12000"))) suite.Require().NoError(err) suite.Require().Equal(1, list.Len()) suite.Assert().True(resourceEqualIgnoreVersion(path1, list.Get(0))) - list, err = safe.StateList[*PathResource](ctx, suite.State, path1.Metadata(), state.WithLabelQuery(resource.LabelLTENumeric("weight", "20000"))) + list, err = safe.StateListByMD(ctx, suite.State, pmd, state.WithLabelQuery(resource.LabelLTENumeric("weight", "20000"))) suite.Require().NoError(err) suite.Require().Equal(2, list.Len()) suite.Assert().True(resourceEqualIgnoreVersion(path1, list.Get(0))) suite.Assert().True(resourceEqualIgnoreVersion(path2, list.Get(1))) - list, err = safe.StateList[*PathResource](ctx, suite.State, path1.Metadata(), state.WithLabelQuery(resource.LabelLTE("app", "app2"))) + list, err = safe.StateListByMD(ctx, suite.State, pmd, state.WithLabelQuery(resource.LabelLTE("app", "app2"))) suite.Require().NoError(err) suite.Require().Equal(2, list.Len()) suite.Assert().True(resourceEqualIgnoreVersion(path1, list.Get(0))) suite.Assert().True(resourceEqualIgnoreVersion(path2, list.Get(1))) - list, err = safe.StateList[*PathResource](ctx, suite.State, path1.Metadata(), state.WithLabelQuery(resource.LabelLT("app", "app2"))) + list, err = safe.StateListByMD(ctx, suite.State, pmd, state.WithLabelQuery(resource.LabelLT("app", "app2"))) suite.Require().NoError(err) suite.Require().Equal(1, list.Len()) suite.Assert().True(resourceEqualIgnoreVersion(path1, list.Get(0))) - list, err = safe.StateList[*PathResource](ctx, suite.State, path1.Metadata(), + list, err = safe.StateListByMD(ctx, suite.State, pmd, state.WithLabelQuery(resource.LabelEqual("app", "app2")), state.WithLabelQuery(resource.LabelEqual("app", "app3")), ) diff --git a/pkg/state/filter_test.go b/pkg/state/filter_test.go index ac8a1b14..c4073b30 100644 --- a/pkg/state/filter_test.go +++ b/pkg/state/filter_test.go @@ -49,7 +49,7 @@ func TestFilterSingleResource(t *testing.T) { state.Filter( namespaced.NewState(inmem.Build), func(ctx context.Context, access state.Access) error { - if access.ResourceNamespace != namespace || access.ResourceType != resourceType || access.ResourceID != resourceID { + if access.ResourceNamespace != namespace || access.ResourceType != resourceType.Naked() || access.ResourceID != resourceID { return fmt.Errorf("access denied") } diff --git a/pkg/state/impl/inmem/local_test.go b/pkg/state/impl/inmem/local_test.go index cec69bae..9147358e 100644 --- a/pkg/state/impl/inmem/local_test.go +++ b/pkg/state/impl/inmem/local_test.go @@ -82,10 +82,10 @@ func TestBufferOverrun(t *testing.T) { watchKindCh := make(chan state.Event) watchCh := make(chan state.Event) - err := st.WatchKind(ctx, resource.NewMetadata(namespace, conformance.PathResourceType, "", resource.VersionUndefined), watchKindCh) + err := st.WatchKind(ctx, resource.NewMetadata(namespace, conformance.PathResourceType.Naked(), "", resource.VersionUndefined), watchKindCh) require.NoError(t, err) - err = st.Watch(ctx, resource.NewMetadata(namespace, conformance.PathResourceType, "0", resource.VersionUndefined), watchCh) + err = st.Watch(ctx, resource.NewMetadata(namespace, conformance.PathResourceType.Naked(), "0", resource.VersionUndefined), watchCh) require.NoError(t, err) // insert 10 resources @@ -176,7 +176,7 @@ func TestNoBufferOverrunDynamic(t *testing.T) { // start watching for changes watchKindCh := make(chan state.Event) - err := st.WatchKind(ctx, resource.NewMetadata(namespace, conformance.PathResourceType, "", resource.VersionUndefined), watchKindCh) + err := st.WatchKind(ctx, resource.NewMetadata(namespace, conformance.PathResourceType.Naked(), "", resource.VersionUndefined), watchKindCh) require.NoError(t, err) // insert N resources diff --git a/pkg/state/impl/store/bolt/bbolt_test.go b/pkg/state/impl/store/bolt/bbolt_test.go index d321a16d..f909a659 100644 --- a/pkg/state/impl/store/bolt/bbolt_test.go +++ b/pkg/state/impl/store/bolt/bbolt_test.go @@ -23,7 +23,7 @@ import ( func TestBboltStore(t *testing.T) { //nolint:tparallel t.Parallel() - require.NoError(t, protobuf.RegisterResource(conformance.PathResourceType, &conformance.PathResource{})) + require.NoError(t, protobuf.RegisterResource(conformance.PathResourceType.Naked(), &conformance.PathResource{})) tmpDir := t.TempDir() diff --git a/pkg/state/impl/store/compression/compression_test.go b/pkg/state/impl/store/compression/compression_test.go index b4f1b306..ea881d99 100644 --- a/pkg/state/impl/store/compression/compression_test.go +++ b/pkg/state/impl/store/compression/compression_test.go @@ -120,5 +120,5 @@ func generateString(lines int) string { } func init() { - ensure.NoError(protobuf.RegisterResource(conformance.PathResourceType, &conformance.PathResource{})) + ensure.NoError(protobuf.RegisterResource(conformance.PathResourceType.Naked(), &conformance.PathResource{})) } diff --git a/pkg/state/impl/store/encryption/marshaler_test.go b/pkg/state/impl/store/encryption/marshaler_test.go index fa39c0b5..1b0c087f 100644 --- a/pkg/state/impl/store/encryption/marshaler_test.go +++ b/pkg/state/impl/store/encryption/marshaler_test.go @@ -19,7 +19,7 @@ import ( ) func init() { - ensure.NoError(protobuf.RegisterResource(conformance.PathResourceType, &conformance.PathResource{})) + ensure.NoError(protobuf.RegisterResource(conformance.PathResourceType.Naked(), &conformance.PathResource{})) } func TestMarshaler_Key(t *testing.T) { diff --git a/pkg/state/impl/store/protobuf_test.go b/pkg/state/impl/store/protobuf_test.go index a8e67d86..467daf9d 100644 --- a/pkg/state/impl/store/protobuf_test.go +++ b/pkg/state/impl/store/protobuf_test.go @@ -50,5 +50,5 @@ func BenchmarkProto(b *testing.B) { } func init() { - ensure.NoError(protobuf.RegisterResource(conformance.PathResourceType, &conformance.PathResource{})) + ensure.NoError(protobuf.RegisterResource(conformance.PathResourceType.Naked(), &conformance.PathResource{})) } diff --git a/pkg/state/protobuf/client/client_test.go b/pkg/state/protobuf/client/client_test.go index 8c89eb18..881d7954 100644 --- a/pkg/state/protobuf/client/client_test.go +++ b/pkg/state/protobuf/client/client_test.go @@ -65,7 +65,7 @@ func TestProtobufSkipUnmarshal(t *testing.T) { stateClient := v1alpha1.NewStateClient(grpcConn) - require.NoError(t, protobuf.RegisterResource(conformance.PathResourceType, &conformance.PathResource{})) + require.NoError(t, protobuf.RegisterResource(conformance.PathResourceType.Naked(), &conformance.PathResource{})) grpcState := state.WrapCore(client.NewAdapter(stateClient)) diff --git a/pkg/state/protobuf/protobuf_test.go b/pkg/state/protobuf/protobuf_test.go index aea32367..eb357b13 100644 --- a/pkg/state/protobuf/protobuf_test.go +++ b/pkg/state/protobuf/protobuf_test.go @@ -65,7 +65,7 @@ func TestProtobufConformance(t *testing.T) { stateClient := v1alpha1.NewStateClient(grpcConn) - require.NoError(t, protobuf.RegisterResource(conformance.PathResourceType, &conformance.PathResource{})) + require.NoError(t, protobuf.RegisterResource(conformance.PathResourceType.Naked(), &conformance.PathResource{})) suite.Run(t, &conformance.StateSuite{ State: state.WrapCore(client.NewAdapter(stateClient)),