diff --git a/go.mod b/go.mod index 48e906f32..436ee6026 100644 --- a/go.mod +++ b/go.mod @@ -86,6 +86,7 @@ require ( github.com/google/gofuzz v1.2.0 // indirect github.com/googleapis/gax-go/v2 v2.3.0 // indirect github.com/googleapis/go-type-adapters v1.0.0 // indirect + github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect @@ -134,3 +135,5 @@ require ( ) replace github.com/aws/amazon-sagemaker-operator-for-k8s => github.com/aws/amazon-sagemaker-operator-for-k8s v1.0.1-0.20210303003444-0fb33b1fd49d + +replace github.com/flyteorg/flyteidl => github.com/flyteorg/flyteidl v1.3.9-0.20230224194627-a1df35060476 diff --git a/go.sum b/go.sum index 559ec0ef5..dfa6d6e1c 100644 --- a/go.sum +++ b/go.sum @@ -232,8 +232,8 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= -github.com/flyteorg/flyteidl v1.3.6 h1:PI846AdnrQZ84pxRVAzA3WGihv+xXmjQHO91nj/kV9g= -github.com/flyteorg/flyteidl v1.3.6/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= +github.com/flyteorg/flyteidl v1.3.9-0.20230224194627-a1df35060476 h1:mA3Ry5YjNu5BqjnCTbA+lFRTRFjGKEMDALRhLTtBuuU= +github.com/flyteorg/flyteidl v1.3.9-0.20230224194627-a1df35060476/go.mod h1:Pkt2skI1LiHs/2ZoekBnyPhuGOFMiuul6HHcKGZBsbM= github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0= github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= @@ -443,6 +443,7 @@ github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:Fecb github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.5/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= +github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= diff --git a/go/tasks/pluginmachinery/internal/webapi/core.go b/go/tasks/pluginmachinery/internal/webapi/core.go index 6d506af19..1db839d19 100644 --- a/go/tasks/pluginmachinery/internal/webapi/core.go +++ b/go/tasks/pluginmachinery/internal/webapi/core.go @@ -69,6 +69,7 @@ func (c CorePlugin) GetProperties() core.PluginProperties { } func (c CorePlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) { + c.metrics.NumberOfTasks.Inc(ctx) incomingState, err := c.unmarshalState(ctx, tCtx.PluginStateReader()) if err != nil { return core.UnknownTransition, err @@ -96,7 +97,10 @@ func (c CorePlugin) Handle(ctx context.Context, tCtx core.TaskExecutionContext) if err := tCtx.PluginStateWriter().Put(pluginStateVersion, nextState); err != nil { return core.UnknownTransition, err } - + c.metrics.NumberOfTasks.Dec(ctx) + logger.Infof(ctx, "number of requests [%v]", c.metrics.NumberOfTasks) + // logger.Infof(ctx, "request latency [%v]", time.Since(start).Round(time.Microsecond).String()) + logger.Infof(ctx, "phaseInfo [%v]", phaseInfo) return core.DoTransitionType(core.TransitionTypeBarrier, phaseInfo), nil } diff --git a/go/tasks/pluginmachinery/internal/webapi/metrics.go b/go/tasks/pluginmachinery/internal/webapi/metrics.go index d5f767a58..7e3b82f68 100644 --- a/go/tasks/pluginmachinery/internal/webapi/metrics.go +++ b/go/tasks/pluginmachinery/internal/webapi/metrics.go @@ -17,6 +17,7 @@ type Metrics struct { ResourceWaitTime prometheus.Summary SucceededUnmarshalState labeled.StopWatch FailedUnmarshalState labeled.Counter + NumberOfTasks labeled.Gauge } var ( @@ -40,5 +41,6 @@ func newMetrics(scope promutils.Scope) Metrics { time.Millisecond, scope), FailedUnmarshalState: labeled.NewCounter("unmarshal_state_failed", "Failed to unmarshal state", scope, labeled.EmitUnlabeledMetric), + NumberOfTasks: labeled.NewGauge("number_of_tasks", "number of running tasks", scope, labeled.EmitUnlabeledMetric), } } diff --git a/go/tasks/pluginmachinery/webapi/plugin.go b/go/tasks/pluginmachinery/webapi/plugin.go index 63b6b5e2b..853993834 100644 --- a/go/tasks/pluginmachinery/webapi/plugin.go +++ b/go/tasks/pluginmachinery/webapi/plugin.go @@ -81,6 +81,7 @@ type TaskExecutionContext interface { type GetContext interface { ResourceMeta() ResourceMeta + Resource() Resource } type DeleteContext interface { diff --git a/go/tasks/plugins/webapi/dummy/config.go b/go/tasks/plugins/webapi/dummy/config.go new file mode 100644 index 000000000..e7f982b33 --- /dev/null +++ b/go/tasks/plugins/webapi/dummy/config.go @@ -0,0 +1,62 @@ +package dummy + +import ( + "time" + + pluginsConfig "github.com/flyteorg/flyteplugins/go/tasks/config" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" + "github.com/flyteorg/flytestdlib/config" +) + +var ( + defaultConfig = Config{ + WebAPI: webapi.PluginConfig{ + ResourceQuotas: map[core.ResourceNamespace]int{ + "default": 1000, + }, + ReadRateLimiter: webapi.RateLimiterConfig{ + Burst: 100, + QPS: 10, + }, + WriteRateLimiter: webapi.RateLimiterConfig{ + Burst: 100, + QPS: 10, + }, + Caching: webapi.CachingConfig{ + Size: 500000, + ResyncInterval: config.Duration{Duration: 30 * time.Second}, + Workers: 10, + MaxSystemFailures: 5, + }, + ResourceMeta: nil, + }, + ResourceConstraints: core.ResourceConstraintsSpec{ + ProjectScopeResourceConstraint: &core.ResourceConstraint{ + Value: 100, + }, + NamespaceScopeResourceConstraint: &core.ResourceConstraint{ + Value: 50, + }, + }, + } + + configSection = pluginsConfig.MustRegisterSubSection("dummy", &defaultConfig) +) + +// Config is config for 'databricks' plugin +type Config struct { + // WebAPI defines config for the base WebAPI plugin + WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` + + // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time + ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."` +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} + +func SetConfig(cfg *Config) error { + return configSection.SetConfig(cfg) +} diff --git a/go/tasks/plugins/webapi/dummy/plugin.go b/go/tasks/plugins/webapi/dummy/plugin.go new file mode 100644 index 000000000..cdf90cae0 --- /dev/null +++ b/go/tasks/plugins/webapi/dummy/plugin.go @@ -0,0 +1,132 @@ +package dummy + +import ( + "context" + "encoding/gob" + flyteIdlCore "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/ioutils" + "math/rand" + "net/http" + + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" +) + +// for mocking/testing purposes, and we'll override this method +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +type Plugin struct { + metricScope promutils.Scope + cfg *Config + client HTTPClient +} + +type ResourceWrapper struct { + StatusCode int + JobID string + Message string +} + +type ResourceMetaWrapper struct { + RunID string + Token string +} + +func (p Plugin) GetConfig() webapi.PluginConfig { + return GetConfig().WebAPI +} + +func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( + namespace core.ResourceNamespace, constraints core.ResourceConstraintsSpec, err error) { + + // Resource requirements are assumed to be the same. + return "default", p.cfg.ResourceConstraints, nil +} + +func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, + webapi.Resource, error) { + _, err := taskCtx.TaskReader().Read(ctx) + if err != nil { + return nil, nil, err + } + + return &ResourceMetaWrapper{RunID: "runID", Token: "token"}, + &ResourceWrapper{StatusCode: 200}, nil +} + +func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { + return &ResourceWrapper{ + StatusCode: 200, + JobID: "jobID", + }, nil +} + +func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error { + return nil +} + +func (p Plugin) Status(ctx context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { + x := rand.Intn(100) + if x < 50 { + err := writeOutput(ctx, taskCtx, "s3://bucket/key") + if err != nil { + return core.PhaseInfo{}, err + } + return pluginsCore.PhaseInfoSuccess(&core.TaskInfo{}), nil + } + return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, &core.TaskInfo{}), nil +} + +func writeOutput(ctx context.Context, tCtx webapi.StatusContext, OutputLocation string) error { + _, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return err + } + + return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader( + &flyteIdlCore.LiteralMap{ + Literals: map[string]*flyteIdlCore.Literal{ + "results": { + Value: &flyteIdlCore.Literal_Scalar{ + Scalar: &flyteIdlCore.Scalar{ + Value: &flyteIdlCore.Scalar_StructuredDataset{ + StructuredDataset: &flyteIdlCore.StructuredDataset{ + Uri: OutputLocation, + Metadata: &flyteIdlCore.StructuredDatasetMetadata{ + StructuredDatasetType: &flyteIdlCore.StructuredDatasetType{Format: ""}, + }, + }, + }, + }, + }, + }, + }, + }, nil, nil)) +} + +func newDummyTaskPlugin() webapi.PluginEntry { + return webapi.PluginEntry{ + ID: "dummy", + SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "snowflake", "spark"}, + PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { + return &Plugin{ + metricScope: iCtx.MetricsScope(), + cfg: GetConfig(), + client: &http.Client{}, + }, nil + }, + } +} + +func init() { + gob.Register(ResourceMetaWrapper{}) + gob.Register(ResourceWrapper{}) + + pluginmachinery.PluginRegistry().RegisterRemotePlugin(newDummyTaskPlugin()) +} diff --git a/go/tasks/plugins/webapi/fastapi/config.go b/go/tasks/plugins/webapi/fastapi/config.go new file mode 100644 index 000000000..f2638c4af --- /dev/null +++ b/go/tasks/plugins/webapi/fastapi/config.go @@ -0,0 +1,70 @@ +package fastapi + +import ( + "time" + + pluginsConfig "github.com/flyteorg/flyteplugins/go/tasks/config" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" + "github.com/flyteorg/flytestdlib/config" +) + +var ( + tokenKey = "FLYTE_FAST_API_TOKEN" // nolint: gosec + + defaultConfig = Config{ + WebAPI: webapi.PluginConfig{ + ResourceQuotas: map[core.ResourceNamespace]int{ + "default": 1000, + }, + ReadRateLimiter: webapi.RateLimiterConfig{ + Burst: 100, + QPS: 10, + }, + WriteRateLimiter: webapi.RateLimiterConfig{ + Burst: 100, + QPS: 10, + }, + Caching: webapi.CachingConfig{ + Size: 500000, + ResyncInterval: config.Duration{Duration: 30 * time.Second}, + Workers: 10, + MaxSystemFailures: 5, + }, + ResourceMeta: nil, + }, + ResourceConstraints: core.ResourceConstraintsSpec{ + ProjectScopeResourceConstraint: &core.ResourceConstraint{ + Value: 100, + }, + NamespaceScopeResourceConstraint: &core.ResourceConstraint{ + Value: 50, + }, + }, + TokenKey: tokenKey, + } + + configSection = pluginsConfig.MustRegisterSubSection("fastapi", &defaultConfig) +) + +// Config is config for 'databricks' plugin +type Config struct { + // WebAPI defines config for the base WebAPI plugin + WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` + + // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time + ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."` + + TokenKey string `json:"fastApiTokenKey" pflag:",Name of the key where to find Fast API access token in the secret manager."` + + // fastAPIEndpoint overrides fastapi server endpoint, only for testing + fastAPIEndpoint string +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} + +func SetConfig(cfg *Config) error { + return configSection.SetConfig(cfg) +} diff --git a/go/tasks/plugins/webapi/fastapi/plugin.go b/go/tasks/plugins/webapi/fastapi/plugin.go new file mode 100644 index 000000000..3b3445f78 --- /dev/null +++ b/go/tasks/plugins/webapi/fastapi/plugin.go @@ -0,0 +1,272 @@ +package fastapi + +import ( + "bytes" + "context" + "encoding/gob" + "encoding/json" + "fmt" + pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flytestdlib/errors" + "github.com/flyteorg/flytestdlib/logger" + "io/ioutil" + "net/http" + "time" + + "github.com/flyteorg/flytestdlib/promutils" + + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" +) + +const ( + ErrSystem errors.ErrorCode = "System" + postMethod string = "POST" + getMethod string = "GET" + deleteMethod string = "DELETE" + pluginAPI string = "plugins/v1/dummy" +) + +// for mocking/testing purposes, and we'll override this method +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +type Plugin struct { + metricScope promutils.Scope + cfg *Config + client HTTPClient +} + +type ResourceWrapper struct { + StatusCode int + State string +} + +type ResourceMetaWrapper struct { + OutputPrefix string + Token string + JobID string +} + +func (p Plugin) GetConfig() webapi.PluginConfig { + return GetConfig().WebAPI +} + +func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( + namespace core.ResourceNamespace, constraints core.ResourceConstraintsSpec, err error) { + + // Resource requirements are assumed to be the same. + return "default", p.cfg.ResourceConstraints, nil +} + +func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, + webapi.Resource, error) { + taskTemplatePath, err := taskCtx.TaskReader().Path(ctx) + if err != nil { + return nil, nil, err + } + + body := map[string]string{ + "inputs_path": taskCtx.InputReader().GetInputPath().String(), + "task_template_path": taskTemplatePath.String(), + } + + mJSON, err := json.Marshal(body) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal data: %v: %v", body, err) + } + + postDataJSON := []byte(string(mJSON)) + req, err := buildRequest(postMethod, postDataJSON, p.cfg.fastAPIEndpoint, "token", "") + if err != nil { + return nil, nil, err + } + + start := time.Now() + resp, err := p.client.Do(req) + logger.Infof(ctx, "fastapi create request latency [%v]", time.Since(start).Round(time.Microsecond).String()) + if err != nil { + return nil, nil, err + } + defer resp.Body.Close() + + data, err := buildResponse(resp) + if err != nil { + return nil, nil, err + } + if data["job_id"] == "" { + return nil, nil, pluginErrors.Wrapf(pluginErrors.RuntimeFailure, err, + "Unable to extract job_id from http response") + } + + jobID := fmt.Sprintf("%s", data["job_id"]) + + return &ResourceMetaWrapper{ + OutputPrefix: taskCtx.OutputWriter().GetOutputPrefixPath().String(), + JobID: jobID, + Token: "", + }, &ResourceWrapper{StatusCode: resp.StatusCode}, nil +} + +func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { + metadata := taskCtx.ResourceMeta().(*ResourceMetaWrapper) + prevState := "running" + if taskCtx.Resource() != nil { + resource := taskCtx.Resource().(*ResourceWrapper) + prevState = resource.State + } + + body := map[string]string{ + "output_prefix": metadata.OutputPrefix, + "job_id": metadata.JobID, + "prev_state": prevState, + } + + mJSON, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal data: %v: %v", body, err) + } + + getDataJSON := []byte(string(mJSON)) + req, err := buildRequest(getMethod, getDataJSON, p.cfg.fastAPIEndpoint, metadata.Token, metadata.JobID) + if err != nil { + logger.Errorf(ctx, "Failed to build fast api job request [%v]", err) + return nil, err + } + resp, err := p.client.Do(req) + if err != nil { + logger.Errorf(ctx, "Failed to get job status [%v]", resp) + return nil, err + } + defer resp.Body.Close() + data, err := buildResponse(resp) + if err != nil { + return nil, err + } + + state := fmt.Sprintf("%s", data["state"]) + return &ResourceWrapper{ + StatusCode: resp.StatusCode, + State: state, + }, nil +} + +func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error { + if taskCtx.ResourceMeta() == nil { + return nil + } + exec := taskCtx.ResourceMeta().(ResourceMetaWrapper) + req, err := buildRequest(deleteMethod, nil, p.cfg.fastAPIEndpoint, exec.Token, exec.JobID) + if err != nil { + return err + } + resp, err := p.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + logger.Info(ctx, "Deleted query execution [%v]", resp) + + return nil +} + +func (p Plugin) Status(_ context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { + resource := taskCtx.Resource().(*ResourceWrapper) + statusCode := resource.StatusCode + state := resource.State + + if statusCode == 0 { + return core.PhaseInfoUndefined, errors.Errorf(ErrSystem, "No Status field set.") + } + + taskInfo := &core.TaskInfo{} + message := "" + + switch statusCode { + case http.StatusAccepted: + return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil + case http.StatusOK: + switch state { + case "succeeded": + return pluginsCore.PhaseInfoSuccess(taskInfo), nil + case "failed": + return core.PhaseInfoFailure(string(rune(statusCode)), "failed to run the job", taskInfo), nil + default: + return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil + } + case http.StatusBadRequest: + fallthrough + case http.StatusInternalServerError: + fallthrough + case http.StatusUnauthorized: + return pluginsCore.PhaseInfoFailure(string(rune(statusCode)), message, taskInfo), nil + } + return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", statusCode) +} + +func buildRequest(method string, data []byte, fastAPIEndpoint string, token string, jobID string) (*http.Request, error) { + var fastAPIURL string + // for mocking/testing purposes + if fastAPIEndpoint == "" { + fastAPIURL = fmt.Sprintf("http://backend-plugin-system.flyte.svc.cluster.local:8000/%v", pluginAPI) + } else { + fastAPIURL = fmt.Sprintf("%v%v", fastAPIEndpoint, pluginAPI) + } + + if method == deleteMethod { + fastAPIURL = fmt.Sprintf("%v/?job_id=%v", fastAPIURL, jobID) + } + + var req *http.Request + var err error + if data == nil { + req, err = http.NewRequest(method, fastAPIURL, nil) + } else { + req, err = http.NewRequest(method, fastAPIURL, bytes.NewBuffer(data)) + } + if err != nil { + return nil, err + } + + // TODO: authentication support + req.Header.Add("Authorization", "Bearer "+token) + req.Header.Add("Content-Type", "application/json") + return req, nil +} + +func buildResponse(response *http.Response) (map[string]interface{}, error) { + responseBody, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, err + } + var data map[string]interface{} + err = json.Unmarshal(responseBody, &data) + if err != nil { + return nil, err + } + return data, nil +} + +func newFastAPIPlugin() webapi.PluginEntry { + return webapi.PluginEntry{ + ID: "fastapi", + SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "snowflake", "spark"}, + PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { + return &Plugin{ + metricScope: iCtx.MetricsScope(), + cfg: GetConfig(), + client: &http.Client{}, + }, nil + }, + } +} + +func init() { + gob.Register(ResourceMetaWrapper{}) + gob.Register(ResourceWrapper{}) + + pluginmachinery.PluginRegistry().RegisterRemotePlugin(newFastAPIPlugin()) +} diff --git a/go/tasks/plugins/webapi/grpc/config.go b/go/tasks/plugins/webapi/grpc/config.go new file mode 100644 index 000000000..520af4ce0 --- /dev/null +++ b/go/tasks/plugins/webapi/grpc/config.go @@ -0,0 +1,71 @@ +package grpc + +import ( + "time" + + pluginsConfig "github.com/flyteorg/flyteplugins/go/tasks/config" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" + "github.com/flyteorg/flytestdlib/config" +) + +var ( + grpcTokenKey = "FLYTE_GRPC_TOKEN" // nolint: gosec + + defaultConfig = Config{ + WebAPI: webapi.PluginConfig{ + ResourceQuotas: map[core.ResourceNamespace]int{ + "default": 1000, + }, + ReadRateLimiter: webapi.RateLimiterConfig{ + Burst: 100, + QPS: 10, + }, + WriteRateLimiter: webapi.RateLimiterConfig{ + Burst: 100, + QPS: 10, + }, + Caching: webapi.CachingConfig{ + Size: 500000, + ResyncInterval: config.Duration{Duration: 30 * time.Second}, + Workers: 10, + MaxSystemFailures: 5, + }, + ResourceMeta: nil, + }, + ResourceConstraints: core.ResourceConstraintsSpec{ + ProjectScopeResourceConstraint: &core.ResourceConstraint{ + Value: 100, + }, + NamespaceScopeResourceConstraint: &core.ResourceConstraint{ + Value: 50, + }, + }, + GrpcTokenKey: grpcTokenKey, + grpcEndpoint: "backend-plugin-system-grpc.flyte.svc.cluster.local:8000", + } + + configSection = pluginsConfig.MustRegisterSubSection("grpc", &defaultConfig) +) + +// Config is config for 'databricks' plugin +type Config struct { + // WebAPI defines config for the base WebAPI plugin + WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."` + + // ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time + ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."` + + GrpcTokenKey string `json:"grpcTokenKey" pflag:",Name of the key where to find grpc access token in the secret manager."` + + // grpcEndpoint overrides grpc server endpoint, only for testing + grpcEndpoint string +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) +} + +func SetConfig(cfg *Config) error { + return configSection.SetConfig(cfg) +} diff --git a/go/tasks/plugins/webapi/grpc/config_test.go b/go/tasks/plugins/webapi/grpc/config_test.go new file mode 100644 index 000000000..9e994f07f --- /dev/null +++ b/go/tasks/plugins/webapi/grpc/config_test.go @@ -0,0 +1,17 @@ +package grpc + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestGetAndSetConfig(t *testing.T) { + cfg := defaultConfig + cfg.WebAPI.Caching.Workers = 1 + cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second + err := SetConfig(&cfg) + assert.NoError(t, err) + assert.Equal(t, &cfg, GetConfig()) +} diff --git a/go/tasks/plugins/webapi/grpc/plugin.go b/go/tasks/plugins/webapi/grpc/plugin.go new file mode 100644 index 000000000..7a3030d88 --- /dev/null +++ b/go/tasks/plugins/webapi/grpc/plugin.go @@ -0,0 +1,165 @@ +package grpc + +import ( + "context" + "encoding/gob" + "fmt" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + pluginErrors "github.com/flyteorg/flyteplugins/go/tasks/errors" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + pluginsCore "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/webapi" + "github.com/flyteorg/flytestdlib/logger" + "github.com/flyteorg/flytestdlib/promutils" + "google.golang.org/grpc" + "time" +) + +type Plugin struct { + metricScope promutils.Scope + cfg *Config +} + +type ResourceWrapper struct { + State service.State + Message string +} + +type ResourceMetaWrapper struct { + OutputPrefix string + Token string + JobID string + TaskType string +} + +func (p Plugin) GetConfig() webapi.PluginConfig { + return GetConfig().WebAPI +} + +func (p Plugin) ResourceRequirements(_ context.Context, _ webapi.TaskExecutionContextReader) ( + namespace core.ResourceNamespace, constraints core.ResourceConstraintsSpec, err error) { + + // Resource requirements are assumed to be the same. + return "default", p.cfg.ResourceConstraints, nil +} + +func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextReader) (webapi.ResourceMeta, + webapi.Resource, error) { + taskTemplate, err := taskCtx.TaskReader().Read(ctx) + if err != nil { + return nil, nil, err + } + inputs, err := taskCtx.InputReader().Get(ctx) + if err != nil { + return nil, nil, err + } + + outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String() + + var opts []grpc.DialOption + opts = append(opts, grpc.WithInsecure()) + conn, err := grpc.Dial(p.cfg.grpcEndpoint, opts...) + if err != nil { + return nil, nil, fmt.Errorf("failed to connect backend plugin system") + } + defer conn.Close() + client := service.NewBackendPluginServiceClient(conn) + t := taskTemplate.Type + taskTemplate.Type = "dummy" // Dummy plugin is used to test performance + start := time.Now() + res, err := client.CreateTask(ctx, &service.TaskCreateRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix}) + logger.Infof(ctx, "grpc create request latency [%v]", time.Since(start).Round(time.Microsecond).String()) + taskTemplate.Type = t + if err != nil { + return nil, nil, err + } + + return &ResourceMetaWrapper{ + OutputPrefix: outputPrefix, + JobID: res.JobId, + Token: "", + TaskType: "dummy", + }, &ResourceWrapper{State: service.State_RUNNING}, nil +} + +func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) { + metadata := taskCtx.ResourceMeta().(*ResourceMetaWrapper) + prevState := service.State_RUNNING + if taskCtx.Resource() != nil { + resource := taskCtx.Resource().(*ResourceWrapper) + prevState = resource.State + } + + var opts []grpc.DialOption + opts = append(opts, grpc.WithInsecure()) + conn, err := grpc.Dial(p.cfg.grpcEndpoint, opts...) + if err != nil { + return nil, fmt.Errorf("failed to connect backend plugin system") + } + defer conn.Close() + + client := service.NewBackendPluginServiceClient(conn) + res, err := client.GetTask(ctx, &service.TaskGetRequest{TaskType: metadata.TaskType, JobId: metadata.JobID, OutputPrefix: metadata.OutputPrefix, PrevState: prevState}) + if err != nil { + return nil, err + } + + return &ResourceWrapper{ + State: res.State, + Message: res.Message, + }, nil +} + +func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error { + if taskCtx.ResourceMeta() == nil { + return nil + } + metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper) + + var opts []grpc.DialOption + opts = append(opts, grpc.WithInsecure()) + conn, err := grpc.Dial(p.cfg.grpcEndpoint, opts...) + if err != nil { + return fmt.Errorf("failed to connect backend plugin system") + } + defer conn.Close() + client := service.NewBackendPluginServiceClient(conn) + _, err = client.DeleteTask(ctx, &service.TaskDeleteRequest{TaskType: metadata.TaskType, JobId: metadata.JobID}) + return err +} + +func (p Plugin) Status(_ context.Context, taskCtx webapi.StatusContext) (phase core.PhaseInfo, err error) { + resource := taskCtx.Resource().(*ResourceWrapper) + taskInfo := &core.TaskInfo{} + + switch resource.State { + case service.State_RUNNING: + return core.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, taskInfo), nil + case service.State_FAILED: + return core.PhaseInfoFailure(resource.Message, "failed to run the job", taskInfo), nil + case service.State_SUCCEEDED: + return core.PhaseInfoSuccess(taskInfo), nil + } + return core.PhaseInfoUndefined, pluginErrors.Errorf(pluginsCore.SystemErrorCode, "unknown execution phase [%v].", resource.Message) +} + +func newGrpcPlugin() webapi.PluginEntry { + return webapi.PluginEntry{ + ID: "grpc", + SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task", "snowflake", "spark"}, + PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) { + return &Plugin{ + metricScope: iCtx.MetricsScope(), + cfg: GetConfig(), + }, nil + }, + } +} + +func init() { + gob.Register(ResourceMetaWrapper{}) + gob.Register(ResourceWrapper{}) + + pluginmachinery.PluginRegistry().RegisterRemotePlugin(newGrpcPlugin()) +} diff --git a/go/tasks/plugins/webapi/grpc/plugin_test.go b/go/tasks/plugins/webapi/grpc/plugin_test.go new file mode 100644 index 000000000..21e034e4c --- /dev/null +++ b/go/tasks/plugins/webapi/grpc/plugin_test.go @@ -0,0 +1 @@ +package grpc