diff --git a/tf5muxserver/mux_server.go b/tf5muxserver/mux_server.go index ab0393b..82421f7 100644 --- a/tf5muxserver/mux_server.go +++ b/tf5muxserver/mux_server.go @@ -55,6 +55,13 @@ type muxServer struct { // Underlying servers for requests that should be handled by all servers servers []tfprotov5.ProviderServer + + // interceptors []tfprotov5.Interceptor + interceptors []Interceptor +} + +type Interceptor struct { + BeforeListResource func(context.Context, *tfprotov5.ListResourceRequest) context.Context } // ProviderServer is a function compatible with tf6server.Serve. @@ -426,3 +433,48 @@ func NewMuxServer(_ context.Context, servers ...func() tfprotov5.ProviderServer) return &result, nil } + +type Option func(*muxServer) + +func Servers(servers ...func() tfprotov5.ProviderServer) Option { + return func(mux *muxServer) { + for _, server := range servers { + mux.servers = append(mux.servers, server()) + } + } +} + +func Interceptors(interceptors ...Interceptor) Option { + return func(mux *muxServer) { + mux.interceptors = append(mux.interceptors, interceptors...) + } +} + +// NewMuxServerWithOptions returns a muxed server that will route gRPC requests between +// tfprotov5.ProviderServers specified. The GetProviderSchema method of each +// is called to verify that the overall muxed server is compatible by ensuring: +// +// - All provider schemas exactly match +// - All provider meta schemas exactly match +// - Only one provider implements each managed resource +// - Only one provider implements each data source +// - Only one provider implements each function +// - Only one provider implements each ephemeral resource +// - Only one provider implements each list resource +// - Only one provider implements each resource identity +func NewMuxServerWithOptions(_ context.Context, options ...Option) (*muxServer, error) { + result := muxServer{ + dataSources: make(map[string]tfprotov5.ProviderServer), + ephemeralResources: make(map[string]tfprotov5.ProviderServer), + listResources: make(map[string]tfprotov5.ProviderServer), + functions: make(map[string]tfprotov5.ProviderServer), + resources: make(map[string]tfprotov5.ProviderServer), + resourceCapabilities: make(map[string]*tfprotov5.ServerCapabilities), + } + + for _, option := range options { + option(&result) + } + + return &result, nil +} diff --git a/tf5muxserver/mux_server_ListResource.go b/tf5muxserver/mux_server_ListResource.go index 3b6be28..4c59aea 100644 --- a/tf5muxserver/mux_server_ListResource.go +++ b/tf5muxserver/mux_server_ListResource.go @@ -17,6 +17,13 @@ func (s *muxServer) ListResource(ctx context.Context, req *tfprotov5.ListResourc ctx = logging.InitContext(ctx) ctx = logging.RpcContext(ctx, rpc) + for _, i := range s.interceptors { + if i.BeforeListResource == nil { + continue + } + ctx = i.BeforeListResource(ctx, req) + } + server, diags, err := s.getListResourceServer(ctx, req.TypeName) if err != nil {