diff --git a/Cargo.lock b/Cargo.lock index 25499d70e8..f0bbfafb30 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4935,6 +4935,7 @@ dependencies = [ "bytes 1.11.0", "chrono", "conditional-trait-gen", + "cookie", "criterion", "darling 0.20.11", "derive_more", diff --git a/Cargo.toml b/Cargo.toml index 48f39c2e49..bc9ceefc09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -145,6 +145,7 @@ combine = "4.6.7" conditional-trait-gen = "0.4.1" console-subscriber = "0.4.1" convert_case = "0.8.0" +cookie = "0.18.1" criterion = "0.5" crossterm = "0.28.1" darling = "0.20.11" diff --git a/cli/golem-cli/src/command_handler/api/deployment.rs b/cli/golem-cli/src/command_handler/api/deployment.rs index 8e772264f8..a1e5a8a71e 100644 --- a/cli/golem-cli/src/command_handler/api/deployment.rs +++ b/cli/golem-cli/src/command_handler/api/deployment.rs @@ -292,7 +292,7 @@ impl ApiDeploymentCommandHandler { &self, http_api_deployment: &DeploymentPlanHttpApiDeploymentEntry, deployable_http_api_deployment: &[HttpApiDefinitionName], - _diff: &diff::DiffForHashOf, + _diff: &diff::DiffForHashOf, ) -> anyhow::Result<()> { log_action( "Updating", diff --git a/cli/golem-cli/src/command_handler/app/mod.rs b/cli/golem-cli/src/command_handler/app/mod.rs index f8f42d1662..c6ba90345b 100644 --- a/cli/golem-cli/src/command_handler/app/mod.rs +++ b/cli/golem-cli/src/command_handler/app/mod.rs @@ -759,11 +759,11 @@ impl AppCommandHandler { let diffable_local_http_api_deployments = { let mut diffable_local_http_api_deployments = - BTreeMap::>::new(); + BTreeMap::>::new(); for (domain, http_api_deployment) in &deployable_manifest_http_api_deployments { diffable_local_http_api_deployments.insert( domain.0.clone(), - diff::HttpApiDeployment { + diff::HttpApiDeploymentLegacy { agent_types: http_api_deployment .iter() .map(|def| def.0.clone()) diff --git a/cli/golem/src/launch.rs b/cli/golem/src/launch.rs index fe3846e6ee..42c17a1a6d 100644 --- a/cli/golem/src/launch.rs +++ b/cli/golem/src/launch.rs @@ -40,7 +40,9 @@ use golem_worker_executor::services::golem_config::{ KeyValueStorageMultiSqliteConfig, ResourceLimitsConfig, ResourceLimitsGrpcConfig, ShardManagerServiceConfig, ShardManagerServiceGrpcConfig, WorkerServiceGrpcConfig, }; -use golem_worker_service::config::{RouteResolverConfig, WorkerServiceConfig}; +use golem_worker_service::config::{ + RouteResolverConfig, SqliteSessionStoreConfig, WorkerServiceConfig, +}; use golem_worker_service::WorkerService; use opentelemetry::global; use opentelemetry_sdk::metrics::MeterProviderBuilder; @@ -325,18 +327,21 @@ fn worker_service_config( port: 0, ..Default::default() }, - gateway_session_storage: golem_worker_service::config::GatewaySessionStorageConfig::Sqlite( - DbSqliteConfig { - database: args - .data_dir - .join("gateway-sessions.db") - .to_string_lossy() - .to_string(), - max_connections: 4, - foreign_keys: false, + gateway_session_storage: golem_worker_service::config::SessionStoreConfig::Sqlite( + SqliteSessionStoreConfig { + pending_login_expiration: Duration::from_hours(1), + cleanup_interval: Duration::from_mins(5), + sqlite_config: DbSqliteConfig { + database: args + .data_dir + .join("gateway-sessions.db") + .to_string_lossy() + .to_string(), + max_connections: 4, + foreign_keys: false, + }, }, ), - blob_storage: blob_storage_config(args), routing_table: RoutingTableConfig { host: args.router_addr.clone(), port: shard_manager_run_details.grpc_port, diff --git a/golem-api-grpc/proto/golem/customapi/core.proto b/golem-api-grpc/proto/golem/customapi/core.proto index ccfe0d23c8..120da4f50f 100644 --- a/golem-api-grpc/proto/golem/customapi/core.proto +++ b/golem-api-grpc/proto/golem/customapi/core.proto @@ -163,12 +163,13 @@ message CompiledRoute { RequestBodySchema body = 4; RouteBehaviour behavior = 5; optional golem.registry.SecuritySchemeId security_scheme = 6; - golem.component.CorsOptions cors = 7; + CorsOptions cors = 7; } message RouteBehaviour { oneof kind { CallAgent call_agent = 1; + CorsPreflight cors_preflight = 2; } message CallAgent { @@ -181,6 +182,11 @@ message RouteBehaviour { repeated MethodParameter method_parameters = 7; golem.component.DataSchema expected_agent_response = 8; } + + message CorsPreflight { + repeated string allowed_origins = 1; + repeated golem.component.HttpMethod allowed_methods = 2; + } } message SecuritySchemeDetails { @@ -192,3 +198,7 @@ message SecuritySchemeDetails { string redirect_url = 6; repeated string scopes = 7; } + +message CorsOptions { + repeated string allowed_patterns = 1; +} diff --git a/golem-common/src/base_model/agent.rs b/golem-common/src/base_model/agent.rs index 66b4eafdd1..2d5cbe3895 100644 --- a/golem-common/src/base_model/agent.rs +++ b/golem-common/src/base_model/agent.rs @@ -737,7 +737,9 @@ pub struct HttpEndpointDetails { pub cors_options: CorsOptions, } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, IntoValue, FromValue)] +#[derive( + Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, IntoValue, FromValue, +)] #[cfg_attr( feature = "full", derive(desert_rust::BinaryCodec, poem_openapi::Union) @@ -790,7 +792,9 @@ impl TryFrom for http::Method { } } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, IntoValue, FromValue)] +#[derive( + Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, IntoValue, FromValue, +)] #[cfg_attr( feature = "full", derive(desert_rust::BinaryCodec, poem_openapi::Object) diff --git a/golem-common/src/base_model/http_api_deployment.rs b/golem-common/src/base_model/http_api_deployment.rs index 00d29898d5..30dd54b97e 100644 --- a/golem-common/src/base_model/http_api_deployment.rs +++ b/golem-common/src/base_model/http_api_deployment.rs @@ -12,27 +12,38 @@ // See the License for the specific language governing permissions and // limitations under the License. +use super::security_scheme::SecuritySchemeName; use crate::base_model::agent::AgentTypeName; use crate::base_model::diff; use crate::base_model::domain_registration::Domain; use crate::base_model::environment::EnvironmentId; use crate::{declare_revision, declare_structs, newtype_uuid}; use chrono::DateTime; -use std::collections::BTreeSet; +use std::collections::BTreeMap; newtype_uuid!(HttpApiDeploymentId); declare_revision!(HttpApiDeploymentRevision); declare_structs! { + #[derive(Default)] + #[cfg_attr(feature = "full", derive(desert_rust::BinaryCodec))] + #[cfg_attr(feature = "full", desert(transparent))] + pub struct HttpApiDeploymentAgentOptions { + /// Security scheme to use for all agent methods that require auth. + /// Failure to provide a security scheme for an agent that requires one will lead to a deployment failure. + /// If the requested security scheme does not exist in the environment, the route will be disabled at runtime. + pub security_scheme: Option + } + pub struct HttpApiDeploymentCreation { pub domain: Domain, - pub agent_types: BTreeSet + pub agents: BTreeMap } pub struct HttpApiDeploymentUpdate { pub current_revision: HttpApiDeploymentRevision, - pub agent_types: Option> + pub agents: Option> } pub struct HttpApiDeployment { @@ -41,7 +52,7 @@ declare_structs! { pub environment_id: EnvironmentId, pub domain: Domain, pub hash: diff::Hash, - pub agent_types: BTreeSet, + pub agents: BTreeMap, pub created_at: DateTime, } } diff --git a/golem-common/src/base_model/mod.rs b/golem-common/src/base_model/mod.rs index aedaf74039..1682cc4485 100644 --- a/golem-common/src/base_model/mod.rs +++ b/golem-common/src/base_model/mod.rs @@ -157,7 +157,7 @@ impl FromValue for Timestamp { } } -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, Default)] #[cfg_attr( feature = "full", derive(desert_rust::BinaryCodec, poem_openapi::Object) diff --git a/golem-common/src/model/diff/deployment.rs b/golem-common/src/model/diff/deployment.rs index eadc3ef220..085ff69fdb 100644 --- a/golem-common/src/model/diff/deployment.rs +++ b/golem-common/src/model/diff/deployment.rs @@ -15,7 +15,7 @@ use crate::model::diff::component::Component; use crate::model::diff::hash::{hash_from_serialized_value, Hash, HashOf, Hashable}; use crate::model::diff::http_api_definition::HttpApiDefinition; -use crate::model::diff::http_api_deployment::HttpApiDeployment; +use crate::model::diff::http_api_deployment::HttpApiDeploymentLegacy; use crate::model::diff::ser::serialize_with_mode; use crate::model::diff::{BTreeMapDiff, Diffable}; use serde::Serialize; @@ -32,7 +32,7 @@ pub struct Deployment { pub http_api_definitions: BTreeMap>, #[serde(skip_serializing_if = "BTreeMap::is_empty")] #[serde(serialize_with = "serialize_with_mode")] - pub http_api_deployments: BTreeMap>, + pub http_api_deployments: BTreeMap>, } #[derive(Debug, Clone, PartialEq, Serialize)] @@ -43,7 +43,7 @@ pub struct DeploymentDiff { #[serde(skip_serializing_if = "BTreeMap::is_empty")] pub http_api_definitions: BTreeMapDiff>, #[serde(skip_serializing_if = "BTreeMap::is_empty")] - pub http_api_deployments: BTreeMapDiff>, + pub http_api_deployments: BTreeMapDiff>, } impl Diffable for Deployment { diff --git a/golem-common/src/model/diff/http_api_deployment.rs b/golem-common/src/model/diff/http_api_deployment.rs index 146449044d..cb9366660d 100644 --- a/golem-common/src/model/diff/http_api_deployment.rs +++ b/golem-common/src/model/diff/http_api_deployment.rs @@ -12,13 +12,42 @@ // See the License for the specific language governing permissions and // limitations under the License. +use super::BTreeMapDiff; use crate::model::diff::{hash_from_serialized_value, BTreeSetDiff, Diffable, Hash, Hashable}; use serde::Serialize; -use std::collections::BTreeSet; +use std::collections::{BTreeMap, BTreeSet}; + +#[derive(Debug, Clone, PartialEq, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct HttpApiDeploymentAgentOptions { + pub security_scheme: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct HttpApiDeploymentAgentOptionsDiff { + pub security_scheme_changed: bool, +} + +impl Diffable for HttpApiDeploymentAgentOptions { + type DiffResult = HttpApiDeploymentAgentOptionsDiff; + + fn diff(new: &Self, current: &Self) -> Option { + let security_scheme_changed = new.security_scheme != current.security_scheme; + + if security_scheme_changed { + Some(HttpApiDeploymentAgentOptionsDiff { + security_scheme_changed, + }) + } else { + None + } + } +} #[derive(Debug, Clone, PartialEq, Serialize)] pub struct HttpApiDeployment { - pub agent_types: BTreeSet, + pub agents: BTreeMap, } impl Hashable for HttpApiDeployment { @@ -28,6 +57,25 @@ impl Hashable for HttpApiDeployment { } impl Diffable for HttpApiDeployment { + type DiffResult = BTreeMapDiff; + + fn diff(new: &Self, current: &Self) -> Option { + new.agents.diff_with_current(¤t.agents) + } +} + +#[derive(Debug, Clone, PartialEq, Serialize)] +pub struct HttpApiDeploymentLegacy { + pub agent_types: BTreeSet, +} + +impl Hashable for HttpApiDeploymentLegacy { + fn hash(&self) -> Hash { + hash_from_serialized_value(self) + } +} + +impl Diffable for HttpApiDeploymentLegacy { type DiffResult = BTreeSetDiff; fn diff(new: &Self, current: &Self) -> Option { diff --git a/golem-common/src/model/http_api_deployment.rs b/golem-common/src/model/http_api_deployment.rs index 764ddde324..7034ec0182 100644 --- a/golem-common/src/model/http_api_deployment.rs +++ b/golem-common/src/model/http_api_deployment.rs @@ -19,7 +19,18 @@ pub use crate::base_model::http_api_deployment::*; impl HttpApiDeployment { pub fn to_diffable(&self) -> diff::HttpApiDeployment { diff::HttpApiDeployment { - agent_types: self.agent_types.iter().map(|def| def.0.clone()).collect(), + agents: self + .agents + .iter() + .map(|(k, v)| { + ( + k.0.clone(), + diff::HttpApiDeploymentAgentOptions { + security_scheme: v.security_scheme.as_ref().map(|v| v.0.clone()), + }, + ) + }) + .collect(), } } } diff --git a/golem-common/src/model/http_api_deployment_legacy.rs b/golem-common/src/model/http_api_deployment_legacy.rs index d96c287a4e..94adc1750a 100644 --- a/golem-common/src/model/http_api_deployment_legacy.rs +++ b/golem-common/src/model/http_api_deployment_legacy.rs @@ -17,8 +17,8 @@ use crate::model::diff; pub use crate::base_model::http_api_deployment_legacy::*; impl LegacyHttpApiDeployment { - pub fn to_diffable(&self) -> diff::HttpApiDeployment { - diff::HttpApiDeployment { + pub fn to_diffable(&self) -> diff::HttpApiDeploymentLegacy { + diff::HttpApiDeploymentLegacy { agent_types: self .api_definitions .iter() diff --git a/golem-registry-service/db/migration/postgres/002_code_first_routes.sql b/golem-registry-service/db/migration/postgres/002_code_first_routes.sql index 00887d7cec..577d2f25f0 100644 --- a/golem-registry-service/db/migration/postgres/002_code_first_routes.sql +++ b/golem-registry-service/db/migration/postgres/002_code_first_routes.sql @@ -17,7 +17,10 @@ DELETE FROM component_plugin_installations; DELETE FROM component_revisions; DELETE FROM components; -ALTER TABLE http_api_deployment_revisions RENAME COLUMN http_api_definitions TO agent_types; +ALTER TABLE http_api_deployment_revisions RENAME COLUMN http_api_definitions TO data; +ALTER TABLE http_api_deployment_revisions + ALTER COLUMN data TYPE BYTEA + USING data::bytea; DROP TABLE deployment_compiled_http_api_definition_routes; DROP TABLE deployment_domain_http_api_definitions; diff --git a/golem-registry-service/db/migration/sqlite/002_code_first_routes.sql b/golem-registry-service/db/migration/sqlite/002_code_first_routes.sql index 00887d7cec..95d248f48d 100644 --- a/golem-registry-service/db/migration/sqlite/002_code_first_routes.sql +++ b/golem-registry-service/db/migration/sqlite/002_code_first_routes.sql @@ -17,7 +17,30 @@ DELETE FROM component_plugin_installations; DELETE FROM component_revisions; DELETE FROM components; -ALTER TABLE http_api_deployment_revisions RENAME COLUMN http_api_definitions TO agent_types; +DROP TABLE http_api_deployment_revisions; + +CREATE TABLE http_api_deployment_revisions +( + http_api_deployment_id UUID NOT NULL, + revision_id BIGINT NOT NULL, + + hash BYTEA NOT NULL, + + created_at TIMESTAMP NOT NULL, + created_by UUID NOT NULL, + deleted BOOLEAN NOT NULL, + + data BYTEA NOT NULL, + + CONSTRAINT http_api_deployment_revisions_pk + PRIMARY KEY (http_api_deployment_id, revision_id), + CONSTRAINT http_api_deployment_revisions_deployments_fk + FOREIGN KEY (http_api_deployment_id) + REFERENCES http_api_deployments +); + +CREATE INDEX http_api_deployment_revisions_latest_revision_by_id_idx + ON http_api_deployment_revisions (http_api_deployment_id, revision_id DESC); DROP TABLE deployment_compiled_http_api_definition_routes; DROP TABLE deployment_domain_http_api_definitions; diff --git a/golem-registry-service/src/model/api_definition.rs b/golem-registry-service/src/model/api_definition.rs index 6afde2bf7c..0a4b86d747 100644 --- a/golem-registry-service/src/model/api_definition.rs +++ b/golem-registry-service/src/model/api_definition.rs @@ -14,13 +14,13 @@ use desert_rust::BinaryCodec; use golem_common::model::account::AccountId; -use golem_common::model::agent::{CorsOptions, HttpMethod}; +use golem_common::model::agent::HttpMethod; use golem_common::model::deployment::DeploymentRevision; use golem_common::model::domain_registration::Domain; use golem_common::model::environment::EnvironmentId; use golem_common::model::security_scheme::{SecuritySchemeId, SecuritySchemeName}; use golem_service_base::custom_api::{ - PathSegment, RequestBodySchema, RouteBehaviour, RouteId, SecuritySchemeDetails, + CorsOptions, PathSegment, RequestBodySchema, RouteBehaviour, RouteId, SecuritySchemeDetails, }; use std::collections::HashMap; diff --git a/golem-registry-service/src/repo/http_api_deployment.rs b/golem-registry-service/src/repo/http_api_deployment.rs index 8b2e4cc28e..0fbf369ee0 100644 --- a/golem-registry-service/src/repo/http_api_deployment.rs +++ b/golem-registry-service/src/repo/http_api_deployment.rs @@ -275,19 +275,21 @@ impl HttpApiDeploymentRepo for DbHttpApiDeploymentRepo { domain: &str, revision: HttpApiDeploymentRevisionRecord, ) -> Result { - let opt_deleted_revision: Option = - self.with_ro("create - get opt deleted").fetch_optional_as( + let opt_deleted_revision: Option = self + .with_ro("create - get opt deleted") + .fetch_optional_as( sqlx::query_as(indoc! { r#" - SELECT h.http_api_deployment_id, h.domain, hr.revision_id, hr.hash, hr.agent_types + SELECT h.http_api_deployment_id, h.domain, hr.revision_id, hr.hash, hr.data FROM http_api_deployments h JOIN http_api_deployment_revisions hr ON h.http_api_deployment_id = hr.http_api_deployment_id AND h.current_revision_id = hr.revision_id WHERE environment_id = $1 AND domain = $2 AND deleted_at IS NOT NULL "#}) - .bind(environment_id) - .bind(domain) - ).await?; + .bind(environment_id) + .bind(domain), + ) + .await?; if let Some(deleted_revision) = opt_deleted_revision { let recreated_revision = revision.for_recreation( @@ -419,7 +421,7 @@ impl HttpApiDeploymentRepo for DbHttpApiDeploymentRepo { .fetch_optional_as( sqlx::query_as(indoc! { r#" SELECT d.environment_id, d.domain, dr.http_api_deployment_id, - dr.revision_id, dr.hash, dr.agent_types, + dr.revision_id, dr.hash, dr.data, dr.created_at, dr.created_by, dr.deleted, d.created_at as entity_created_at FROM http_api_deployments d @@ -441,7 +443,7 @@ impl HttpApiDeploymentRepo for DbHttpApiDeploymentRepo { .fetch_optional_as( sqlx::query_as(indoc! { r#" SELECT d.environment_id, d.domain, dr.http_api_deployment_id, - dr.revision_id, dr.hash, dr.agent_types, + dr.revision_id, dr.hash, dr.data, dr.created_at, dr.created_by, dr.deleted, d.created_at as entity_created_at FROM http_api_deployments d @@ -464,7 +466,7 @@ impl HttpApiDeploymentRepo for DbHttpApiDeploymentRepo { .fetch_optional_as( sqlx::query_as(indoc! { r#" SELECT d.environment_id, d.domain, dr.http_api_deployment_id, - dr.revision_id, dr.hash, dr.agent_types, + dr.revision_id, dr.hash, dr.data, dr.created_at, dr.created_by, dr.deleted, d.created_at as entity_created_at FROM http_api_deployments d @@ -486,7 +488,7 @@ impl HttpApiDeploymentRepo for DbHttpApiDeploymentRepo { .fetch_all_as( sqlx::query_as(indoc! { r#" SELECT d.environment_id, d.domain, dr.http_api_deployment_id, - dr.revision_id, dr.hash, dr.agent_types, + dr.revision_id, dr.hash, dr.data, dr.created_at, dr.created_by, dr.deleted, d.created_at as entity_created_at FROM http_api_deployments d @@ -509,7 +511,7 @@ impl HttpApiDeploymentRepo for DbHttpApiDeploymentRepo { .fetch_all_as( sqlx::query_as(indoc! { r#" SELECT had.environment_id, had.domain, hadr.http_api_deployment_id, - hadr.revision_id, hadr.hash, hadr.agent_types, + hadr.revision_id, hadr.hash, hadr.data, hadr.created_at, hadr.created_by, hadr.deleted, had.created_at as entity_created_at FROM http_api_deployments had @@ -536,7 +538,7 @@ impl HttpApiDeploymentRepo for DbHttpApiDeploymentRepo { .fetch_optional_as( sqlx::query_as(indoc! { r#" SELECT had.environment_id, had.domain, hadr.http_api_deployment_id, - hadr.revision_id, hadr.hash, hadr.agent_types, + hadr.revision_id, hadr.hash, hadr.data, hadr.created_at, hadr.created_by, hadr.deleted, had.created_at as entity_created_at FROM http_api_deployments had @@ -578,15 +580,15 @@ impl HttpApiDeploymentRepoInternal for DbHttpApiDeploymentRepo { tx.fetch_one_as( sqlx::query_as(indoc! { r#" INSERT INTO http_api_deployment_revisions - (http_api_deployment_id, revision_id, agent_types, + (http_api_deployment_id, revision_id, data, hash, created_at, created_by, deleted) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING http_api_deployment_id, revision_id, hash, - created_at, created_by, deleted, agent_types + created_at, created_by, deleted, data "# }) .bind(revision.http_api_deployment_id) .bind(revision.revision_id) - .bind(revision.agent_types) + .bind(&revision.data) .bind(revision.hash) .bind_deletable_revision_audit(revision.audit), ) diff --git a/golem-registry-service/src/repo/model/http_api_deployment.rs b/golem-registry-service/src/repo/model/http_api_deployment.rs index 510ba0f75e..6fa968fb31 100644 --- a/golem-registry-service/src/repo/model/http_api_deployment.rs +++ b/golem-registry-service/src/repo/model/http_api_deployment.rs @@ -15,6 +15,7 @@ use super::datetime::SqlDateTime; use crate::repo::model::audit::{AuditFields, DeletableRevisionAuditFields}; use crate::repo::model::hash::SqlBlake3Hash; +use desert_rust::BinaryCodec; use golem_common::error_forwarding; use golem_common::model::account::AccountId; use golem_common::model::agent::AgentTypeName; @@ -24,13 +25,13 @@ use golem_common::model::diff::Hashable; use golem_common::model::domain_registration::Domain; use golem_common::model::environment::EnvironmentId; use golem_common::model::http_api_deployment::{ - HttpApiDeployment, HttpApiDeploymentId, HttpApiDeploymentRevision, + HttpApiDeployment, HttpApiDeploymentAgentOptions, HttpApiDeploymentId, + HttpApiDeploymentRevision, }; use golem_service_base::repo::RepoError; -use sqlx::encode::IsNull; -use sqlx::error::BoxDynError; -use sqlx::{Database, FromRow}; -use std::collections::BTreeSet; +use golem_service_base::repo::blob::Blob; +use sqlx::FromRow; +use std::collections::BTreeMap; use uuid::Uuid; #[derive(Debug, thiserror::Error)] @@ -45,52 +46,10 @@ pub enum HttpApiDeploymentRepoError { error_forwarding!(HttpApiDeploymentRepoError, RepoError); -// stored as string containing a json array -#[derive(Debug, Clone, PartialEq)] -pub struct AgentTypeSet(pub BTreeSet); - -impl sqlx::Type for AgentTypeSet -where - String: sqlx::Type, -{ - fn type_info() -> DB::TypeInfo { - >::type_info() - } - - fn compatible(ty: &DB::TypeInfo) -> bool { - >::compatible(ty) - } -} - -impl<'q, DB: Database> sqlx::Encode<'q, DB> for AgentTypeSet -where - String: sqlx::Encode<'q, DB>, -{ - fn encode_by_ref( - &self, - buf: &mut ::ArgumentBuffer<'q>, - ) -> Result { - let serialized = serde_json::to_string(&self.0)?; - serialized.encode(buf) - } - - fn size_hint(&self) -> usize { - match serde_json::to_string(&self.0) { - Ok(string) => string.size_hint(), - Err(_) => 0, - } - } -} - -impl<'r, DB: Database> sqlx::Decode<'r, DB> for AgentTypeSet -where - &'r str: sqlx::Decode<'r, DB>, -{ - fn decode(value: ::ValueRef<'r>) -> Result { - let deserialized: BTreeSet = - serde_json::from_str(<&'r str>::decode(value)?)?; - Ok(Self(deserialized)) - } +#[derive(Debug, Clone, PartialEq, BinaryCodec)] +#[desert(evolution())] +pub struct HttpApiDeploymentData { + pub agents: BTreeMap, } #[derive(Debug, Clone, FromRow, PartialEq)] @@ -111,8 +70,7 @@ pub struct HttpApiDeploymentRevisionRecord { #[sqlx(flatten)] pub audit: DeletableRevisionAuditFields, - // json string array as string - pub agent_types: AgentTypeSet, + pub data: Blob, } impl HttpApiDeploymentRevisionRecord { @@ -132,7 +90,7 @@ impl HttpApiDeploymentRevisionRecord { pub fn creation( http_api_deployment_id: HttpApiDeploymentId, - agent_types: BTreeSet, + agents: BTreeMap, actor: AccountId, ) -> Self { let mut value = Self { @@ -140,7 +98,7 @@ impl HttpApiDeploymentRevisionRecord { revision_id: HttpApiDeploymentRevision::INITIAL.into(), hash: SqlBlake3Hash::empty(), audit: DeletableRevisionAuditFields::new(actor.0), - agent_types: AgentTypeSet(agent_types), + data: Blob::new(HttpApiDeploymentData { agents }), }; value.update_hash(); value @@ -152,7 +110,9 @@ impl HttpApiDeploymentRevisionRecord { revision_id: value.revision.into(), hash: SqlBlake3Hash::empty(), audit, - agent_types: AgentTypeSet(value.agent_types), + data: Blob::new(HttpApiDeploymentData { + agents: value.agents, + }), }; value.update_hash(); value @@ -168,13 +128,28 @@ impl HttpApiDeploymentRevisionRecord { revision_id: current_revision_id, hash: SqlBlake3Hash::empty(), audit: DeletableRevisionAuditFields::deletion(created_by), - agent_types: AgentTypeSet(BTreeSet::new()), + data: Blob::new(HttpApiDeploymentData { + agents: BTreeMap::new(), + }), } } pub fn to_diffable(&self) -> diff::HttpApiDeployment { diff::HttpApiDeployment { - agent_types: self.agent_types.0.iter().map(|had| had.0.clone()).collect(), + agents: self + .data + .value() + .agents + .iter() + .map(|(k, v)| { + ( + k.0.clone(), + diff::HttpApiDeploymentAgentOptions { + security_scheme: v.security_scheme.as_ref().map(|v| v.0.clone()), + }, + ) + }) + .collect(), } } @@ -208,7 +183,7 @@ impl TryFrom for HttpApiDeployment { environment_id: EnvironmentId(value.environment_id), domain: Domain(value.domain), hash: value.revision.hash.into(), - agent_types: value.revision.agent_types.0, + agents: value.revision.data.into_value().agents, created_at: value.entity_created_at.into(), }) } diff --git a/golem-registry-service/src/services/deployment/deployment_context.rs b/golem-registry-service/src/services/deployment/deployment_context.rs index 0e56ef572c..7cf1d4bf62 100644 --- a/golem-registry-service/src/services/deployment/deployment_context.rs +++ b/golem-registry-service/src/services/deployment/deployment_context.rs @@ -19,6 +19,7 @@ use super::http_parameter_conversion::{ use crate::model::api_definition::UnboundCompiledRoute; use crate::model::component::Component; use crate::services::deployment::write::DeployValidationError; +use golem_common::model::Empty; use golem_common::model::agent::wit_naming::ToWitNaming; use golem_common::model::agent::{ AgentMethod, AgentType, AgentTypeName, DataSchema, ElementSchema, HttpEndpointDetails, @@ -28,10 +29,13 @@ use golem_common::model::agent::{ use golem_common::model::component::ComponentName; use golem_common::model::diff::{self, HashOf, Hashable}; use golem_common::model::domain_registration::Domain; -use golem_common::model::http_api_deployment::HttpApiDeployment; -use golem_service_base::custom_api::{ConstructorParameter, PathSegment, RouteBehaviour}; +use golem_common::model::http_api_deployment::{HttpApiDeployment, HttpApiDeploymentAgentOptions}; +use golem_service_base::custom_api::{ + CallAgentBehaviour, ConstructorParameter, CorsOptions, CorsPreflightBehaviour, OriginPattern, + PathSegment, RequestBodySchema, RouteBehaviour, +}; use itertools::Itertools; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, BTreeSet, HashMap}; macro_rules! ok_or_continue { ($expr:expr, $errors:ident) => {{ @@ -121,7 +125,7 @@ impl DeploymentContext { let mut errors = Vec::new(); for deployment in self.http_api_deployments.values() { - for agent_type in &deployment.agent_types { + for (agent_type, agent_options) in &deployment.agents { let registered_agent_type = ok_or_continue!( registered_agent_types.get(agent_type).ok_or( DeployValidationError::HttpApiDeploymentMissingAgentType { @@ -156,6 +160,7 @@ impl DeploymentContext { http_mount, ®istered_agent_type.agent_type.methods, constructor_parameters, + agent_options, &mut errors, ); @@ -184,9 +189,27 @@ impl DeploymentContext { http_mount: &HttpMountDetails, agent_methods: &[AgentMethod], constructor_parameters: Vec, + agent_options: &HttpApiDeploymentAgentOptions, errors: &mut Vec, ) -> Vec { - let mut compiled_routes = Vec::new(); + let mut compiled_routes: HashMap<(HttpMethod, Vec), UnboundCompiledRoute> = + HashMap::new(); + + struct PreflightMapEntry { + allowed_methods: BTreeSet, + allowed_origins: BTreeSet, + } + + impl PreflightMapEntry { + fn new() -> Self { + PreflightMapEntry { + allowed_methods: BTreeSet::new(), + allowed_origins: BTreeSet::new(), + } + } + } + + let mut preflight_map: HashMap, PreflightMapEntry> = HashMap::new(); for agent_method in agent_methods { for http_endpoint in &agent_method.http_endpoint { @@ -198,12 +221,33 @@ impl DeploymentContext { agent_method, ); - let cors = if !http_endpoint.cors_options.allowed_patterns.is_empty() { - http_endpoint.cors_options.clone() - } else { - http_mount.cors_options.clone() + let mut cors = CorsOptions { + allowed_patterns: vec![], }; + if !http_mount.cors_options.allowed_patterns.is_empty() { + cors.allowed_patterns.extend( + http_mount + .cors_options + .allowed_patterns + .iter() + .cloned() + .map(OriginPattern), + ); + } + if !http_endpoint.cors_options.allowed_patterns.is_empty() { + cors.allowed_patterns.extend( + http_endpoint + .cors_options + .allowed_patterns + .iter() + .cloned() + .map(OriginPattern), + ); + } + cors.allowed_patterns.sort(); + cors.allowed_patterns.dedup(); + let route_id = *current_route_id; *current_route_id = current_route_id.checked_add(1).unwrap(); @@ -225,7 +269,7 @@ impl DeploymentContext { errors ); - let path = http_mount + let path_segments: Vec = http_mount .path_prefix .iter() .cloned() @@ -233,13 +277,52 @@ impl DeploymentContext { .map(|p| compile_agent_path_segment(agent, implementer, p)) .collect(); + if !cors.allowed_patterns.is_empty() { + let entry = preflight_map + .entry(path_segments.clone()) + .or_insert(PreflightMapEntry::new()); + + entry + .allowed_methods + .insert(http_endpoint.http_method.clone()); + for allowed_pattern in &cors.allowed_patterns { + entry.allowed_origins.insert(allowed_pattern.clone()); + } + } + + let mut auth_required = false; + if let Some(auth_details) = &http_mount.auth_details { + auth_required = auth_details.required; + } + if let Some(auth_details) = &http_endpoint.auth_details { + auth_required = auth_details.required; + } + + let security_scheme = if auth_required { + let security_scheme = ok_or_continue!( + agent_options.security_scheme.clone().ok_or( + DeployValidationError::NoSecuritySchemeConfigured( + agent.type_name.clone() + ) + ), + errors + ); + + Some(security_scheme) + } else { + None + }; + + // TODO: check whether a security scheme with this name currently exists in the environment + // and emit a warning to the cli if it doesn't. + let compiled = UnboundCompiledRoute { route_id, domain: deployment.domain.clone(), method: http_endpoint.http_method.clone(), - path, + path: path_segments.clone(), body, - behaviour: RouteBehaviour::CallAgent { + behaviour: RouteBehaviour::CallAgent(CallAgentBehaviour { component_id: implementer.component_id, component_revision: implementer.component_revision, agent_type: agent.type_name.clone(), @@ -248,16 +331,65 @@ impl DeploymentContext { constructor_parameters: constructor_parameters.clone(), method_parameters, expected_agent_response: agent_method.output_schema.clone(), - }, - security_scheme: None, + }), + security_scheme, cors, }; - compiled_routes.push(compiled); + { + let key = (http_endpoint.http_method.clone(), path_segments); + if let std::collections::hash_map::Entry::Vacant(e) = compiled_routes.entry(key) + { + e.insert(compiled); + } else { + errors.push(make_route_validation_error( + "Duplicate route detected".into(), + )); + } + } + } + } + + // Generate synthetic OPTIONS routes for preflight requests + for ( + path_segments, + PreflightMapEntry { + allowed_methods, + allowed_origins, + }, + ) in preflight_map + { + let key = (HttpMethod::Options(Empty {}), path_segments.clone()); + if compiled_routes.contains_key(&key) { + // Skip synthetic OPTIONS if user already defined one + // TODO: Emit to the cli as warning + continue; } + + let route_id = *current_route_id; + *current_route_id = current_route_id.checked_add(1).unwrap(); + + compiled_routes.insert( + key, + UnboundCompiledRoute { + route_id, + domain: deployment.domain.clone(), + method: HttpMethod::Options(Empty {}), + path: path_segments, + body: RequestBodySchema::Unused, + behaviour: RouteBehaviour::CorsPreflight(CorsPreflightBehaviour { + allowed_origins, + allowed_methods, + }), + security_scheme: None, + cors: CorsOptions { + allowed_patterns: vec![], + }, + }, + ); } - compiled_routes + compiled_routes.into_values().collect() } } diff --git a/golem-registry-service/src/services/deployment/write.rs b/golem-registry-service/src/services/deployment/write.rs index a04a00e4c8..ab377f7c4d 100644 --- a/golem-registry-service/src/services/deployment/write.rs +++ b/golem-registry-service/src/services/deployment/write.rs @@ -107,6 +107,8 @@ pub enum DeployValidationError { ComponentNotFound(ComponentName), #[error("Agent type name {0} is provided by multiple components")] AmbiguousAgentTypeName(AgentTypeName), + #[error("No security scheme configured for agent {0} but agent has methods that require auth")] + NoSecuritySchemeConfigured(AgentTypeName), #[error( "Method {agent_method} of agent {agent_type} used by http api at {method} {domain}/{path} is invalid: {error}" )] diff --git a/golem-registry-service/src/services/http_api_deployment.rs b/golem-registry-service/src/services/http_api_deployment.rs index 1c22d978e9..a3726d94a6 100644 --- a/golem-registry-service/src/services/http_api_deployment.rs +++ b/golem-registry-service/src/services/http_api_deployment.rs @@ -135,8 +135,7 @@ impl HttpApiDeploymentService { })?; let id = HttpApiDeploymentId::new(); - let record = - HttpApiDeploymentRevisionRecord::creation(id, data.agent_types, auth.account_id()); + let record = HttpApiDeploymentRevisionRecord::creation(id, data.agents, auth.account_id()); let stored_http_api_deployment: HttpApiDeployment = self .http_api_deployment_repo @@ -200,8 +199,8 @@ impl HttpApiDeploymentService { }; http_api_deployment.revision = http_api_deployment.revision.next()?; - if let Some(api_definitions) = update.agent_types { - http_api_deployment.agent_types = api_definitions; + if let Some(api_definitions) = update.agents { + http_api_deployment.agents = api_definitions; }; let record = HttpApiDeploymentRevisionRecord::from_model( diff --git a/golem-registry-service/tests/repo/common.rs b/golem-registry-service/tests/repo/common.rs index 97d122a028..dda34f3c51 100644 --- a/golem-registry-service/tests/repo/common.rs +++ b/golem-registry-service/tests/repo/common.rs @@ -19,6 +19,7 @@ use futures::future::join_all; use golem_common::model::agent::AgentTypeName; use golem_common::model::component::ComponentFilePermissions; use golem_common::model::component_metadata::ComponentMetadata; +use golem_common::model::http_api_deployment::HttpApiDeploymentAgentOptions; use golem_registry_service::repo::environment::EnvironmentRevisionRecord; use golem_registry_service::repo::model::account::{ AccountExtRevisionRecord, AccountRepoError, AccountRevisionRecord, @@ -37,7 +38,7 @@ use golem_registry_service::repo::model::datetime::SqlDateTime; use golem_registry_service::repo::model::environment::EnvironmentRepoError; use golem_registry_service::repo::model::hash::SqlBlake3Hash; use golem_registry_service::repo::model::http_api_deployment::{ - AgentTypeSet, HttpApiDeploymentRepoError, HttpApiDeploymentRevisionRecord, + HttpApiDeploymentData, HttpApiDeploymentRepoError, HttpApiDeploymentRevisionRecord, }; use golem_registry_service::repo::model::new_repo_uuid; use golem_registry_service::repo::model::plugin::PluginRecord; @@ -816,7 +817,12 @@ pub async fn test_http_api_deployment_stage(deps: &Deps) { revision_id: 0, hash: SqlBlake3Hash::empty(), audit: DeletableRevisionAuditFields::new(user.revision.account_id), - agent_types: AgentTypeSet([AgentTypeName("test-agent".to_string())].into()), + data: Blob::new(HttpApiDeploymentData { + agents: BTreeMap::from_iter([( + AgentTypeName("test-agent".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), + }), } .with_updated_hash(); diff --git a/golem-service-base/src/custom_api/mod.rs b/golem-service-base/src/custom_api/mod.rs index 2c9e8b64c5..7b012da3df 100644 --- a/golem-service-base/src/custom_api/mod.rs +++ b/golem-service-base/src/custom_api/mod.rs @@ -18,7 +18,7 @@ mod protobuf; use crate::model::SafeIndex; use desert_rust::BinaryCodec; use golem_common::model::account::AccountId; -use golem_common::model::agent::{AgentTypeName, CorsOptions, DataSchema, HttpMethod}; +use golem_common::model::agent::{AgentTypeName, DataSchema, HttpMethod}; use golem_common::model::component::{ComponentId, ComponentRevision}; use golem_common::model::deployment::DeploymentRevision; use golem_common::model::environment::EnvironmentId; @@ -26,12 +26,12 @@ use golem_common::model::security_scheme::{Provider, SecuritySchemeId, SecurityS use golem_wasm::analysis::analysed_type; use golem_wasm::analysis::{AnalysedType, TypeList, TypeOption}; use openidconnect::{ClientId, ClientSecret, RedirectUrl, Scope}; -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap}; use std::fmt; pub type RouteId = i32; -#[derive(Debug, Clone, BinaryCodec)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, BinaryCodec)] #[desert(evolution())] pub enum PathSegment { Literal { value: String }, @@ -239,16 +239,28 @@ pub struct CompiledRoute { #[derive(Debug, BinaryCodec)] #[desert(evolution())] pub enum RouteBehaviour { - CallAgent { - component_id: ComponentId, - component_revision: ComponentRevision, - agent_type: AgentTypeName, - constructor_parameters: Vec, - phantom: bool, - method_name: String, - method_parameters: Vec, - expected_agent_response: DataSchema, - }, + CallAgent(CallAgentBehaviour), + CorsPreflight(CorsPreflightBehaviour), +} + +#[derive(Debug, BinaryCodec)] +#[desert(evolution())] +pub struct CallAgentBehaviour { + pub component_id: ComponentId, + pub component_revision: ComponentRevision, + pub agent_type: AgentTypeName, + pub constructor_parameters: Vec, + pub phantom: bool, + pub method_name: String, + pub method_parameters: Vec, + pub expected_agent_response: DataSchema, +} + +#[derive(Debug, BinaryCodec)] +#[desert(evolution())] +pub struct CorsPreflightBehaviour { + pub allowed_origins: BTreeSet, + pub allowed_methods: BTreeSet, } #[derive(Debug, Clone)] @@ -261,3 +273,46 @@ pub struct SecuritySchemeDetails { pub redirect_url: RedirectUrl, pub scopes: Vec, } + +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, BinaryCodec)] +#[desert(evolution())] +pub struct CorsOptions { + pub allowed_patterns: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, BinaryCodec)] +#[desert(transparent)] +// Note: Wildcards are only considered during matching. When setting the allow-origin header +// always use the exact origin that made the request to avoid complications with +// presence of auth information. +// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Access-Control-Allow-Origin#sect +pub struct OriginPattern(pub String); + +impl OriginPattern { + // match origin according to cors spec https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Access-Control-Allow-Origin + pub fn matches(&self, origin: &str) -> bool { + if self.0 == "*" { + true + } else { + self.0 == origin + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use test_r::test; + + #[test] + fn test_origin_pattern_matches() { + let wildcard = OriginPattern("*".to_string()); + assert!(wildcard.matches("https://example.com")); + assert!(wildcard.matches("https://foo.bar")); + + let exact = OriginPattern("https://example.com".to_string()); + assert!(exact.matches("https://example.com")); + assert!(!exact.matches("https://other.com")); + assert!(!exact.matches("http://example.com")); // scheme matters + } +} diff --git a/golem-service-base/src/custom_api/protobuf.rs b/golem-service-base/src/custom_api/protobuf.rs index 686602049c..acf51d37c2 100644 --- a/golem-service-base/src/custom_api/protobuf.rs +++ b/golem-service-base/src/custom_api/protobuf.rs @@ -12,16 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -use super::SecuritySchemeDetails; use super::{CompiledRoute, CompiledRoutes}; +use super::{CorsOptions, SecuritySchemeDetails}; use super::{PathSegment, PathSegmentType, RequestBodySchema, RouteBehaviour}; -use crate::custom_api::{ConstructorParameter, MethodParameter, QueryOrHeaderType}; +use crate::custom_api::{ + CallAgentBehaviour, ConstructorParameter, CorsPreflightBehaviour, MethodParameter, + OriginPattern, QueryOrHeaderType, +}; use golem_api_grpc::proto; -use golem_common::model::agent::AgentTypeName; +use golem_common::model::agent::{AgentTypeName, HttpMethod}; use golem_common::model::security_scheme::SecuritySchemeName; use golem_wasm::analysis::TypeEnum; use openidconnect::{ClientId, ClientSecret, RedirectUrl, Scope}; -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap}; use std::ops::Deref; impl TryFrom for SecuritySchemeDetails { @@ -158,7 +161,7 @@ impl TryFrom for RouteBehaviour { use proto::golem::customapi::route_behaviour::Kind; match value.kind.ok_or("RouteBehaviour.kind missing")? { - Kind::CallAgent(agent) => Ok(RouteBehaviour::CallAgent { + Kind::CallAgent(agent) => Ok(RouteBehaviour::CallAgent(CallAgentBehaviour { component_id: agent .component_id .ok_or("Missing component_id")? @@ -181,7 +184,21 @@ impl TryFrom for RouteBehaviour { .expected_agent_response .ok_or("Missing expected_agent_response")? .try_into()?, - }), + })), + Kind::CorsPreflight(cors_preflight) => { + Ok(RouteBehaviour::CorsPreflight(CorsPreflightBehaviour { + allowed_origins: cors_preflight + .allowed_origins + .into_iter() + .map(OriginPattern) + .collect(), + allowed_methods: cors_preflight + .allowed_methods + .into_iter() + .map(HttpMethod::try_from) + .collect::, _>>()?, + })) + } } } } @@ -191,7 +208,7 @@ impl From for proto::golem::customapi::RouteBehaviour { use proto::golem::customapi::route_behaviour::Kind; match value { - RouteBehaviour::CallAgent { + RouteBehaviour::CallAgent(CallAgentBehaviour { component_id, component_revision, agent_type, @@ -200,7 +217,7 @@ impl From for proto::golem::customapi::RouteBehaviour { method_name, method_parameters, expected_agent_response, - } => Self { + }) => Self { kind: Some(Kind::CallAgent( proto::golem::customapi::route_behaviour::CallAgent { component_id: Some(component_id.into()), @@ -217,6 +234,20 @@ impl From for proto::golem::customapi::RouteBehaviour { }, )), }, + RouteBehaviour::CorsPreflight(CorsPreflightBehaviour { + allowed_origins, + allowed_methods, + }) => Self { + kind: Some(Kind::CorsPreflight( + proto::golem::customapi::route_behaviour::CorsPreflight { + allowed_origins: allowed_origins.into_iter().map(|ao| ao.0).collect(), + allowed_methods: allowed_methods + .into_iter() + .map(proto::golem::component::HttpMethod::from) + .collect(), + }, + )), + }, } } } @@ -578,3 +609,27 @@ impl From for proto::golem::customapi::QueryOrHeaderType { Self { kind: Some(kind) } } } + +impl TryFrom for CorsOptions { + type Error = String; + + fn try_from( + value: golem_api_grpc::proto::golem::customapi::CorsOptions, + ) -> Result { + Ok(Self { + allowed_patterns: value + .allowed_patterns + .into_iter() + .map(OriginPattern) + .collect(), + }) + } +} + +impl From for golem_api_grpc::proto::golem::customapi::CorsOptions { + fn from(value: CorsOptions) -> Self { + Self { + allowed_patterns: value.allowed_patterns.into_iter().map(|op| op.0).collect(), + } + } +} diff --git a/golem-worker-service/Cargo.toml b/golem-worker-service/Cargo.toml index f0edd6bcd4..4b770b9c2c 100644 --- a/golem-worker-service/Cargo.toml +++ b/golem-worker-service/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "golem-worker-service" version = "0.0.0" -edition = "2021" +edition = "2024" homepage = "https://golem.cloud" repository = "https://github.com/golemcloud/golem" @@ -44,6 +44,7 @@ bigdecimal = { workspace = true } bytes = { workspace = true } chrono = { workspace = true } conditional-trait-gen = { workspace = true } +cookie = { workspace = true } darling = { workspace = true } derive_more = { workspace = true } desert_rust = { workspace = true } diff --git a/golem-worker-service/config/worker-service.sample.env b/golem-worker-service/config/worker-service.sample.env index 38d598f6c1..70e6a45b27 100644 --- a/golem-worker-service/config/worker-service.sample.env +++ b/golem-worker-service/config/worker-service.sample.env @@ -11,14 +11,13 @@ GOLEM__AUTH_SERVICE__AUTH_CTX_CACHE_TTL="10m" GOLEM__AUTH_SERVICE__ENVIRONMENT_AUTH_DETAILS_CACHE_EVICTION_PERIOD="1m" GOLEM__AUTH_SERVICE__ENVIRONMENT_AUTH_DETAILS_CACHE_MAX_CAPACITY=1024 GOLEM__AUTH_SERVICE__ENVIRONMENT_AUTH_DETAILS_CACHE_TTL="10m" -GOLEM__BLOB_STORAGE__TYPE="LocalFileSystem" -GOLEM__BLOB_STORAGE__CONFIG__ROOT="../data/blob_storage" GOLEM__COMPONENT_SERVICE__COMPONENT_CACHE_MAX_CAPACITY=1024 GOLEM__GATEWAY_SESSION_STORAGE__TYPE="Redis" GOLEM__GATEWAY_SESSION_STORAGE__CONFIG__DATABASE=0 GOLEM__GATEWAY_SESSION_STORAGE__CONFIG__HOST="localhost" GOLEM__GATEWAY_SESSION_STORAGE__CONFIG__KEY_PREFIX="" #GOLEM__GATEWAY_SESSION_STORAGE__CONFIG__PASSWORD= +GOLEM__GATEWAY_SESSION_STORAGE__CONFIG__PENDING_LOGIN_EXPIRATION="1h" GOLEM__GATEWAY_SESSION_STORAGE__CONFIG__POOL_SIZE=8 GOLEM__GATEWAY_SESSION_STORAGE__CONFIG__PORT=6380 GOLEM__GATEWAY_SESSION_STORAGE__CONFIG__TRACING=false diff --git a/golem-worker-service/config/worker-service.toml b/golem-worker-service/config/worker-service.toml index 185616f14c..0f12abe030 100644 --- a/golem-worker-service/config/worker-service.toml +++ b/golem-worker-service/config/worker-service.toml @@ -13,12 +13,6 @@ environment_auth_details_cache_eviction_period = "1m" environment_auth_details_cache_max_capacity = 1024 environment_auth_details_cache_ttl = "10m" -[blob_storage] -type = "LocalFileSystem" - -[blob_storage.config] -root = "../data/blob_storage" - [component_service] component_cache_max_capacity = 1024 @@ -29,6 +23,7 @@ type = "Redis" database = 0 host = "localhost" key_prefix = "" +pending_login_expiration = "1h" pool_size = 8 port = 6380 tracing = false diff --git a/golem-worker-service/src/api/agents.rs b/golem-worker-service/src/api/agents.rs index 34a3f1d71e..fa7c949044 100644 --- a/golem-worker-service/src/api/agents.rs +++ b/golem-worker-service/src/api/agents.rs @@ -2,10 +2,10 @@ use crate::api::common::ApiEndpointError; use crate::service::auth::AuthService; use crate::service::worker::AgentsService; use chrono::{DateTime, Utc}; +use golem_common::model::IdempotencyKey; use golem_common::model::agent::{AgentTypeName, UntypedJsonDataValue}; use golem_common::model::application::ApplicationName; use golem_common::model::environment::EnvironmentName; -use golem_common::model::IdempotencyKey; use golem_common::recorded_http_api_request; use golem_service_base::api_tags::ApiTags; use golem_service_base::model::auth::GolemSecurityScheme; diff --git a/golem-worker-service/src/api/common.rs b/golem-worker-service/src/api/common.rs index 56ed5dc873..13317ce83b 100644 --- a/golem-worker-service/src/api/common.rs +++ b/golem-worker-service/src/api/common.rs @@ -12,21 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::gateway_execution::request_handler::RequestHandlerError; -use crate::gateway_execution::route_resolver::RouteResolverError; +use crate::custom_api::error::RequestHandlerError; +use crate::custom_api::route_resolver::RouteResolverError; use crate::service::auth::AuthServiceError; use crate::service::component::ComponentServiceError; use crate::service::limit::LimitServiceError; use crate::service::worker::{CallWorkerExecutorError, WorkerServiceError}; +use golem_common::SafeDisplay; use golem_common::metrics::api::ApiErrorDetails; use golem_common::model::error::ErrorBody; use golem_common::model::error::ErrorsBody; -use golem_common::SafeDisplay; use golem_service_base::clients::registry::RegistryServiceError; use golem_service_base::error::worker_executor::WorkerExecutorError; use golem_service_base::model::auth::AuthorizationError; -use poem_openapi::payload::Json; use poem_openapi::ApiResponse; +use poem_openapi::payload::Json; use serde::{Deserialize, Serialize}; /// Detail in case the error was caused by the worker failing @@ -275,23 +275,26 @@ impl From for ApiEndpointError { | RequestHandlerError::HeaderIsNotAscii { .. } | RequestHandlerError::BodyIsNotValidJson { .. } | RequestHandlerError::JsonBodyParsingFailed { .. } - | RequestHandlerError::UnsupportedMimeType { .. } => Self::bad_request(value), - - RequestHandlerError::ResolvingRouteFailed( + | RequestHandlerError::UnsupportedMimeType { .. } + | RequestHandlerError::ResolvingRouteFailed( RouteResolverError::CouldNotGetDomainFromRequest(_) | RouteResolverError::MalformedPath(_), ) => Self::bad_request(value), - RequestHandlerError::ResolvingRouteFailed(RouteResolverError::CouldNotBuildRouter) => { - Self::internal(value) - } + RequestHandlerError::ResolvingRouteFailed(RouteResolverError::NoMatchingRoute) => { Self::not_found(value) } + RequestHandlerError::OidcTokenExchangeFailed + | RequestHandlerError::UnknownOidcState => Self::forbidden(value), + RequestHandlerError::AgentResponseTypeMismatch { .. } | RequestHandlerError::InvariantViolated { .. } | RequestHandlerError::AgentInvocationFailed(_) - | RequestHandlerError::InternalError(_) => Self::internal(value), + | RequestHandlerError::InternalError(_) + | RequestHandlerError::ResolvingRouteFailed(RouteResolverError::CouldNotBuildRouter) => { + Self::internal(value) + } } } } diff --git a/golem-worker-service/src/api/mod.rs b/golem-worker-service/src/api/mod.rs index 470cb467e0..19bbb0e0a2 100644 --- a/golem-worker-service/src/api/mod.rs +++ b/golem-worker-service/src/api/mod.rs @@ -14,13 +14,11 @@ pub mod agents; pub mod common; -mod custom_http_request; mod worker; -use self::custom_http_request::CustomHttpRequestApi; use crate::api::agents::AgentsApi; use crate::api::worker::WorkerApi; -use crate::service::Services; +use crate::bootstrap::Services; use golem_service_base::api::HealthcheckApi; use poem_openapi::OpenApiService; @@ -44,7 +42,3 @@ pub fn make_open_api_service(services: &Services) -> OpenApiService { "1.0", ) } - -pub fn custom_http_request_api(services: &Services) -> CustomHttpRequestApi { - CustomHttpRequestApi::new(services.request_handler.clone()) -} diff --git a/golem-worker-service/src/api/worker.rs b/golem-worker-service/src/api/worker.rs index 0385996920..cc56b4f647 100644 --- a/golem-worker-service/src/api/worker.rs +++ b/golem-worker-service/src/api/worker.rs @@ -17,7 +17,7 @@ use crate::model; use crate::service::auth::AuthService; use crate::service::component::ComponentService; use crate::service::worker::ConnectWorkerStream; -use crate::service::worker::{proxy_worker_connection, InvocationParameters, WorkerService}; +use crate::service::worker::{InvocationParameters, WorkerService, proxy_worker_connection}; use futures::StreamExt; use futures::TryStreamExt; use golem_common::model::auth::TokenSecret; @@ -30,14 +30,14 @@ use golem_common::model::oplog::OplogCursor; use golem_common::model::oplog::OplogIndex; use golem_common::model::worker::{RevertWorkerTarget, WorkerCreationRequest, WorkerMetadataDto}; use golem_common::model::{IdempotencyKey, ScanCursor, WorkerFilter, WorkerId}; -use golem_common::{recorded_http_api_request, SafeDisplay}; +use golem_common::{SafeDisplay, recorded_http_api_request}; use golem_service_base::api_tags::ApiTags; use golem_service_base::model::auth::{ AuthCtx, EnvironmentAction, GolemSecurityScheme, WrappedGolemSecuritySchema, }; use golem_service_base::model::*; -use poem::web::websocket::{BoxWebSocketUpgraded, WebSocket}; use poem::Body; +use poem::web::websocket::{BoxWebSocketUpgraded, WebSocket}; use poem_openapi::param::{Header, Path, Query}; use poem_openapi::payload::{Binary, Json}; use poem_openapi::*; @@ -810,7 +810,7 @@ impl WorkerApi { match (from, query) { (Some(_), Some(_)) => Err(ApiEndpointError::BadRequest(Json(ErrorsBody { errors: vec![ - "Cannot specify both the 'from' and the 'query' parameters".to_string() + "Cannot specify both the 'from' and the 'query' parameters".to_string(), ], cause: None, }))), diff --git a/golem-worker-service/src/bootstrap.rs b/golem-worker-service/src/bootstrap.rs new file mode 100644 index 0000000000..74863b3bbc --- /dev/null +++ b/golem-worker-service/src/bootstrap.rs @@ -0,0 +1,164 @@ +// Copyright 2024-2025 Golem Cloud +// +// Licensed under the Golem Source License v1.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://license.golem.cloud/LICENSE +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::config::{SessionStoreConfig, WorkerServiceConfig}; +use crate::custom_api::api_definition_lookup::{ + HttpApiDefinitionsLookup, RegistryServiceApiDefinitionsLookup, +}; +use crate::custom_api::call_agent::CallAgentHandler; +use crate::custom_api::request_handler::RequestHandler; +use crate::custom_api::route_resolver::RouteResolver; +use crate::custom_api::security::DefaultIdentityProvider; +use crate::custom_api::security::handler::OidcHandler; +use crate::custom_api::security::session_store::{ + RedisSessionStore, SessionStore, SqliteSessionStore, +}; +use crate::service::auth::{AuthService, RemoteAuthService}; +use crate::service::component::{ComponentService, RemoteComponentService}; +use crate::service::limit::{LimitService, RemoteLimitService}; +use crate::service::worker::{ + AgentsService, WorkerClient, WorkerExecutorWorkerClient, WorkerService, +}; +use golem_api_grpc::proto::golem::workerexecutor::v1::worker_executor_client::WorkerExecutorClient; +use golem_common::redis::RedisPool; +use golem_service_base::clients::registry::{GrpcRegistryService, RegistryService}; +use golem_service_base::db::sqlite::SqlitePool; +use golem_service_base::grpc::client::MultiTargetGrpcClient; +use golem_service_base::service::routing_table::{RoutingTableService, RoutingTableServiceDefault}; +use std::sync::Arc; +use tonic::codec::CompressionEncoding; + +#[derive(Clone)] +pub struct Services { + pub auth_service: Arc, + pub limit_service: Arc, + pub component_service: Arc, + pub worker_service: Arc, + pub request_handler: Arc, + pub agents_service: Arc, +} + +impl Services { + pub async fn new(config: &WorkerServiceConfig) -> anyhow::Result { + let registry_service_client: Arc = + Arc::new(GrpcRegistryService::new(&config.registry_service)); + + let auth_service: Arc = Arc::new(RemoteAuthService::new( + registry_service_client.clone(), + &config.auth_service, + )); + + let component_service: Arc = Arc::new(RemoteComponentService::new( + registry_service_client.clone(), + &config.component_service, + )); + + let limit_service: Arc = + Arc::new(RemoteLimitService::new(registry_service_client.clone())); + + let routing_table_service: Arc = Arc::new( + RoutingTableServiceDefault::new(config.routing_table.clone()), + ); + + let worker_executor_clients = MultiTargetGrpcClient::new( + "worker_executor", + |channel| { + WorkerExecutorClient::new(channel) + .send_compressed(CompressionEncoding::Gzip) + .accept_compressed(CompressionEncoding::Gzip) + }, + config.worker_executor.client.clone(), + ); + + let worker_client: Arc = Arc::new(WorkerExecutorWorkerClient::new( + worker_executor_clients.clone(), + config.worker_executor.retries.clone(), + routing_table_service.clone(), + )); + + let worker_service: Arc = Arc::new(WorkerService::new( + component_service.clone(), + auth_service.clone(), + limit_service.clone(), + worker_client.clone(), + )); + + let api_definition_lookup_service: Arc = Arc::new( + RegistryServiceApiDefinitionsLookup::new(registry_service_client.clone()), + ); + + let route_resolver = Arc::new(RouteResolver::new( + &config.route_resolver, + api_definition_lookup_service.clone(), + )); + + let call_agent_handler = Arc::new(CallAgentHandler::new(worker_service.clone())); + + let identity_provider = Arc::new(DefaultIdentityProvider); + + let session_store: Arc = match &config.gateway_session_storage { + SessionStoreConfig::Redis(inner) => { + let redis = RedisPool::configured(&inner.redis_config).await?; + + let session_store = RedisSessionStore::new( + redis, + fred::types::Expiration::EX( + inner.pending_login_expiration.as_secs().try_into()?, + ), + ); + + Arc::new(session_store) + } + + SessionStoreConfig::Sqlite(inner) => { + let pool = SqlitePool::configured(&inner.sqlite_config).await?; + + let gateway_session_with_sqlite = SqliteSessionStore::new( + pool, + inner.pending_login_expiration.as_secs().try_into()?, + inner.cleanup_interval, + ) + .await?; + + Arc::new(gateway_session_with_sqlite) + } + }; + + let oidc_handler = Arc::new(OidcHandler::new( + session_store.clone(), + identity_provider.clone(), + )); + + let request_handler = Arc::new(RequestHandler::new( + route_resolver.clone(), + call_agent_handler.clone(), + oidc_handler.clone(), + )); + + let agents_service: Arc = Arc::new(AgentsService::new( + registry_service_client.clone(), + component_service.clone(), + worker_service.clone(), + )); + + Ok(Self { + auth_service, + limit_service, + component_service, + worker_service, + request_handler, + agents_service, + }) + } +} diff --git a/golem-worker-service/src/config.rs b/golem-worker-service/src/config.rs index 959b4b5c93..21751412b0 100644 --- a/golem-worker-service/src/config.rs +++ b/golem-worker-service/src/config.rs @@ -12,15 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -// use crate::service::gateway::api_definition::ApiDefinitionServiceConfig; +use golem_common::SafeDisplay; use golem_common::config::DbSqliteConfig; use golem_common::config::RedisConfig; use golem_common::config::{ConfigExample, ConfigLoader, HasConfigExamples}; use golem_common::model::RetryConfig; use golem_common::tracing::TracingConfig; -use golem_common::SafeDisplay; use golem_service_base::clients::registry::GrpcRegistryServiceConfig; -use golem_service_base::config::BlobStorageConfig; use golem_service_base::grpc::client::GrpcClientConfig; use golem_service_base::grpc::server::GrpcServerTlsConfig; use golem_service_base::service::routing_table::RoutingTableConfig; @@ -33,13 +31,12 @@ use std::time::Duration; pub struct WorkerServiceConfig { pub environment: String, pub tracing: TracingConfig, - pub gateway_session_storage: GatewaySessionStorageConfig, + pub gateway_session_storage: SessionStoreConfig, pub port: u16, pub custom_request_port: u16, pub grpc: GrpcApiConfig, pub routing_table: RoutingTableConfig, pub worker_executor: WorkerExecutorClientConfig, - pub blob_storage: BlobStorageConfig, pub workspace: String, pub registry_service: GrpcRegistryServiceConfig, pub cors_origin_regex: String, @@ -84,12 +81,7 @@ impl SafeDisplay for WorkerServiceConfig { ); let _ = writeln!(&mut result, "worker executor:"); let _ = writeln!(result, "{}", self.worker_executor.to_safe_string_indented()); - let _ = writeln!(&mut result, "blob storage:"); - let _ = writeln!( - &mut result, - "{}", - self.blob_storage.to_safe_string_indented() - ); + let _ = writeln!(&mut result, "workspace: {}", self.workspace); let _ = writeln!(&mut result, "registry service:"); let _ = writeln!( @@ -129,14 +121,13 @@ impl Default for WorkerServiceConfig { fn default() -> Self { Self { environment: "local".to_string(), - gateway_session_storage: GatewaySessionStorageConfig::default_redis(), + gateway_session_storage: SessionStoreConfig::Redis(Default::default()), tracing: TracingConfig::local_dev("worker-service"), port: 9005, custom_request_port: 9006, grpc: GrpcApiConfig::default(), routing_table: RoutingTableConfig::default(), worker_executor: WorkerExecutorClientConfig::default(), - blob_storage: BlobStorageConfig::default(), workspace: "release".to_string(), registry_service: GrpcRegistryServiceConfig::default(), cors_origin_regex: "https://*.golem.cloud".to_string(), @@ -183,20 +174,20 @@ impl Default for GrpcApiConfig { #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(tag = "type", content = "config")] -pub enum GatewaySessionStorageConfig { - Redis(RedisConfig), - Sqlite(DbSqliteConfig), +pub enum SessionStoreConfig { + Redis(RedisSessionStoreConfig), + Sqlite(SqliteSessionStoreConfig), } -impl SafeDisplay for GatewaySessionStorageConfig { +impl SafeDisplay for SessionStoreConfig { fn to_safe_string(&self) -> String { let mut result = String::new(); match self { - GatewaySessionStorageConfig::Redis(redis) => { + SessionStoreConfig::Redis(redis) => { let _ = writeln!(&mut result, "redis:"); let _ = writeln!(&mut result, "{}", redis.to_safe_string_indented()); } - GatewaySessionStorageConfig::Sqlite(sqlite) => { + SessionStoreConfig::Sqlite(sqlite) => { let _ = writeln!(&mut result, "sqlite:"); let _ = writeln!(&mut result, "{}", sqlite.to_safe_string_indented()); } @@ -205,15 +196,57 @@ impl SafeDisplay for GatewaySessionStorageConfig { } } -impl Default for GatewaySessionStorageConfig { +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RedisSessionStoreConfig { + #[serde(with = "humantime_serde")] + pub pending_login_expiration: std::time::Duration, + #[serde(flatten)] + pub redis_config: RedisConfig, +} + +impl Default for RedisSessionStoreConfig { fn default() -> Self { - Self::default_redis() + Self { + pending_login_expiration: Duration::from_hours(1), + redis_config: RedisConfig::default(), + } } } -impl GatewaySessionStorageConfig { - pub fn default_redis() -> Self { - Self::Redis(RedisConfig::default()) +impl SafeDisplay for RedisSessionStoreConfig { + fn to_safe_string(&self) -> String { + let mut result = String::new(); + let _ = writeln!( + &mut result, + "pending_login_expiration: {:?}", + self.pending_login_expiration + ); + let _ = writeln!(&mut result, "{}", self.redis_config.to_safe_string()); + result + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SqliteSessionStoreConfig { + #[serde(with = "humantime_serde")] + pub pending_login_expiration: std::time::Duration, + #[serde(with = "humantime_serde")] + pub cleanup_interval: std::time::Duration, + #[serde(flatten)] + pub sqlite_config: DbSqliteConfig, +} + +impl SafeDisplay for SqliteSessionStoreConfig { + fn to_safe_string(&self) -> String { + let mut result = String::new(); + let _ = writeln!( + &mut result, + "pending_login_expiration: {:?}", + self.pending_login_expiration + ); + let _ = writeln!(&mut result, "cleanup_interval: {:?}", self.cleanup_interval); + let _ = writeln!(&mut result, "{}", self.sqlite_config.to_safe_string()); + result } } diff --git a/golem-worker-service/src/gateway_execution/api_definition_lookup.rs b/golem-worker-service/src/custom_api/api_definition_lookup.rs similarity index 97% rename from golem-worker-service/src/gateway_execution/api_definition_lookup.rs rename to golem-worker-service/src/custom_api/api_definition_lookup.rs index 82029fb02f..a5180e2dbf 100644 --- a/golem-worker-service/src/gateway_execution/api_definition_lookup.rs +++ b/golem-worker-service/src/custom_api/api_definition_lookup.rs @@ -14,7 +14,7 @@ use async_trait::async_trait; use golem_common::model::domain_registration::Domain; -use golem_common::{error_forwarding, SafeDisplay}; +use golem_common::{SafeDisplay, error_forwarding}; use golem_service_base::clients::registry::{RegistryService, RegistryServiceError}; use golem_service_base::custom_api::CompiledRoutes; use std::sync::Arc; diff --git a/golem-worker-service/src/custom_api/call_agent/mod.rs b/golem-worker-service/src/custom_api/call_agent/mod.rs new file mode 100644 index 0000000000..d3334c541d --- /dev/null +++ b/golem-worker-service/src/custom_api/call_agent/mod.rs @@ -0,0 +1,288 @@ +// Copyright 2024-2025 Golem Cloud +// +// Licensed under the Golem Source License v1.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://license.golem.cloud/LICENSE +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod parameter_parsing; +mod response_mapping; + +use self::parameter_parsing::{ + parse_path_segment_value, parse_path_segment_value_to_component_model, + parse_query_or_header_value, parse_request_body, +}; +use self::response_mapping::interpret_agent_response; +use super::error::RequestHandlerError; +use super::model::RichRequest; +use super::route_resolver::ResolvedRouteEntry; +use super::{ParsedRequestBody, RouteExecutionResult}; +use crate::service::worker::WorkerService; +use anyhow::anyhow; +use golem_common::model::agent::{ + AgentId, BinaryReference, BinaryReferenceValue, DataValue, ElementValue, ElementValues, + OidcPrincipal, Principal, UntypedDataValue, UntypedElementValue, +}; +use golem_common::model::{IdempotencyKey, WorkerId}; +use golem_service_base::custom_api::{CallAgentBehaviour, ConstructorParameter, MethodParameter}; +use golem_service_base::model::auth::AuthCtx; +use golem_wasm::json::ValueAndTypeJsonExtensions; +use golem_wasm::{IntoValue, ValueAndType}; +use std::sync::Arc; +use tracing::debug; +use uuid::Uuid; + +pub struct CallAgentHandler { + worker_service: Arc, +} + +impl CallAgentHandler { + pub fn new(worker_service: Arc) -> Self { + Self { worker_service } + } + + pub async fn handle_call_agent_behaviour( + &self, + request: &mut RichRequest, + resolved_route: &ResolvedRouteEntry, + behaviour: &CallAgentBehaviour, + ) -> Result { + let worker_id = self.build_worker_id(resolved_route, behaviour)?; + + let parsed_body = parse_request_body(request, &resolved_route.route.body).await?; + + let method_params = + self.resolve_method_arguments(resolved_route, request, behaviour, parsed_body)?; + + debug!("Invoking agent {worker_id}"); + + let agent_response = self + .invoke_agent( + &worker_id, + resolved_route, + method_params, + behaviour, + request, + ) + .await?; + + debug!("Received agent response: {agent_response:?}"); + + debug!( + "Json agent response: {}", + agent_response.clone().unwrap().to_json_value().unwrap() + ); + + let route_result = + interpret_agent_response(agent_response, &behaviour.expected_agent_response)?; + + debug!("Returning call agent route result: {route_result:?}"); + + Ok(route_result) + } + + fn build_worker_id( + &self, + resolved_route: &ResolvedRouteEntry, + behaviour: &CallAgentBehaviour, + ) -> Result { + let CallAgentBehaviour { + component_id, + agent_type, + constructor_parameters, + phantom, + .. + } = behaviour; + + let mut values = Vec::with_capacity(constructor_parameters.len()); + + for param in constructor_parameters { + match param { + ConstructorParameter::Path { + path_segment_index, + parameter_type, + } => { + let raw = resolved_route.captured_path_parameters + [usize::from(*path_segment_index)] + .clone(); + + let value = parse_path_segment_value_to_component_model(raw, parameter_type)?; + + values.push(ElementValue::ComponentModel(ValueAndType::new( + value, + parameter_type.clone().into(), + ))); + } + } + } + + let data_value = DataValue::Tuple(ElementValues { elements: values }); + + let phantom_id = phantom.then(Uuid::new_v4); + + let agent_id = AgentId::new(agent_type.clone(), data_value, phantom_id); + + Ok(WorkerId { + component_id: *component_id, + worker_name: agent_id.to_string(), + }) + } + + fn resolve_method_arguments( + &self, + resolved_route: &ResolvedRouteEntry, + request: &RichRequest, + behaviour: &CallAgentBehaviour, + mut body: ParsedRequestBody, + ) -> Result, RequestHandlerError> { + let query_params = request.query_params(); + let headers = request.headers(); + + let mut values = Vec::with_capacity(behaviour.method_parameters.len()); + + for param in &behaviour.method_parameters { + let value = match param { + MethodParameter::Path { + path_segment_index, + parameter_type, + } => { + let raw = resolved_route.captured_path_parameters + [usize::from(*path_segment_index)] + .clone(); + + parse_path_segment_value(raw, parameter_type)? + } + + MethodParameter::Query { + query_parameter_name, + parameter_type, + } => { + let empty = Vec::new(); + let vals = query_params.get(query_parameter_name).unwrap_or(&empty); + + parse_query_or_header_value(vals, parameter_type)? + } + + MethodParameter::Header { + header_name, + parameter_type, + } => { + let vals = headers + .get_all(header_name) + .iter() + .map(|h| { + h.to_str().map(String::from).map_err(|_| { + RequestHandlerError::HeaderIsNotAscii { + header_name: header_name.clone(), + } + }) + }) + .collect::, _>>()?; + + parse_query_or_header_value(&vals, parameter_type)? + } + + MethodParameter::JsonObjectBodyField { field_index } => match &body { + ParsedRequestBody::JsonBody(golem_wasm::Value::Record(fields)) => { + UntypedElementValue::ComponentModel( + fields[usize::from(*field_index)].clone(), + ) + } + + ParsedRequestBody::JsonBody(_) => { + return Err(RequestHandlerError::invariant_violated( + "Inconsistent API definition: JSON field parameter but body is not an object", + )); + } + + _ => { + return Err(RequestHandlerError::invariant_violated( + "JSON body parameter used but no JSON body schema", + )); + } + }, + + MethodParameter::UnstructuredBinaryBody => match &mut body { + ParsedRequestBody::UnstructuredBinary(binary_source) => { + let binary_source = binary_source.take().ok_or_else(|| { + RequestHandlerError::invariant_violated( + "Parsed body was already consumed", + ) + })?; + + UntypedElementValue::UnstructuredBinary(BinaryReferenceValue { + value: BinaryReference::Inline(binary_source), + }) + } + + _ => { + return Err(RequestHandlerError::invariant_violated( + "Binary body parameter used but no binary body schema", + )); + } + }, + }; + + values.push(value); + } + + Ok(values) + } + + async fn invoke_agent( + &self, + worker_id: &WorkerId, + resolved_route: &ResolvedRouteEntry, + params: Vec, + behaviour: &CallAgentBehaviour, + request: &RichRequest, + ) -> Result, RequestHandlerError> { + let method_params_data_value = UntypedDataValue::Tuple(params); + + let principal = principal_from_request(request)?; + + self.worker_service + .invoke_and_await_owned_agent( + worker_id, + Some(IdempotencyKey::fresh()), + "golem:agent/guest.{invoke}".to_string(), + vec![ + golem_wasm::protobuf::Val::from(behaviour.method_name.clone().into_value()), + golem_wasm::protobuf::Val::from(method_params_data_value.into_value()), + golem_wasm::protobuf::Val::from(principal.into_value()), + ], + None, + resolved_route.route.environment_id, + resolved_route.route.account_id, + AuthCtx::impersonated_user(resolved_route.route.account_id), + ) + .await + .map_err(Into::into) + } +} + +fn principal_from_request(request: &RichRequest) -> Result { + match request.authenticated_session() { + Some(session) => Ok(Principal::Oidc(OidcPrincipal { + sub: session.subject.clone(), + issuer: session.issuer.clone(), + email: session.email.clone(), + name: session.name.clone(), + email_verified: session.email_verified, + given_name: session.given_name.clone(), + family_name: session.family_name.clone(), + picture: session.picture.clone(), + preferred_username: session.preferred_username.clone(), + claims: serde_json::to_string(&session.claims) + .map_err(|e| anyhow!("CoreIdTokenClaims serialization error: {e}"))?, + })), + None => Ok(Principal::anonymous()), + } +} diff --git a/golem-worker-service/src/gateway_execution/parameter_parsing.rs b/golem-worker-service/src/custom_api/call_agent/parameter_parsing.rs similarity index 97% rename from golem-worker-service/src/gateway_execution/parameter_parsing.rs rename to golem-worker-service/src/custom_api/call_agent/parameter_parsing.rs index 5f202bfb9f..bfef971306 100644 --- a/golem-worker-service/src/gateway_execution/parameter_parsing.rs +++ b/golem-worker-service/src/custom_api/call_agent/parameter_parsing.rs @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use super::request::RichRequest; -use super::request_handler::RequestHandlerError; use super::ParsedRequestBody; +use crate::custom_api::error::RequestHandlerError; +use crate::custom_api::model::RichRequest; use anyhow::anyhow; use golem_common::model::agent::{BinarySource, BinaryType, UntypedElementValue}; use golem_service_base::custom_api::{PathSegmentType, QueryOrHeaderType, RequestBodySchema}; -use golem_wasm::json::ValueAndTypeJsonExtensions; use golem_wasm::ValueAndType; +use golem_wasm::json::ValueAndTypeJsonExtensions; pub fn parse_path_segment_value( value: String, @@ -263,13 +263,13 @@ async fn parse_binary_body( .map(|v| v.to_string()) .unwrap_or_else(|| "application/octet-stream".to_string()); - if let Some(allowed) = allowed_mime_types { - if !allowed.iter().any(|allowed| allowed == &mime_type) { - return Err(RequestHandlerError::UnsupportedMimeType { - mime_type, - allowed_mime_types: allowed.clone(), - }); - } + if let Some(allowed) = allowed_mime_types + && !allowed.iter().any(|allowed| allowed == &mime_type) + { + return Err(RequestHandlerError::UnsupportedMimeType { + mime_type, + allowed_mime_types: allowed.clone(), + }); } Ok(ParsedRequestBody::UnstructuredBinary(Some(BinarySource { @@ -528,7 +528,7 @@ mod request_body_tests { use super::*; use assert2::{assert, let_assert}; use golem_service_base::custom_api::RequestBodySchema; - use golem_wasm::analysis::{analysed_type, NameTypePair}; + use golem_wasm::analysis::{NameTypePair, analysed_type}; use http::Method; use poem::{Body, Request}; use serde_json::json; diff --git a/golem-worker-service/src/gateway_execution/agent_response_mapping.rs b/golem-worker-service/src/custom_api/call_agent/response_mapping.rs similarity index 74% rename from golem-worker-service/src/gateway_execution/agent_response_mapping.rs rename to golem-worker-service/src/custom_api/call_agent/response_mapping.rs index 35fbf9b8c9..9c9d5d60b8 100644 --- a/golem-worker-service/src/gateway_execution/agent_response_mapping.rs +++ b/golem-worker-service/src/custom_api/call_agent/response_mapping.rs @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use super::request_handler::RequestHandlerError; -use super::RouteExecutionResult; +use crate::custom_api::error::RequestHandlerError; +use crate::custom_api::{ResponseBody, RouteExecutionResult}; use anyhow::anyhow; use golem_common::model::agent::{ AgentError, BinaryReference, DataSchema, DataValue, ElementValue, ElementValues, @@ -22,6 +22,7 @@ use golem_common::model::agent::{ use golem_wasm::analysis::AnalysedType; use golem_wasm::{FromValue, ValueAndType}; use http::StatusCode; +use std::collections::HashMap; use tracing::debug; pub fn interpret_agent_response( @@ -57,9 +58,11 @@ fn map_agent_error(agent_error: AgentError) -> Result Err(RequestHandlerError::invariant_violated( "unexpected agent error type", )), - AgentError::CustomError(inner) => { - Ok(RouteExecutionResult::CustomAgentError { body: inner }) - } + AgentError::CustomError(inner) => Ok(RouteExecutionResult { + status: StatusCode::INTERNAL_SERVER_ERROR, + headers: HashMap::new(), + body: ResponseBody::ComponentModelJsonBody { body: inner }, + }), } } @@ -74,8 +77,10 @@ fn map_successful_agent_response( match typed_value { DataValue::Tuple(ElementValues { elements }) => match elements.len() { - 0 => Ok(RouteExecutionResult::NoBody { + 0 => Ok(RouteExecutionResult { status: StatusCode::NO_CONTENT, + headers: HashMap::new(), + body: ResponseBody::NoBody, }), 1 => map_single_element_agent_response(elements.into_iter().next().unwrap()), _ => Err(RequestHandlerError::invariant_violated( @@ -97,7 +102,11 @@ fn map_single_element_agent_response( } ElementValue::UnstructuredBinary(BinaryReference::Inline(binary)) => { - Ok(RouteExecutionResult::UnstructuredBinaryBody { body: binary }) + Ok(RouteExecutionResult { + status: StatusCode::OK, + headers: HashMap::new(), + body: ResponseBody::UnstructuredBinaryBody { body: binary }, + }) } _ => Err(RequestHandlerError::invariant_violated( @@ -112,40 +121,55 @@ fn map_component_model_agent_response( use golem_wasm::Value; match value_and_type.value { - Value::Option(None) => Ok(RouteExecutionResult::NoBody { + Value::Option(None) => Ok(RouteExecutionResult { status: StatusCode::NOT_FOUND, + headers: HashMap::new(), + body: ResponseBody::NoBody, }), Value::Option(Some(inner)) => { let inner_type = unwrap_option_type(value_and_type.typ)?; - Ok(json_response_body(*inner, inner_type, StatusCode::OK)) + Ok(RouteExecutionResult { + status: StatusCode::OK, + headers: HashMap::new(), + body: json_response_body(*inner, inner_type), + }) } - Value::Result(Ok(None)) => Ok(RouteExecutionResult::NoBody { + Value::Result(Ok(None)) => Ok(RouteExecutionResult { status: StatusCode::NO_CONTENT, + headers: HashMap::new(), + body: ResponseBody::NoBody, }), Value::Result(Ok(Some(inner))) => { let inner_type = unwrap_result_ok_type(value_and_type.typ)?; - Ok(json_response_body(*inner, inner_type, StatusCode::OK)) + Ok(RouteExecutionResult { + status: StatusCode::OK, + headers: HashMap::new(), + body: json_response_body(*inner, inner_type), + }) } - Value::Result(Err(None)) => Ok(RouteExecutionResult::NoBody { + Value::Result(Err(None)) => Ok(RouteExecutionResult { status: StatusCode::INTERNAL_SERVER_ERROR, + headers: HashMap::new(), + body: ResponseBody::NoBody, }), Value::Result(Err(Some(inner))) => { let inner_type = unwrap_result_err_type(value_and_type.typ)?; - Ok(json_response_body( - *inner, - inner_type, - StatusCode::INTERNAL_SERVER_ERROR, - )) + Ok(RouteExecutionResult { + status: StatusCode::INTERNAL_SERVER_ERROR, + headers: HashMap::new(), + body: json_response_body(*inner, inner_type), + }) } - other => Ok(RouteExecutionResult::ComponentModelJsonBody { - body: ValueAndType::new(other, value_and_type.typ), + other => Ok(RouteExecutionResult { status: StatusCode::OK, + headers: HashMap::new(), + body: json_response_body(other, value_and_type.typ), }), } } @@ -192,13 +216,8 @@ fn unwrap_result_err_type(typ: AnalysedType) -> Result RouteExecutionResult { - RouteExecutionResult::ComponentModelJsonBody { +fn json_response_body(value: golem_wasm::Value, typ: AnalysedType) -> ResponseBody { + ResponseBody::ComponentModelJsonBody { body: ValueAndType::new(value, typ), - status, } } diff --git a/golem-worker-service/src/custom_api/cors.rs b/golem-worker-service/src/custom_api/cors.rs new file mode 100644 index 0000000000..f0ba3deff3 --- /dev/null +++ b/golem-worker-service/src/custom_api/cors.rs @@ -0,0 +1,108 @@ +// Copyright 2024-2025 Golem Cloud +// +// Licensed under the Golem Source License v1.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://license.golem.cloud/LICENSE +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::error::RequestHandlerError; +use super::model::RichRequest; +use super::route_resolver::ResolvedRouteEntry; +use super::{ResponseBody, RouteExecutionResult}; +use golem_common::model::agent::HttpMethod; +use golem_service_base::custom_api::OriginPattern; +use http::StatusCode; +use std::collections::{BTreeSet, HashMap}; +use tracing::debug; + +pub fn handle_cors_preflight_behaviour( + request: &RichRequest, + allowed_origins: &BTreeSet, + allowed_methods: &BTreeSet, +) -> Result { + let origin = request.origin()?.ok_or(RequestHandlerError::MissingValue { + expected: "Origin header", + })?; + + let origin_allowed = allowed_origins + .iter() + .any(|pattern| pattern.matches(origin)); + if !origin_allowed { + return Ok(RouteExecutionResult { + status: StatusCode::FORBIDDEN, + headers: HashMap::new(), + body: ResponseBody::NoBody, + }); + } + + let allow_methods = allowed_methods + .iter() + .map(|m| { + let converted = http::Method::try_from(m.clone()).map_err(|_| { + RequestHandlerError::invariant_violated("HttpMethod conversion error") + })?; + let rendered = converted.to_string(); + Ok::<_, RequestHandlerError>(rendered) + }) + .collect::, _>>()? + .join(", "); + + let mut headers = HashMap::new(); + + headers.insert( + http::header::ACCESS_CONTROL_ALLOW_ORIGIN, + origin.to_string(), + ); + headers.insert(http::header::ACCESS_CONTROL_ALLOW_METHODS, allow_methods); + headers.insert( + http::header::ACCESS_CONTROL_ALLOW_HEADERS, + "Content-Type, Authorization".to_string(), + ); + headers.insert(http::header::ACCESS_CONTROL_MAX_AGE, "3600".to_string()); + headers.insert(http::header::VARY, "Origin".to_string()); + + Ok(RouteExecutionResult { + status: StatusCode::NO_CONTENT, + headers, + body: ResponseBody::NoBody, + }) +} + +pub async fn apply_cors_outgoing_middleware( + result: &mut RouteExecutionResult, + request: &RichRequest, + resolved_route: &ResolvedRouteEntry, +) -> Result<(), RequestHandlerError> { + debug!("Begin executing SetCorsResponseHeadersMiddleware"); + + let cors = &resolved_route.route.cors; + + if cors.allowed_patterns.is_empty() { + return Ok(()); + } + + if let Some(origin) = request.origin()? + && cors.allowed_patterns.iter().any(|p| p.matches(origin)) + { + result.headers.insert( + http::header::ACCESS_CONTROL_ALLOW_ORIGIN, + origin.to_string(), + ); + result + .headers + .insert(http::header::VARY, "Origin".to_string()); + result.headers.insert( + http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + "true".to_string(), + ); + } + + Ok(()) +} diff --git a/golem-worker-service/src/custom_api/error.rs b/golem-worker-service/src/custom_api/error.rs new file mode 100644 index 0000000000..f151be92d0 --- /dev/null +++ b/golem-worker-service/src/custom_api/error.rs @@ -0,0 +1,97 @@ +// Copyright 2024-2025 Golem Cloud +// +// Licensed under the Golem Source License v1.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://license.golem.cloud/LICENSE +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::route_resolver::RouteResolverError; +use super::security::IdentityProviderError; +use super::security::session_store::SessionStoreError; +use crate::service::worker::WorkerServiceError; +use golem_common::{SafeDisplay, error_forwarding}; + +#[derive(Debug, thiserror::Error)] +pub enum RequestHandlerError { + #[error("Failed parsing value; Provided: {value}; Expected type: {expected}")] + ValueParsingFailed { + value: String, + expected: &'static str, + }, + #[error("Expected {expected} values to be provided, but found none")] + MissingValue { expected: &'static str }, + #[error("Expected {expected} values to be provided, but found too many")] + TooManyValues { expected: &'static str }, + #[error("Header value of {header_name} is not valid ascii")] + HeaderIsNotAscii { header_name: String }, + #[error("Request body was not valid json: {error}")] + BodyIsNotValidJson { error: String }, + #[error("Failed parsing json body: [{formatted}]", formatted=.errors.join(","))] + JsonBodyParsingFailed { errors: Vec }, + #[error("Agent response did not match expected type: {error}")] + AgentResponseTypeMismatch { error: String }, + #[error("Mime type {mime_type} is not supported. Allowed mime types: [{formatted_mime_types}]", formatted_mime_types=.allowed_mime_types.join(","))] + UnsupportedMimeType { + mime_type: String, + allowed_mime_types: Vec, + }, + #[error("Unknown OIDC state")] + UnknownOidcState, + #[error("OIDC token exchange failed")] + OidcTokenExchangeFailed, + #[error("Invariant violated: {msg}")] + InvariantViolated { msg: &'static str }, + #[error("Resolving route failed: {0}")] + ResolvingRouteFailed(#[from] RouteResolverError), + #[error("Invocation failed: {0}")] + AgentInvocationFailed(#[from] WorkerServiceError), + #[error(transparent)] + InternalError(#[from] anyhow::Error), +} + +impl RequestHandlerError { + pub fn invariant_violated(msg: &'static str) -> Self { + Self::InvariantViolated { msg } + } +} + +impl SafeDisplay for RequestHandlerError { + fn to_safe_string(&self) -> String { + match self { + Self::ValueParsingFailed { .. } => self.to_string(), + Self::MissingValue { .. } => self.to_string(), + Self::TooManyValues { .. } => self.to_string(), + Self::HeaderIsNotAscii { .. } => self.to_string(), + Self::BodyIsNotValidJson { .. } => self.to_string(), + Self::JsonBodyParsingFailed { .. } => self.to_string(), + Self::AgentResponseTypeMismatch { .. } => self.to_string(), + Self::UnsupportedMimeType { .. } => self.to_string(), + Self::UnknownOidcState => self.to_string(), + Self::OidcTokenExchangeFailed => self.to_string(), + + Self::InvariantViolated { .. } => "internal error".to_string(), + + Self::ResolvingRouteFailed(inner) => { + format!("Resolving route failed: {}", inner.to_safe_string()) + } + Self::AgentInvocationFailed(inner) => { + format!("Invocation failed: {}", inner.to_safe_string()) + } + + Self::InternalError(_) => "internal error".to_string(), + } + } +} + +error_forwarding!( + RequestHandlerError, + SessionStoreError, + IdentityProviderError +); diff --git a/golem-worker-service/src/custom_api/mod.rs b/golem-worker-service/src/custom_api/mod.rs new file mode 100644 index 0000000000..70be135f9c --- /dev/null +++ b/golem-worker-service/src/custom_api/mod.rs @@ -0,0 +1,32 @@ +// Copyright 2024-2025 Golem Cloud +// +// Licensed under the Golem Source License v1.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://license.golem.cloud/LICENSE +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod api_definition_lookup; +pub mod call_agent; +mod cors; +pub mod error; +pub mod model; +pub mod poem_endpoint; +pub mod request_handler; +pub mod route_resolver; +pub mod router; +pub mod security; + +use self::poem_endpoint::CustomApiPoemEndpoint; +use crate::bootstrap::Services; +pub use model::*; + +pub fn make_custom_api_endpoint(services: &Services) -> CustomApiPoemEndpoint { + CustomApiPoemEndpoint::new(services.request_handler.clone()) +} diff --git a/golem-worker-service/src/custom_api/model.rs b/golem-worker-service/src/custom_api/model.rs new file mode 100644 index 0000000000..2120e713f1 --- /dev/null +++ b/golem-worker-service/src/custom_api/model.rs @@ -0,0 +1,228 @@ +// Copyright 2024-2025 Golem Cloud +// +// Licensed under the Golem Source License v1.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://license.golem.cloud/LICENSE +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::error::RequestHandlerError; +use chrono::{DateTime, Utc}; +use golem_common::model::account::AccountId; +use golem_common::model::agent::BinarySource; +use golem_common::model::environment::EnvironmentId; +use golem_service_base::custom_api::{ + CallAgentBehaviour, CorsOptions, CorsPreflightBehaviour, SecuritySchemeDetails, +}; +use golem_service_base::custom_api::{PathSegment, RequestBodySchema, RouteBehaviour, RouteId}; +use http::{HeaderMap, Method}; +use http::{HeaderName, StatusCode}; +use openidconnect::Scope; +use openidconnect::core::CoreIdTokenClaims; +use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::sync::{Arc, OnceLock}; +use uuid::Uuid; + +const COOKIE_HEADER_NAMES: [&str; 2] = ["cookie", "Cookie"]; + +pub struct RichRequest { + pub underlying: poem::Request, + pub request_id: Uuid, + pub authenticated_session: Option, + + parsed_cookies: OnceLock>, + parsed_query_params: OnceLock>>, +} + +impl RichRequest { + pub fn new(underlying: poem::Request) -> RichRequest { + RichRequest { + underlying, + request_id: Uuid::new_v4(), + authenticated_session: None, + parsed_cookies: OnceLock::new(), + parsed_query_params: OnceLock::new(), + } + } + + pub fn origin(&self) -> Result, RequestHandlerError> { + match self.underlying.headers().get("Origin") { + Some(header) => { + let result = + header + .to_str() + .map_err(|_| RequestHandlerError::HeaderIsNotAscii { + header_name: "Origin".to_string(), + })?; + Ok(Some(result)) + } + None => Ok(None), + } + } + + pub fn headers(&self) -> &HeaderMap { + self.underlying.headers() + } + + pub fn query_params(&self) -> &HashMap> { + self.parsed_query_params.get_or_init(|| { + let mut params: HashMap> = HashMap::new(); + + if let Some(q) = self.underlying.uri().query() { + for (key, value) in url::form_urlencoded::parse(q.as_bytes()).into_owned() { + params.entry(key).or_default().push(value); + } + } + + params + }) + } + + pub fn get_single_param(&self, name: &'static str) -> Result<&str, RequestHandlerError> { + match self.query_params().get(name).map(|qp| qp.as_slice()) { + Some([single]) => Ok(single), + None | Some([]) => Err(RequestHandlerError::MissingValue { expected: name }), + _ => Err(RequestHandlerError::TooManyValues { expected: name }), + } + } + + pub fn cookies(&self) -> &HashMap { + self.parsed_cookies.get_or_init(|| { + let mut map = HashMap::new(); + for header_name in COOKIE_HEADER_NAMES.iter() { + if let Some(value) = self.underlying.header(header_name) { + for part in value.split(';') { + let mut kv = part.splitn(2, '='); + if let (Some(k), Some(v)) = (kv.next(), kv.next()) { + map.insert(k.trim().to_string(), v.trim().to_string()); + } + } + } + } + map + }) + } + + pub fn cookie(&self, name: &str) -> Option<&str> { + self.cookies().get(name).map(|s| s.as_str()) + } + + pub fn set_authenticated_session(&mut self, session: OidcSession) { + self.authenticated_session = Some(session); + } + + pub fn authenticated_session(&self) -> Option<&OidcSession> { + self.authenticated_session.as_ref() + } +} + +pub struct OidcSession { + pub subject: String, + pub issuer: String, + + pub email: Option, + pub name: Option, + pub email_verified: Option, + pub given_name: Option, + pub family_name: Option, + pub picture: Option, + pub preferred_username: Option, + + pub claims: CoreIdTokenClaims, + pub scopes: HashSet, + pub expires_at: DateTime, +} + +impl OidcSession { + pub fn is_expired(&self) -> bool { + Utc::now() >= self.expires_at + } + + pub fn scopes(&self) -> &HashSet { + &self.scopes + } +} + +#[derive(Debug)] +pub struct RichCompiledRoute { + pub account_id: AccountId, + pub environment_id: EnvironmentId, + pub route_id: RouteId, + pub method: Method, + pub path: Vec, + pub body: RequestBodySchema, + pub behavior: RichRouteBehaviour, + pub security_scheme: Option>, + pub cors: CorsOptions, +} + +#[derive(Debug)] +pub enum RichRouteBehaviour { + CallAgent(CallAgentBehaviour), + CorsPreflight(CorsPreflightBehaviour), + OidcCallback(OidcCallbackBehaviour), +} + +impl From for RichRouteBehaviour { + fn from(value: RouteBehaviour) -> Self { + match value { + RouteBehaviour::CallAgent(inner) => Self::CallAgent(inner), + RouteBehaviour::CorsPreflight(inner) => Self::CorsPreflight(inner), + } + } +} + +#[derive(Debug)] +pub struct OidcCallbackBehaviour { + pub security_scheme: Arc, +} + +#[derive(Debug)] +pub struct RouteExecutionResult { + pub status: StatusCode, + pub headers: HashMap, + pub body: ResponseBody, +} + +pub enum ResponseBody { + NoBody, + ComponentModelJsonBody { body: golem_wasm::ValueAndType }, + UnstructuredBinaryBody { body: BinarySource }, +} + +impl fmt::Debug for ResponseBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ResponseBody::NoBody => f.debug_struct("NoBody").finish(), + ResponseBody::ComponentModelJsonBody { body } => f + .debug_struct("ComponentModelJsonBody") + .field("body", body) + .finish(), + ResponseBody::UnstructuredBinaryBody { .. } => f.write_str("UnstructuredBinaryBody"), + } + } +} + +pub enum ParsedRequestBody { + Unused, + JsonBody(golem_wasm::Value), + // Always Some initially, will be None after being consumed by handler code + UnstructuredBinary(Option), +} + +impl fmt::Debug for ParsedRequestBody { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ParsedRequestBody::Unused => f.write_str("Unused"), + ParsedRequestBody::JsonBody(value) => f.debug_tuple("JsonBody").field(value).finish(), + ParsedRequestBody::UnstructuredBinary(_) => f.write_str("UnstructuredBinary"), + } + } +} diff --git a/golem-worker-service/src/api/custom_http_request.rs b/golem-worker-service/src/custom_api/poem_endpoint.rs similarity index 90% rename from golem-worker-service/src/api/custom_http_request.rs rename to golem-worker-service/src/custom_api/poem_endpoint.rs index 3e975424b3..1860ad3623 100644 --- a/golem-worker-service/src/api/custom_http_request.rs +++ b/golem-worker-service/src/custom_api/poem_endpoint.rs @@ -12,21 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::future::Future; -use std::sync::Arc; - use crate::api::common::ApiEndpointError; -use crate::gateway_execution::request_handler::RequestHandler; +use crate::custom_api::request_handler::RequestHandler; use futures::{FutureExt, TryFutureExt}; use golem_common::recorded_http_api_request; use poem::{Endpoint, IntoResponse, Request, Response}; +use std::future::Future; +use std::sync::Arc; use tracing::Instrument; -pub struct CustomHttpRequestApi { +pub struct CustomApiPoemEndpoint { pub request_handler: Arc, } -impl CustomHttpRequestApi { +impl CustomApiPoemEndpoint { pub fn new(request_handler: Arc) -> Self { Self { request_handler } } @@ -50,7 +49,7 @@ impl CustomHttpRequestApi { } } -impl Endpoint for CustomHttpRequestApi { +impl Endpoint for CustomApiPoemEndpoint { type Output = Response; fn call(&self, req: Request) -> impl Future> + Send { diff --git a/golem-worker-service/src/custom_api/request_handler.rs b/golem-worker-service/src/custom_api/request_handler.rs new file mode 100644 index 0000000000..0de49e77d7 --- /dev/null +++ b/golem-worker-service/src/custom_api/request_handler.rs @@ -0,0 +1,144 @@ +// Copyright 2024-2025 Golem Cloud +// +// Licensed under the Golem Source License v1.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://license.golem.cloud/LICENSE +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::call_agent::CallAgentHandler; +use super::cors::{apply_cors_outgoing_middleware, handle_cors_preflight_behaviour}; +use super::error::RequestHandlerError; +use super::model::RichRequest; +use super::model::RichRouteBehaviour; +use super::route_resolver::{ResolvedRouteEntry, RouteResolver}; +use super::security::handler::OidcHandler; +use super::{OidcCallbackBehaviour, ResponseBody, RouteExecutionResult}; +use anyhow::anyhow; +use golem_service_base::custom_api::CorsPreflightBehaviour; +use golem_wasm::json::ValueAndTypeJsonExtensions; +use poem::{Request, Response}; +use std::sync::Arc; +use tracing::{Instrument, debug}; + +pub struct RequestHandler { + route_resolver: Arc, + call_agent_handler: Arc, + oidc_handler: Arc, +} + +#[allow(irrefutable_let_patterns)] +impl RequestHandler { + pub fn new( + route_resolver: Arc, + call_agent_handler: Arc, + oidc_handler: Arc, + ) -> Self { + Self { + route_resolver, + call_agent_handler, + oidc_handler, + } + } + + pub async fn handle_request(&self, request: Request) -> Result { + debug!("Begin http request handling for request {request:?}"); + + let matching_route = self.route_resolver.resolve_matching_route(&request).await?; + let mut request = RichRequest::new(request); + + let execution_result = self + .execute_route_and_middlewares(&mut request, &matching_route) + .instrument(tracing::span!( + tracing::Level::INFO, + "handle_route", + domain = %matching_route.domain, + method = %matching_route.route.method, + route = %matching_route.route.path.iter().map(|p| p.to_string()).collect::>().join("/") + )) + .await?; + + let response = route_execution_result_to_response(execution_result)?; + + Ok(response) + } + + async fn execute_route_and_middlewares( + &self, + request: &mut RichRequest, + resolved_route: &ResolvedRouteEntry, + ) -> Result { + if let Some(short_circuit) = self + .oidc_handler + .apply_oidc_incoming_middleware(request, resolved_route) + .await? + { + return Ok(short_circuit); + } + + let mut result = self.execute_route(request, resolved_route).await?; + + apply_cors_outgoing_middleware(&mut result, request, resolved_route).await?; + + Ok(result) + } + + async fn execute_route( + &self, + request: &mut RichRequest, + resolved_route: &ResolvedRouteEntry, + ) -> Result { + match &resolved_route.route.behavior { + RichRouteBehaviour::CallAgent(behaviour) => { + self.call_agent_handler + .handle_call_agent_behaviour(request, resolved_route, behaviour) + .await + } + + RichRouteBehaviour::CorsPreflight(CorsPreflightBehaviour { + allowed_origins, + allowed_methods, + }) => handle_cors_preflight_behaviour(request, allowed_origins, allowed_methods), + + RichRouteBehaviour::OidcCallback(OidcCallbackBehaviour { security_scheme }) => { + self.oidc_handler + .handle_oidc_callback_behaviour(request, security_scheme) + .await + } + } + } +} + +fn route_execution_result_to_response( + result: RouteExecutionResult, +) -> Result { + let mut response_builder = Response::builder().status(result.status); + + for (name, value) in result.headers { + response_builder = response_builder.header(name, value); + } + + match result.body { + ResponseBody::NoBody => Ok(response_builder.finish()), + + ResponseBody::ComponentModelJsonBody { body } => { + let body = poem::Body::from_json( + body.to_json_value() + .map_err(|e| anyhow!("ComponentModelJsonBody conversion error: {e}"))?, + ) + .map_err(anyhow::Error::from)?; + + Ok(response_builder.body(body)) + } + + ResponseBody::UnstructuredBinaryBody { body } => Ok(response_builder + .body(body.data) + .set_content_type(body.binary_type.mime_type)), + } +} diff --git a/golem-worker-service/src/gateway_execution/route_resolver.rs b/golem-worker-service/src/custom_api/route_resolver.rs similarity index 56% rename from golem-worker-service/src/gateway_execution/route_resolver.rs rename to golem-worker-service/src/custom_api/route_resolver.rs index 65baeab361..44611ddd27 100644 --- a/golem-worker-service/src/gateway_execution/route_resolver.rs +++ b/golem-worker-service/src/custom_api/route_resolver.rs @@ -13,19 +13,15 @@ // limitations under the License. use super::api_definition_lookup::{ApiDefinitionLookupError, HttpApiDefinitionsLookup}; +use super::model::RichCompiledRoute; +use super::router::Router; use crate::config::RouteResolverConfig; -use crate::gateway_router::Router; -// use crate::model::{HttpMiddleware, RichCompiledRoute, RichGatewayBindingCompiled, SwaggerHtml}; -// use crate::swagger_ui::generate_swagger_html; +use crate::custom_api::{OidcCallbackBehaviour, RichRouteBehaviour}; +use golem_common::SafeDisplay; use golem_common::cache::SimpleCache; use golem_common::cache::{BackgroundEvictionMode, Cache, FullCacheEvictionMode}; use golem_common::model::domain_registration::Domain; -// use golem_service_base::custom_api::openapi::HttpApiDefinitionOpenApiSpec; -// use golem_service_base::custom_api::HttpCors; -use golem_service_base::custom_api::CompiledRoutes; -// use golem_service_base::custom_api::{RouteBehaviour, SwaggerUiBindingCompiled}; -use crate::model::RichCompiledRoute; -use golem_common::SafeDisplay; +use golem_service_base::custom_api::{CompiledRoutes, CorsOptions, PathSegment, RequestBodySchema}; use std::collections::HashMap; use std::sync::Arc; use tracing::debug; @@ -151,46 +147,19 @@ impl RouteResolver { .map(|(id, details)| (id, Arc::new(details))) .collect(); - // let swagger_ui_htmls = Self::precompute_swagger_ui_htmls( - // domain, - // &compiled_routes.routes, - // &compiled_routes.security_schemes, - // ) - // .await?; - - // let cors_routes: HashMap = compiled_routes - // .routes - // .iter() - // .filter_map(|r| match &r.binding { - // RouteBehaviour::HttpCorsPreflight(inner) => { - // Some((r.path.clone(), inner.http_cors.clone())) - // } - // _ => None, - // }) - // .collect(); - let mut enriched_routes = Vec::with_capacity(compiled_routes.routes.len()); for route in compiled_routes.routes { - // let mut middlewares = Vec::new(); - let security_scheme = if let Some(security_scheme_id) = route.security_scheme { let security_scheme = security_schemes .get(&security_scheme_id) .ok_or(format!("Security scheme {security_scheme_id} not found"))? .clone(); - // middlewares.push(HttpMiddleware::AuthenticateRequest(security_scheme.clone())); Some(security_scheme) } else { None }; - // if route.method != RouteMethod::Options { - // if let Some(cors) = cors_routes.get(&route.path) { - // middlewares.push(HttpMiddleware::Cors(cors.clone())); - // } - // } - let enriched = RichCompiledRoute { account_id: compiled_routes.account_id, environment_id: compiled_routes.environment_id, @@ -201,7 +170,7 @@ impl RouteResolver { .map_err(|e| format!("Failed converting HttpMethod to http::Method: {e}"))?, path: route.path, body: route.body, - behavior: route.behavior, + behavior: route.behavior.into(), security_scheme, cors: route.cors, }; @@ -209,95 +178,40 @@ impl RouteResolver { enriched_routes.push(enriched); } - // let auth_call_back_routes = Self::get_auth_call_back_routes( - // &compiled_routes.account_id, - // &compiled_routes.environment_id, - // compiled_routes.security_schemes, - // )?; + // add synthethic oidc callback routes + for scheme in security_schemes.values() { + let redirect_url_path_segments: Vec = scheme + .redirect_url + .url() + .path_segments() + .ok_or_else(|| "Failed splitting security scheme redirect url".to_string())? + .map(|s| PathSegment::Literal { + value: s.to_string(), + }) + .collect(); + + let callback_route = RichCompiledRoute { + account_id: compiled_routes.account_id, + environment_id: compiled_routes.environment_id, + // TODO: Have some helper for synthethic vs user defined routes + route_id: -1, + method: http::Method::GET, + path: redirect_url_path_segments, + body: RequestBodySchema::Unused, + behavior: RichRouteBehaviour::OidcCallback(OidcCallbackBehaviour { + security_scheme: scheme.clone(), + }), + security_scheme: None, + cors: CorsOptions { + allowed_patterns: Vec::new(), + }, + }; - // for auth_call_back_route in auth_call_back_routes { - // let existing_route = transformed_routes.iter().find(|r| { - // r.method == auth_call_back_route.method && r.path == auth_call_back_route.path - // }); - // if existing_route.is_none() { - // transformed_routes.push(auth_call_back_route); - // } - // } + enriched_routes.push(callback_route); + } Ok(enriched_routes) } - - // async fn precompute_swagger_ui_htmls( - // domain: &Domain, - // compiled_routes: &[CompiledRoute], - // security_schemes: &HashMap, - // ) -> Result>, String> { - // let definitions_that_need_ui: HashMap = - // compiled_routes - // .iter() - // .filter_map(|r| match &r.binding { - // RouteBehaviour::SwaggerUi(inner) => { - // Some((inner.http_api_definition_id, *inner.clone())) - // } - // _ => None, - // }) - // .collect(); - - // let mut swagger_uis = HashMap::with_capacity(definitions_that_need_ui.len()); - // for (_, swagger_ui_binding) in definitions_that_need_ui { - // let matching_routes: Vec<&CompiledRoute> = compiled_routes - // .iter() - // .filter(|cr| cr.http_api_definition_id == swagger_ui_binding.http_api_definition_id) - // .collect(); - - // let openapi_definition = HttpApiDefinitionOpenApiSpec::from_routes( - // &swagger_ui_binding.http_api_definition_name, - // &swagger_ui_binding.http_api_definition_version, - // matching_routes, - // security_schemes, - // ) - // .await?; - - // let swagger_html = generate_swagger_html(&domain.0, openapi_definition)?; - // swagger_uis.insert( - // swagger_ui_binding.http_api_definition_id, - // Arc::new(swagger_html), - // ); - // } - - // Ok(swagger_uis) - // } - - // fn get_auth_call_back_routes( - // account_id: &AccountId, - // environment_id: &EnvironmentId, - // security_schemes: HashMap, - // ) -> Result, String> { - // let mut routes = vec![]; - - // for (_, scheme) in security_schemes { - // // In a security scheme, the auth-call-back (aka redirect_url) is full URL - // // and not just the relative path - // let redirect_url = scheme.redirect_url.clone(); - // let path = redirect_url.url().path(); - // let path = AllPathPatterns::parse(path)?; - // let method = RouteMethod::Get; - // let binding = RichGatewayBindingCompiled::HttpAuthCallBack(Box::new(scheme)); - - // let route = RichCompiledRoute { - // path, - // method, - // binding, - // middlewares: Vec::new(), - // account_id: *account_id, - // environment_id: *environment_id, - // }; - - // routes.push(route) - // } - - // Ok(routes) - // } } fn authority_from_request(request: &poem::Request) -> Result { diff --git a/golem-worker-service/src/gateway_router/mod.rs b/golem-worker-service/src/custom_api/router/mod.rs similarity index 97% rename from golem-worker-service/src/gateway_router/mod.rs rename to golem-worker-service/src/custom_api/router/mod.rs index 010c81df56..6975c4e355 100644 --- a/golem-worker-service/src/gateway_router/mod.rs +++ b/golem-worker-service/src/custom_api/router/mod.rs @@ -15,7 +15,7 @@ pub mod tree; use self::tree::RadixNode; -use crate::model::RichCompiledRoute; +use super::RichCompiledRoute; use golem_service_base::custom_api::PathSegment; use http::Method; use std::sync::Arc; @@ -69,7 +69,7 @@ impl Router> { #[cfg(test)] mod tests { - use crate::gateway_router::Router; + use super::Router; use golem_service_base::custom_api::PathSegment; use http::Method; use test_r::test; diff --git a/golem-worker-service/src/gateway_router/tree.rs b/golem-worker-service/src/custom_api/router/tree.rs similarity index 100% rename from golem-worker-service/src/gateway_router/tree.rs rename to golem-worker-service/src/custom_api/router/tree.rs diff --git a/golem-worker-service/src/custom_api/security/handler.rs b/golem-worker-service/src/custom_api/security/handler.rs new file mode 100644 index 0000000000..95a61e393f --- /dev/null +++ b/golem-worker-service/src/custom_api/security/handler.rs @@ -0,0 +1,198 @@ +// Copyright 2024-2025 Golem Cloud +// +// Licensed under the Golem Source License v1.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://license.golem.cloud/LICENSE +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::IdentityProvider; +use super::model::AuthorizationUrl; +use super::session_store::SessionStore; +use crate::custom_api::error::RequestHandlerError; +use crate::custom_api::model::{OidcSession, RichRequest}; +use crate::custom_api::route_resolver::ResolvedRouteEntry; +use crate::custom_api::security::model::SessionId; +use crate::custom_api::{ResponseBody, RouteExecutionResult}; +use cookie::Cookie; +use golem_service_base::custom_api::SecuritySchemeDetails; +use http::StatusCode; +use openidconnect::{AuthorizationCode, OAuth2TokenResponse}; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use tracing::debug; +use uuid::Uuid; + +const GOLEM_SESSION_ID_COOKIE_NAME: &str = "golem_session_id"; + +pub struct OidcHandler { + session_store: Arc, + identity_provider: Arc, +} + +impl OidcHandler { + pub fn new( + session_store: Arc, + identity_provider: Arc, + ) -> Self { + Self { + session_store, + identity_provider, + } + } + + pub async fn handle_oidc_callback_behaviour( + &self, + request: &mut RichRequest, + scheme: &Arc, + ) -> Result { + let code = request.get_single_param("code")?; + let state = request.get_single_param("state")?; + + let pending_login = self + .session_store + .take_pending_oidc_login(state) + .await? + .ok_or(RequestHandlerError::UnknownOidcState)?; + + let client = self.identity_provider.get_client(scheme).await?; + + let nonce = pending_login.nonce.clone(); + + let token_response = self + .identity_provider + .exchange_code_for_tokens(&client, &AuthorizationCode::new(code.to_string())) + .await + .map_err(|err| { + tracing::warn!("OIDC token exchange failed: {err}"); + RequestHandlerError::OidcTokenExchangeFailed + })?; + + let id_token_verifier = self.identity_provider.get_id_token_verifier(&client); + let id_token_claims = + self.identity_provider + .get_claims(&id_token_verifier, &token_response, &nonce)?; + + let session = OidcSession { + subject: id_token_claims.subject().to_string(), + issuer: id_token_claims.issuer().to_string(), + + email: id_token_claims.email().map(|v| v.to_string()), + name: id_token_claims + .name() + .and_then(|v| v.get(None)) + .map(|v| v.to_string()), + email_verified: id_token_claims.email_verified(), + given_name: id_token_claims + .given_name() + .and_then(|v| v.get(None)) + .map(|v| v.to_string()), + family_name: id_token_claims + .family_name() + .and_then(|v| v.get(None)) + .map(|v| v.to_string()), + picture: id_token_claims + .picture() + .and_then(|v| v.get(None)) + .map(|v| v.to_string()), + preferred_username: id_token_claims.preferred_username().map(|v| v.to_string()), + + claims: id_token_claims.clone(), + scopes: HashSet::from_iter(token_response.scopes().cloned().unwrap_or_default()), + expires_at: id_token_claims.expiration(), + }; + + let session_id = SessionId(Uuid::new_v4()); + + self.session_store + .store_authenticated_session(&session_id, session) + .await?; + + let cookie = Cookie::build((GOLEM_SESSION_ID_COOKIE_NAME, session_id.0.to_string())) + .path("/") + .http_only(true) + .secure(true) + .same_site(cookie::SameSite::Lax) + .build(); + + let mut headers = HashMap::new(); + headers.insert(http::header::SET_COOKIE, cookie.to_string()); + headers.insert(http::header::LOCATION, pending_login.original_uri.clone()); + + Ok(RouteExecutionResult { + status: StatusCode::FOUND, + headers, + body: ResponseBody::NoBody, + }) + } + + pub async fn apply_oidc_incoming_middleware( + &self, + request: &mut RichRequest, + resolved_route: &ResolvedRouteEntry, + ) -> Result, RequestHandlerError> { + debug!("Begin executing OidcSecurityMiddleware"); + + let Some(security_scheme) = resolved_route.route.security_scheme.as_ref() else { + return Ok(None); + }; + + let session_id = if let Some(s) = request.cookie(GOLEM_SESSION_ID_COOKIE_NAME) + && let Ok(parsed) = Uuid::parse_str(s) + { + SessionId(parsed) + } else { + // missing or invalid session_id -> restart flow + let execution_result = + start_oidc_flow_for_route(security_scheme, self.identity_provider.clone()).await?; + return Ok(Some(execution_result)); + }; + + let session_opt = self + .session_store + .get_authenticated_session(&session_id) + .await?; + + let Some(session) = session_opt else { + // session information missing, restart flow + let auth_url = + start_oidc_flow_for_route(security_scheme, self.identity_provider.clone()).await?; + return Ok(Some(auth_url)); + }; + + request.set_authenticated_session(session); + + Ok(None) + } +} + +async fn start_oidc_flow_for_route( + security_scheme: &SecuritySchemeDetails, + identity_provider: Arc, +) -> Result { + let client = identity_provider.get_client(security_scheme).await?; + let auth_url = identity_provider.get_authorization_url( + &client, + security_scheme.scopes.clone(), + None, + None, + ); + + Ok(start_oidc_flow(auth_url)) +} + +fn start_oidc_flow(auth_url: AuthorizationUrl) -> RouteExecutionResult { + let mut headers = std::collections::HashMap::new(); + headers.insert(http::header::LOCATION, auth_url.url.to_string()); + RouteExecutionResult { + status: http::StatusCode::FOUND, + headers, + body: ResponseBody::NoBody, + } +} diff --git a/golem-worker-service/src/gateway_security/identity_provider.rs b/golem-worker-service/src/custom_api/security/identity_provider.rs similarity index 61% rename from golem-worker-service/src/gateway_security/identity_provider.rs rename to golem-worker-service/src/custom_api/security/identity_provider.rs index 10a4ecf4ab..69348ae52e 100644 --- a/golem-worker-service/src/gateway_security/identity_provider.rs +++ b/golem-worker-service/src/custom_api/security/identity_provider.rs @@ -13,67 +13,66 @@ // limitations under the License. use super::identity_provider_metadata::GolemIdentityProviderMetadata; -use crate::gateway_security::open_id_client::OpenIdClient; +use super::model::AuthorizationUrl; +use super::open_id_client::OpenIdClient; use async_trait::async_trait; +use golem_common::IntoAnyhow; use golem_common::model::security_scheme::Provider; -use golem_common::SafeDisplay; use golem_service_base::custom_api::SecuritySchemeDetails; use openidconnect::core::{ CoreClient, CoreIdTokenClaims, CoreIdTokenVerifier, CoreProviderMetadata, CoreResponseType, CoreTokenResponse, }; use openidconnect::{AuthenticationFlow, AuthorizationCode, CsrfToken, Nonce, Scope}; -use std::fmt::{Display, Formatter}; use tracing::debug; -use url::Url; -// A high level abstraction of an identity-provider, that expose -// necessary functionalities that gets called at various points in gateway security integration. +#[derive(Debug, thiserror::Error)] +pub enum IdentityProviderError { + #[error("Failed to initialize client: {0}")] + ClientInitError(String), + #[error("Invalid issuer URL: {0}")] + InvalidIssuerUrl(String), + #[error("Failed to discover provider metadata: {0}")] + FailedToDiscoverProviderMetadata(String), + #[error("Failed to exchange code for tokens: {0}")] + FailedToExchangeCodeForTokens(String), + #[error("ID token verification error: {0}")] + IdTokenVerificationError(String), +} + +impl IntoAnyhow for IdentityProviderError { + fn into_anyhow(self) -> anyhow::Error { + anyhow::Error::from(self).context("IdentityProviderError") + } +} + #[async_trait] pub trait IdentityProvider: Send + Sync { - // Fetches the provider metadata from the issuer url, and this must be called - // during the registration of the security scheme with golem. - // The security scheme regisration stores the provider metadata, along with the security scheme - // in the security scheme store of Golem async fn get_provider_metadata( &self, provider: &Provider, ) -> Result; - // Exchange of Code token happens during the auth_call_back phase of the OpenID workflow - // In other words, this gets called only during the execution of static binding backing auth_call_back endpoint. async fn exchange_code_for_tokens( &self, client: &OpenIdClient, code: &AuthorizationCode, ) -> Result; - // A client can be created given provider-metadata at any phase of the security workflow in API Gateway. - // It can be created to create the authorisation URL to redirect user to the provider's login page - // Or It can be created before exchange of token during the execution of static binding backing auth_call_back endpoint. async fn get_client( &self, security_scheme: &SecuritySchemeDetails, ) -> Result; - // Get IDToken verifier - // For the most part, this is an internal detail to openidconnect, however, - // to test verifying claims using our own key pairs, this can be exposed fn get_id_token_verifier<'a>(&self, client: &'a OpenIdClient) -> CoreIdTokenVerifier<'a>; - // Claims are fetched from the ID token, and this gets called during the execution of static binding backing auth_call_back endpoint. - // If needed this can be called just before serving the protected route, to fetch the claims from the ID token as a middleware - // and feed it to the protected route handler through Rib. In any case, claims needs to be stored in a session - // as the OAuth2 workflow in OpenID gets initiated by the gateway and not the client user-agent. fn get_claims( &self, client: &CoreIdTokenVerifier, - core_token_response: CoreTokenResponse, + core_token_response: &CoreTokenResponse, nonce: &Nonce, ) -> Result; - // This gets called during the redirect to the provider's login page, - // and this is the first step in the OAuth2 workflow in serving a protected route. fn get_authorization_url( &self, client: &OpenIdClient, @@ -83,67 +82,24 @@ pub trait IdentityProvider: Send + Sync { ) -> AuthorizationUrl; } -pub struct AuthorizationUrl { - pub url: Url, - pub csrf_state: CsrfToken, - pub nonce: Nonce, -} - -#[derive(Debug, Clone)] -pub enum IdentityProviderError { - ClientInitError(String), - InvalidIssuerUrl(String), - FailedToDiscoverProviderMetadata(String), - FailedToExchangeCodeForTokens(String), - IdTokenVerificationError(String), -} - -// To satisfy thiserror -// https://github.com/golemcloud/golem/issues/1071 -impl Display for IdentityProviderError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.to_safe_string()) - } -} - -impl SafeDisplay for IdentityProviderError { - fn to_safe_string(&self) -> String { - match self { - IdentityProviderError::ClientInitError(err) => format!("ClientInitError: {err}"), - IdentityProviderError::InvalidIssuerUrl(err) => format!("InvalidIssuerUrl: {err}"), - IdentityProviderError::FailedToDiscoverProviderMetadata(err) => { - format!("FailedToDiscoverProviderMetadata: {err}") - } - IdentityProviderError::FailedToExchangeCodeForTokens(err) => { - format!("FailedToExchangeCodeForTokens: {err}") - } - IdentityProviderError::IdTokenVerificationError(err) => { - format!("IdTokenVerificationError: {err}") - } - } - } -} - pub struct DefaultIdentityProvider; #[async_trait] impl IdentityProvider for DefaultIdentityProvider { - // To be called during API definition registration to then store them in the database async fn get_provider_metadata( &self, provider: &Provider, ) -> Result { - let provide_metadata = CoreProviderMetadata::discover_async( + let provider_metadata = CoreProviderMetadata::discover_async( provider.issuer_url(), openidconnect::reqwest::async_http_client, ) .await .map_err(|err| IdentityProviderError::FailedToDiscoverProviderMetadata(err.to_string()))?; - Ok(provide_metadata) + Ok(provider_metadata) } - // To be called during call_back authentication URL which is a injected URL async fn exchange_code_for_tokens( &self, client: &OpenIdClient, @@ -189,7 +145,7 @@ impl IdentityProvider for DefaultIdentityProvider { fn get_claims( &self, id_token_verifier: &CoreIdTokenVerifier, - core_token_response: CoreTokenResponse, + core_token_response: &CoreTokenResponse, nonce: &Nonce, ) -> Result { let id_token_claims: &CoreIdTokenClaims = core_token_response diff --git a/golem-worker-service/src/gateway_security/identity_provider_metadata.rs b/golem-worker-service/src/custom_api/security/identity_provider_metadata.rs similarity index 100% rename from golem-worker-service/src/gateway_security/identity_provider_metadata.rs rename to golem-worker-service/src/custom_api/security/identity_provider_metadata.rs diff --git a/golem-worker-service/src/gateway_security/mod.rs b/golem-worker-service/src/custom_api/security/mod.rs similarity index 93% rename from golem-worker-service/src/gateway_security/mod.rs rename to golem-worker-service/src/custom_api/security/mod.rs index c47cd40544..ef986a331d 100644 --- a/golem-worker-service/src/gateway_security/mod.rs +++ b/golem-worker-service/src/custom_api/security/mod.rs @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod handler; mod identity_provider; mod identity_provider_metadata; +mod model; mod open_id_client; +pub mod session_store; pub use identity_provider::*; pub use open_id_client::*; diff --git a/golem-worker-service/src/custom_api/security/model.rs b/golem-worker-service/src/custom_api/security/model.rs new file mode 100644 index 0000000000..18d58af8b3 --- /dev/null +++ b/golem-worker-service/src/custom_api/security/model.rs @@ -0,0 +1,33 @@ +// Copyright 2024-2025 Golem Cloud +// +// Licensed under the Golem Source License v1.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://license.golem.cloud/LICENSE +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use golem_common::model::security_scheme::SecuritySchemeId; +use openidconnect::{CsrfToken, Nonce}; +use url::Url; +use uuid::Uuid; + +pub struct SessionId(pub Uuid); + +#[derive(Debug)] +pub struct PendingOidcLogin { + pub scheme_id: SecuritySchemeId, + pub original_uri: String, + pub nonce: Nonce, +} + +pub struct AuthorizationUrl { + pub url: Url, + pub csrf_state: CsrfToken, + pub nonce: Nonce, +} diff --git a/golem-worker-service/src/gateway_security/open_id_client.rs b/golem-worker-service/src/custom_api/security/open_id_client.rs similarity index 100% rename from golem-worker-service/src/gateway_security/open_id_client.rs rename to golem-worker-service/src/custom_api/security/open_id_client.rs diff --git a/golem-worker-service/src/custom_api/security/session_store.rs b/golem-worker-service/src/custom_api/security/session_store.rs new file mode 100644 index 0000000000..0d927366c4 --- /dev/null +++ b/golem-worker-service/src/custom_api/security/session_store.rs @@ -0,0 +1,525 @@ +// Copyright 2024-2025 Golem Cloud +// +// Licensed under the Golem Source License v1.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://license.golem.cloud/LICENSE +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::model::{PendingOidcLogin, SessionId}; +use crate::custom_api::model::OidcSession; +use anyhow::anyhow; +use async_trait::async_trait; +use bytes::Bytes; +use chrono::{TimeDelta, Utc}; +use fred::types::Expiration; +use golem_common::error_forwarding; +use golem_common::redis::{RedisError, RedisPool}; +use golem_service_base::db::sqlite::SqlitePool; +use golem_service_base::repo::RepoError; +use sqlx::Row; +use std::time::Duration; +use tokio::task; +use tokio::time::interval; +use tracing::{Instrument, error}; + +#[derive(Debug, thiserror::Error)] +pub enum SessionStoreError { + #[error(transparent)] + InternalError(#[from] anyhow::Error), +} + +error_forwarding!(SessionStoreError, RepoError); + +impl From for SessionStoreError { + fn from(value: RedisError) -> Self { + Self::InternalError(anyhow::Error::from(value).context("RedisError")) + } +} + +#[async_trait] +pub trait SessionStore: Send + Sync { + async fn store_pending_oidc_login( + &self, + state: String, + login: PendingOidcLogin, + ) -> Result<(), SessionStoreError>; + + async fn take_pending_oidc_login( + &self, + state: &str, + ) -> Result, SessionStoreError>; + + async fn store_authenticated_session( + &self, + session_id: &SessionId, + session: OidcSession, + ) -> Result<(), SessionStoreError>; + + async fn get_authenticated_session( + &self, + session_id: &SessionId, + ) -> Result, SessionStoreError>; +} + +#[derive(Clone)] +pub struct RedisSessionStore { + redis: RedisPool, + pending_login_expiration: Expiration, +} + +impl RedisSessionStore { + pub fn new(redis: RedisPool, pending_login_expiration: Expiration) -> Self { + Self { + redis, + pending_login_expiration, + } + } + + fn redis_key_for_session(session_id: &SessionId) -> String { + format!("oidc_session:{}", session_id.0) + } + + fn redis_key_for_pending(state: &str) -> String { + format!("oidc_pending_login:{}", state) + } +} + +#[async_trait] +impl SessionStore for RedisSessionStore { + async fn store_pending_oidc_login( + &self, + state: String, + login: PendingOidcLogin, + ) -> Result<(), SessionStoreError> { + let record = records::PendingOidcLoginRecord::from(login); + let serialized = golem_common::serialization::serialize(&record) + .map_err(|e| anyhow!("PendingOidcLoginRecord serialization error: {e}"))?; + + let _: () = self + .redis + .with("session_store", "store_pending_oidc_login") + .set( + Self::redis_key_for_pending(&state), + serialized, + Some(self.pending_login_expiration.clone()), + None, + false, + ) + .await?; + + Ok(()) + } + + async fn take_pending_oidc_login( + &self, + state: &str, + ) -> Result, SessionStoreError> { + let key = Self::redis_key_for_pending(state); + let maybe_bytes: Option = self + .redis + .with("session_store", "take_pending_oidc_login") + .get(&key) + .await?; + + if let Some(bytes) = maybe_bytes { + let record: records::PendingOidcLoginRecord = + golem_common::serialization::deserialize(&bytes) + .map_err(|e| anyhow!("PendingOidcLogin deserialization error: {e}"))?; + let login = PendingOidcLogin::from(record); + + let _: i32 = self + .redis + .with("session_store", "del_pending") + .del(&key) + .await?; + Ok(Some(login)) + } else { + Ok(None) + } + } + + async fn store_authenticated_session( + &self, + session_id: &SessionId, + session: OidcSession, + ) -> Result<(), SessionStoreError> { + let record = records::OidcSessionRecord::try_from(session)?; + let serialized = golem_common::serialization::serialize(&record) + .map_err(|e| anyhow!("OidcSessionRecord serialization error: {e}"))?; + + let now = chrono::Utc::now(); + let ttl_secs = (record.expires_at - now).num_seconds(); + let expiration = if ttl_secs > 0 { + Expiration::EX(ttl_secs) + } else { + Expiration::EX(1) + }; + + let _: () = self + .redis + .with("session_store", "store_authenticated_session") + .set( + Self::redis_key_for_session(session_id), + serialized, + Some(expiration), + None, + false, + ) + .await?; + + Ok(()) + } + + async fn get_authenticated_session( + &self, + session_id: &SessionId, + ) -> Result, SessionStoreError> { + let maybe_bytes: Option = self + .redis + .with("session_store", "get_authenticated_session") + .get(&Self::redis_key_for_session(session_id)) + .await?; + + if let Some(bytes) = maybe_bytes { + let record: records::OidcSessionRecord = + golem_common::serialization::deserialize(&bytes) + .map_err(|e| anyhow!("OidcSession deserialization error: {e}"))?; + let session = OidcSession::try_from(record)?; + + Ok(Some(session)) + } else { + Ok(None) + } + } +} + +pub struct SqliteSessionStore { + pool: SqlitePool, + pending_login_expiration: i64, +} + +impl SqliteSessionStore { + pub async fn new( + pool: SqlitePool, + pending_login_expiration: i64, + cleanup_interval: Duration, + ) -> anyhow::Result { + Self::init(&pool).await?; + Self::spawn_expiration_task(pool.clone(), cleanup_interval); + Ok(Self { + pool, + pending_login_expiration, + }) + } + + async fn init(pool: &SqlitePool) -> anyhow::Result<()> { + pool.with_rw("session_store", "init") + .execute(sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS oidc_pending_login ( + state TEXT PRIMARY KEY, + value BLOB NOT NULL, + expires_at INTEGER NOT NULL + ); + + CREATE TABLE IF NOT EXISTS oidc_session ( + session_id TEXT PRIMARY KEY, + value BLOB NOT NULL, + expires_at INTEGER NOT NULL + ); + "#, + )) + .await?; + + Ok(()) + } + + fn spawn_expiration_task(db_pool: SqlitePool, cleanup_interval: Duration) { + task::spawn( + async move { + let mut cleanup_interval = interval(cleanup_interval); + + loop { + cleanup_interval.tick().await; + + if let Err(e) = Self::cleanup_expired_oidc_pending_login( + db_pool.clone(), + Self::current_time(), + ) + .await + { + error!("Failed to expire oidc pending logins: {}", e); + } + + if let Err(e) = + Self::cleanup_expired_oidc_session(db_pool.clone(), Self::current_time()) + .await + { + error!("Failed to expire oidc sessions: {}", e); + } + } + } + .in_current_span(), + ); + } + + async fn cleanup_expired_oidc_pending_login( + pool: SqlitePool, + current_time: i64, + ) -> anyhow::Result<()> { + let query = + sqlx::query("DELETE FROM oidc_pending_login WHERE expiry_time < ?;").bind(current_time); + + pool.with_rw("session_store", "cleanup_expired_oidc_pending_login") + .execute(query) + .await?; + + Ok(()) + } + + async fn cleanup_expired_oidc_session( + pool: SqlitePool, + current_time: i64, + ) -> anyhow::Result<()> { + let query = + sqlx::query("DELETE FROM oidc_session WHERE expiry_time < ?;").bind(current_time); + + pool.with_rw("session_store", "cleanup_expired_oidc_session") + .execute(query) + .await?; + + Ok(()) + } + + pub fn current_time() -> i64 { + chrono::Utc::now().timestamp() + } +} + +#[async_trait] +impl SessionStore for SqliteSessionStore { + async fn store_pending_oidc_login( + &self, + state: String, + login: PendingOidcLogin, + ) -> Result<(), SessionStoreError> { + let record = records::PendingOidcLoginRecord::from(login); + let serialized = golem_common::serialization::serialize(&record) + .map_err(|e| SessionStoreError::InternalError(anyhow::anyhow!(e)))?; + + let expiry = Utc::now() + .checked_add_signed(TimeDelta::seconds(self.pending_login_expiration)) + .ok_or_else(|| anyhow!("Failed to compute expiry"))? + .timestamp(); + + self + .pool + .with_rw("session_store", "store_pending_oidc_login") + .execute( + sqlx::query("INSERT OR REPLACE INTO oidc_pending_login (state, value, expires_at) VALUES (?, ?, ?)") + .bind(state) + .bind(serialized) + .bind(expiry) + ) + .await?; + + Ok(()) + } + + async fn take_pending_oidc_login( + &self, + state: &str, + ) -> Result, SessionStoreError> { + let row = self + .pool + .with_ro("session_store", "take_pending_oidc_login_read") + .fetch_optional( + sqlx::query("SELECT value FROM oidc_pending_login WHERE state = ?").bind(state), + ) + .await?; + + if let Some(row) = row { + let bytes: Vec = row.get(0); + let record: records::PendingOidcLoginRecord = + golem_common::serialization::deserialize(&bytes) + .map_err(|e| SessionStoreError::InternalError(anyhow::anyhow!(e)))?; + + let login = PendingOidcLogin::from(record); + + self.pool + .with_rw("session_store", "take_pending_oidc_login_write") + .execute(sqlx::query("DELETE FROM oidc_pending_login WHERE state = ?").bind(state)) + .await?; + + Ok(Some(login)) + } else { + Ok(None) + } + } + + async fn store_authenticated_session( + &self, + session_id: &SessionId, + session: OidcSession, + ) -> Result<(), SessionStoreError> { + let record = records::OidcSessionRecord::try_from(session)?; + let serialized = golem_common::serialization::serialize(&record) + .map_err(|e| SessionStoreError::InternalError(anyhow::anyhow!(e)))?; + + let expires_at = record.expires_at.timestamp(); + + self + .pool + .with_rw("session_store", "store_authenticated_session") + .execute( + sqlx::query("INSERT OR REPLACE INTO oidc_session (session_id, value, expires_at) VALUES (?, ?, ?)") + .bind(session_id.0) + .bind(serialized) + .bind(expires_at) + ) + .await?; + + Ok(()) + } + + async fn get_authenticated_session( + &self, + session_id: &SessionId, + ) -> Result, SessionStoreError> { + let row = self + .pool + .with_ro("session_store", "get_authenticated_session_read") + .fetch_optional( + sqlx::query("SELECT value, expires_at FROM oidc_session WHERE session_id = ?") + .bind(session_id.0), + ) + .await?; + + if let Some(row) = row { + let bytes: Vec = row.get(0); + let record: records::OidcSessionRecord = + golem_common::serialization::deserialize(&bytes) + .map_err(|e| SessionStoreError::InternalError(anyhow::anyhow!(e)))?; + + let session = OidcSession::try_from(record)?; + + Ok(Some(session)) + } else { + Ok(None) + } + } +} + +mod records { + use super::SessionStoreError; + use crate::custom_api::model::OidcSession; + use crate::custom_api::security::model::PendingOidcLogin; + use anyhow::anyhow; + use chrono::{DateTime, Utc}; + use desert_rust::BinaryCodec; + use golem_common::model::security_scheme::SecuritySchemeId; + use openidconnect::{Nonce, Scope}; + use std::collections::HashSet; + + #[derive(Debug, BinaryCodec)] + #[desert(evolution())] + pub struct PendingOidcLoginRecord { + pub scheme_id: SecuritySchemeId, + pub original_uri: String, + pub nonce: String, + } + + impl From for PendingOidcLoginRecord { + fn from(value: PendingOidcLogin) -> Self { + Self { + scheme_id: value.scheme_id, + original_uri: value.original_uri, + nonce: value.nonce.secret().clone(), + } + } + } + + impl From for PendingOidcLogin { + fn from(value: PendingOidcLoginRecord) -> Self { + Self { + scheme_id: value.scheme_id, + original_uri: value.original_uri, + nonce: Nonce::new(value.nonce), + } + } + } + + #[derive(Debug, BinaryCodec)] + #[desert(evolution())] + pub struct OidcSessionRecord { + pub subject: String, + pub issuer: String, + + pub email: Option, + pub name: Option, + pub email_verified: Option, + pub given_name: Option, + pub family_name: Option, + pub picture: Option, + pub preferred_username: Option, + + pub claims: String, + pub scopes: HashSet, + pub expires_at: DateTime, + } + + impl TryFrom for OidcSessionRecord { + type Error = SessionStoreError; + + fn try_from(value: OidcSession) -> Result { + Ok(Self { + subject: value.subject, + issuer: value.issuer, + + email: value.email, + name: value.name, + email_verified: value.email_verified, + given_name: value.given_name, + family_name: value.family_name, + picture: value.picture, + preferred_username: value.preferred_username, + + claims: serde_json::to_string(&value.claims) + .map_err(|e| anyhow!("CoreIdTokenClaims serialization error: {e}"))?, + scopes: value.scopes.into_iter().map(|s| s.to_string()).collect(), + expires_at: value.expires_at, + }) + } + } + + impl TryFrom for OidcSession { + type Error = SessionStoreError; + + fn try_from(value: OidcSessionRecord) -> Result { + Ok(Self { + subject: value.subject, + issuer: value.issuer, + + email: value.email, + name: value.name, + email_verified: value.email_verified, + given_name: value.given_name, + family_name: value.family_name, + picture: value.picture, + preferred_username: value.preferred_username, + + claims: serde_json::from_str(&value.claims) + .map_err(|e| anyhow!("CoreIdTokenClaims deserialization error: {e}"))?, + scopes: value.scopes.into_iter().map(Scope::new).collect(), + expires_at: value.expires_at, + }) + } + } +} diff --git a/golem-worker-service/src/gateway_execution/auth_call_back_binding_handler.rs b/golem-worker-service/src/gateway_execution/auth_call_back_binding_handler.rs deleted file mode 100644 index e97038aa24..0000000000 --- a/golem-worker-service/src/gateway_execution/auth_call_back_binding_handler.rs +++ /dev/null @@ -1,208 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::gateway_execution::gateway_session_store::{ - DataKey, DataValue, GatewaySessionError, GatewaySessionStore, SessionId, -}; -use crate::gateway_security::{IdentityProvider, IdentityProviderError}; -use async_trait::async_trait; -use golem_common::SafeDisplay; -use golem_service_base::custom_api::SecuritySchemeDetails; -use openidconnect::core::CoreTokenResponse; -use openidconnect::{AuthorizationCode, OAuth2TokenResponse}; -use std::collections::HashMap; -use std::sync::Arc; - -#[async_trait] -pub trait AuthCallBackBindingHandler: Send + Sync { - async fn handle_auth_call_back( - &self, - query_params: &HashMap, - security_scheme: &SecuritySchemeDetails, - ) -> Result; -} - -pub struct AuthenticationSuccess { - pub token_response: CoreTokenResponse, - pub target_path: String, - pub id_token: Option, - pub access_token: String, - pub session: String, -} - -#[derive(Debug)] -pub enum AuthorisationError { - Internal(String), - CodeNotFound, - InvalidCode, - StateNotFound, - InvalidState, - InvalidSession, - InvalidNonce, - MissingParametersInSession, - AccessTokenNotFound, - InvalidToken, - IdTokenNotFound, - ConflictingState, // Possible CSRF attack - NonceNotFound, - FailedCodeExchange(IdentityProviderError), - ClaimFetchError(IdentityProviderError), - IdentityProviderError(IdentityProviderError), - SessionError(GatewaySessionError), -} - -// Only SafeDisplay is allowed for AuthorisationError -impl SafeDisplay for AuthorisationError { - fn to_safe_string(&self) -> String { - match self { - AuthorisationError::Internal(_) => "Failed authentication".to_string(), - AuthorisationError::InvalidNonce => "Failed authentication".to_string(), - AuthorisationError::CodeNotFound => "The authorisation code is missing.".to_string(), - AuthorisationError::InvalidCode => "The authorisation code is invalid.".to_string(), - AuthorisationError::StateNotFound => { - "Missing parameters from identity provider".to_string() - } - AuthorisationError::InvalidState => { - "Invalid parameters from identity provider.".to_string() - } - AuthorisationError::InvalidSession => "The session is no longer valid.".to_string(), - AuthorisationError::MissingParametersInSession => "Session failures".to_string(), - AuthorisationError::ClaimFetchError(err) => { - format!( - "Failed to fetch claims. Error details: {}", - err.to_safe_string() - ) - } - AuthorisationError::InvalidToken => "Invalid token".to_string(), - AuthorisationError::IdentityProviderError(err) => { - format!("Identity provider error: {}", err.to_safe_string()) - } - AuthorisationError::AccessTokenNotFound => { - "Unable to continue with authorisation".to_string() - } - AuthorisationError::IdTokenNotFound => { - "Unable to continue with authentication.".to_string() - } - AuthorisationError::ConflictingState => "Suspicious login attempt".to_string(), - AuthorisationError::FailedCodeExchange(err) => { - format!( - "Failed to exchange code for tokens. Error details: {}", - err.to_safe_string() - ) - } - AuthorisationError::NonceNotFound => { - "Suspicious authorisation attempt. Failed checks.".to_string() - } - AuthorisationError::SessionError(err) => format!( - "An error occurred while updating the session. Error details: {}", - err.to_safe_string() - ), - } - } -} - -pub struct DefaultAuthCallBackBindingHandler { - gateway_session_store: Arc, - identity_provider: Arc, -} - -impl DefaultAuthCallBackBindingHandler { - pub fn new( - gateway_session_store: Arc, - identity_provider: Arc, - ) -> Self { - Self { - gateway_session_store, - identity_provider, - } - } -} - -#[async_trait] -impl AuthCallBackBindingHandler for DefaultAuthCallBackBindingHandler { - async fn handle_auth_call_back( - &self, - query_params: &HashMap, - security_scheme: &SecuritySchemeDetails, - ) -> Result { - let code = query_params - .get("code") - .map(|c| AuthorizationCode::new(c.to_string())); - let state = query_params.get("state").cloned(); - - let authorisation_code = code.ok_or(AuthorisationError::CodeNotFound)?; - let state = state.ok_or(AuthorisationError::StateNotFound)?; - - let target_path = self - .gateway_session_store - .get( - &SessionId(state.clone()), - &DataKey("redirect_url".to_string()), - ) - .await - .map_err(AuthorisationError::SessionError)? - .as_string() - .ok_or(AuthorisationError::Internal( - "Invalid redirect url (target url of the protected resource)".to_string(), - ))?; - - let open_id_client = self - .identity_provider - .get_client(security_scheme) - .await - .map_err(AuthorisationError::IdentityProviderError)?; - - let token_response = self - .identity_provider - .exchange_code_for_tokens(&open_id_client, &authorisation_code) - .await - .map_err(AuthorisationError::FailedCodeExchange)?; - - let access_token = token_response.access_token().secret().clone(); - let id_token = token_response - .extra_fields() - .id_token() - .map(|x| x.to_string()); - - // access token in session store - self.gateway_session_store - .insert( - SessionId(state.clone()), - DataKey::access_token(), - DataValue(serde_json::Value::String(access_token.clone())), - ) - .await - .map_err(AuthorisationError::SessionError)?; - - if let Some(id_token) = &id_token { - // id token in session store - self.gateway_session_store - .insert( - SessionId(state.clone()), - DataKey::id_token(), - DataValue(serde_json::Value::String(id_token.to_string())), - ) - .await - .map_err(AuthorisationError::SessionError)?; - } - - Ok(AuthenticationSuccess { - token_response, - target_path, - id_token, - access_token, - session: state, - }) - } -} diff --git a/golem-worker-service/src/gateway_execution/file_server_binding_handler.rs b/golem-worker-service/src/gateway_execution/file_server_binding_handler.rs deleted file mode 100644 index e8b64aa96b..0000000000 --- a/golem-worker-service/src/gateway_execution/file_server_binding_handler.rs +++ /dev/null @@ -1,301 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::getter::{get_response_headers_or_default, get_status_code}; -use crate::service::component::{ComponentService, ComponentServiceError}; -use crate::service::worker::{WorkerService, WorkerServiceError}; -use bytes::Bytes; -use futures::Stream; -use futures::TryStreamExt; -use golem_common::model::account::AccountId; -use golem_common::model::component::{ComponentDto, ComponentFilePath, ComponentId}; -use golem_common::model::environment::EnvironmentId; -use golem_common::model::WorkerId; -use golem_common::SafeDisplay; -use golem_service_base::model::auth::AuthCtx; -use golem_service_base::service::initial_component_files::InitialComponentFilesService; -use golem_wasm::analysis::AnalysedType; -use golem_wasm::{Value, ValueAndType}; -use http::StatusCode; -use poem::web::headers::ContentType; -use rib::RibResult; -use std::pin::Pin; -use std::str::FromStr; -use std::sync::Arc; - -pub struct FileServerBindingHandler { - component_service: Arc, - initial_component_files_service: Arc, - worker_service: Arc, -} - -impl FileServerBindingHandler { - pub fn new( - component_service: Arc, - initial_component_files_service: Arc, - worker_service: Arc, - ) -> Self { - Self { - component_service, - initial_component_files_service, - worker_service, - } - } - - pub async fn handle_file_server_binding_result( - &self, - worker_name: &str, - component_id: ComponentId, - environment_id: EnvironmentId, - account_id: AccountId, - original_result: RibResult, - ) -> Result { - let binding_details = FileServerBindingDetails::from_rib_result(original_result) - .map_err(FileServerBindingError::InvalidRibResult)?; - - let component_metadata = self - .get_component_metadata(worker_name, component_id, account_id) - .await?; - - // if we are serving a read_only file, we can just go straight to the blob storage. - let matching_ro_file = component_metadata - .files - .iter() - .find(|file| file.path == binding_details.file_path && file.is_read_only()); - - if let Some(file) = matching_ro_file { - let data = self - .initial_component_files_service - .get(environment_id, file.content_hash) - .await - .map_err(|e| { - FileServerBindingError::InternalError(format!( - "Failed looking up file in storage: {e}" - )) - })? - .ok_or(FileServerBindingError::InternalError(format!( - "File not found in file storage: {}", - file.content_hash - ))) - .map(|stream| { - let mapped = stream.map_err(std::io::Error::other); - Box::pin(mapped) - })?; - - Ok(FileServerBindingSuccess { - binding_details, - data, - }) - } else { - // Read write files need to be fetched from a running worker. - // Ask the worker service to get the file contents. If no worker is running, one will be started. - - let worker_id = WorkerId::from_component_metadata_and_worker_id( - component_id, - &component_metadata.metadata, - worker_name, - ) - .map_err(|e| { - FileServerBindingError::InternalError(format!("Invalid worker name: {e}")) - })?; - - let stream = self - .worker_service - .get_file_contents( - &worker_id, - binding_details.file_path.clone(), - AuthCtx::impersonated_user(account_id), - ) - .await?; - - let stream = stream.map_err(|e| std::io::Error::other(e.to_string())); - - Ok(FileServerBindingSuccess { - binding_details, - data: Box::pin(stream), - }) - } - } - - async fn get_component_metadata( - &self, - worker_name: &str, - component_id: ComponentId, - account_id: AccountId, - ) -> Result { - // Two cases, we either have an existing worker or not (either not configured or not existing). - // If there is no worker we need use the lastest component version, if there is none we need to use the exact component version - // the worker is using. Not doing that would make the blob_storage optimization for read-only files visible to users. - - let component_revision = { - let worker_metadata = self - .worker_service - .get_metadata( - &WorkerId { - component_id, - worker_name: worker_name.to_string(), - }, - AuthCtx::impersonated_user(account_id), - ) - .await; - - match worker_metadata { - Ok(metadata) => Some(metadata.component_revision), - Err(WorkerServiceError::WorkerNotFound(_)) => None, - Err(other) => Err(other)?, - } - }; - - let component_metadata = if let Some(component_revision) = component_revision { - self.component_service - .get_revision(component_id, component_revision) - .await - .map_err(FileServerBindingError::ComponentServiceError)? - } else { - self.component_service - .get_latest_by_id(component_id) - .await - .map_err(FileServerBindingError::ComponentServiceError)? - }; - - Ok(component_metadata) - } -} - -pub struct FileServerBindingSuccess { - pub binding_details: FileServerBindingDetails, - pub data: Pin> + Send + 'static>>, -} - -#[derive(Debug, thiserror::Error)] -pub enum FileServerBindingError { - #[error(transparent)] - WorkerServiceError(#[from] WorkerServiceError), - #[error(transparent)] - ComponentServiceError(#[from] ComponentServiceError), - #[error("Internal error: {0}")] - InternalError(String), - #[error("Invalid rib result: {0}")] - InvalidRibResult(String), -} - -impl SafeDisplay for FileServerBindingError { - fn to_safe_string(&self) -> String { - match self { - Self::WorkerServiceError(inner) => inner.to_safe_string(), - Self::ComponentServiceError(inner) => inner.to_safe_string(), - - Self::InternalError(_) => self.to_string(), - Self::InvalidRibResult(_) => self.to_string(), - } - } -} - -#[derive(Debug, Clone)] -pub struct FileServerBindingDetails { - pub content_type: ContentType, - pub status_code: StatusCode, - pub file_path: ComponentFilePath, -} - -impl FileServerBindingDetails { - pub fn from_rib_result(result: RibResult) -> Result { - // Three supported formats: - // 1. A string path. Mime type is guessed from the path. Status code is 200. - // 2. A record with a 'file-path' field. Mime type and status are optionally taken from the record, otherwise guessed. - // 3. A result of either of the above, with the same rules applied. - match result { - RibResult::Val(value) => match value { - ValueAndType { - value: Value::Result(value), - typ: AnalysedType::Result(typ), - } => match value { - Ok(ok) => { - let ok = ValueAndType::new( - *ok.ok_or("ok unset".to_string())?, - (*typ.ok.ok_or("Missing 'ok' type")?).clone(), - ); - Self::from_rib_happy(ok) - } - Err(err) => { - let value = err.ok_or("err unset".to_string())?; - Err(format!("Error result: {value:?}")) - } - }, - other => Self::from_rib_happy(other), - }, - RibResult::Unit => Err("Expected a value".to_string()), - } - } - - /// Like the above, just without the result case. - fn from_rib_happy(value: ValueAndType) -> Result { - match &value { - ValueAndType { - value: Value::String(raw_path), - .. - } => Self::make_from(raw_path.clone(), None, None), - ValueAndType { - value: Value::Record(field_values), - typ: AnalysedType::Record(record), - } => { - let path_position = record - .fields - .iter() - .position(|pair| &pair.name == "file-path") - .ok_or("Record must contain 'file-path' field")?; - - let path = if let Value::String(path) = &field_values[path_position] { - path - } else { - return Err("file-path must be a string".to_string()); - }; - - let status = get_status_code(field_values, record)?; - let headers = get_response_headers_or_default(&value)?; - let content_type = headers.get_content_type(); - - Self::make_from(path.to_string(), content_type, status) - } - _ => Err("Response value expected".to_string()), - } - } - - fn make_from( - path: String, - content_type: Option, - status_code: Option, - ) -> Result { - let file_path = ComponentFilePath::from_either_str(&path)?; - - let content_type = match content_type { - Some(content_type) => content_type, - None => { - let mime_type = mime_guess::from_path(&path) - .first() - .ok_or("Could not determine mime type")?; - ContentType::from_str(mime_type.as_ref()) - .map_err(|e| format!("Invalid mime type: {e}"))? - } - }; - - let status_code = status_code.unwrap_or(StatusCode::OK); - - Ok(FileServerBindingDetails { - status_code, - content_type, - file_path, - }) - } -} diff --git a/golem-worker-service/src/gateway_execution/gateway_http_input_executor.rs b/golem-worker-service/src/gateway_execution/gateway_http_input_executor.rs deleted file mode 100644 index a04b5888b4..0000000000 --- a/golem-worker-service/src/gateway_execution/gateway_http_input_executor.rs +++ /dev/null @@ -1,927 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use super::auth_call_back_binding_handler::AuthenticationSuccess; -use super::file_server_binding_handler::{FileServerBindingError, FileServerBindingSuccess}; -use super::http_handler_binding_handler::{HttpHandlerBindingHandler, HttpHandlerBindingResult}; -use super::request::{split_resolved_route_entry, RichRequest, SplitResolvedRouteEntryResult}; -use super::route_resolver::{GatewayBindingResolverError, RouteResolver}; -use super::to_response::GatewayHttpResult; -use super::{GatewayWorkerRequestExecutor, WorkerDetails}; -use crate::gateway_execution::auth_call_back_binding_handler::AuthCallBackBindingHandler; -use crate::gateway_execution::file_server_binding_handler::FileServerBindingHandler; -use crate::gateway_execution::gateway_session_store::GatewaySessionStore; -use crate::gateway_execution::to_response::{GatewayHttpError, ToHttpResponse}; -use crate::gateway_execution::to_response_failure::ToHttpResponseFromSafeDisplay; -use crate::gateway_middleware::{ - process_middleware_in, process_middleware_out, MiddlewareError, MiddlewareSuccess, -}; -use crate::gateway_security::IdentityProvider; -use crate::http_invocation_context::{extract_request_attributes, invocation_context_from_request}; -use crate::model::{HttpMiddleware, RichGatewayBindingCompiled}; -use golem_common::model::account::AccountId; -use golem_common::model::component::{ComponentId, ComponentRevision}; -use golem_common::model::environment::EnvironmentId; -use golem_common::model::invocation_context::{ - AttributeValue, InvocationContextSpan, InvocationContextStack, SpanId, TraceId, -}; -use golem_common::model::IdempotencyKey; -use golem_common::SafeDisplay; -use golem_service_base::custom_api::SecuritySchemeDetails; -use golem_service_base::custom_api::{ - FileServerBindingCompiled, HttpHandlerBindingCompiled, IdempotencyKeyCompiled, - InvocationContextCompiled, ResponseMappingCompiled, WorkerBindingCompiled, WorkerNameCompiled, -}; -use golem_service_base::headers::TraceContextHeaders; -use golem_wasm::analysis::analysed_type::record; -use golem_wasm::analysis::{AnalysedType, NameTypePair}; -use golem_wasm::json::ValueAndTypeJsonExtensions; -use golem_wasm::{IntoValue, IntoValueAndType, ValueAndType}; -use http::StatusCode; -use poem::Body; -use rib::{RibInput, RibInputTypeInfo, RibResult, TypeName}; -use std::collections::HashMap; -use std::ops::Deref; -use std::str::FromStr; -use std::sync::Arc; -use tracing::error; -use uuid::Uuid; - -pub struct GatewayHttpInputExecutor { - route_resolver: Arc, - gateway_worker_request_executor: Arc, - file_server_binding_handler: Arc, - auth_call_back_binding_handler: Arc, - http_handler_binding_handler: Arc, - gateway_session_store: Arc, - identity_provider: Arc, -} - -impl GatewayHttpInputExecutor { - pub fn new( - route_resolver: Arc, - gateway_worker_request_executor: Arc, - file_server_binding_handler: Arc, - auth_call_back_binding_handler: Arc, - http_handler_binding_handler: Arc, - gateway_session_store: Arc, - identity_provider: Arc, - ) -> Self { - Self { - route_resolver, - gateway_worker_request_executor, - file_server_binding_handler, - auth_call_back_binding_handler, - http_handler_binding_handler, - gateway_session_store, - identity_provider, - } - } - - pub async fn execute_http_request(&self, request: poem::Request) -> poem::Response { - let resolved_route_entry = match self.route_resolver.resolve_route(&request).await { - Ok(value) => value, - Err(GatewayBindingResolverError::CouldNotBuildRouter) => { - return poem::Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from_string( - "Failed to build router for request".to_string(), - )); - } - Err(GatewayBindingResolverError::CouldNotGetDomainFromRequest(err)) => { - return poem::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::from_string(err)); - } - Err(GatewayBindingResolverError::NoMatchingRoute) => { - return poem::Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::from_string("Route not found".to_string())); - } - }; - - let SplitResolvedRouteEntryResult { - binding, - middlewares, - rich_request, - account_id, - environment_id, - } = split_resolved_route_entry(request, resolved_route_entry); - - let mut rich_request = match self.apply_middlewares_in(rich_request, &middlewares).await { - Ok(req) => req, - Err(resp) => { - tracing::debug!("Middleware short-circuited the request handling"); - return resp; - } - }; - - match binding { - RichGatewayBindingCompiled::HttpCorsPreflight(inner) => { - inner - .http_cors - .to_response(&rich_request, &self.gateway_session_store) - .await - } - - RichGatewayBindingCompiled::HttpAuthCallBack(security_scheme) => { - let result = self - .handle_http_auth_callback_binding(&security_scheme, &rich_request) - .await; - - let response = result - .to_response(&rich_request, &self.gateway_session_store) - .await; - - apply_middlewares_out(response, &middlewares).await - } - - RichGatewayBindingCompiled::Worker(resolved_worker_binding) => { - let result = self - .handle_worker_binding(&mut rich_request, *resolved_worker_binding, account_id) - .await; - - let response = result - .to_response(&rich_request, &self.gateway_session_store) - .await; - - apply_middlewares_out(response, &middlewares).await - } - - RichGatewayBindingCompiled::HttpHandler(http_handler_binding) => { - let result = self - .handle_http_handler_binding( - &mut rich_request, - *http_handler_binding, - account_id, - ) - .await; - - let response = result - .to_response(&rich_request, &self.gateway_session_store) - .await; - - apply_middlewares_out(response, &middlewares).await - } - - RichGatewayBindingCompiled::FileServer(resolved_file_server_binding) => { - let result = self - .handle_file_server_binding( - &mut rich_request, - *resolved_file_server_binding, - account_id, - environment_id, - ) - .await; - - let response = result - .to_response(&rich_request, &self.gateway_session_store) - .await; - - apply_middlewares_out(response, &middlewares).await - } - - RichGatewayBindingCompiled::SwaggerUi(swagger_binding) => { - let result = swagger_binding.swagger_html.deref().clone(); - - let response = result - .to_response(&rich_request, &self.gateway_session_store) - .await; - - apply_middlewares_out(response, &middlewares).await - } - } - } - - async fn handle_worker_binding( - &self, - request: &mut RichRequest, - binding: WorkerBindingCompiled, - account_id: AccountId, - ) -> GatewayHttpResult { - let WorkerBindingCompiled { - response_compiled, - component_id, - component_revision, - idempotency_key_compiled, - invocation_context_compiled, - .. - } = binding; - - let worker_detail = self - .get_worker_details( - request, - None, - idempotency_key_compiled, - component_id, - component_revision, - invocation_context_compiled, - ) - .await?; - - self.execute_response_mapping_script(response_compiled, request, worker_detail, account_id) - .await - } - - async fn handle_http_handler_binding( - &self, - request: &mut RichRequest, - binding: HttpHandlerBindingCompiled, - account_id: AccountId, - ) -> GatewayHttpResult { - let HttpHandlerBindingCompiled { - component_id, - component_revision, - worker_name_compiled, - idempotency_key_compiled, - invocation_context_compiled, - .. - } = binding; - - let worker_detail = self - .get_worker_details( - request, - Some(worker_name_compiled), - idempotency_key_compiled, - component_id, - component_revision, - invocation_context_compiled, - ) - .await?; - - let incoming_http_request = request - .as_wasi_http_input() - .await - .map_err(GatewayHttpError::BadRequest)?; - - let result = self - .http_handler_binding_handler - .handle_http_handler_binding(&worker_detail, incoming_http_request, account_id) - .await; - - match result { - Ok(_) => tracing::debug!("http handler binding successful"), - Err(ref e) => tracing::warn!("http handler binding failed: {e:?}"), - } - - Ok(result) - } - - async fn handle_file_server_binding( - &self, - request: &mut RichRequest, - binding: FileServerBindingCompiled, - account_id: AccountId, - environment_id: EnvironmentId, - ) -> GatewayHttpResult { - let FileServerBindingCompiled { - component_id, - component_revision, - worker_name_compiled, - response_compiled, - .. - } = binding; - - let worker_detail = self - .get_worker_details( - request, - Some(worker_name_compiled), - None, - component_id, - component_revision, - None, - ) - .await?; - - let worker_name = worker_detail - .worker_name - .as_ref() - .ok_or_else(|| { - GatewayHttpError::FileServerBindingError(FileServerBindingError::InternalError( - "Missing worker name".to_string(), - )) - })? - .clone(); - - let response_script_result = self - .execute_response_mapping_script(response_compiled, request, worker_detail, account_id) - .await?; - - self.file_server_binding_handler - .handle_file_server_binding_result( - &worker_name, - component_id, - environment_id, - account_id, - response_script_result, - ) - .await - .map_err(GatewayHttpError::FileServerBindingError) - } - - async fn handle_http_auth_callback_binding( - &self, - security_scheme: &SecuritySchemeDetails, - request: &RichRequest, - ) -> GatewayHttpResult { - self.auth_call_back_binding_handler - .handle_auth_call_back(&request.query_params(), security_scheme) - .await - .map_err(GatewayHttpError::AuthorisationError) - } - - async fn evaluate_worker_name_rib_script( - &self, - script: WorkerNameCompiled, - request: &mut RichRequest, - ) -> GatewayHttpResult { - let WorkerNameCompiled { - compiled_worker_name, - rib_input: rib_input_type_info, - .. - } = script; - - let rib_input: RibInput = resolve_rib_input(request, &rib_input_type_info).await?; - - let result = rib::interpret_pure(compiled_worker_name, rib_input, None) - .await - .map_err(|err| GatewayHttpError::RibInterpretPureError(err.to_string()))? - .get_literal() - .ok_or(GatewayHttpError::BadRequest( - "Worker name is not a Rib expression that resolves to String".to_string(), - ))? - .as_string(); - - Ok(result) - } - - async fn evaluate_idempotency_key_rib_script( - &self, - script: IdempotencyKeyCompiled, - request: &mut RichRequest, - ) -> GatewayHttpResult { - let IdempotencyKeyCompiled { - compiled_idempotency_key, - rib_input, - .. - } = script; - - let rib_input: RibInput = resolve_rib_input(request, &rib_input).await?; - - let value = rib::interpret_pure(compiled_idempotency_key, rib_input, None) - .await - .map_err(|err| GatewayHttpError::RibInterpretPureError(err.to_string()))? - .get_literal() - .ok_or(GatewayHttpError::BadRequest( - "Idempotency key is not a Rib expression that resolves to String".to_string(), - ))? - .as_string(); - - Ok(IdempotencyKey::new(value)) - } - - async fn evaluate_invocation_context_rib_script( - &self, - script: InvocationContextCompiled, - request: &mut RichRequest, - ) -> GatewayHttpResult<(Option, HashMap)> { - let InvocationContextCompiled { - compiled_invocation_context, - rib_input, - .. - } = script; - - let rib_input: RibInput = resolve_rib_input(request, &rib_input).await?; - - let value = rib::interpret_pure(compiled_invocation_context, rib_input, None) - .await - .map_err(|err| GatewayHttpError::RibInterpretPureError(err.to_string()))? - .get_record() - .ok_or(GatewayHttpError::BadRequest( - "Invocation context must be a Rib expression that resolves to record".to_string(), - ))?; - let record: HashMap = HashMap::from_iter(value); - - let trace_id = record - .get("trace_id") - .or(record.get("trace-id")) - .map(to_attribute_value) - .transpose()? - .map(TraceId::from_attribute_value) - .transpose() - .map_err(|err| GatewayHttpError::BadRequest(format!("Invalid Trace ID: {err}")))?; - - Ok((trace_id, record)) - } - - fn materialize_user_invocation_context( - record: HashMap, - parent: Option>, - request_attributes: HashMap, - ) -> GatewayHttpResult> { - let span_id = record - .get("span_id") - .or(record.get("span-id")) - .map(to_attribute_value) - .transpose()? - .map(SpanId::from_attribute_value) - .transpose() - .map_err(|err| GatewayHttpError::BadRequest(format!("Invalid Span ID: {err}")))?; - - let span = InvocationContextSpan::local() - .span_id(span_id) - .parent(parent) - .with_attributes(request_attributes) - .build(); - - for (key, value) in record { - if key != "span_id" && key != "span-id" && key != "trace_id" && key != "trace-id" { - span.set_attribute(key, to_attribute_value(&value)?); - } - } - - Ok(span) - } - - async fn get_worker_details( - &self, - request: &mut RichRequest, - worker_name_compiled: Option, - idempotency_key_compiled: Option, - component_id: ComponentId, - component_revision: ComponentRevision, - invocation_context_compiled: Option, - ) -> GatewayHttpResult { - let worker_name = if let Some(worker_name_compiled) = worker_name_compiled { - let result = self - .evaluate_worker_name_rib_script(worker_name_compiled, request) - .await?; - Some(result) - } else { - None - }; - - // We prefer to take the idempotency key from the rib script, - // if that is not available, we fall back to our custom header. - // If neither is available, the worker-executor will later generate an idempotency key. - let idempotency_key = if let Some(idempotency_key_compiled) = idempotency_key_compiled { - let result = self - .evaluate_idempotency_key_rib_script(idempotency_key_compiled, request) - .await?; - Some(result) - } else { - request - .underlying - .headers() - .get("idempotency-key") - .and_then(|h| h.to_str().ok()) - .map(|value| IdempotencyKey::new(value.to_string())) - }; - - let invocation_context = if let Some(invocation_context_compiled) = - invocation_context_compiled - { - let request_attributes = extract_request_attributes(&request.underlying); - - let trace_context_headers = TraceContextHeaders::parse(request.underlying.headers()); - - let (user_defined_trace_id, user_defined_span) = self - .evaluate_invocation_context_rib_script(invocation_context_compiled, request) - .await?; - - match (trace_context_headers, &user_defined_trace_id) { - (Some(ctx), None) => { - // Trace context found in headers and not overridden, starting a new span in it - let mut ctx = InvocationContextStack::new( - ctx.trace_id, - InvocationContextSpan::external_parent(ctx.parent_id), - ctx.trace_states, - ); - let user_defined_span = Self::materialize_user_invocation_context( - user_defined_span, - Some(ctx.spans.first().clone()), - request_attributes, - )?; - ctx.push(user_defined_span); - ctx - } - (_, Some(trace_id)) => { - // Forced a new trace, ignoring the trace context in the headers - let user_defined_span = Self::materialize_user_invocation_context( - user_defined_span, - None, - request_attributes, - )?; - InvocationContextStack::new(trace_id.clone(), user_defined_span, Vec::new()) - } - (None, _) => { - // No trace context in headers, starting a new trace - let user_defined_span = Self::materialize_user_invocation_context( - user_defined_span, - None, - request_attributes, - )?; - InvocationContextStack::new( - user_defined_trace_id.unwrap_or_else(TraceId::generate), - user_defined_span, - Vec::new(), - ) - } - } - } else { - invocation_context_from_request(&request.underlying) - }; - - Ok(WorkerDetails { - component_id, - component_revision, - worker_name, - idempotency_key, - invocation_context, - }) - } - - async fn execute_response_mapping_script( - &self, - compiled_response_mapping: ResponseMappingCompiled, - request: &mut RichRequest, - worker_detail: WorkerDetails, - account_id: AccountId, - ) -> GatewayHttpResult { - let WorkerDetails { - invocation_context, - idempotency_key, - .. - } = worker_detail; - - let ResponseMappingCompiled { - response_mapping_compiled, - rib_input, - .. - } = compiled_response_mapping; - - let rib_input = resolve_rib_input(request, &rib_input).await?; - - self.gateway_worker_request_executor - .evaluate_rib( - idempotency_key, - invocation_context, - account_id, - response_mapping_compiled, - rib_input, - ) - .await - .map_err(GatewayHttpError::EvaluationError) - } - - async fn apply_middlewares_in( - &self, - mut request: RichRequest, - middlewares: &Vec, - ) -> Result { - let input_middleware_result = process_middleware_in( - middlewares, - &request, - &self.gateway_session_store, - &self.identity_provider, - ) - .await; - - let input_middleware_result = match input_middleware_result { - Ok(MiddlewareSuccess::PassThrough { - session_id: session_id_opt, - }) => { - if let Some(session_id) = session_id_opt.as_ref() { - let result = request - .add_auth_details(session_id, &self.gateway_session_store) - .await; - - if let Err(err_response) = result { - Err(MiddlewareError::InternalError(err_response)) - } else { - Ok(MiddlewareSuccess::PassThrough { - session_id: session_id_opt, - }) - } - } else { - Ok(MiddlewareSuccess::PassThrough { - session_id: session_id_opt, - }) - } - } - other => other, - }; - - match input_middleware_result { - Ok(MiddlewareSuccess::Redirect(response)) => Err(response)?, - Ok(MiddlewareSuccess::PassThrough { .. }) => Ok(request), - Err(err) => { - error!("Middleware error: {}", err.to_safe_string()); - let response = err.to_response_from_safe_display(|error| match error { - MiddlewareError::InternalError(_) => StatusCode::INTERNAL_SERVER_ERROR, - MiddlewareError::Unauthorized(_) => StatusCode::UNAUTHORIZED, - MiddlewareError::CorsError(_) => StatusCode::FORBIDDEN, - }); - Err(response)? - } - } - } -} - -async fn resolve_rib_input( - rich_request: &mut RichRequest, - required_types: &RibInputTypeInfo, -) -> Result { - let mut values: Vec = vec![]; - let mut types: Vec = vec![]; - - let request_analysed_type = required_types.types.get("request"); - - match request_analysed_type { - Some(AnalysedType::Record(type_record)) => { - for record in type_record.fields.iter() { - let field_name = record.name.as_str(); - - types.push(NameTypePair { - name: field_name.to_string(), - typ: record.typ.clone(), - }); - - match field_name { - "body" => { - let body = rich_request.request_body().await.map_err(|err| { - GatewayHttpError::BadRequest(format!( - "invalid http request body. {err}" - )) - })?; - - let body_value = - ValueAndType::parse_with_type(body, &record.typ).map_err(|err| { - GatewayHttpError::BadRequest(format!( - "invalid http request body\n{}\nexpected request body: {}", - err.join("\n"), - TypeName::try_from(record.typ.clone()) - .map(|x| x.to_string()) - .unwrap_or_else(|_| format!("{:?}", &record.typ)) - )) - })?; - - values.push(body_value.value); - } - "headers" | "header" => { - let header_values = get_wasm_rpc_value_for_primitives( - &record.typ, - rich_request, - &|request, key| { - request - .headers() - .get(key) - .map(|x| x.to_str().unwrap().to_string()) - .ok_or(format!("missing header: {}", &key)) - }, - ) - .map_err(|err| { - GatewayHttpError::BadRequest(format!( - "invalid http request header. {err}" - )) - })?; - - values.push(header_values); - } - "query" => { - let query_value = get_wasm_rpc_value_for_primitives( - &record.typ, - rich_request, - &|request, key| { - request - .query_params() - .get(key) - .map(|x| x.to_string()) - .ok_or(format!("Missing query parameter: {key}")) - }, - ) - .map_err(|err| { - GatewayHttpError::BadRequest(format!( - "invalid http request query. {err}" - )) - })?; - - values.push(query_value); - } - "path" => { - let path_values = get_wasm_rpc_value_for_primitives( - &record.typ, - rich_request, - &|request, key| { - request - .path_params() - .get(key) - .map(|x| x.to_string()) - .ok_or(format!("Missing path parameter: {key}")) - }, - ) - .map_err(|err| { - GatewayHttpError::BadRequest(format!( - "invalid http request path. {err}" - )) - })?; - - values.push(path_values); - } - - "auth" => { - let auth_data = - rich_request - .auth_data() - .ok_or(GatewayHttpError::BadRequest( - "missing auth data".to_string(), - ))?; - - let auth_value = ValueAndType::parse_with_type(auth_data, &record.typ) - .map_err(|err| { - GatewayHttpError::BadRequest(format!( - "invalid auth data\n{}\nexpected auth: {}", - err.join("\n"), - TypeName::try_from(record.typ.clone()) - .map(|x| x.to_string()) - .unwrap_or_else(|_| format!("{:?}", &record.typ)) - )) - })?; - - values.push(auth_value.value); - } - - "request_id" => { - // Limitation of the current GlobalVariableTypeSpec. We cannot tell rib to directly the type of this field, only of all children. - // Add a dummy value field that needs to be used so inference works. - let value_and_type = RequestIdContainer { - value: rich_request.request_id, - } - .into_value_and_type(); - let expected_type = value_and_type.typ.with_optional_name(None); - - if record.typ != expected_type { - return Err(GatewayHttpError::InternalError(format!( - "invalid expected rib script input type for request.request_id: {:?}; Should be: {:?}", - record.typ, - expected_type - ))); - } - - values.push(value_and_type.value); - } - - field_name => { - // This is already type checked during API registration, - // however we still fail if we happen to have other inputs - // at this stage instead of silently ignoring them. - return Err(GatewayHttpError::InternalError(format!( - "invalid rib script with unknown input: request.{field_name}" - ))); - } - } - } - - let mut result_map: HashMap = HashMap::new(); - - result_map.insert( - "request".to_string(), - ValueAndType::new(golem_wasm::Value::Record(values), record(types)), - ); - - Ok(RibInput { input: result_map }) - } - - Some(_) => Err(GatewayHttpError::InternalError( - "invalid rib script with unsupported type for `request`".to_string(), - )), - - None => Ok(RibInput::default()), - } -} - -async fn apply_middlewares_out( - mut response: poem::Response, - middlewares: &Vec, -) -> poem::Response { - let result = process_middleware_out(middlewares, &mut response).await; - match result { - Ok(_) => response, - Err(err) => { - error!("Middleware error: {}", err.to_safe_string()); - err.to_response_from_safe_display(|_| StatusCode::INTERNAL_SERVER_ERROR) - } - } -} - -fn to_attribute_value(value: &ValueAndType) -> GatewayHttpResult { - match &value.value { - golem_wasm::Value::String(value) => Ok(AttributeValue::String(value.clone())), - _ => Err(GatewayHttpError::BadRequest( - "Invocation context values must be string".to_string(), - )), - } -} - -/// Map against the required types and get `wasm_rpc::Value` from http request -/// # Parameters -/// - `analysed_type: &AnalysedType` -/// - RibInput requirement follows a pseudo form like `{request : {headers: record-type, query: record-type, path: record-type, body: analysed-type}}`. -/// - The `analysed_type` here is the type of headers, query, or path (and not body). i.e, `record-type` in the above pseudo form. -/// - This `record-type` is expected to have primitive field types. Example for a Rib `request.path.user-id` `user-id` is some primitive and `path` should be hence a record. -/// - This analysed doesn't handle (or shouldn't correspond to) the `body` field because it can be anything and not a record of primitives -/// - `request: RichRequest` -/// - The incoming request from the client -/// - `fetch_input: &FnOnce(RichRequest) -> String`, making sure we fetch anything out of the request only if it is needed -/// -fn get_wasm_rpc_value_for_primitives( - required_type: &AnalysedType, - request: &RichRequest, - fetch_key_value: &F, -) -> Result -where - F: Fn(&RichRequest, &String) -> Result, -{ - let mut header_values: Vec = vec![]; - - if let AnalysedType::Record(record_type) = required_type { - for field in record_type.fields.iter() { - let typ = &field.typ; - - let header_value = fetch_key_value(request, &field.name)?; - - let value_and_type = match typ { - AnalysedType::Str(_) => { - parse_to_value::(field.name.clone(), header_value, "string")? - } - AnalysedType::Bool(_) => { - parse_to_value::(field.name.clone(), header_value, "bool")? - } - AnalysedType::U8(_) => { - parse_to_value::(field.name.clone(), header_value, "number")? - } - AnalysedType::U16(_) => { - parse_to_value::(field.name.clone(), header_value, "number")? - } - AnalysedType::U32(_) => { - parse_to_value::(field.name.clone(), header_value, "number")? - } - AnalysedType::U64(_) => { - parse_to_value::(field.name.clone(), header_value, "number")? - } - AnalysedType::S8(_) => { - parse_to_value::(field.name.clone(), header_value, "number")? - } - AnalysedType::S16(_) => { - parse_to_value::(field.name.clone(), header_value, "number")? - } - AnalysedType::S32(_) => { - parse_to_value::(field.name.clone(), header_value, "number")? - } - AnalysedType::S64(_) => { - parse_to_value::(field.name.clone(), header_value, "number")? - } - AnalysedType::F32(_) => { - parse_to_value::(field.name.clone(), header_value, "number")? - } - AnalysedType::F64(_) => { - parse_to_value::(field.name.clone(), header_value, "number")? - } - _ => { - return Err(format!("Invalid type: {}", field.name)); - } - }; - - header_values.push(value_and_type); - } - } - - Ok(golem_wasm::Value::Record(header_values)) -} - -fn parse_to_value( - field_name: String, - field_value: String, - type_name: &str, -) -> Result { - let value = field_value.parse::().map_err(|_| { - format!("Invalid value for key {field_name}. Expected {type_name}, Found {field_value}") - })?; - Ok(value.into_value_and_type().value) -} - -#[derive(golem_wasm_derive::IntoValue)] -struct RequestIdContainer { - value: Uuid, -} diff --git a/golem-worker-service/src/gateway_execution/gateway_session_store.rs b/golem-worker-service/src/gateway_execution/gateway_session_store.rs deleted file mode 100644 index e00505ff98..0000000000 --- a/golem-worker-service/src/gateway_execution/gateway_session_store.rs +++ /dev/null @@ -1,382 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use async_trait::async_trait; -use bytes::Bytes; -use desert_rust::BinaryCodec; -use golem_common::redis::{RedisError, RedisPool}; -use golem_common::SafeDisplay; -use golem_service_base::db::sqlite::SqlitePool; -use sqlx::Row; -use std::collections::HashMap; -use std::hash::Hash; -use std::time::Duration; -use tokio::task; -use tokio::time::interval; -use tracing::{error, info, Instrument}; - -#[async_trait] -pub trait GatewaySessionStore: Send + Sync { - async fn insert( - &self, - session_id: SessionId, - data_key: DataKey, - data_value: DataValue, - ) -> Result<(), GatewaySessionError>; - - async fn get( - &self, - session_id: &SessionId, - data_key: &DataKey, - ) -> Result; -} - -#[derive(Debug, Clone)] -pub enum GatewaySessionError { - InternalError(String), - MissingValue { - session_id: SessionId, - data_key: DataKey, - }, -} - -impl SafeDisplay for GatewaySessionError { - fn to_safe_string(&self) -> String { - match self { - GatewaySessionError::InternalError(e) => format!("Internal error: {e}"), - GatewaySessionError::MissingValue { session_id, .. } => { - format!("Invalid session {}", session_id.0) - } - } - } -} - -#[derive(Debug, Hash, PartialEq, Eq, Clone)] -pub struct SessionId(pub String); - -#[derive(Debug, Hash, PartialEq, Eq, Clone)] -pub struct DataKey(pub String); - -impl DataKey { - pub fn nonce() -> DataKey { - DataKey("nonce".to_string()) - } - - pub fn access_token() -> DataKey { - DataKey("access_token".to_string()) - } - - pub fn id_token() -> DataKey { - DataKey("id_token".to_string()) - } - - pub fn claims() -> DataKey { - DataKey("claims".to_string()) - } - - pub fn redirect_url() -> DataKey { - DataKey("redirect_url".to_string()) - } -} - -#[derive(Debug, Eq, PartialEq, Clone, BinaryCodec)] -#[desert(transparent)] -pub struct DataValue(pub serde_json::Value); - -impl DataValue { - pub fn as_string(&self) -> Option { - self.0.as_str().map(|s| s.to_string()) - } -} - -#[derive(Clone)] -pub struct SessionData { - pub value: HashMap, -} - -#[derive(Clone)] -pub struct RedisGatewaySession { - redis: RedisPool, - expiration: RedisGatewaySessionExpiration, -} - -impl RedisGatewaySession { - pub fn new(redis: RedisPool, expiration: RedisGatewaySessionExpiration) -> Self { - Self { redis, expiration } - } - - pub fn redis_key(session_id: &SessionId) -> String { - format!("gateway_session:{}", session_id.0) - } -} - -#[derive(Clone)] -pub struct RedisGatewaySessionExpiration { - pub session_expiry: Duration, -} - -impl RedisGatewaySessionExpiration { - pub fn new(session_expiry: Duration) -> Self { - Self { session_expiry } - } -} - -impl Default for RedisGatewaySessionExpiration { - fn default() -> Self { - Self::new(Duration::from_secs(60 * 60)) - } -} - -#[async_trait] -impl GatewaySessionStore for RedisGatewaySession { - async fn insert( - &self, - session_id: SessionId, - data_key: DataKey, - data_value: DataValue, - ) -> Result<(), GatewaySessionError> { - let serialised = golem_common::serialization::serialize(&data_value) - .map_err(|e| GatewaySessionError::InternalError(e.to_string()))?; - - let result: Result<(), RedisError> = self - .redis - .with("gateway_session", "insert") - .hset( - Self::redis_key(&session_id), - (data_key.0.as_str(), serialised), - ) - .await; - - result.map_err(|e| { - error!("Failed to insert session data into Redis: {}", e); - GatewaySessionError::InternalError(e.to_string()) - })?; - - self.redis - .with("gateway_session", "insert") - .expire( - Self::redis_key(&session_id), - self.expiration.session_expiry.as_secs() as i64, - ) - .await - .map_err(|e| { - error!("Failed to set expiry on session data in Redis: {}", e); - GatewaySessionError::InternalError(e.to_string()) - }) - } - - async fn get( - &self, - session_id: &SessionId, - data_key: &DataKey, - ) -> Result { - let result: Option = self - .redis - .with("gateway_session", "get_data_value") - .hget(Self::redis_key(session_id), data_key.0.as_str()) - .await - .map_err(|e| { - error!("Failed to get session data from Redis: {}", e); - GatewaySessionError::InternalError(e.to_string()) - })?; - - if let Some(result) = result { - let data_value: DataValue = golem_common::serialization::deserialize(&result) - .map_err(|e| GatewaySessionError::InternalError(e.to_string()))?; - - Ok(data_value) - } else { - Err(GatewaySessionError::MissingValue { - session_id: session_id.clone(), - data_key: data_key.clone(), - }) - } - } -} - -#[derive(Debug, Clone)] -pub struct SqliteGatewaySession { - pool: SqlitePool, - expiration: SqliteGatewaySessionExpiration, -} - -#[derive(Debug, Clone)] -pub struct SqliteGatewaySessionExpiration { - pub session_expiry: Duration, - pub cleanup_interval: Duration, -} - -impl SqliteGatewaySessionExpiration { - pub fn new(session_expiry: Duration, cleanup_interval: Duration) -> Self { - Self { - session_expiry, - cleanup_interval, - } - } -} - -impl Default for SqliteGatewaySessionExpiration { - fn default() -> Self { - Self::new(Duration::from_secs(60 * 60), Duration::from_secs(60)) - } -} - -impl SqliteGatewaySession { - pub async fn new( - pool: SqlitePool, - expiration: SqliteGatewaySessionExpiration, - ) -> Result { - let result = Self { pool, expiration }; - - result.init().await?; - - let cloned_session = result.clone(); - - Self::spawn_expiration_task( - cloned_session.expiration.cleanup_interval, - cloned_session.pool, - ); - - Ok(result) - } - - async fn init(&self) -> Result<(), String> { - self.pool - .with_rw("gateway_session", "init") - .execute(sqlx::query( - r#" - CREATE TABLE IF NOT EXISTS gateway_session ( - session_id TEXT NOT NULL, - data_key TEXT NOT NULL, - data_value BLOB NOT NULL, - expiry_time INTEGER NOT NULL, - PRIMARY KEY (session_id, data_key) - ); - "#, - )) - .await - .map_err(|err| err.to_safe_string())?; - - info!("Initialized gateway session SQLite table"); - - Ok(()) - } - - pub fn spawn_expiration_task(cleanup_internal: Duration, db_pool: SqlitePool) { - task::spawn( - async move { - let mut cleanup_interval = interval(cleanup_internal); - - loop { - cleanup_interval.tick().await; - - if let Err(e) = - Self::cleanup_expired(db_pool.clone(), Self::current_time()).await - { - error!("Failed to expire sessions: {}", e); - } - } - } - .in_current_span(), - ); - } - - pub async fn cleanup_expired(pool: SqlitePool, current_time: i64) -> Result<(), String> { - let query = - sqlx::query("DELETE FROM gateway_session WHERE expiry_time < ?;").bind(current_time); - - pool.with_rw("gateway_session", "cleanup_expired") - .execute(query) - .await - .map(|_| ()) - .map_err(|err| err.to_safe_string()) - } - - pub fn current_time() -> i64 { - chrono::Utc::now().timestamp() - } -} - -#[async_trait] -impl GatewaySessionStore for SqliteGatewaySession { - async fn insert( - &self, - session_id: SessionId, - data_key: DataKey, - data_value: DataValue, - ) -> Result<(), GatewaySessionError> { - let expiry_time = Self::current_time() + self.expiration.session_expiry.as_secs() as i64; - - let serialized_value: &[u8] = &golem_common::serialization::serialize(&data_value) - .map_err(|e| GatewaySessionError::InternalError(e.to_string()))?; - - let result = self - .pool - .with_rw("gateway_session", "insert") - .execute( - sqlx::query( - r#" - INSERT INTO gateway_session (session_id, data_key, data_value, expiry_time) - VALUES (?, ?, ?, ?); - "#, - ) - .bind(session_id.0) - .bind(data_key.0) - .bind(serialized_value) - .bind(expiry_time), - ) - .await; - - result.map_err(|e| { - error!("Failed to insert session data into SQLite: {}", e); - GatewaySessionError::InternalError(e.to_string()) - })?; - - Ok(()) - } - - async fn get( - &self, - session_id: &SessionId, - data_key: &DataKey, - ) -> Result { - let query = sqlx::query( - "SELECT data_value FROM gateway_session WHERE session_id = ? AND data_key = ?;", - ) - .bind(&session_id.0) - .bind(&data_key.0); - - let result = self - .pool - .with_ro("gateway_sesssion", "get") - .fetch_optional(query) - .await - .map_err(|e| GatewaySessionError::InternalError(e.to_string()))?; - - match result { - Some(row) => { - let row = row.get::, _>(0); - - let data_value = golem_common::serialization::deserialize(&row) - .map_err(|e| GatewaySessionError::InternalError(e.to_string()))?; - - Ok(data_value) - } - None => Err(GatewaySessionError::MissingValue { - session_id: session_id.clone(), - data_key: data_key.clone(), - }), - } - } -} diff --git a/golem-worker-service/src/gateway_execution/gateway_worker_request_executor.rs b/golem-worker-service/src/gateway_execution/gateway_worker_request_executor.rs deleted file mode 100644 index 26e2b47db1..0000000000 --- a/golem-worker-service/src/gateway_execution/gateway_worker_request_executor.rs +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::gateway_execution::GatewayResolvedWorkerRequest; -use crate::service::component::ComponentService; -use crate::service::worker::WorkerService; -use async_trait::async_trait; -use golem_common::model::account::AccountId; -use golem_common::model::agent::{AgentId, AgentMode, AgentTypeName}; -use golem_common::model::component::{ComponentId, ComponentRevision}; -use golem_common::model::invocation_context::InvocationContextStack; -use golem_common::model::{IdempotencyKey, WorkerId}; -use golem_common::SafeDisplay; -use golem_service_base::model::auth::AuthCtx; -use golem_wasm::analysis::AnalysedType; -use golem_wasm::ValueAndType; -use rib::InstructionId; -use rib::{ - ComponentDependencyKey, EvaluatedFnArgs, EvaluatedFqFn, EvaluatedWorkerName, RibByteCode, - RibComponentFunctionInvoke, RibFunctionInvokeResult, RibInput, RibResult, -}; -use std::collections::BTreeMap; -use std::fmt::Display; -use std::sync::Arc; -use tracing::debug; -use uuid::Uuid; - -pub struct GatewayWorkerRequestExecutor { - worker_service: Arc, - component_service: Arc, -} - -impl GatewayWorkerRequestExecutor { - pub fn new( - worker_service: Arc, - component_service: Arc, - ) -> Self { - Self { - worker_service, - component_service, - } - } - - pub async fn evaluate_rib( - self: &Arc, - idempotency_key: Option, - invocation_context: InvocationContextStack, - account_id: AccountId, - expr: RibByteCode, - rib_input: RibInput, - ) -> Result { - let worker_invoke_function: Arc = - Arc::new(self.rib_invoke(idempotency_key, invocation_context, account_id)); - - let result = rib::interpret(expr, rib_input, worker_invoke_function, None) - .await - .map_err(|err| WorkerRequestExecutorError(err.to_string()))?; - Ok(result) - } - - pub async fn execute( - &self, - resolved_worker_request: GatewayResolvedWorkerRequest, - account_id: AccountId, - ) -> Result, WorkerRequestExecutorError> { - let component = self - .component_service - .get_revision( - resolved_worker_request.component_id, - resolved_worker_request.component_revision, - ) - .await - .map_err(|err| WorkerRequestExecutorError(err.to_safe_string()))?; - - let mut worker_name = resolved_worker_request.worker_name; - - if component.metadata.is_agent() { - let agent_type_name = AgentId::parse_agent_type_name(&worker_name) - .map_err(|err| WorkerRequestExecutorError(format!("Invalid agent ID: {err}")))?; - let agent_type = component - .metadata - .find_agent_type_by_wrapper_name(&AgentTypeName(agent_type_name.to_string())) - .map_err(|err| { - WorkerRequestExecutorError(format!("Failed to extract agent type: {err}")) - })? - .ok_or_else(|| WorkerRequestExecutorError("Agent type not found".to_string()))?; - - if agent_type.mode == AgentMode::Ephemeral { - let phantom_id = Uuid::new_v4(); - let phantom_id_postfix = format!("[{phantom_id}]"); - worker_name.push_str(&phantom_id_postfix); - } - } - - let worker_id = WorkerId::from_component_metadata_and_worker_id( - component.id, - &component.metadata, - worker_name, - )?; - - debug!( - component_id = resolved_worker_request.component_id.to_string(), - function_name = resolved_worker_request.function_name, - worker_name = worker_id.worker_name.clone(), - "Executing invocation", - ); - - let result = self - .worker_service - .invoke_and_await_typed( - &worker_id, - resolved_worker_request.idempotency_key, - resolved_worker_request.function_name.to_string(), - resolved_worker_request.function_params, - Some(golem_api_grpc::proto::golem::worker::InvocationContext { - parent: None, - env: Default::default(), - wasi_config_vars: Some(BTreeMap::new().into()), - tracing: Some(resolved_worker_request.invocation_context.into()), - }), - AuthCtx::impersonated_user(account_id), - ) - .await - .map_err(|e| format!("Error when executing resolved worker request. Error: {e}"))?; - - Ok(result) - } - - fn rib_invoke( - self: &Arc, - idempotency_key: Option, - invocation_context: InvocationContextStack, - account_id: AccountId, - ) -> WorkerRequestExecutorRibInvoke { - WorkerRequestExecutorRibInvoke { - idempotency_key, - invocation_context, - executor: self.clone(), - account_id, - } - } -} - -#[derive(Clone, Debug)] -pub struct WorkerRequestExecutorError(pub String); - -impl std::error::Error for WorkerRequestExecutorError {} - -impl Display for WorkerRequestExecutorError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -impl> From for WorkerRequestExecutorError { - fn from(err: T) -> Self { - WorkerRequestExecutorError(err.as_ref().to_string()) - } -} - -impl SafeDisplay for WorkerRequestExecutorError { - fn to_safe_string(&self) -> String { - self.0.clone() - } -} - -struct WorkerRequestExecutorRibInvoke { - executor: Arc, - idempotency_key: Option, - invocation_context: InvocationContextStack, - account_id: AccountId, -} - -#[async_trait] -impl RibComponentFunctionInvoke for WorkerRequestExecutorRibInvoke { - async fn invoke( - &self, - component_dependency_key: ComponentDependencyKey, - _instruction_id: &InstructionId, - worker_name: EvaluatedWorkerName, - function_name: EvaluatedFqFn, - parameters: EvaluatedFnArgs, - _return_type: Option, - ) -> RibFunctionInvokeResult { - let worker_name = worker_name.0; - - let idempotency_key = self.idempotency_key.clone(); - let invocation_context = self.invocation_context.clone(); - let executor = self.executor.clone(); - - let function_name = function_name.0; - let function_params: Vec = parameters.0; - - let component_id = ComponentId(component_dependency_key.component_id); - let component_revision: ComponentRevision = - component_dependency_key.component_revision.try_into()?; - - let worker_request = GatewayResolvedWorkerRequest { - component_id, - component_revision, - worker_name, - function_name, - function_params, - idempotency_key, - invocation_context, - }; - - let result = executor.execute(worker_request, self.account_id).await?; - Ok(result) - } -} diff --git a/golem-worker-service/src/gateway_execution/http_content_type_mapper.rs b/golem-worker-service/src/gateway_execution/http_content_type_mapper.rs deleted file mode 100644 index 152dec2850..0000000000 --- a/golem-worker-service/src/gateway_execution/http_content_type_mapper.rs +++ /dev/null @@ -1,956 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use golem_wasm::analysis::AnalysedType; -use golem_wasm::ValueAndType; -use mime::Mime; -use poem::web::headers::ContentType; -use poem::web::WithContentType; -use poem::Body; -use std::fmt::{Display, Formatter}; -use std::str::FromStr; - -pub trait HttpContentTypeResponseMapper { - fn to_http_resp_with_content_type( - &self, - content_type_headers: ContentTypeHeaders, - ) -> Result, ContentTypeMapError>; -} - -#[derive(Debug, Clone)] -pub enum ContentTypeHeaders { - FromClientAccept(AcceptHeaders), - FromUserDefinedResponseMapping(ContentType), - Empty, -} - -#[derive(Clone, Debug)] -pub struct AcceptHeaders(Vec); - -pub trait ContentTypeHeaderExt { - fn has_application_json(&self) -> bool; - fn response_content_type(&self) -> Result; -} - -impl ContentTypeHeaderExt for ContentType { - fn has_application_json(&self) -> bool { - self == &ContentType::json() - } - fn response_content_type(&self) -> Result { - Ok(self.clone()) - } -} -impl ContentTypeHeaderExt for AcceptHeaders { - fn has_application_json(&self) -> bool { - self.0.iter().any(|v| { - if let Ok(mime) = Mime::from_str(v) { - matches!( - (mime.type_(), mime.subtype()), - (mime::APPLICATION, mime::JSON) - | (mime::APPLICATION, mime::STAR) - | (mime::STAR, mime::STAR) - | (mime::STAR, mime::JSON) - ) - } else { - false - } - }) - } - - fn response_content_type(&self) -> Result { - internal::pick_highest_priority_content_type(self) - } -} - -impl AcceptHeaders { - fn from_str>(input: A) -> AcceptHeaders { - let headers = input - .as_ref() - .split(',') - .map(|v| v.trim().to_string()) - .collect::>(); - - AcceptHeaders(headers) - } -} - -impl Display for AcceptHeaders { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self.0.join(", ")) - } -} - -impl AcceptHeaders { - fn contains(&self, content_type: &ContentType) -> bool { - self.0 - .iter() - .any(|accept_header| accept_header.contains(content_type.to_string().as_str())) - } -} - -impl ContentTypeHeaders { - pub fn from( - response_content_type: Option, - accepted_content_types: Option, - ) -> Self { - if let Some(response_content_type) = response_content_type { - ContentTypeHeaders::FromUserDefinedResponseMapping(response_content_type) - } else if let Some(accept_header_string) = accepted_content_types { - ContentTypeHeaders::FromClientAccept(AcceptHeaders::from_str(accept_header_string)) - } else { - ContentTypeHeaders::Empty - } - } -} - -impl HttpContentTypeResponseMapper for ValueAndType { - fn to_http_resp_with_content_type( - &self, - content_type_headers: ContentTypeHeaders, - ) -> Result, ContentTypeMapError> { - match content_type_headers { - ContentTypeHeaders::FromUserDefinedResponseMapping(content_type) => { - internal::get_response_body_based_on_content_type(self, &content_type) - } - ContentTypeHeaders::FromClientAccept(accept_content_headers) => { - internal::get_response_body_based_on_content_type(self, &accept_content_headers) - } - ContentTypeHeaders::Empty => internal::get_response_body(self), - } - } -} - -#[derive(PartialEq, Debug)] -pub enum ContentTypeMapError { - IllegalMapping { - input_type: AnalysedType, - expected_content_types: String, - }, - InternalError(String), -} - -impl ContentTypeMapError { - fn internal>(msg: A) -> ContentTypeMapError { - ContentTypeMapError::InternalError(msg.as_ref().to_string()) - } - - fn illegal_mapping( - input_type: &AnalysedType, - expected_content_types: &A, - ) -> ContentTypeMapError { - ContentTypeMapError::IllegalMapping { - input_type: input_type.clone(), - expected_content_types: expected_content_types.to_string(), - } - } -} - -impl Display for ContentTypeMapError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - ContentTypeMapError::InternalError(message) => { - write!(f, "{message}") - } - ContentTypeMapError::IllegalMapping { - input_type, - expected_content_types, - } => { - write!( - f, - "Failed to map input type {input_type:?} to any of the expected content types: {expected_content_types:?}" - ) - } - } - } -} - -mod internal { - use crate::gateway_execution::http_content_type_mapper::{ - AcceptHeaders, ContentTypeHeaderExt, ContentTypeMapError, - }; - use golem_wasm::analysis::{AnalysedType, TypeEnum, TypeList, TypeOption, TypeRecord}; - use golem_wasm::json::ValueAndTypeJsonExtensions; - use golem_wasm::{Value, ValueAndType}; - use poem::web::headers::ContentType; - use poem::web::WithContentType; - use poem::{Body, IntoResponse}; - use std::fmt::Display; - - pub(crate) fn get_response_body_based_on_content_type( - value_and_type: &ValueAndType, - content_header: &A, - ) -> Result, ContentTypeMapError> { - match (&value_and_type.typ, &value_and_type.value) { - (AnalysedType::Record(_record), Value::Record(_values)) => { - handle_record(value_and_type, content_header) - } - (AnalysedType::Variant(_variant), Value::Variant { .. }) => { - handle_record(value_and_type, content_header) - } - (AnalysedType::List(TypeList { inner, .. }), Value::List(values)) => { - handle_list(value_and_type, values, inner, content_header) - } - (AnalysedType::Bool(_), Value::Bool(value)) => { - handle_primitive(value, &value_and_type.typ, content_header) - } - (AnalysedType::S8(_), Value::S8(value)) => { - handle_primitive(value, &value_and_type.typ, content_header) - } - (AnalysedType::U8(_), Value::U8(value)) => { - handle_primitive(value, &value_and_type.typ, content_header) - } - (AnalysedType::S16(_), Value::S16(value)) => { - handle_primitive(value, &value_and_type.typ, content_header) - } - (AnalysedType::U16(_), Value::U16(value)) => { - handle_primitive(value, &value_and_type.typ, content_header) - } - (AnalysedType::S32(_), Value::S32(value)) => { - handle_primitive(value, &value_and_type.typ, content_header) - } - (AnalysedType::U32(_), Value::U32(value)) => { - handle_primitive(value, &value_and_type.typ, content_header) - } - (AnalysedType::S64(_), Value::S64(value)) => { - handle_primitive(value, &value_and_type.typ, content_header) - } - (AnalysedType::U64(_), Value::U64(value)) => { - handle_primitive(value, &value_and_type.typ, content_header) - } - (AnalysedType::F32(_), Value::F32(value)) => { - handle_primitive(value, &value_and_type.typ, content_header) - } - (AnalysedType::F64(_), Value::F64(value)) => { - handle_primitive(value, &value_and_type.typ, content_header) - } - (AnalysedType::Chr(_), Value::Char(value)) => { - handle_primitive(value, &value_and_type.typ, content_header) - } - (AnalysedType::Str(_), Value::String(string)) => handle_string(string, content_header), - (AnalysedType::Tuple(_), Value::Tuple(_)) => { - handle_complex(value_and_type, content_header) - } - (AnalysedType::Flags(_), Value::Flags(_)) => { - handle_complex(value_and_type, content_header) - } - // Can be considered as a record - (AnalysedType::Result(_), Value::Result(_)) => { - handle_complex(value_and_type, content_header) - } - (AnalysedType::Handle(_), Value::Handle { .. }) => { - handle_complex(value_and_type, content_header) - } - (AnalysedType::Enum(TypeEnum { cases, .. }), Value::Enum(name_idx)) => { - let name = cases - .get(*name_idx as usize) - .ok_or(ContentTypeMapError::internal("Invalid enum index"))?; - handle_string(name, content_header) - } - (AnalysedType::Option(TypeOption { inner, .. }), Value::Option(value)) => match value { - Some(value) => { - let value_and_type = ValueAndType::new((**value).clone(), (**inner).clone()); - get_response_body_based_on_content_type(&value_and_type, content_header) - } - None => { - if content_header.has_application_json() { - get_json_null() - } else { - let typ = value_and_type.typ.clone(); - Err(ContentTypeMapError::illegal_mapping(&typ, content_header)) - } - } - }, - _ => Err(ContentTypeMapError::InternalError( - "Value and type mismatch".to_string(), - )), - } - } - - pub(crate) fn get_response_body( - value_and_type: &ValueAndType, - ) -> Result, ContentTypeMapError> { - match (&value_and_type.typ, &value_and_type.value) { - (AnalysedType::Record(TypeRecord { .. }), Value::Record { .. }) => { - get_json(value_and_type) - } - (AnalysedType::List(TypeList { inner, .. }), Value::List(values)) => match &**inner { - AnalysedType::U8(_) => get_byte_stream_body(values), - _ => get_json(value_and_type), - }, - (AnalysedType::Str(_), Value::String(string)) => { - Ok(Body::from_string(string.to_string()) - .with_content_type(ContentType::json().to_string())) - } - (AnalysedType::Enum(TypeEnum { cases, .. }), Value::Enum(case_idx)) => { - let case_name = cases - .get(*case_idx as usize) - .ok_or(ContentTypeMapError::internal("Invalid enum index"))?; - Ok(Body::from_string(case_name.to_string()) - .with_content_type(ContentType::json().to_string())) - } - (AnalysedType::Bool(_), Value::Bool(bool)) => get_json_of(bool), - (AnalysedType::S8(_), Value::S8(s8)) => get_json_of(s8), - (AnalysedType::U8(_), Value::U8(u8)) => get_json_of(u8), - (AnalysedType::S16(_), Value::S16(s16)) => get_json_of(s16), - (AnalysedType::U16(_), Value::U16(u16)) => get_json_of(u16), - (AnalysedType::S32(_), Value::S32(s32)) => get_json_of(s32), - (AnalysedType::U32(_), Value::U32(u32)) => get_json_of(u32), - (AnalysedType::S64(_), Value::S64(s64)) => get_json_of(s64), - (AnalysedType::U64(_), Value::U64(u64)) => get_json_of(u64), - (AnalysedType::F32(_), Value::F32(f32)) => get_json_of(f32), - (AnalysedType::F64(_), Value::F64(f64)) => get_json_of(f64), - (AnalysedType::Chr(_), Value::Char(char)) => get_json_of(char), - (AnalysedType::Tuple(_), Value::Tuple(_)) => get_json(value_and_type), - (AnalysedType::Flags(_), Value::Flags(_)) => get_json(value_and_type), - (AnalysedType::Variant(_), Value::Variant { .. }) => get_json(value_and_type), - (AnalysedType::Result(_), Value::Result { .. }) => get_json(value_and_type), - (AnalysedType::Handle(_), Value::Handle { .. }) => get_json(value_and_type), - (AnalysedType::Option(TypeOption { inner, .. }), Value::Option(value)) => match value { - Some(value) => { - let value = ValueAndType::new((**value).clone(), (**inner).clone()); - get_response_body(&value) - } - None => get_json_null(), - }, - _ => Err(ContentTypeMapError::internal("Value and type mismatch")), - } - } - - pub(crate) fn pick_highest_priority_content_type( - input_content_types: &AcceptHeaders, - ) -> Result { - let content_headers_in_priority: Vec = vec![ - ContentType::json(), - ContentType::text(), - ContentType::text_utf8(), - ContentType::html(), - ContentType::xml(), - ContentType::form_url_encoded(), - ContentType::jpeg(), - ContentType::png(), - ContentType::octet_stream(), - ]; - - let mut prioritised_content_type = None; - for content_type in &content_headers_in_priority { - if input_content_types.contains(content_type) { - prioritised_content_type = Some(content_type.clone()); - break; - } - } - - if let Some(prioritised) = prioritised_content_type { - Ok(prioritised) - } else { - Err(ContentTypeMapError::internal( - "Failed to pick a content type to set in response headers", - )) - } - } - - fn get_byte_stream(values: &[Value]) -> Result, ContentTypeMapError> { - let bytes = values - .iter() - .map(|v| match v { - Value::U8(u8) => Ok(*u8), - _ => Err(ContentTypeMapError::internal( - "The analysed type is a binary stream however unable to fetch vec", - )), - }) - .collect::, ContentTypeMapError>>()?; - - Ok(bytes) - } - - fn get_byte_stream_body( - values: &[Value], - ) -> Result, ContentTypeMapError> { - let bytes = get_byte_stream(values)?; - Ok(Body::from_bytes(bytes::Bytes::from(bytes)) - .with_content_type(ContentType::octet_stream().to_string())) - } - - fn get_json( - value_and_type: &ValueAndType, - ) -> Result, ContentTypeMapError> { - let json = value_and_type.to_json_value().map_err(|err| { - ContentTypeMapError::internal(format!("Failed to encode value as JSON: {err}")) - })?; - Body::from_json(json) - .map(|body| body.with_content_type(ContentType::json().to_string())) - .map_err(|_| ContentTypeMapError::internal("Failed to convert to json body")) - } - - fn get_json_of( - a: A, - ) -> Result, ContentTypeMapError> { - let json = serde_json::to_value(&a).map_err(|_| { - ContentTypeMapError::internal(format!("Failed to serialise {a} to json")) - })?; - let body = Body::from_json(json) - .map_err(|_| ContentTypeMapError::internal("Failed to create body from JSON"))?; - - Ok(body.with_content_type(ContentType::json().to_string())) - } - - fn get_json_null() -> Result, ContentTypeMapError> { - Body::from_json(serde_json::Value::Null) - .map(|body| body.with_content_type(ContentType::json().to_string())) - .map_err(|_| ContentTypeMapError::internal("Failed to convert to json body")) - } - - fn handle_complex( - complex: &ValueAndType, - content_header: &A, - ) -> Result, ContentTypeMapError> { - if content_header.has_application_json() { - get_json(complex) - } else { - let typ = complex.typ.clone(); - Err(ContentTypeMapError::illegal_mapping(&typ, content_header)) - } - } - - fn handle_list( - original: &ValueAndType, - inner_values: &[Value], - elem_type: &AnalysedType, - content_header: &A, - ) -> Result, ContentTypeMapError> { - match elem_type { - AnalysedType::U8(_) => { - let byte_stream = get_byte_stream(inner_values)?; - let body = Body::from_bytes(bytes::Bytes::from(byte_stream)); - let content_type_header = content_header.response_content_type()?; - Ok(body.with_content_type(content_type_header.to_string())) - } - _ => { - if content_header.has_application_json() { - get_json(original) - } else { - Err(ContentTypeMapError::illegal_mapping( - elem_type, - content_header, - )) - } - } - } - } - - fn handle_primitive( - input: &A, - primitive_type: &AnalysedType, - content_header: &B, - ) -> Result, ContentTypeMapError> where { - if content_header.has_application_json() { - let json = serde_json::to_value(input) - .map_err(|_| ContentTypeMapError::internal("Failed to convert to json body"))?; - - let body = Body::from_json(json).map_err(|_| { - ContentTypeMapError::internal(format!("Failed to convert {input} to json body")) - })?; - - Ok(body.with_content_type(ContentType::json().to_string())) - } else { - Err(ContentTypeMapError::illegal_mapping( - primitive_type, - content_header, - )) - } - } - - fn handle_record( - value_and_type: &ValueAndType, - content_header: &A, - ) -> Result, ContentTypeMapError> { - // if record, we prioritise JSON - if content_header.has_application_json() { - get_json(value_and_type) - } else { - let typ = value_and_type.typ.clone(); - // There is no way a Record can be properly serialised into any other formats to satisfy any other headers, therefore fail - Err(ContentTypeMapError::illegal_mapping(&typ, content_header)) - } - } - - fn handle_string( - string: &str, - content_type: &A, - ) -> Result, ContentTypeMapError> { - let response_content_type = content_type.response_content_type()?; - - let body = if content_type.has_application_json() { - bytes::Bytes::from(format!("\"{string}\"")) - } else { - bytes::Bytes::from(string.to_string()) - }; - - Ok(Body::from_bytes(body).with_content_type(response_content_type.to_string())) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use golem_wasm::analysis::analysed_type::{field, list, record, str}; - use golem_wasm::{IntoValue, Value}; - use poem::web::headers::ContentType; - use poem::IntoResponse; - - fn sample_record() -> ValueAndType { - ValueAndType::new( - Value::Record(vec!["Hello".into_value()]), - record(vec![field("name", str())]), - ) - } - - fn create_list(vec: Vec, analysed_type: AnalysedType) -> ValueAndType { - ValueAndType::new(Value::List(vec), list(analysed_type)) - } - - fn create_record(values: Vec<(&str, ValueAndType)>) -> ValueAndType { - ValueAndType::new( - Value::Record(values.iter().map(|(_, v)| v.value.clone()).collect()), - record( - values - .iter() - .map(|(name, value)| field(name, value.typ.clone())) - .collect(), - ), - ) - } - - #[cfg(test)] - mod no_content_type_header { - use test_r::test; - - use super::*; - use golem_wasm::analysis::analysed_type::{u16, u8}; - use golem_wasm::IntoValueAndType; - - fn get_content_type_and_body(input: &ValueAndType) -> (Option, Body) { - let response_body = internal::get_response_body(input).unwrap(); - let response = response_body.into_response(); - let (parts, body) = response.into_parts(); - let content_type = parts - .headers - .get("content-type") - .map(|v| v.to_str().unwrap().to_string()); - (content_type, body) - } - - #[test] - async fn test_string_type() { - let type_annotated_value = "Hello".into_value_and_type(); - let (content_type, body) = get_content_type_and_body(&type_annotated_value); - let result = String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - // Had it serialized as json, it would have been "\"Hello\"" - assert_eq!( - (result, content_type), - ("Hello".to_string(), Some("application/json".to_string())) - ); - } - - #[test] - async fn test_singleton_u8_type() { - let type_annotated_value = 10u8.into_value_and_type(); - let (content_type, body) = get_content_type_and_body(&type_annotated_value); - let result = String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - assert_eq!( - (result, content_type), - ("10".to_string(), Some("application/json".to_string())) - ); - } - - #[test] - async fn test_list_u8_type() { - let type_annotated_value = create_list(vec![10u8.into_value()], u8()); - let (content_type, body) = get_content_type_and_body(&type_annotated_value); - let result = body.into_bytes().await.unwrap(); - - assert_eq!( - (result, content_type), - ( - bytes::Bytes::from(vec![10]), - Some("application/octet-stream".to_string()) - ) - ); - } - - #[test] - async fn test_list_non_u8_type() { - let type_annotated_value = create_list(vec![10u16.into_value()], u16()); - - let (content_type, body) = get_content_type_and_body(&type_annotated_value); - let data_as_str = - String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - let result_json: serde_json::Value = - serde_json::from_str(data_as_str.as_str()).unwrap(); - let expected_json = serde_json::Value::Array(vec![serde_json::Value::Number( - serde_json::Number::from(10), - )]); - assert_eq!( - (result_json, content_type), - (expected_json, Some("application/json".to_string())) - ); - } - - #[test] - async fn test_record_type() { - let type_annotated_value = create_record(vec![("name", "Hello".into_value_and_type())]); - - let (content_type, body) = get_content_type_and_body(&type_annotated_value); - let data_as_str = - String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - let result_json: serde_json::Value = - serde_json::from_str(data_as_str.as_str()).unwrap(); - let expected_json = serde_json::json!({"name": "Hello"}); - assert_eq!( - (result_json, content_type), - (expected_json, Some("application/json".to_string())) - ); - } - } - - #[cfg(test)] - mod with_response_type_header { - use test_r::test; - - use super::*; - use golem_wasm::analysis::analysed_type::{u16, u8}; - use golem_wasm::IntoValueAndType; - - fn get_content_type_and_body( - input: &ValueAndType, - header: &ContentType, - ) -> (Option, Body) { - let response_body = - internal::get_response_body_based_on_content_type(input, header).unwrap(); - let response = response_body.into_response(); - let (parts, body) = response.into_parts(); - let content_type = parts - .headers - .get("content-type") - .map(|v| v.to_str().unwrap().to_string()); - (content_type, body) - } - - #[test] - async fn test_string_type_as_text() { - let type_annotated_value = "Hello".into_value_and_type(); - let (content_type, body) = - get_content_type_and_body(&type_annotated_value, &ContentType::text()); - let result = String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - // Had it serialized as json, it would have been "\"Hello\"" - assert_eq!( - (result, content_type), - ("Hello".to_string(), Some("text/plain".to_string())) - ); - } - - #[test] - async fn test_string_type_as_json() { - let type_annotated_value = "\"Hello\"".into_value_and_type(); - let (content_type, body) = - get_content_type_and_body(&type_annotated_value, &ContentType::json()); - let result = String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - - // It doesn't matter if the string is already jsonified, it will be jsonified again - assert_eq!( - (result, content_type), - ( - "\"\"Hello\"\"".to_string(), - Some("application/json".to_string()) - ) - ); - } - - #[test] - async fn test_singleton_u8_type() { - let type_annotated_value = 10u8.into_value_and_type(); - let (content_type, body) = - get_content_type_and_body(&type_annotated_value, &ContentType::json()); - let result = String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - assert_eq!( - (result, content_type), - ("10".to_string(), Some("application/json".to_string())) - ); - } - - #[test] - async fn test_list_u8_type() { - let type_annotated_value = create_list(vec![10u8.into_value()], u8()); - - let (content_type, body) = - get_content_type_and_body(&type_annotated_value, &ContentType::json()); - let result = &body.into_bytes().await.unwrap(); - let data_as_str = String::from_utf8_lossy(result).to_string(); - let result_json: Result = - serde_json::from_str(data_as_str.as_str()); - - assert_eq!( - (result, content_type), - ( - &bytes::Bytes::from(vec![10]), - Some("application/json".to_string()) - ) - ); - assert!(result_json.is_err()); // That we haven't jsonified this case explicitly - } - - #[test] - async fn test_list_non_u8_type() { - let type_annotated_value = create_list(vec![10u16.into_value()], u16()); - - let (content_type, body) = - get_content_type_and_body(&type_annotated_value, &ContentType::json()); - let data_as_str = - String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - let result_json: serde_json::Value = - serde_json::from_str(data_as_str.as_str()).unwrap(); - let expected_json = serde_json::Value::Array(vec![serde_json::Value::Number( - serde_json::Number::from(10), - )]); - // That we jsonify any list other than u8, and can be retrieveed as a valid JSON - assert_eq!( - (result_json, content_type), - (expected_json, Some("application/json".to_string())) - ); - } - - #[test] - async fn test_record_type() { - // Record - let type_annotated_value = sample_record(); - - let (content_type, body) = - get_content_type_and_body(&type_annotated_value, &ContentType::json()); - let data_as_str = - String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - let result_json: serde_json::Value = - serde_json::from_str(data_as_str.as_str()).unwrap(); - let expected_json = serde_json::json!({"name": "Hello"}); - assert_eq!( - (result_json, content_type), - (expected_json, Some("application/json".to_string())) - ); - } - } - - #[cfg(test)] - mod with_accept_headers { - use test_r::test; - - use super::*; - use golem_wasm::analysis::analysed_type::{u16, u8}; - use golem_wasm::IntoValueAndType; - - fn get_content_type_and_body( - input: &ValueAndType, - headers: &AcceptHeaders, - ) -> (Option, Body) { - let response_body = - internal::get_response_body_based_on_content_type(input, headers).unwrap(); - let response = response_body.into_response(); - let (parts, body) = response.into_parts(); - let content_type = parts - .headers - .get("content-type") - .map(|v| v.to_str().unwrap().to_string()); - (content_type, body) - } - - #[test] - async fn test_string_type_with_json() { - let type_annotated_value = "Hello".into_value_and_type(); - let (content_type, body) = get_content_type_and_body( - &type_annotated_value, - &AcceptHeaders::from_str("text/html;q=0.8, application/json;q=0.5"), - ); - let result = String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - assert_eq!( - (result, content_type), - ( - "\"Hello\"".to_string(), - Some("application/json".to_string()) - ) - ); - } - - #[test] - async fn test_string_type_without_json() { - let type_annotated_value = "Hello".into_value_and_type(); - let (content_type, body) = get_content_type_and_body( - &type_annotated_value, - &AcceptHeaders::from_str("text/html;q=0.8, application/json;q=0.5"), - ); - let result = String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - assert_eq!( - (result, content_type), - ( - "\"Hello\"".to_string(), - Some("application/json".to_string()) - ) - ); - } - - #[test] - async fn test_string_type_with_html() { - let type_annotated_value = "Hello".into_value_and_type(); - let (content_type, body) = get_content_type_and_body( - &type_annotated_value, - &AcceptHeaders::from_str("text/html"), - ); - let result = String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - assert_eq!( - (result, content_type), - ("Hello".to_string(), Some("text/html".to_string())) - ); - } - - #[test] - async fn test_singleton_u8_type_text() { - let type_annotated_value = 10u8.into_value_and_type(); - let (content_type, body) = get_content_type_and_body( - &type_annotated_value, - &AcceptHeaders::from_str("application/json"), - ); - let result = String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - assert_eq!( - (result, content_type), - ("10".to_string(), Some("application/json".to_string())) - ); - } - - #[test] - async fn test_singleton_u8_type_json() { - let type_annotated_value = 10u8.into_value_and_type(); - let (content_type, body) = get_content_type_and_body( - &type_annotated_value, - &AcceptHeaders::from_str("application/json"), - ); - let result = String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - assert_eq!( - (result, content_type), - ("10".to_string(), Some("application/json".to_string())) - ); - } - - #[test] - async fn test_singleton_u8_failed_content_mapping() { - let type_annotated_value = 10u8.into_value_and_type(); - let result = internal::get_response_body_based_on_content_type( - &type_annotated_value, - &AcceptHeaders::from_str("text/html"), - ); - - assert!(matches!( - result, - Err(ContentTypeMapError::IllegalMapping { .. }) - )); - } - - #[test] - async fn test_list_u8_type_with_json() { - let type_annotated_value = create_list(vec![10u8.into_value()], u8()); - - let (content_type, body) = get_content_type_and_body( - &type_annotated_value, - &AcceptHeaders::from_str("text/html;q=0.8, application/json;q=0.50"), - ); - let result = &body.into_bytes().await.unwrap(); - let data_as_str = String::from_utf8_lossy(result).to_string(); - let result_json: Result = - serde_json::from_str(data_as_str.as_str()); - - assert_eq!( - (result, content_type), - ( - &bytes::Bytes::from(vec![10]), - Some("application/json".to_string()) - ) - ); - assert!(result_json.is_err()); // That we haven't jsonified this case explicitly - } - - #[test] - async fn test_list_non_u8_type_with_json() { - let type_annotated_value = create_list(vec![10u16.into_value()], u16()); - - let (content_type, body) = get_content_type_and_body( - &type_annotated_value, - &AcceptHeaders::from_str("text/html;q=0.8, application/json;q=0.5"), - ); - let data_as_str = - String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - let result_json: serde_json::Value = - serde_json::from_str(data_as_str.as_str()).unwrap(); - let expected_json = serde_json::Value::Array(vec![serde_json::Value::Number( - serde_json::Number::from(10), - )]); - assert_eq!( - (result_json, content_type), - (expected_json, Some("application/json".to_string())) - ); - } - - #[test] - async fn test_list_non_u8_type_with_html_fail() { - let type_annotated_value = create_list(vec![10u16.into_value()], u16()); - - let result = internal::get_response_body_based_on_content_type( - &type_annotated_value, - &AcceptHeaders::from_str("text/html"), - ); - - assert!(matches!( - result, - Err(ContentTypeMapError::IllegalMapping { .. }) - )); - } - - #[test] - async fn test_record_type_json() { - let type_annotated_value = sample_record(); - let (content_type, body) = get_content_type_and_body( - &type_annotated_value, - &AcceptHeaders::from_str("text/html;q=0.8, application/json;q=0.5"), - ); - let data_as_str = - String::from_utf8_lossy(&body.into_bytes().await.unwrap()).to_string(); - let result_json: serde_json::Value = - serde_json::from_str(data_as_str.as_str()).unwrap(); - let expected_json = serde_json::json!({"name": "Hello"}); - assert_eq!( - (result_json, content_type), - (expected_json, Some("application/json".to_string())) - ); - } - - #[test] - async fn test_record_type_html() { - let type_annotated_value = sample_record(); - - let result = internal::get_response_body_based_on_content_type( - &type_annotated_value, - &AcceptHeaders::from_str("text/html"), - ); - - assert!(matches!( - result, - Err(ContentTypeMapError::IllegalMapping { .. }) - )); - } - } -} diff --git a/golem-worker-service/src/gateway_execution/http_handler_binding_handler.rs b/golem-worker-service/src/gateway_execution/http_handler_binding_handler.rs deleted file mode 100644 index 495daa7ea3..0000000000 --- a/golem-worker-service/src/gateway_execution/http_handler_binding_handler.rs +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use super::{GatewayWorkerRequestExecutor, WorkerRequestExecutorError}; -use crate::gateway_execution::{GatewayResolvedWorkerRequest, WorkerDetails}; -use bytes::Bytes; -use golem_common::model::account::AccountId; -use golem_common::virtual_exports::http_incoming_handler::IncomingHttpRequest; -use golem_common::{virtual_exports, widen_infallible}; -use golem_wasm::ValueAndType; -use http::StatusCode; -use http_body_util::combinators::BoxBody; -use http_body_util::BodyExt; -use std::str::FromStr; -use std::sync::Arc; - -pub struct HttpHandlerBindingHandler { - worker_request_executor: Arc, -} - -impl HttpHandlerBindingHandler { - pub fn new(worker_request_executor: Arc) -> Self { - Self { - worker_request_executor, - } - } - - pub async fn handle_http_handler_binding( - &self, - worker_detail: &WorkerDetails, - incoming_http_request: IncomingHttpRequest, - account_id: AccountId, - ) -> HttpHandlerBindingResult { - let type_annotated_param = ValueAndType::new( - incoming_http_request.to_value(), - IncomingHttpRequest::analysed_type(), - ); - - let resolved_request = GatewayResolvedWorkerRequest { - component_id: worker_detail.component_id, - component_revision: worker_detail.component_revision, - worker_name: worker_detail - .worker_name - .as_ref() - .ok_or_else(|| { - HttpHandlerBindingError::InternalError("Missing worker name".to_string()) - })? - .clone(), - function_name: virtual_exports::http_incoming_handler::FUNCTION_NAME.to_string(), - function_params: vec![type_annotated_param], - idempotency_key: worker_detail.idempotency_key.clone(), - invocation_context: worker_detail.invocation_context.clone(), - }; - - let response = self - .worker_request_executor - .execute(resolved_request, account_id) - .await; - - match response { - Ok(_) => { - tracing::debug!("http_handler received successful response from worker invocation") - } - Err(ref e) => tracing::warn!("worker invocation of http_handler failed: {}", e), - } - - let response = response.map_err(HttpHandlerBindingError::WorkerRequestExecutorError)?; - - let poem_response = { - use golem_common::virtual_exports::http_incoming_handler as hic; - - let parsed_response = - hic::HttpResponse::from_function_output(response).map_err(|e| { - HttpHandlerBindingError::InternalError(format!("Failed parsing response: {e}")) - })?; - - let converted_status_code = - StatusCode::from_u16(parsed_response.status).map_err(|e| { - HttpHandlerBindingError::InternalError(format!( - "Failed to parse response status: {e}" - )) - })?; - - let mut builder = poem::Response::builder().status(converted_status_code); - - for (header_name, header_value) in parsed_response.headers.0 { - let converted_header_value = - http::HeaderValue::from_bytes(&header_value).map_err(|e| { - HttpHandlerBindingError::InternalError(format!( - "Failed to parse response header: {e}" - )) - })?; - builder = builder.header(header_name, converted_header_value); - } - - if let Some(body) = parsed_response.body { - let converted_body = http_body_util::Full::new(body.content.0); - - let trailers = if let Some(trailers) = body.trailers { - let mut acc = http::HeaderMap::new(); - for (header_name, header_value) in trailers.0.into_iter() { - let converted_header_name = http::HeaderName::from_str(&header_name) - .map_err(|e| { - HttpHandlerBindingError::InternalError(format!( - "Failed to parse response trailer name: {e}" - )) - })?; - let converted_header_value = http::HeaderValue::from_bytes(&header_value) - .map_err(|e| { - HttpHandlerBindingError::InternalError(format!( - "Failed to parse response trailer value: {e}" - )) - })?; - - acc.insert(converted_header_name, converted_header_value); - } - Some(Ok(acc)) - } else { - None - }; - - let body_with_trailers = converted_body.with_trailers(async { trailers }); - - let boxed: BoxBody = - BoxBody::new(body_with_trailers.map_err(widen_infallible::)); - - builder.body(boxed) - } else { - builder.body(poem::Body::empty()) - } - }; - - Ok(HttpHandlerBindingSuccess { - response: poem_response, - }) - } -} - -pub type HttpHandlerBindingResult = Result; - -pub struct HttpHandlerBindingSuccess { - pub response: poem::Response, -} - -#[derive(Debug)] -pub enum HttpHandlerBindingError { - InternalError(String), - WorkerRequestExecutorError(WorkerRequestExecutorError), -} diff --git a/golem-worker-service/src/gateway_execution/mod.rs b/golem-worker-service/src/gateway_execution/mod.rs deleted file mode 100644 index bc9f799fb3..0000000000 --- a/golem-worker-service/src/gateway_execution/mod.rs +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -pub mod api_definition_lookup; -// pub mod auth_call_back_binding_handler; -// pub mod file_server_binding_handler; -// pub mod gateway_http_input_executor; -// pub mod gateway_session_store; -// mod gateway_worker_request_executor; -// mod http_content_type_mapper; -// pub mod http_handler_binding_handler; -mod agent_response_mapping; -pub mod model; -mod parameter_parsing; -pub mod request; -pub mod request_handler; -pub mod route_resolver; -// pub mod to_response; -// pub mod to_response_failure; -// pub use gateway_worker_request_executor::*; - -pub use model::*; - -use golem_common::model::component::{ComponentId, ComponentRevision}; -use golem_common::model::invocation_context::InvocationContextStack; -use golem_common::model::IdempotencyKey; -use golem_common::SafeDisplay; -use golem_wasm::json::ValueAndTypeJsonExtensions; -use golem_wasm::ValueAndType; -use rib::{RibInput, RibInputTypeInfo}; -use serde_json::Value; -use std::collections::HashMap; -use std::fmt::Display; - -#[derive(PartialEq, Debug, Clone)] -pub struct GatewayResolvedWorkerRequest { - pub component_id: ComponentId, - pub component_revision: ComponentRevision, - pub worker_name: String, - pub function_name: String, - pub function_params: Vec, - pub idempotency_key: Option, - pub invocation_context: InvocationContextStack, -} - -#[derive(Debug, Clone, PartialEq)] -pub struct WorkerDetails { - pub component_id: ComponentId, - pub component_revision: ComponentRevision, - pub worker_name: Option, - pub idempotency_key: Option, - pub invocation_context: InvocationContextStack, -} - -impl WorkerDetails { - fn as_json(&self) -> Value { - let mut worker_detail_content = HashMap::new(); - worker_detail_content.insert( - "component_id".to_string(), - Value::String(self.component_id.0.to_string()), - ); - - if let Some(worker_name) = &self.worker_name { - worker_detail_content - .insert("name".to_string(), Value::String(worker_name.to_string())); - } - - if let Some(idempotency_key) = &self.idempotency_key { - worker_detail_content.insert( - "idempotency_key".to_string(), - Value::String(idempotency_key.value.clone()), - ); - } - - let map = serde_json::Map::from_iter(worker_detail_content); - - Value::Object(map) - } - - pub fn resolve_rib_input_value( - &self, - required_types: &RibInputTypeInfo, - ) -> Result { - let request_type_info = required_types.types.get("worker"); - - match request_type_info { - Some(worker_details_type) => { - let rib_input_with_request_content = &self.as_json(); - let request_value = - ValueAndType::parse_with_type(rib_input_with_request_content, worker_details_type) - .map_err(|err| RibInputTypeMismatch(format!("Worker details don't match the requirements for rib expression to execute: {}. Requirements. {:?}", err.join(", "), worker_details_type)))?; - - let mut rib_input_map = HashMap::new(); - rib_input_map.insert("worker".to_string(), request_value); - Ok(RibInput { - input: rib_input_map, - }) - } - None => Ok(RibInput::default()), - } - } -} - -#[derive(Debug)] -pub struct RibInputTypeMismatch(pub String); - -impl Display for RibInputTypeMismatch { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Rib input type mismatch: {}", self.0) - } -} - -impl SafeDisplay for RibInputTypeMismatch { - fn to_safe_string(&self) -> String { - self.0.clone() - } -} diff --git a/golem-worker-service/src/gateway_execution/model.rs b/golem-worker-service/src/gateway_execution/model.rs deleted file mode 100644 index 3dd2788ef9..0000000000 --- a/golem-worker-service/src/gateway_execution/model.rs +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use golem_common::model::agent::BinarySource; -use golem_wasm::ValueAndType; -use http::StatusCode; -use std::fmt; - -pub enum RouteExecutionResult { - NoBody { - status: StatusCode, - }, - ComponentModelJsonBody { - body: golem_wasm::ValueAndType, - status: StatusCode, - }, - UnstructuredBinaryBody { - body: BinarySource, - }, - CustomAgentError { - body: ValueAndType, - }, -} - -impl fmt::Debug for RouteExecutionResult { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - RouteExecutionResult::NoBody { status } => { - f.debug_struct("NoBody").field("status", status).finish() - } - RouteExecutionResult::ComponentModelJsonBody { body, status } => f - .debug_struct("ComponentModelJsonBody") - .field("body", body) - .field("status", status) - .finish(), - RouteExecutionResult::UnstructuredBinaryBody { .. } => { - f.write_str("UnstructuredBinaryBody") - } - RouteExecutionResult::CustomAgentError { body } => f - .debug_struct("CustomAgentError") - .field("body", body) - .finish(), - } - } -} - -pub enum ParsedRequestBody { - Unused, - JsonBody(golem_wasm::Value), - // Always Some initially, will be None after being consumed by handler code - UnstructuredBinary(Option), -} - -impl fmt::Debug for ParsedRequestBody { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ParsedRequestBody::Unused => f.write_str("Unused"), - ParsedRequestBody::JsonBody(value) => f.debug_tuple("JsonBody").field(value).finish(), - ParsedRequestBody::UnstructuredBinary(_) => f.write_str("UnstructuredBinary"), - } - } -} diff --git a/golem-worker-service/src/gateway_execution/request.rs b/golem-worker-service/src/gateway_execution/request.rs deleted file mode 100644 index 160330ec5b..0000000000 --- a/golem-worker-service/src/gateway_execution/request.rs +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// use super::gateway_session_store::{DataKey, GatewaySessionStore, SessionId}; -// use crate::gateway_router::PathParamExtractor; -// use crate::model::{HttpMiddleware, RichGatewayBindingCompiled}; -use http::HeaderMap; -use std::collections::HashMap; -use uuid::Uuid; - -const COOKIE_HEADER_NAMES: [&str; 2] = ["cookie", "Cookie"]; - -/// Thin wrapper around a poem::Request that is used to evaluate all binding types when coming from an http gateway. -pub struct RichRequest { - pub underlying: poem::Request, - pub request_id: Uuid, - // path_segments: Vec, - // path_param_extractors: Vec, - // auth_data: Option, - // cached_request_body: Value, -} - -impl RichRequest { - pub fn new(underlying: poem::Request) -> RichRequest { - RichRequest { - underlying, - request_id: Uuid::new_v4(), - // path_segments: vec![], - // path_param_extractors: vec![], - // auth_data: None, - // cached_request_body: serde_json::Value::Null, - } - } - - // pub fn auth_data(&self) -> Option<&Value> { - // self.auth_data.as_ref() - // } - - pub fn headers(&self) -> &HeaderMap { - self.underlying.headers() - } - - pub fn query_params(&self) -> HashMap> { - let mut params: HashMap> = HashMap::new(); - - if let Some(q) = self.underlying.uri().query() { - for (key, value) in url::form_urlencoded::parse(q.as_bytes()).into_owned() { - params.entry(key).or_default().push(value); - } - } - - params - } - - // pub async fn add_auth_details( - // &mut self, - // session_id: &SessionId, - // gateway_session_store: &Arc, - // ) -> Result<(), String> { - // let claims = gateway_session_store - // .get(session_id, &DataKey::claims()) - // .await - // .map_err(|err| err.to_safe_string())?; - - // self.auth_data = Some(claims.0); - - // Ok(()) - // } - - // fn path_and_query(&self) -> Result { - // self.underlying - // .uri() - // .path_and_query() - // .map(|paq| paq.to_string()) - // .ok_or("No path and query provided".to_string()) - // } - - // pub fn path_params(&self) -> HashMap { - // self.path_param_extractors - // .iter() - // .map(|param| match param { - // PathParamExtractor::Single { var_info, index } => ( - // var_info.key_name.clone(), - // self.path_segments[*index].clone(), - // ), - // PathParamExtractor::AllFollowing { var_info, index } => { - // let value = self.path_segments[*index..].join("/"); - // (var_info.key_name.clone(), value) - // } - // }) - // .collect() - // } - - // pub async fn request_body(&mut self) -> Result<&Value, String> { - // self.take_request_body().await?; - // Ok(self.cached_request_body()) - // } - - pub fn cookies(&self) -> HashMap<&str, &str> { - let mut result = HashMap::new(); - - for header_name in COOKIE_HEADER_NAMES.iter() { - if let Some(value) = self.underlying.header(header_name) { - let parts: Vec<&str> = value.split(';').collect(); - for part in parts { - let key_value: Vec<&str> = part.split('=').collect(); - if let (Some(key), Some(value)) = (key_value.first(), key_value.get(1)) { - result.insert(key.trim(), value.trim()); - } - } - } - } - - result - } - - // fn cached_request_body(&self) -> &Value { - // &self.cached_request_body - // } - - // /// Consumes the body of the underlying request, and make it as part of RichRequest as `cached_request_body`. - // /// The following logic is subtle enough that it takes the following into consideration: - // /// 99% of the time, number of separate rib scripts in API definition that needs to look up request body is 1, - // /// and for that rib-script, there will be no extra logic to read the request body in the hot path. - // /// At the same, if by any chance, multiple rib scripts exist (within a request) that require to lookup the request body, `take_request_body` - // /// is idempotent, that it doesn't affect correctness. - // /// We intentionally don't consume the body if its not required in any Rib script. - // async fn take_request_body(&mut self) -> Result<(), String> { - // let body = self.underlying.take_body(); - - // if !body.is_empty() { - // match body.into_json().await { - // Ok(json_request_body) => { - // self.cached_request_body = json_request_body; - // } - // Err(err) => { - // tracing::error!("Failed reading http request body as json: {}", err); - // return Err(format!("Request body parse error: {err}"))?; - // } - // } - // }; - - // Ok(()) - // } -} - -// pub struct SplitResolvedRouteEntryResult { -// pub binding: RichGatewayBindingCompiled, -// pub middlewares: Vec, -// pub rich_request: RichRequest, -// pub account_id: AccountId, -// pub environment_id: EnvironmentId, -// } - -// pub fn split_resolved_route_entry( -// request: poem::Request, -// entry: ResolvedRouteEntry, -// ) -> SplitResolvedRouteEntryResult { -// // helper function to save a few clones - -// let binding = entry.route_entry.binding; -// let middlewares = entry.route_entry.middlewares; -// let account_id = entry.route_entry.account_id; -// let environment_id = entry.route_entry.environment_id; - -// let rich_request = RichRequest { -// underlying: request, -// request_id: Uuid::new_v4(), -// path_segments: entry.path_segments, -// path_param_extractors: entry.route_entry.path_params, -// auth_data: None, -// cached_request_body: Value::Null, -// }; - -// SplitResolvedRouteEntryResult { -// binding, -// middlewares, -// rich_request, -// account_id, -// environment_id, -// } -// } - -// #[derive(Debug, Clone)] -// pub struct RequestQueryValues(pub HashMap); - -// impl RequestQueryValues { -// pub fn from( -// query_key_values: &HashMap, -// query_keys: &[QueryInfo], -// ) -> Result> { -// let mut unavailable_query_variables: Vec = vec![]; -// let mut query_variable_map: HashMap = HashMap::new(); - -// for spec_query_variable in query_keys.iter() { -// let key = &spec_query_variable.key_name; -// if let Some(query_value) = query_key_values.get(key) { -// query_variable_map.insert(key.clone(), query_value.to_string()); -// } else { -// unavailable_query_variables.push(spec_query_variable.to_string()); -// } -// } - -// if unavailable_query_variables.is_empty() { -// Ok(RequestQueryValues(query_variable_map)) -// } else { -// Err(unavailable_query_variables) -// } -// } -// } - -// #[derive(Debug, Clone)] -// pub struct RequestHeaderValues(pub HashMap); - -// impl RequestHeaderValues { -// pub fn from(headers: &HeaderMap) -> Result> { -// let mut headers_map: HashMap = HashMap::new(); - -// for (header_name, header_value) in headers { -// let header_value_str = header_value.to_str().map_err(|err| vec![err.to_string()])?; - -// headers_map.insert(header_name.to_string(), header_value_str.to_string()); -// } - -// Ok(RequestHeaderValues(headers_map)) -// } -// } - -// #[cfg(test)] -// mod tests { -// use super::*; -// use poem::http::Uri; -// use test_r::test; - -// #[test] -// fn test_query_params_with_plus_encoded_spaces() -> anyhow::Result<()> { -// let uri: Uri = "/search?q=hello+world&lang=rust".parse()?; -// let req = poem::Request::builder().uri(uri).finish(); -// let rich_req = RichRequest::new(req); -// let params = rich_req.query_params(); - -// assert_eq!(params.get("q"), Some(&"hello world".to_string())); // '+' decoded to space -// assert_eq!(params.get("lang"), Some(&"rust".to_string())); - -// Ok(()) -// } -// } diff --git a/golem-worker-service/src/gateway_execution/request_handler.rs b/golem-worker-service/src/gateway_execution/request_handler.rs deleted file mode 100644 index 8844e51c2f..0000000000 --- a/golem-worker-service/src/gateway_execution/request_handler.rs +++ /dev/null @@ -1,400 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use super::agent_response_mapping::interpret_agent_response; -use super::parameter_parsing::{ - parse_path_segment_value, parse_path_segment_value_to_component_model, - parse_query_or_header_value, parse_request_body, -}; -use super::request::RichRequest; -use super::route_resolver::{ResolvedRouteEntry, RouteResolver, RouteResolverError}; -use super::{ParsedRequestBody, RouteExecutionResult}; -use crate::service::worker::{WorkerService, WorkerServiceError}; -use anyhow::anyhow; -use golem_common::model::agent::{ - AgentId, BinaryReference, BinaryReferenceValue, DataValue, ElementValue, ElementValues, - UntypedDataValue, UntypedElementValue, -}; -use golem_common::model::{IdempotencyKey, WorkerId}; -use golem_common::{error_forwarding, SafeDisplay}; -use golem_service_base::custom_api::{ConstructorParameter, MethodParameter, RouteBehaviour}; -use golem_service_base::model::auth::AuthCtx; -use golem_wasm::json::ValueAndTypeJsonExtensions; -use golem_wasm::IntoValue; -use golem_wasm::ValueAndType; -use http::StatusCode; -use poem::{Request, Response}; -use std::sync::Arc; -use tracing::debug; -use uuid::Uuid; - -#[derive(Debug, thiserror::Error)] -pub enum RequestHandlerError { - #[error("Failed parsing value; Provided: {value}; Expected type: {expected}")] - ValueParsingFailed { - value: String, - expected: &'static str, - }, - #[error("Expected {expected} values to be provided, but found none")] - MissingValue { expected: &'static str }, - #[error("Expected {expected} values to be provided, but found too many")] - TooManyValues { expected: &'static str }, - #[error("Header value of {header_name} is not valid ascii")] - HeaderIsNotAscii { header_name: String }, - #[error("Request body was not valid json: {error}")] - BodyIsNotValidJson { error: String }, - #[error("Failed parsing json body: [{formatted}]", formatted=.errors.join(","))] - JsonBodyParsingFailed { errors: Vec }, - #[error("Agent response did not match expected type: {error}")] - AgentResponseTypeMismatch { error: String }, - #[error("Mime type {mime_type} is not supported. Allowed mime types: [{formatted_mime_types}]", formatted_mime_types=.allowed_mime_types.join(","))] - UnsupportedMimeType { - mime_type: String, - allowed_mime_types: Vec, - }, - #[error("Invariant violated: {msg}")] - InvariantViolated { msg: &'static str }, - #[error("Resolving route failed: {0}")] - ResolvingRouteFailed(#[from] RouteResolverError), - #[error("Invocation failed: {0}")] - AgentInvocationFailed(#[from] WorkerServiceError), - #[error(transparent)] - InternalError(#[from] anyhow::Error), -} - -impl RequestHandlerError { - pub fn invariant_violated(msg: &'static str) -> Self { - Self::InvariantViolated { msg } - } -} - -impl SafeDisplay for RequestHandlerError { - fn to_safe_string(&self) -> String { - match self { - Self::ValueParsingFailed { .. } => self.to_string(), - Self::MissingValue { .. } => self.to_string(), - Self::TooManyValues { .. } => self.to_string(), - Self::HeaderIsNotAscii { .. } => self.to_string(), - Self::BodyIsNotValidJson { .. } => self.to_string(), - Self::JsonBodyParsingFailed { .. } => self.to_string(), - Self::AgentResponseTypeMismatch { .. } => self.to_string(), - Self::UnsupportedMimeType { .. } => self.to_string(), - - Self::InvariantViolated { .. } => "internal error".to_string(), - - Self::ResolvingRouteFailed(inner) => { - format!("Resolving route failed: {}", inner.to_safe_string()) - } - Self::AgentInvocationFailed(inner) => { - format!("Invocation failed: {}", inner.to_safe_string()) - } - - Self::InternalError(_) => "internal error".to_string(), - } - } -} - -error_forwarding!(RequestHandlerError); - -pub struct RequestHandler { - route_resolver: Arc, - worker_service: Arc, -} - -#[allow(irrefutable_let_patterns)] -impl RequestHandler { - pub fn new(route_resolver: Arc, worker_service: Arc) -> Self { - Self { - route_resolver, - worker_service, - } - } - - pub async fn handle_request(&self, request: Request) -> Result { - debug!("Begin http request handling for request {request:?}"); - - let matching_route = self.route_resolver.resolve_matching_route(&request).await?; - let mut request = RichRequest::new(request); - let execution_result = self.execute_route(&mut request, &matching_route).await?; - let response = route_execution_result_to_response(execution_result)?; - Ok(response) - } - - async fn execute_route( - &self, - request: &mut RichRequest, - resolved_route: &ResolvedRouteEntry, - ) -> Result { - match &resolved_route.route.behavior { - RouteBehaviour::CallAgent { .. } => { - self.execute_call_agent(request, resolved_route).await - } - } - } - - async fn execute_call_agent( - &self, - request: &mut RichRequest, - resolved_route: &ResolvedRouteEntry, - ) -> Result { - let RouteBehaviour::CallAgent { - expected_agent_response, - .. - } = &resolved_route.route.behavior - else { - unreachable!() - }; - - let worker_id = self.build_worker_id(resolved_route)?; - - let parsed_body = parse_request_body(request, &resolved_route.route.body).await?; - - let method_params = self.resolve_method_arguments(resolved_route, request, parsed_body)?; - - debug!("Invoking agent {worker_id}"); - - let agent_response = self - .invoke_agent(&worker_id, resolved_route, method_params) - .await?; - - debug!("Received agent response: {agent_response:?}"); - - debug!( - "Json agent response: {}", - agent_response.clone().unwrap().to_json_value().unwrap() - ); - - let mapped_result = interpret_agent_response(agent_response, expected_agent_response)?; - - debug!("Returning mapped agent result: {mapped_result:?}"); - - Ok(mapped_result) - } - - fn build_worker_id( - &self, - resolved_route: &ResolvedRouteEntry, - ) -> Result { - let RouteBehaviour::CallAgent { - component_id, - agent_type, - constructor_parameters, - phantom, - .. - } = &resolved_route.route.behavior - else { - unreachable!() - }; - - let mut values = Vec::with_capacity(constructor_parameters.len()); - - for param in constructor_parameters { - match param { - ConstructorParameter::Path { - path_segment_index, - parameter_type, - } => { - let raw = resolved_route.captured_path_parameters - [usize::from(*path_segment_index)] - .clone(); - - let value = parse_path_segment_value_to_component_model(raw, parameter_type)?; - - values.push(ElementValue::ComponentModel(ValueAndType::new( - value, - parameter_type.clone().into(), - ))); - } - } - } - - let data_value = DataValue::Tuple(ElementValues { elements: values }); - - let phantom_id = phantom.then(Uuid::new_v4); - - let agent_id = AgentId::new(agent_type.clone(), data_value, phantom_id); - - Ok(WorkerId { - component_id: *component_id, - worker_name: agent_id.to_string(), - }) - } - - fn resolve_method_arguments( - &self, - resolved_route: &ResolvedRouteEntry, - request: &RichRequest, - mut body: ParsedRequestBody, - ) -> Result, RequestHandlerError> { - let RouteBehaviour::CallAgent { - method_parameters, .. - } = &resolved_route.route.behavior - else { - unreachable!() - }; - - let query_params = request.query_params(); - let headers = request.headers(); - - let mut values = Vec::with_capacity(method_parameters.len()); - - for param in method_parameters { - let value = match param { - MethodParameter::Path { - path_segment_index, - parameter_type, - } => { - let raw = resolved_route.captured_path_parameters[usize::from(*path_segment_index)].clone(); - - parse_path_segment_value(raw, parameter_type)? - } - - MethodParameter::Query { - query_parameter_name, - parameter_type, - } => { - let empty = Vec::new(); - let vals = query_params.get(query_parameter_name).unwrap_or(&empty); - - parse_query_or_header_value(vals, parameter_type)? - } - - MethodParameter::Header { - header_name, - parameter_type, - } => { - let vals = headers - .get_all(header_name) - .iter() - .map(|h| { - h.to_str().map(String::from).map_err(|_| { - RequestHandlerError::HeaderIsNotAscii { - header_name: header_name.clone(), - } - }) - }) - .collect::, _>>()?; - - parse_query_or_header_value(&vals, parameter_type)? - } - - MethodParameter::JsonObjectBodyField { field_index } => { - match &body { - ParsedRequestBody::JsonBody(golem_wasm::Value::Record(fields)) => { - UntypedElementValue::ComponentModel(fields[usize::from(*field_index)].clone()) - } - - ParsedRequestBody::JsonBody(_) => { - return Err(RequestHandlerError::invariant_violated( - "Inconsistent API definition: JSON field parameter but body is not an object", - )) - } - - _ => return Err(RequestHandlerError::invariant_violated( - "JSON body parameter used but no JSON body schema", - )), - } - } - - MethodParameter::UnstructuredBinaryBody => { - match &mut body { - ParsedRequestBody::UnstructuredBinary(binary_source) => { - let binary_source = binary_source.take().ok_or_else(|| RequestHandlerError::invariant_violated( - "Parsed body was already consumed", - ))?; - - UntypedElementValue::UnstructuredBinary(BinaryReferenceValue { value: BinaryReference::Inline(binary_source) }) - } - - _ => return Err(RequestHandlerError::invariant_violated( - "Binary body parameter used but no binary body schema", - )), - } - } - }; - - values.push(value); - } - - Ok(values) - } - - async fn invoke_agent( - &self, - worker_id: &WorkerId, - resolved_route: &ResolvedRouteEntry, - params: Vec, - ) -> Result, RequestHandlerError> { - let RouteBehaviour::CallAgent { method_name, .. } = &resolved_route.route.behavior else { - unreachable!() - }; - - let method_params_data_value = UntypedDataValue::Tuple(params); - - self.worker_service - .invoke_and_await_owned_agent( - worker_id, - Some(IdempotencyKey::fresh()), - "golem:agent/guest.{invoke}".to_string(), - vec![ - golem_wasm::protobuf::Val::from(method_name.clone().into_value()), - golem_wasm::protobuf::Val::from(method_params_data_value.into_value()), - golem_wasm::protobuf::Val::from( - golem_common::model::agent::Principal::anonymous().into_value(), - ), - ], - None, - resolved_route.route.environment_id, - resolved_route.route.account_id, - AuthCtx::impersonated_user(resolved_route.route.account_id), - ) - .await - .map_err(Into::into) - } -} - -fn route_execution_result_to_response( - result: RouteExecutionResult, -) -> Result { - match result { - RouteExecutionResult::NoBody { status } => Ok(Response::builder().status(status).finish()), - - RouteExecutionResult::ComponentModelJsonBody { body, status } => { - let body = poem::Body::from_json( - body.to_json_value() - .map_err(|e| anyhow!("ComponentModelJsonBody conversion error: {e}"))?, - ) - .map_err(anyhow::Error::from)?; - - Ok(Response::builder().status(status).body(body)) - } - - RouteExecutionResult::UnstructuredBinaryBody { body } => Ok(Response::builder() - .status(StatusCode::OK) - .body(body.data) - .set_content_type(body.binary_type.mime_type)), - - RouteExecutionResult::CustomAgentError { body } => { - let body = poem::Body::from_json( - body.to_json_value() - .map_err(|e| anyhow!("CustomAgentError conversion error: {e}"))?, - ) - .map_err(anyhow::Error::from)?; - - Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(body)) - } - } -} diff --git a/golem-worker-service/src/gateway_execution/to_response.rs b/golem-worker-service/src/gateway_execution/to_response.rs deleted file mode 100644 index e3a5196eb7..0000000000 --- a/golem-worker-service/src/gateway_execution/to_response.rs +++ /dev/null @@ -1,534 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use super::auth_call_back_binding_handler::{AuthenticationSuccess, AuthorisationError}; -use super::file_server_binding_handler::FileServerBindingSuccess; -use super::http_handler_binding_handler::{HttpHandlerBindingError, HttpHandlerBindingSuccess}; -use super::{RibInputTypeMismatch, WorkerRequestExecutorError}; -use crate::api::common::ApiEndpointError; -use crate::gateway_execution::file_server_binding_handler::FileServerBindingError; -use crate::gateway_execution::gateway_session_store::GatewaySessionStore; -use crate::gateway_execution::request::RichRequest; -use crate::gateway_execution::to_response_failure::ToHttpResponseFromSafeDisplay; -use crate::model::SwaggerHtml; -use async_trait::async_trait; -use golem_service_base::custom_api::HttpCors; -use http::header::*; -use http::StatusCode; -use poem::Body; -use poem::IntoResponse; -use rib::RibResult; -use std::sync::Arc; - -#[async_trait] -pub trait ToHttpResponse { - async fn to_response( - self, - request: &RichRequest, - session_store: &Arc, - ) -> poem::Response; -} - -#[async_trait] -impl ToHttpResponse for Result { - async fn to_response( - self, - request: &RichRequest, - session_store: &Arc, - ) -> poem::Response { - match self { - Ok(t) => t.to_response(request, session_store).await, - Err(e) => e.to_response(request, session_store).await, - } - } -} - -pub type GatewayHttpResult = Result; - -pub enum GatewayHttpError { - BadRequest(String), - InternalError(String), - RibInputTypeMismatch(RibInputTypeMismatch), - EvaluationError(WorkerRequestExecutorError), - RibInterpretPureError(String), - HttpHandlerBindingError(HttpHandlerBindingError), - FileServerBindingError(FileServerBindingError), - AuthorisationError(AuthorisationError), -} - -#[async_trait] -impl ToHttpResponse for GatewayHttpError { - async fn to_response( - self, - request_details: &RichRequest, - session_store: &Arc, - ) -> poem::Response { - match self { - GatewayHttpError::BadRequest(e) => poem::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::from_string(e)), - GatewayHttpError::RibInputTypeMismatch(err) => { - err.to_response_from_safe_display(|_| StatusCode::BAD_REQUEST) - } - GatewayHttpError::RibInterpretPureError(err) => poem::Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from_string(format!( - "Failed interpreting pure rib expression: {err}" - ))), - GatewayHttpError::EvaluationError(err) => { - err.to_response_from_safe_display(|_| StatusCode::INTERNAL_SERVER_ERROR) - } - GatewayHttpError::HttpHandlerBindingError(inner) => { - inner.to_response(request_details, session_store).await - } - GatewayHttpError::FileServerBindingError(inner) => { - inner.to_response(request_details, session_store).await - } - GatewayHttpError::AuthorisationError(inner) => { - inner.to_response(request_details, session_store).await - } - GatewayHttpError::InternalError(e) => poem::Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from_string(e)), - } - } -} - -#[async_trait] -impl ToHttpResponse for FileServerBindingSuccess { - async fn to_response( - self, - _request_details: &RichRequest, - _session_store: &Arc, - ) -> poem::Response { - Body::from_bytes_stream(self.data) - .with_content_type(self.binding_details.content_type.to_string()) - .with_status(self.binding_details.status_code) - .into_response() - } -} - -#[async_trait] -impl ToHttpResponse for FileServerBindingError { - async fn to_response( - self, - _request_details: &RichRequest, - _session_store: &Arc, - ) -> poem::Response { - match self { - FileServerBindingError::InternalError(e) => poem::Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from_string(format!("Error {e}"))), - FileServerBindingError::ComponentServiceError(inner) => { - ApiEndpointError::from(inner).into_response() - } - FileServerBindingError::WorkerServiceError(inner) => { - ApiEndpointError::from(inner).into_response() - } - FileServerBindingError::InvalidRibResult(e) => poem::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::from_string(format!( - "Error while processing rib result: {e}" - ))), - } - } -} - -#[async_trait] -impl ToHttpResponse for HttpHandlerBindingSuccess { - async fn to_response( - self, - _request_details: &RichRequest, - _session_store: &Arc, - ) -> poem::Response { - self.response - } -} - -#[async_trait] -impl ToHttpResponse for HttpHandlerBindingError { - async fn to_response( - self, - _request_details: &RichRequest, - _session_store: &Arc, - ) -> poem::Response { - match self { - HttpHandlerBindingError::InternalError(e) => poem::Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from_string(format!("Error {e}"))), - HttpHandlerBindingError::WorkerRequestExecutorError(e) => poem::Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from_string(format!( - "Error calling worker executor {e}" - ))), - } - } -} - -// Preflight (OPTIONS) response that will consist of all configured CORS headers -#[async_trait] -impl ToHttpResponse for HttpCors { - async fn to_response( - self, - _request_details: &RichRequest, - _session_store: &Arc, - ) -> poem::Response { - let mut response = poem::Response::builder().status(StatusCode::OK).finish(); - - // TODO: should not unwrap here - response.headers_mut().insert( - ACCESS_CONTROL_ALLOW_ORIGIN, - self.allow_origin.clone().parse().unwrap(), - ); - response.headers_mut().insert( - ACCESS_CONTROL_ALLOW_METHODS, - self.allow_methods.clone().parse().unwrap(), - ); - response.headers_mut().insert( - ACCESS_CONTROL_ALLOW_HEADERS, - self.allow_headers.clone().parse().unwrap(), - ); - - if let Some(expose_headers) = &self.expose_headers { - response.headers_mut().insert( - ACCESS_CONTROL_EXPOSE_HEADERS, - expose_headers.clone().parse().unwrap(), - ); - } - - if let Some(allow_credentials) = self.allow_credentials { - response.headers_mut().insert( - ACCESS_CONTROL_ALLOW_CREDENTIALS, - allow_credentials.to_string().parse().unwrap(), - ); - } - - if let Some(max_age) = self.max_age { - response - .headers_mut() - .insert(ACCESS_CONTROL_MAX_AGE, max_age.to_string().parse().unwrap()); - } - - response - } -} - -#[async_trait] -impl ToHttpResponse for RibResult { - async fn to_response( - self, - request_details: &RichRequest, - _session_store: &Arc, - ) -> poem::Response { - match internal::IntermediateRibResultHttpResponse::from(&self) { - Ok(intermediate_response) => intermediate_response.to_http_response(request_details), - Err(e) => e.to_response_from_safe_display(|_| StatusCode::INTERNAL_SERVER_ERROR), - } - } -} - -#[async_trait] -impl ToHttpResponse for AuthenticationSuccess { - async fn to_response( - self, - _request_details: &RichRequest, - _session_store: &Arc, - ) -> poem::Response { - let access_token = self.access_token; - let id_token = self.id_token; - let session_id = self.session; - - let mut response = poem::Response::builder() - .status(StatusCode::FOUND) - .header("Location", self.target_path) - .header( - "Set-Cookie", - format!("access_token={access_token}; HttpOnly; Secure; Path=/; SameSite=None") - .as_str(), - ); - - if let Some(id_token) = id_token { - response = response.header( - "Set-Cookie", - format!("id_token={id_token}; HttpOnly; Secure; Path=/; SameSite=None").as_str(), - ) - } - - response = response.header( - "Set-Cookie", - format!("session_id={session_id}; HttpOnly; Secure; Path=/; SameSite=None").as_str(), - ); - - response.body(()) - } -} - -#[async_trait] -impl ToHttpResponse for AuthorisationError { - async fn to_response( - self, - _request_details: &RichRequest, - _session_store: &Arc, - ) -> poem::Response { - self.to_response_from_safe_display(|_| StatusCode::UNAUTHORIZED) - } -} - -#[async_trait] -impl ToHttpResponse for SwaggerHtml { - async fn to_response( - self, - _request_details: &RichRequest, - _session_store: &Arc, - ) -> poem::Response { - poem::Response::builder() - .content_type("text/html") - .body(Body::from_string(self.0)) - } -} - -mod internal { - use crate::gateway_execution::http_content_type_mapper::{ - ContentTypeHeaders, HttpContentTypeResponseMapper, - }; - use crate::gateway_execution::request::RichRequest; - use http::StatusCode; - - use crate::getter::{get_response_headers_or_default, get_status_code_or_ok, GetterExt}; - use crate::path::Path; - - use crate::gateway_execution::WorkerRequestExecutorError; - use crate::headers::ResolvedResponseHeaders; - use golem_wasm::ValueAndType; - use poem::{Body, IntoResponse, ResponseParts}; - use rib::RibResult; - - #[derive(Debug)] - pub(crate) struct IntermediateRibResultHttpResponse { - body: Option, - status: StatusCode, - headers: ResolvedResponseHeaders, - } - - impl IntermediateRibResultHttpResponse { - pub(crate) fn from( - evaluation_result: &RibResult, - ) -> Result { - match evaluation_result { - RibResult::Val(rib_result) => { - let status = - get_status_code_or_ok(rib_result).map_err(WorkerRequestExecutorError)?; - - let headers = get_response_headers_or_default(rib_result) - .map_err(WorkerRequestExecutorError)?; - - let body = rib_result - .get_optional(&Path::from_key("body")) - .unwrap_or(rib_result.clone()); - - Ok(IntermediateRibResultHttpResponse { - body: Some(body), - status, - headers, - }) - } - RibResult::Unit => Ok(IntermediateRibResultHttpResponse { - body: None, - status: StatusCode::default(), - headers: ResolvedResponseHeaders::default(), - }), - } - } - - pub(crate) fn to_http_response(&self, request_details: &RichRequest) -> poem::Response { - let response_content_type = self.headers.get_content_type(); - let response_headers = self.headers.headers.clone(); - - let status = &self.status; - let evaluation_result = &self.body; - - let accepted_content_types = request_details - .underlying - .header(http::header::ACCEPT) - .map(|s| s.to_string()); - - let content_type = - ContentTypeHeaders::from(response_content_type, accepted_content_types); - - let response = match evaluation_result { - Some(type_annotated_value) => { - match type_annotated_value.to_http_resp_with_content_type(content_type) { - Ok(body_with_header) => { - let mut response = body_with_header.into_response(); - response.set_status(*status); - response.headers_mut().extend(response_headers); - response - } - Err(content_map_error) => poem::Response::builder() - .status(StatusCode::BAD_REQUEST) - .body(Body::from_string(content_map_error.to_string())), - } - } - None => { - let parts = ResponseParts { - status: *status, - version: Default::default(), - headers: response_headers, - extensions: Default::default(), - }; - - poem::Response::from_parts(parts, Body::empty()) - } - }; - - response - } - } -} - -#[cfg(test)] -mod test { - use async_trait::async_trait; - use std::sync::Arc; - use test_r::test; - - use crate::gateway_execution::gateway_session_store::{ - DataKey, DataValue, GatewaySessionError, GatewaySessionStore, SessionId, - }; - use crate::gateway_execution::request::RichRequest; - use crate::gateway_execution::to_response::ToHttpResponse; - use golem_wasm::analysis::analysed_type::record; - use golem_wasm::analysis::NameTypePair; - use golem_wasm::{IntoValueAndType, Value, ValueAndType}; - use http::header::CONTENT_TYPE; - use http::StatusCode; - use rib::RibResult; - - fn create_record(values: Vec<(String, ValueAndType)>) -> ValueAndType { - let mut fields = vec![]; - let mut field_values = vec![]; - - for (key, vnt) in values { - fields.push(NameTypePair { - name: key, - typ: vnt.typ, - }); - field_values.push(vnt.value); - } - - ValueAndType { - value: Value::Record(field_values), - typ: record(fields), - } - } - - fn test_request() -> RichRequest { - RichRequest::new(poem::Request::default()) - } - - #[test] - async fn test_evaluation_result_to_response_with_http_specifics() { - let record = create_record(vec![ - ("status".to_string(), 400u16.into_value_and_type()), - ( - "headers".to_string(), - create_record(vec![( - "Content-Type".to_string(), - "application/json".into_value_and_type(), - )]), - ), - ("body".to_string(), "Hello".into_value_and_type()), - ]); - - let evaluation_result: RibResult = RibResult::Val(record); - - let session_store: Arc = Arc::new(TestSessionStore); - - let http_response: poem::Response = evaluation_result - .to_response(&test_request(), &session_store) - .await; - - let (response_parts, body) = http_response.into_parts(); - let body = body.into_string().await.unwrap(); - let headers = response_parts.headers; - let status = response_parts.status; - - let expected_body = "\"Hello\""; - let expected_headers = poem::web::headers::HeaderMap::from_iter(vec![( - CONTENT_TYPE, - "application/json".parse().unwrap(), - )]); - - let expected_status = StatusCode::BAD_REQUEST; - - assert_eq!(body, expected_body); - assert_eq!(headers.clone(), expected_headers); - assert_eq!(status, expected_status); - } - - #[test] - async fn test_evaluation_result_to_response_with_no_http_specifics() { - let evaluation_result: RibResult = RibResult::Val("Healthy".into_value_and_type()); - - let session_store: Arc = Arc::new(TestSessionStore); - - let http_response: poem::Response = evaluation_result - .to_response(&test_request(), &session_store) - .await; - - let (response_parts, body) = http_response.into_parts(); - let body = body.into_string().await.unwrap(); - let headers = response_parts.headers; - let status = response_parts.status; - - let expected_body = "Healthy"; - - // Deault content response is application/json. Refer HttpResponse - let expected_headers = poem::web::headers::HeaderMap::from_iter(vec![( - CONTENT_TYPE, - "application/json".parse().unwrap(), - )]); - let expected_status = StatusCode::OK; - - assert_eq!(body, expected_body); - assert_eq!(headers.clone(), expected_headers); - assert_eq!(status, expected_status); - } - - struct TestSessionStore; - - #[async_trait] - impl GatewaySessionStore for TestSessionStore { - async fn insert( - &self, - _session_id: SessionId, - _data_key: DataKey, - _data_value: DataValue, - ) -> Result<(), GatewaySessionError> { - Err(GatewaySessionError::InternalError( - "unimplemented".to_string(), - )) - } - - async fn get( - &self, - _session_id: &SessionId, - _data_key: &DataKey, - ) -> Result { - Err(GatewaySessionError::InternalError( - "unimplemented".to_string(), - )) - } - } -} diff --git a/golem-worker-service/src/gateway_execution/to_response_failure.rs b/golem-worker-service/src/gateway_execution/to_response_failure.rs deleted file mode 100644 index d48e078aba..0000000000 --- a/golem-worker-service/src/gateway_execution/to_response_failure.rs +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use golem_common::SafeDisplay; -use http::StatusCode; -use poem::Body; - -pub trait ToHttpResponseFromSafeDisplay { - fn to_response_from_safe_display(&self, get_status_code: F) -> poem::Response - where - Self: SafeDisplay, - F: Fn(&Self) -> StatusCode, - Self: Sized; -} - -// Only SafeDisplay'd errors are allowed to be embedded in any output response -impl ToHttpResponseFromSafeDisplay for E { - fn to_response_from_safe_display(&self, get_status_code: F) -> poem::Response - where - F: Fn(&Self) -> StatusCode, - Self: Sized, - { - poem::Response::builder() - .status(get_status_code(self)) - .body(Body::from_string(self.to_safe_string())) - } -} diff --git a/golem-worker-service/src/gateway_middleware/auth.rs b/golem-worker-service/src/gateway_middleware/auth.rs deleted file mode 100644 index 4a741e279a..0000000000 --- a/golem-worker-service/src/gateway_middleware/auth.rs +++ /dev/null @@ -1,246 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use super::{MiddlewareError, MiddlewareSuccess}; -use crate::gateway_execution::auth_call_back_binding_handler::AuthorisationError; -use crate::gateway_execution::gateway_session_store::{ - DataKey, DataValue, GatewaySessionError, GatewaySessionStore, SessionId, -}; -use crate::gateway_execution::request::RichRequest; -use crate::gateway_security::{IdentityProvider, OpenIdClient}; -use golem_common::SafeDisplay; -use golem_service_base::custom_api::SecuritySchemeDetails; -use http::StatusCode; -use openidconnect::core::{CoreIdToken, CoreIdTokenClaims, CoreIdTokenVerifier}; -use openidconnect::{ClaimsVerificationError, Nonce}; -use std::str::FromStr; -use std::sync::Arc; -use tracing::{debug, error}; - -pub async fn apply_http_auth( - security_scheme: &SecuritySchemeDetails, - input: &RichRequest, - session_store: &Arc, - identity_provider: &Arc, -) -> Result { - let open_id_client = identity_provider - .get_client(security_scheme) - .await - .map_err(|err| { - MiddlewareError::Unauthorized(AuthorisationError::IdentityProviderError(err)) - })?; - - let identity_token_verifier = open_id_client.id_token_verifier(); - - let cookie_values = input.get_cookie_values(); - - let id_token = cookie_values.get("id_token"); - let state = cookie_values.get("session_id"); - - if let (Some(id_token), Some(state)) = (id_token, state) { - get_session_details_or_redirect( - state, - identity_token_verifier, - id_token, - session_store, - input, - identity_provider, - &open_id_client, - security_scheme, - ) - .await - } else { - redirect( - session_store, - input, - identity_provider, - &open_id_client, - security_scheme, - ) - .await - } -} - -async fn get_session_details_or_redirect( - state_from_request: &str, - identity_token_verifier: CoreIdTokenVerifier<'_>, - id_token: &str, - session_store: &Arc, - input: &RichRequest, - identity_provider: &Arc, - open_id_client: &OpenIdClient, - security_scheme: &SecuritySchemeDetails, -) -> Result { - let session_id = SessionId(state_from_request.to_string()); - - let nonce_from_session = session_store.get(&session_id, &DataKey::nonce()).await; - - match nonce_from_session { - Ok(nonce) => { - let id_token = CoreIdToken::from_str(id_token).map_err(|err| { - debug!( - "Failed to parse id token for session {}: {}", - err, session_id.0 - ); - MiddlewareError::Unauthorized(AuthorisationError::InvalidToken) - })?; - - get_claims( - &nonce, - id_token, - identity_token_verifier, - &session_id, - session_store, - input, - identity_provider, - open_id_client, - security_scheme, - ) - .await - } - Err(GatewaySessionError::MissingValue { .. }) => { - redirect( - session_store, - input, - identity_provider, - open_id_client, - security_scheme, - ) - .await - } - Err(err) => { - debug!( - "Failed to get nonce from session store: {:?} for session {}", - err, session_id.0 - ); - Err(MiddlewareError::Unauthorized( - AuthorisationError::SessionError(err), - )) - } - } -} - -async fn get_claims( - nonce: &DataValue, - id_token: CoreIdToken, - identity_token_verifier: CoreIdTokenVerifier<'_>, - session_id: &SessionId, - session_store: &Arc, - input: &RichRequest, - identity_provider: &Arc, - open_id_client: &OpenIdClient, - security_scheme: &SecuritySchemeDetails, -) -> Result { - if let Some(nonce) = nonce.as_string() { - let token_claims_result: Result<&CoreIdTokenClaims, ClaimsVerificationError> = - id_token.claims(&identity_token_verifier, &Nonce::new(nonce)); - - match token_claims_result { - Ok(claims) => { - store_claims_in_session_store(session_id, claims, session_store).await?; - - Ok(MiddlewareSuccess::PassThrough { - session_id: Some(session_id.clone()), - }) - } - Err(ClaimsVerificationError::Expired(_)) => { - redirect( - session_store, - input, - identity_provider, - open_id_client, - security_scheme, - ) - .await - } - Err(claims_verification_error) => { - error!("Invalid token for session {}", claims_verification_error); - - Err(MiddlewareError::Unauthorized( - AuthorisationError::InvalidToken, - )) - } - } - } else { - Err(MiddlewareError::Unauthorized( - AuthorisationError::InvalidNonce, - )) - } -} - -async fn redirect( - session_store: &Arc, - input: &RichRequest, - identity_provider: &Arc, - client: &OpenIdClient, - security_scheme: &SecuritySchemeDetails, -) -> Result { - let redirect_uri = input - .underlying - .uri() - .path_and_query() - .ok_or(MiddlewareError::InternalError( - "Failed to get redirect uri".to_string(), - ))? - .to_string(); - - let authorization = - identity_provider.get_authorization_url(client, security_scheme.scopes.clone(), None, None); - - let state = authorization.csrf_state.secret(); - - let session_id = SessionId(state.clone()); - let nonce_data_key = DataKey::nonce(); - let nonce_data_value = DataValue(serde_json::Value::String( - authorization.nonce.secret().clone(), - )); - - let redirect_url_data_key = DataKey::redirect_url(); - - let redirect_url_data_value = DataValue(serde_json::Value::String(redirect_uri)); - - session_store - .insert(session_id.clone(), nonce_data_key, nonce_data_value) - .await - .map_err(|err| MiddlewareError::Unauthorized(AuthorisationError::SessionError(err)))?; - session_store - .insert(session_id, redirect_url_data_key, redirect_url_data_value) - .await - .map_err(|err| MiddlewareError::Unauthorized(AuthorisationError::SessionError(err)))?; - - let response = poem::Response::builder(); - let result = response - .header("Location", authorization.url.to_string()) - .status(StatusCode::FOUND) - .body(()); - - Ok(MiddlewareSuccess::Redirect(result)) -} - -async fn store_claims_in_session_store( - session_id: &SessionId, - claims: &CoreIdTokenClaims, - session_store: &Arc, -) -> Result<(), MiddlewareError> { - let claims_data_key = DataKey::claims(); - let json = serde_json::to_value(claims) - .map_err(|err| MiddlewareError::InternalError(err.to_string()))?; - - let claims_data_value = DataValue(json); - - session_store - .insert(session_id.clone(), claims_data_key, claims_data_value) - .await - .map_err(|err| MiddlewareError::InternalError(err.to_safe_string())) -} diff --git a/golem-worker-service/src/gateway_middleware/cors.rs b/golem-worker-service/src/gateway_middleware/cors.rs deleted file mode 100644 index 598743a2a6..0000000000 --- a/golem-worker-service/src/gateway_middleware/cors.rs +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::gateway_execution::request::RichRequest; -use golem_service_base::custom_api::HttpCors; -use http::{HeaderValue, Method}; - -#[derive(Debug)] -#[allow(clippy::enum_variant_names)] -pub enum CorsError { - OriginNotAllowed, - MethodNotAllowed, - HeadersNotAllowed, -} - -pub fn apply_cors(cors: &HttpCors, request: &RichRequest) -> Result<(), CorsError> { - let origin = match request.headers().get(http::header::ORIGIN) { - Some(origin) => origin.clone(), - None => return Ok(()), - }; - - if let OriginStatus::NotAllowed = check_origin(cors, &origin) { - return Err(CorsError::OriginNotAllowed); - } - - if request.underlying.method() == Method::OPTIONS { - let allow_method = request - .headers() - .get(http::header::ACCESS_CONTROL_REQUEST_METHOD) - .and_then(|val| val.to_str().ok()) - .and_then(|m| m.parse::().ok()); - - if let Some(method) = allow_method { - if !cors.allow_methods.trim().is_empty() - && !split_origin(&cors.allow_methods) - .any(|m| m.eq_ignore_ascii_case(method.as_str())) - { - return Err(CorsError::MethodNotAllowed); - } - } else { - return Err(CorsError::MethodNotAllowed); - } - - check_headers_allowed(cors, request)?; - } - - Ok(()) -} - -pub fn add_cors_headers_to_response(cors: &HttpCors, response: &mut poem::Response) { - response.headers_mut().insert( - http::header::ACCESS_CONTROL_ALLOW_ORIGIN, - cors.allow_origin.clone().parse().unwrap(), - ); - - if let Some(allow_credentials) = &cors.allow_credentials { - response.headers_mut().insert( - http::header::ACCESS_CONTROL_ALLOW_CREDENTIALS, - allow_credentials.to_string().clone().parse().unwrap(), - ); - } - - if let Some(expose_headers) = &cors.expose_headers { - response.headers_mut().insert( - http::header::ACCESS_CONTROL_EXPOSE_HEADERS, - expose_headers.clone().parse().unwrap(), - ); - } -} - -fn split_origin(input: &str) -> impl Iterator { - input.split(',').map(|s| s.trim()).filter(|s| !s.is_empty()) -} - -enum OriginStatus { - AllowedExact, - AllowedWildcard, - NotAllowed, -} - -fn check_origin(cors: &HttpCors, origin: &HeaderValue) -> OriginStatus { - let origin_str = match origin.to_str() { - Ok(s) => s, - Err(_) => return OriginStatus::NotAllowed, - }; - - if split_origin(&cors.allow_origin).any(|o| o == origin_str) { - return OriginStatus::AllowedExact; - } - - if split_origin(&cors.allow_origin) - .any(|pattern| pattern.contains('*') && wildcard_match(pattern, origin_str)) - { - return OriginStatus::AllowedWildcard; - } - - OriginStatus::NotAllowed -} - -fn wildcard_match(pattern: &str, text: &str) -> bool { - if !pattern.contains('*') { - return pattern == text; - } - - let parts: Vec<&str> = pattern.split('*').collect(); - if parts.len() == 2 { - text.starts_with(parts[0]) && text.ends_with(parts[1]) - } else { - false - } -} - -fn check_headers_allowed<'a>( - cors: &HttpCors, - req: &'a RichRequest, -) -> Result, CorsError> { - let request_headers = req - .headers() - .get(http::header::ACCESS_CONTROL_REQUEST_HEADERS); - - if let Some(headers_value) = request_headers { - let allow_list: Vec<_> = split_origin(&cors.allow_headers).collect(); - if allow_list.is_empty() { - return Ok(Some(headers_value)); - } - - let header_str = headers_value - .to_str() - .map_err(|_| CorsError::HeadersNotAllowed)?; - - let all_allowed = split_origin(header_str).all(|h| { - allow_list - .iter() - .any(|&allowed| allowed.eq_ignore_ascii_case(h)) - }); - - if !all_allowed { - return Err(CorsError::HeadersNotAllowed); - } - - Ok(Some(headers_value)) - } else { - Ok(None) - } -} diff --git a/golem-worker-service/src/gateway_middleware/mod.rs b/golem-worker-service/src/gateway_middleware/mod.rs deleted file mode 100644 index 4d0a822ea6..0000000000 --- a/golem-worker-service/src/gateway_middleware/mod.rs +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -mod auth; -mod cors; - -use self::auth::apply_http_auth; -use self::cors::{add_cors_headers_to_response, apply_cors, CorsError}; -use crate::gateway_execution::auth_call_back_binding_handler::AuthorisationError; -use crate::gateway_execution::gateway_session_store::{GatewaySessionStore, SessionId}; -use crate::gateway_execution::request::RichRequest; -use crate::gateway_security::IdentityProvider; -use crate::model::HttpMiddleware; -use golem_common::SafeDisplay; -use std::sync::Arc; - -pub enum MiddlewareSuccess { - PassThrough { session_id: Option }, - Redirect(poem::Response), -} - -#[derive(Debug)] -pub enum MiddlewareError { - Unauthorized(AuthorisationError), - CorsError(CorsError), - InternalError(String), -} - -impl SafeDisplay for MiddlewareError { - fn to_safe_string(&self) -> String { - match self { - MiddlewareError::Unauthorized(msg) => format!("Unauthorized: {}", msg.to_safe_string()), - MiddlewareError::CorsError(error) => match error { - CorsError::OriginNotAllowed => "CORS Error: Origin not allowed".to_string(), - CorsError::MethodNotAllowed => "CORS Error: Method not allowed".to_string(), - CorsError::HeadersNotAllowed => "CORS Error: Headers not allowed".to_string(), - }, - MiddlewareError::InternalError(msg) => { - format!("Internal Server Error: {msg}") - } - } - } -} - -pub async fn process_middleware_in( - middlewares: &Vec, - rich_request: &RichRequest, - session_store: &Arc, - identity_provider: &Arc, -) -> Result { - let mut final_session_id = None; - - for middleware in middlewares { - match middleware { - HttpMiddleware::Cors(cors) => { - apply_cors(cors, rich_request).map_err(MiddlewareError::CorsError)?; - } - HttpMiddleware::AuthenticateRequest(auth) => { - let result = - apply_http_auth(auth, rich_request, session_store, identity_provider).await?; - - match result { - MiddlewareSuccess::Redirect(response) => { - return Ok(MiddlewareSuccess::Redirect(response)) - } - MiddlewareSuccess::PassThrough { session_id } => { - final_session_id = session_id; - } - } - } - } - } - - Ok(MiddlewareSuccess::PassThrough { - session_id: final_session_id, - }) -} - -pub async fn process_middleware_out( - middlewares: &Vec, - response: &mut poem::Response, -) -> Result<(), MiddlewareError> { - for middleware in middlewares { - match middleware { - HttpMiddleware::Cors(cors) => { - add_cors_headers_to_response(cors, response); - } - HttpMiddleware::AuthenticateRequest(_) => {} - } - } - - Ok(()) -} diff --git a/golem-worker-service/src/getter.rs b/golem-worker-service/src/getter.rs deleted file mode 100644 index 0922d2792b..0000000000 --- a/golem-worker-service/src/getter.rs +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::headers::ResolvedResponseHeaders; -use crate::path::{Path, PathComponent}; -use golem_wasm::analysis::{AnalysedType, TypeList, TypeRecord, TypeTuple}; -use golem_wasm::json::ValueAndTypeJsonExtensions; -use golem_wasm::{Value, ValueAndType}; -use http::StatusCode; -use rib::GetLiteralValue; -use rib::LiteralValue; - -pub trait Getter { - fn get(&self, key: &Path) -> Result; -} - -#[derive(Debug, PartialEq, Eq, thiserror::Error)] -pub enum GetError { - #[error("Key not found: {0}")] - KeyNotFound(String), - #[error("Index not found: {0}")] - IndexNotFound(usize), - #[error("Not a record: key_name: {key_name}, original_value: {found}")] - NotRecord { key_name: String, found: String }, - #[error("Not an array: index: {index}, original_value: {found}")] - NotArray { index: usize, found: String }, - #[error("Internal error: {0}")] - Internal(String), -} - -// To deal with fields in a TypeAnnotatedValue (that's returned from golem-rib) -impl Getter for ValueAndType { - fn get(&self, key: &Path) -> Result { - let size = key.0.len(); - fn go( - value_and_type: &ValueAndType, - paths: Vec, - index: usize, - size: usize, - ) -> Result { - if index < size { - match &paths[index] { - PathComponent::KeyName(key) => { - match (&value_and_type.typ, &value_and_type.value) { - ( - AnalysedType::Record(TypeRecord { fields, .. }), - Value::Record(field_values), - ) => { - let new_value = fields - .iter() - .zip(field_values) - .find(|(field, _value)| field.name == key.0) - .map(|(field, value)| { - ValueAndType::new(value.clone(), field.typ.clone()) - }); - match new_value { - Some(new_value) => go(&new_value, paths, index + 1, size), - _ => Err(GetError::KeyNotFound(key.0.clone())), - } - } - _ => match value_and_type.to_json_value() { - Ok(json) => Err(GetError::NotRecord { - key_name: key.0.clone(), - found: json.to_string(), - }), - Err(err) => Err(GetError::Internal(err)), - }, - } - } - PathComponent::Index(value_index) => match get_array(value_and_type) { - Some(type_values) => { - let new_value = type_values.get(value_index.0); - match new_value { - Some(new_value) => go(new_value, paths, index + 1, size), - None => Err(GetError::IndexNotFound(value_index.0)), - } - } - None => match value_and_type.to_json_value() { - Ok(json) => Err(GetError::NotArray { - index: value_index.0, - found: json.to_string(), - }), - Err(err) => Err(GetError::Internal(err)), - }, - }, - } - } else { - Ok(value_and_type.clone()) - } - } - - go(self, key.0.clone(), 0, size) - } -} - -fn get_array(value: &ValueAndType) -> Option> { - match (&value.typ, &value.value) { - (AnalysedType::List(TypeList { inner, .. }), Value::List(values)) => { - let vec = values - .iter() - .map(|v| ValueAndType::new(v.clone(), (**inner).clone())) - .collect::>(); - Some(vec) - } - (AnalysedType::Tuple(TypeTuple { items, .. }), Value::Tuple(values)) => { - let vec = items - .iter() - .zip(values) - .map(|(typ, v)| ValueAndType::new(v.clone(), typ.clone())) - .collect::>(); - Some(vec) - } - _ => None, - } -} - -pub trait GetterExt { - fn get_optional(&self, key: &Path) -> Option; -} - -impl> GetterExt for T { - fn get_optional(&self, key: &Path) -> Option { - self.get(key).ok() - } -} - -pub fn get_response_headers( - field_values: &[Value], - record: &TypeRecord, -) -> Result, String> { - match record - .fields - .iter() - .position(|pair| &pair.name == "headers") - { - None => Ok(None), - Some(field_position) => Ok(Some(ResolvedResponseHeaders::from_typed_value( - ValueAndType::new( - field_values[field_position].clone(), - record.fields[field_position].typ.clone(), - ), - )?)), - } -} - -pub fn get_response_headers_or_default( - value: &ValueAndType, -) -> Result { - match value { - ValueAndType { - value: Value::Record(field_values), - typ: AnalysedType::Record(record), - } => get_response_headers(field_values, record).map(|headers| headers.unwrap_or_default()), - _ => Ok(ResolvedResponseHeaders::default()), - } -} - -pub fn get_status_code( - field_values: &[Value], - record: &TypeRecord, -) -> Result, String> { - match record - .fields - .iter() - .position(|field| &field.name == "status") - { - None => Ok(None), - Some(field_position) => Ok(Some(get_status_code_inner(ValueAndType::new( - field_values[field_position].clone(), - record.fields[field_position].typ.clone(), - ))?)), - } -} - -pub fn get_status_code_or_ok(value: &ValueAndType) -> Result { - match value { - ValueAndType { - value: Value::Record(field_values), - typ: AnalysedType::Record(record), - } => get_status_code(field_values, record).map(|status| status.unwrap_or(StatusCode::OK)), - _ => Ok(StatusCode::OK), - } -} - -fn get_status_code_inner(status_code: ValueAndType) -> Result { - let status_res: Result = - match status_code.get_literal() { - Some(LiteralValue::String(status_str)) => status_str.parse().map_err(|e| { - format!( - "Invalid Status Code Expression. It is resolved to a string but not a number {status_str}. Error: {e}" - ) - }), - Some(LiteralValue::Num(number)) => number.to_string().parse().map_err(|e| { - format!( - "Invalid Status Code Expression. It is resolved to a number but not a u16 {number}. Error: {e}" - ) - }), - _ => Err(format!( - "Status Code Expression is evaluated to a complex value. It is resolved to {:?}", - status_code.value - )) - }; - - let status_u16 = status_res?; - - StatusCode::from_u16(status_u16).map_err(|e| - format!( - "Invalid Status Code. A valid status code cannot be formed from the evaluated status code expression {status_u16}. Error: {e}" - )) -} diff --git a/golem-worker-service/src/grpcapi/mod.rs b/golem-worker-service/src/grpcapi/mod.rs index b3235c27cd..12e44b0f80 100644 --- a/golem-worker-service/src/grpcapi/mod.rs +++ b/golem-worker-service/src/grpcapi/mod.rs @@ -15,27 +15,27 @@ mod error; mod worker; +use crate::bootstrap::Services; use crate::config::GrpcApiConfig; use crate::grpcapi::worker::WorkerGrpcApi; -use crate::service::Services; use futures::TryFutureExt; use golem_api_grpc::proto; use golem_api_grpc::proto::golem::common::{ErrorBody, ErrorsBody}; use golem_api_grpc::proto::golem::worker::v1::worker_service_server::WorkerServiceServer; use golem_api_grpc::proto::golem::worker::v1::{ - worker_error, worker_execution_error, WorkerError, WorkerExecutionError, + WorkerError, WorkerExecutionError, worker_error, worker_execution_error, }; -use golem_common::model::component::ComponentFilePath; use golem_common::model::WorkerId; +use golem_common::model::component::ComponentFilePath; use golem_service_base::grpc::server::GrpcServerTlsConfig; use golem_wasm::json::OptionallyValueAndTypeJson; use std::net::{Ipv4Addr, SocketAddrV4}; use tokio::net::TcpListener; use tokio::task::JoinSet; use tokio_stream::wrappers::TcpListenerStream; +use tonic::Status; use tonic::codec::CompressionEncoding; use tonic::transport::Server; -use tonic::Status; use tonic_tracing_opentelemetry::middleware; use tonic_tracing_opentelemetry::middleware::filters; use tracing::Instrument; diff --git a/golem-worker-service/src/grpcapi/worker.rs b/golem-worker-service/src/grpcapi/worker.rs index 182d6e7a92..d894dd79d3 100644 --- a/golem-worker-service/src/grpcapi/worker.rs +++ b/golem-worker-service/src/grpcapi/worker.rs @@ -16,22 +16,22 @@ use super::error::WorkerTraceErrorKind; use super::{bad_request_error, validate_protobuf_worker_id}; use crate::service::worker::WorkerService; use golem_api_grpc::proto::golem::common::Empty; +use golem_api_grpc::proto::golem::worker::InvokeResultTyped; use golem_api_grpc::proto::golem::worker::v1::worker_service_server::WorkerService as GrpcWorkerService; use golem_api_grpc::proto::golem::worker::v1::{ + CompletePromiseRequest, CompletePromiseResponse, ForkWorkerRequest, ForkWorkerResponse, + InvokeAndAwaitRequest, InvokeAndAwaitResponse, InvokeRequest, InvokeResponse, + LaunchNewWorkerRequest, LaunchNewWorkerResponse, LaunchNewWorkerSuccessResponse, + ResumeWorkerRequest, ResumeWorkerResponse, RevertWorkerRequest, RevertWorkerResponse, + UpdateWorkerRequest, UpdateWorkerResponse, WorkerError as GrpcWorkerError, complete_promise_response, fork_worker_response, invoke_and_await_response, invoke_response, launch_new_worker_response, resume_worker_response, revert_worker_response, - update_worker_response, CompletePromiseRequest, CompletePromiseResponse, ForkWorkerRequest, - ForkWorkerResponse, InvokeAndAwaitRequest, InvokeAndAwaitResponse, InvokeRequest, - InvokeResponse, LaunchNewWorkerRequest, LaunchNewWorkerResponse, - LaunchNewWorkerSuccessResponse, ResumeWorkerRequest, ResumeWorkerResponse, RevertWorkerRequest, - RevertWorkerResponse, UpdateWorkerRequest, UpdateWorkerResponse, - WorkerError as GrpcWorkerError, + update_worker_response, }; -use golem_api_grpc::proto::golem::worker::InvokeResultTyped; +use golem_common::model::WorkerId; use golem_common::model::component::ComponentRevision; use golem_common::model::oplog::OplogIndex; use golem_common::model::worker::WorkerUpdateMode; -use golem_common::model::WorkerId; use golem_common::recorded_grpc_api_request; use golem_service_base::grpc::{ proto_component_id_string, proto_idempotency_key_string, diff --git a/golem-worker-service/src/headers.rs b/golem-worker-service/src/headers.rs deleted file mode 100644 index a655e2feaa..0000000000 --- a/golem-worker-service/src/headers.rs +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use golem_wasm::analysis::AnalysedType; -use golem_wasm::{Value, ValueAndType}; -use http::HeaderMap; -use poem::web::headers::ContentType; -use rib::GetLiteralValue; -use std::collections::HashMap; -use std::str::FromStr; - -#[derive(Default, Debug, PartialEq)] -pub struct ResolvedResponseHeaders { - pub headers: HeaderMap, -} - -impl ResolvedResponseHeaders { - pub fn from_typed_value(header_map: ValueAndType) -> Result { - match header_map { - ValueAndType { - value: Value::Record(field_values), - typ: AnalysedType::Record(record), - } => { - let mut resolved_headers: HashMap = HashMap::new(); - - for (value, field_def) in field_values.into_iter().zip(record.fields) { - let value = ValueAndType::new(value, field_def.typ); - let value_str = value - .get_literal() - .map(|primitive| primitive.to_string()) - .unwrap_or_else(|| { - "header values in the http response should be a literal".to_string() - }); - - resolved_headers.insert(field_def.name, value_str); - } - - let headers = (&resolved_headers) - .try_into() - .map_err(|e: http::Error| e.to_string()) - .map_err(|e| format!("unable to infer valid headers. Error: {e}"))?; - - Ok(ResolvedResponseHeaders { headers }) - } - - _ => Err(format!( - "Header expression is not a record. It is resolved to {header_map}", - )), - } - } - - pub fn get_content_type(&self) -> Option { - self.headers - .get(http::header::CONTENT_TYPE.to_string()) - .and_then(|header_value| { - header_value - .to_str() - .ok() - .and_then(|header_str| ContentType::from_str(header_str).ok()) - }) - } -} - -#[cfg(test)] -mod test { - use crate::headers::ResolvedResponseHeaders; - use golem_wasm::analysis::analysed_type::{field, record}; - use golem_wasm::{IntoValueAndType, Value, ValueAndType}; - use http::{HeaderMap, HeaderValue}; - use test_r::test; - - fn create_record(values: Vec<(&str, ValueAndType)>) -> ValueAndType { - ValueAndType::new( - Value::Record(values.iter().map(|(_, vnt)| vnt.value.clone()).collect()), - record( - values - .iter() - .map(|(name, vnt)| field(name, vnt.typ.clone())) - .collect(), - ), - ) - } - - #[test] - fn test_get_response_headers_from_typed_value() { - let header_map: ValueAndType = create_record(vec![ - ("header1", "value1".into_value_and_type()), - ("header2", 1.0f32.into_value_and_type()), - ]); - - let resolved_headers = ResolvedResponseHeaders::from_typed_value(header_map).unwrap(); - - let mut header_map = HeaderMap::new(); - - header_map.insert("header1", HeaderValue::from_str("value1").unwrap()); - header_map.insert("header2", HeaderValue::from_str("1").unwrap()); - - let expected = ResolvedResponseHeaders { - headers: header_map, - }; - - assert_eq!(resolved_headers, expected) - } -} diff --git a/golem-worker-service/src/http_invocation_context.rs b/golem-worker-service/src/http_invocation_context.rs deleted file mode 100644 index c471f0dd33..0000000000 --- a/golem-worker-service/src/http_invocation_context.rs +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use golem_common::model::invocation_context::{ - AttributeValue, InvocationContextSpan, InvocationContextStack, TraceId, -}; -use golem_service_base::headers::TraceContextHeaders; -use std::collections::HashMap; - -pub fn extract_request_attributes(request: &poem::Request) -> HashMap { - let mut result = HashMap::new(); - - result.insert( - "request.method".to_string(), - AttributeValue::String(request.method().to_string()), - ); - result.insert( - "request.uri".to_string(), - AttributeValue::String(request.uri().to_string()), - ); - result.insert( - "request.remote_addr".to_string(), - AttributeValue::String(request.remote_addr().to_string()), - ); - - result -} - -pub fn invocation_context_from_request(request: &poem::Request) -> InvocationContextStack { - let trace_context_headers = TraceContextHeaders::parse(request.headers()); - let request_attributes = extract_request_attributes(request); - - match trace_context_headers { - Some(ctx) => { - // Trace context found in headers, starting a new span - let mut ctx = InvocationContextStack::new( - ctx.trace_id, - InvocationContextSpan::external_parent(ctx.parent_id), - ctx.trace_states, - ); - ctx.push( - InvocationContextSpan::local() - .with_attributes(request_attributes) - .with_parent(ctx.spans.first().clone()) - .build(), - ); - ctx - } - None => { - // No trace context in headers, starting a new trace - InvocationContextStack::new( - TraceId::generate(), - InvocationContextSpan::local() - .with_attributes(request_attributes) - .build(), - Vec::new(), - ) - } - } -} diff --git a/golem-worker-service/src/lib.rs b/golem-worker-service/src/lib.rs index 42b6e003ed..dd641a5fb8 100644 --- a/golem-worker-service/src/lib.rs +++ b/golem-worker-service/src/lib.rs @@ -13,24 +13,18 @@ // limitations under the License. pub mod api; +pub mod bootstrap; pub mod config; -pub mod gateway_execution; -// pub mod gateway_middleware; -pub mod gateway_router; -// pub mod gateway_security; -pub mod getter; +pub mod custom_api; pub mod grpcapi; -pub mod headers; -pub mod http_invocation_context; pub mod metrics; pub mod model; pub mod path; pub mod service; -// pub mod swagger_ui; +use crate::bootstrap::Services; use crate::config::WorkerServiceConfig; -use crate::service::Services; -use anyhow::{anyhow, Context}; +use anyhow::{Context, anyhow}; use golem_common::poem::LazyEndpointExt; use opentelemetry_sdk::trace::SdkTracer; use poem::endpoint::{BoxEndpoint, PrometheusExporter}; @@ -40,7 +34,7 @@ use poem::middleware::{CookieJarManager, Cors, OpenTelemetryMetrics, OpenTelemet use poem::{EndpointExt, Route}; use prometheus::Registry; use tokio::task::JoinSet; -use tracing::{info, Instrument}; +use tracing::{Instrument, info}; #[cfg(test)] test_r::enable!(); @@ -69,9 +63,7 @@ impl WorkerService { config: WorkerServiceConfig, prometheus_registry: Registry, ) -> anyhow::Result { - let services: Services = Services::new(&config) - .await - .map_err(|err| anyhow!(err).context("Service initialization"))?; + let services: Services = Services::new(&config).await?; Ok(Self { config, @@ -180,7 +172,7 @@ impl WorkerService { tracer: Option, ) -> Result { let route = Route::new() - .nest("/", api::custom_http_request_api(&self.services)) + .nest("/", custom_api::make_custom_api_endpoint(&self.services)) .with(OpenTelemetryMetrics::new()) .with_if_lazy(tracer.is_some(), || { OpenTelemetryTracing::new(tracer.unwrap()) diff --git a/golem-worker-service/src/model.rs b/golem-worker-service/src/model.rs index f9f631c98a..4dabb076ad 100644 --- a/golem-worker-service/src/model.rs +++ b/golem-worker-service/src/model.rs @@ -12,110 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -// use crate::gateway_api_definition::{ApiDefinitionId, ApiVersion}; -// use crate::gateway_api_deployment::ApiSite; -use golem_common::model::worker::WorkerMetadataDto; use golem_common::model::ScanCursor; -// use golem_service_base::custom_api::HttpCors; -use golem_common::model::account::AccountId; -use golem_common::model::agent::CorsOptions; -use golem_common::model::environment::EnvironmentId; -use golem_service_base::custom_api::SecuritySchemeDetails; -use golem_service_base::custom_api::{PathSegment, RequestBodySchema, RouteBehaviour, RouteId}; -use http::Method; +use golem_common::model::worker::WorkerMetadataDto; use poem_openapi::Object; use std::fmt::Debug; -use std::sync::Arc; #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, Object)] pub struct WorkersMetadataResponse { pub workers: Vec, pub cursor: Option, } - -// #[derive(Debug, Clone)] -// pub enum HttpMiddleware { -// Cors(HttpCors), -// AuthenticateRequest(SecuritySchemeDetails), -// } - -#[derive(Debug, Clone)] -pub struct SwaggerHtml(pub String); - -#[derive(Debug, Clone)] -pub struct SwaggerUiBinding { - pub swagger_html: Arc, -} - -#[derive(Debug)] -pub struct RichCompiledRoute { - pub account_id: AccountId, - pub environment_id: EnvironmentId, - pub route_id: RouteId, - pub method: Method, - pub path: Vec, - pub body: RequestBodySchema, - pub behavior: RouteBehaviour, - pub security_scheme: Option>, - pub cors: CorsOptions, -} - -// #[derive(Debug, Clone)] -// pub enum RichGatewayBindingCompiled { -// HttpCorsPreflight(Box), -// HttpAuthCallBack(Box), -// Worker(Box), -// FileServer(Box), -// HttpHandler(Box), -// SwaggerUi(SwaggerUiBinding), -// } - -// impl RichGatewayBindingCompiled { -// pub fn from_compiled_binding( -// binding: RouteBehaviour, -// precomputed_swagger_ui_htmls: &HashMap>, -// ) -> Result { -// match binding { -// RouteBehaviour::FileServer(inner) => { -// Ok(RichGatewayBindingCompiled::FileServer(inner)) -// } -// RouteBehaviour::HttpCorsPreflight(inner) => { -// Ok(RichGatewayBindingCompiled::HttpCorsPreflight(inner)) -// } -// RouteBehaviour::Worker(inner) => Ok(RichGatewayBindingCompiled::Worker(inner)), -// RouteBehaviour::HttpHandler(inner) => { -// Ok(RichGatewayBindingCompiled::HttpHandler(inner)) -// } -// RouteBehaviour::SwaggerUi(inner) => { -// let swagger_html = precomputed_swagger_ui_htmls -// .get(&inner.http_api_definition_id) -// .ok_or("no precomputed swagger html".to_string())? -// .clone(); -// Ok(RichGatewayBindingCompiled::SwaggerUi(SwaggerUiBinding { -// swagger_html, -// })) -// } -// } -// } -// } - -// #[derive(Debug, Clone)] -// pub struct RichCompiledRoute { -// pub account_id: AccountId, -// pub environment_id: EnvironmentId, -// pub method: RouteMethod, -// pub path: AllPathPatterns, -// pub binding: RichGatewayBindingCompiled, -// pub middlewares: Vec, -// } - -// impl RichCompiledRoute { -// pub fn get_security_middleware(&self) -> Option { -// for middleware in &self.middlewares { -// if let HttpMiddleware::AuthenticateRequest(security) = middleware { -// return Some(security.clone()); -// } -// } -// None -// } -// } diff --git a/golem-worker-service/src/server.rs b/golem-worker-service/src/server.rs index 3fced9d00b..4bdb7cf1f8 100644 --- a/golem-worker-service/src/server.rs +++ b/golem-worker-service/src/server.rs @@ -12,12 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use anyhow::anyhow; -use golem_common::tracing::init_tracing_with_default_env_filter; use golem_common::SafeDisplay; -use golem_worker_service::config::{make_worker_service_config_loader, WorkerServiceConfig}; -use golem_worker_service::service::Services; +use golem_common::tracing::init_tracing_with_default_env_filter; use golem_worker_service::WorkerService; +use golem_worker_service::bootstrap::Services; +use golem_worker_service::config::{WorkerServiceConfig, make_worker_service_config_loader}; use opentelemetry::global; use opentelemetry_sdk::metrics::MeterProviderBuilder; use opentelemetry_sdk::trace::SdkTracer; @@ -62,9 +61,7 @@ fn main() -> anyhow::Result<()> { pub async fn dump_openapi_yaml() -> anyhow::Result<()> { let config = WorkerServiceConfig::default(); - let services = Services::new(&config) - .await - .map_err(|e| anyhow!("Services - init error: {}", e))?; + let services = Services::new(&config).await?; let open_api_service = golem_worker_service::api::make_open_api_service(&services); println!("{}", open_api_service.spec_yaml()); Ok(()) diff --git a/golem-worker-service/src/service/auth.rs b/golem-worker-service/src/service/auth.rs index a1ab610644..a6654274f6 100644 --- a/golem-worker-service/src/service/auth.rs +++ b/golem-worker-service/src/service/auth.rs @@ -18,7 +18,7 @@ use async_trait::async_trait; use golem_common::cache::{BackgroundEvictionMode, Cache, FullCacheEvictionMode, SimpleCache}; use golem_common::model::auth::TokenSecret; use golem_common::model::environment::EnvironmentId; -use golem_common::{error_forwarding, SafeDisplay}; +use golem_common::{SafeDisplay, error_forwarding}; use golem_service_base::clients::registry::{RegistryService, RegistryServiceError}; use golem_service_base::model::auth::AuthorizationError; use golem_service_base::model::auth::{AuthCtx, AuthDetailsForEnvironment, EnvironmentAction}; diff --git a/golem-worker-service/src/service/component.rs b/golem-worker-service/src/service/component.rs index df22b7f5fc..58eaf86722 100644 --- a/golem-worker-service/src/service/component.rs +++ b/golem-worker-service/src/service/component.rs @@ -18,7 +18,7 @@ use async_trait::async_trait; use golem_common::cache::{BackgroundEvictionMode, Cache, FullCacheEvictionMode, SimpleCache}; use golem_common::model::component::ComponentId; use golem_common::model::component::{ComponentDto, ComponentRevision}; -use golem_common::{error_forwarding, SafeDisplay}; +use golem_common::{SafeDisplay, error_forwarding}; use golem_service_base::clients::registry::{RegistryService, RegistryServiceError}; use std::sync::Arc; diff --git a/golem-worker-service/src/service/limit.rs b/golem-worker-service/src/service/limit.rs index 193df4df55..5549d12667 100644 --- a/golem-worker-service/src/service/limit.rs +++ b/golem-worker-service/src/service/limit.rs @@ -13,9 +13,9 @@ // limitations under the License. use async_trait::async_trait; -use golem_common::model::account::AccountId; use golem_common::model::WorkerId; -use golem_common::{error_forwarding, SafeDisplay}; +use golem_common::model::account::AccountId; +use golem_common::{SafeDisplay, error_forwarding}; use golem_service_base::clients::registry::{RegistryService, RegistryServiceError}; use std::sync::Arc; diff --git a/golem-worker-service/src/service/mod.rs b/golem-worker-service/src/service/mod.rs index ce6c293c8e..bcd0ffb775 100644 --- a/golem-worker-service/src/service/mod.rs +++ b/golem-worker-service/src/service/mod.rs @@ -16,207 +16,3 @@ pub mod auth; pub mod component; pub mod limit; pub mod worker; - -use self::auth::{AuthService, RemoteAuthService}; -use self::component::RemoteComponentService; -use self::limit::{LimitService, RemoteLimitService}; -use self::worker::WorkerService; -use crate::config::WorkerServiceConfig; -// use crate::gateway_execution::api_definition_lookup::{ -// HttpApiDefinitionsLookup, RegistryServiceApiDefinitionsLookup, -// }; -// use crate::gateway_execution::auth_call_back_binding_handler::{ -// AuthCallBackBindingHandler, DefaultAuthCallBackBindingHandler, -// }; -// use crate::gateway_execution::file_server_binding_handler::FileServerBindingHandler; -// use crate::gateway_execution::gateway_http_input_executor::GatewayHttpInputExecutor; -// use crate::gateway_execution::gateway_session_store::{ -// GatewaySessionStore, RedisGatewaySession, RedisGatewaySessionExpiration, SqliteGatewaySession, -// SqliteGatewaySessionExpiration, -// }; -// use crate::gateway_execution::http_handler_binding_handler::HttpHandlerBindingHandler; -// use crate::gateway_execution::route_resolver::RouteResolver; -// use crate::gateway_execution::GatewayWorkerRequestExecutor; -// use crate::gateway_security::DefaultIdentityProvider; -use crate::gateway_execution::api_definition_lookup::{ - HttpApiDefinitionsLookup, RegistryServiceApiDefinitionsLookup, -}; -use crate::gateway_execution::request_handler::RequestHandler; -use crate::gateway_execution::route_resolver::RouteResolver; -use crate::service::component::ComponentService; -use crate::service::worker::{AgentsService, WorkerClient, WorkerExecutorWorkerClient}; -use golem_api_grpc::proto::golem::workerexecutor::v1::worker_executor_client::WorkerExecutorClient; -use golem_service_base::clients::registry::{GrpcRegistryService, RegistryService}; -use golem_service_base::config::BlobStorageConfig; -use golem_service_base::db::sqlite::SqlitePool; -use golem_service_base::grpc::client::MultiTargetGrpcClient; -use golem_service_base::service::initial_component_files::InitialComponentFilesService; -use golem_service_base::service::routing_table::{RoutingTableService, RoutingTableServiceDefault}; -use golem_service_base::storage::blob::BlobStorage; -use std::sync::Arc; -use tonic::codec::CompressionEncoding; - -#[derive(Clone)] -pub struct Services { - pub auth_service: Arc, - pub limit_service: Arc, - pub component_service: Arc, - pub worker_service: Arc, - pub request_handler: Arc, - pub agents_service: Arc, -} - -impl Services { - pub async fn new(config: &WorkerServiceConfig) -> Result { - let registry_service_client: Arc = - Arc::new(GrpcRegistryService::new(&config.registry_service)); - - let auth_service: Arc = Arc::new(RemoteAuthService::new( - registry_service_client.clone(), - &config.auth_service, - )); - - // let gateway_session_store: Arc = - // match &config.gateway_session_storage { - // GatewaySessionStorageConfig::Redis(redis_config) => { - // let redis = RedisPool::configured(redis_config) - // .await - // .map_err(|e| e.to_string())?; - - // let gateway_session_with_redis = - // RedisGatewaySession::new(redis, RedisGatewaySessionExpiration::default()); - - // Arc::new(gateway_session_with_redis) - // } - - // GatewaySessionStorageConfig::Sqlite(sqlite_config) => { - // let pool = SqlitePool::configured(sqlite_config) - // .await - // .map_err(|e| e.to_string())?; - - // let gateway_session_with_sqlite = - // SqliteGatewaySession::new(pool, SqliteGatewaySessionExpiration::default()) - // .await?; - - // Arc::new(gateway_session_with_sqlite) - // } - // }; - - let blob_storage: Arc = match &config.blob_storage { - BlobStorageConfig::S3(config) => Arc::new( - golem_service_base::storage::blob::s3::S3BlobStorage::new(config.clone()).await, - ), - BlobStorageConfig::LocalFileSystem(config) => Arc::new( - golem_service_base::storage::blob::fs::FileSystemBlobStorage::new(&config.root) - .await - .map_err(|e| e.to_string())?, - ), - BlobStorageConfig::Sqlite(sqlite) => { - let pool = SqlitePool::configured(sqlite) - .await - .map_err(|e| format!("Failed to create sqlite pool: {e}"))?; - Arc::new( - golem_service_base::storage::blob::sqlite::SqliteBlobStorage::new(pool.clone()) - .await - .map_err(|e| e.to_string())?, - ) - } - BlobStorageConfig::InMemory(_) => { - Arc::new(golem_service_base::storage::blob::memory::InMemoryBlobStorage::new()) - } - _ => { - return Err("Unsupported blob storage configuration".to_string()); - } - }; - - let _initial_component_files_service: Arc = - Arc::new(InitialComponentFilesService::new(blob_storage.clone())); - - let component_service: Arc = Arc::new(RemoteComponentService::new( - registry_service_client.clone(), - &config.component_service, - )); - - // let identity_provider = Arc::new(DefaultIdentityProvider); - - let limit_service: Arc = - Arc::new(RemoteLimitService::new(registry_service_client.clone())); - - let routing_table_service: Arc = Arc::new( - RoutingTableServiceDefault::new(config.routing_table.clone()), - ); - - let worker_executor_clients = MultiTargetGrpcClient::new( - "worker_executor", - |channel| { - WorkerExecutorClient::new(channel) - .send_compressed(CompressionEncoding::Gzip) - .accept_compressed(CompressionEncoding::Gzip) - }, - config.worker_executor.client.clone(), - ); - - let worker_client: Arc = Arc::new(WorkerExecutorWorkerClient::new( - worker_executor_clients.clone(), - config.worker_executor.retries.clone(), - routing_table_service.clone(), - )); - - let worker_service: Arc = Arc::new(WorkerService::new( - component_service.clone(), - auth_service.clone(), - limit_service.clone(), - worker_client.clone(), - )); - - // let gateway_worker_request_executor: Arc = Arc::new( - // GatewayWorkerRequestExecutor::new(worker_service.clone(), component_service.clone()), - // ); - - // let file_server_binding_handler: Arc = - // Arc::new(FileServerBindingHandler::new( - // component_service.clone(), - // initial_component_files_service.clone(), - // worker_service.clone(), - // )); - - // let auth_call_back_binding_handler: Arc = - // Arc::new(DefaultAuthCallBackBindingHandler::new( - // gateway_session_store.clone(), - // identity_provider.clone(), - // )); - - // let http_handler_binding_handler: Arc = Arc::new( - // HttpHandlerBindingHandler::new(gateway_worker_request_executor.clone()), - // ); - - let api_definition_lookup_service: Arc = Arc::new( - RegistryServiceApiDefinitionsLookup::new(registry_service_client.clone()), - ); - - let route_resolver = Arc::new(RouteResolver::new( - &config.route_resolver, - api_definition_lookup_service.clone(), - )); - - let request_handler = Arc::new(RequestHandler::new( - route_resolver.clone(), - worker_service.clone(), - )); - - let agents_service: Arc = Arc::new(AgentsService::new( - registry_service_client.clone(), - component_service.clone(), - worker_service.clone(), - )); - - Ok(Self { - auth_service, - limit_service, - component_service, - worker_service, - request_handler, - agents_service, - }) - } -} diff --git a/golem-worker-service/src/service/worker/agents.rs b/golem-worker-service/src/service/worker/agents.rs index 9837d1e3ee..e5396c72ad 100644 --- a/golem-worker-service/src/service/worker/agents.rs +++ b/golem-worker-service/src/service/worker/agents.rs @@ -15,8 +15,8 @@ use crate::api::agents::{AgentInvocationMode, AgentInvocationRequest, AgentInvocationResult}; use crate::service::component::ComponentService; use crate::service::worker::{WorkerResult, WorkerService, WorkerServiceError}; -use golem_common::model::agent::{AgentError, AgentId, DataValue, UntypedDataValue}; use golem_common::model::WorkerId; +use golem_common::model::agent::{AgentError, AgentId, DataValue, UntypedDataValue}; use golem_service_base::clients::registry::RegistryService; use golem_service_base::model::auth::AuthCtx; use golem_wasm::{FromValue, IntoValueAndType, Value}; diff --git a/golem-worker-service/src/service/worker/client.rs b/golem-worker-service/src/service/worker/client.rs index 10a461db9c..6375200c5c 100644 --- a/golem-worker-service/src/service/worker/client.rs +++ b/golem-worker-service/src/service/worker/client.rs @@ -30,6 +30,7 @@ use golem_api_grpc::proto::golem::workerexecutor::v1::{ InvokeAndAwaitWorkerJsonRequest, InvokeAndAwaitWorkerRequest, ResumeWorkerRequest, RevertWorkerRequest, SearchOplogResponse, UpdateWorkerRequest, }; +use golem_common::model::RetryConfig; use golem_common::model::account::AccountId; use golem_common::model::component::{ ComponentFilePath, ComponentId, ComponentRevision, PluginPriority, @@ -39,7 +40,6 @@ use golem_common::model::oplog::{OplogCursor, PublicOplogEntry}; use golem_common::model::oplog::{OplogIndex, PublicOplogEntryWithIndex}; use golem_common::model::worker::WorkerUpdateMode; use golem_common::model::worker::{RevertWorkerTarget, WorkerMetadataDto}; -use golem_common::model::RetryConfig; use golem_common::model::{ FilterComparator, IdempotencyKey, PromiseId, ScanCursor, WorkerFilter, WorkerId, WorkerStatus, }; @@ -48,14 +48,14 @@ use golem_service_base::grpc::client::MultiTargetGrpcClient; use golem_service_base::model::auth::AuthCtx; use golem_service_base::model::{ComponentFileSystemNode, GetOplogResponse}; use golem_service_base::service::routing_table::{HasRoutingTableService, RoutingTableService}; +use golem_wasm::ValueAndType; use golem_wasm::analysis::AnalysedFunctionResult; use golem_wasm::protobuf::Val as ProtoVal; -use golem_wasm::ValueAndType; use std::collections::BTreeMap; use std::pin::Pin; use std::{collections::HashMap, sync::Arc}; -use tonic::transport::Channel; use tonic::Code; +use tonic::transport::Channel; use tonic_tracing_opentelemetry::middleware::client::OtelGrpcService; #[async_trait] diff --git a/golem-worker-service/src/service/worker/connect.rs b/golem-worker-service/src/service/worker/connect.rs index 53d33e4921..eb69fa3c6f 100644 --- a/golem-worker-service/src/service/worker/connect.rs +++ b/golem-worker-service/src/service/worker/connect.rs @@ -16,8 +16,8 @@ use super::WorkerStream; use crate::service::limit::LimitService; use futures::{Stream, StreamExt}; use golem_api_grpc::proto::golem::worker::LogEvent; -use golem_common::model::account::AccountId; use golem_common::model::WorkerId; +use golem_common::model::account::AccountId; use std::sync::Arc; use tonic::Status; diff --git a/golem-worker-service/src/service/worker/connect_proxy.rs b/golem-worker-service/src/service/worker/connect_proxy.rs index 90eb49b863..38b12c9f9c 100644 --- a/golem-worker-service/src/service/worker/connect_proxy.rs +++ b/golem-worker-service/src/service/worker/connect_proxy.rs @@ -320,13 +320,13 @@ mod keep_alive { mod test { use test_r::test; - use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Once; + use std::sync::atomic::{AtomicBool, Ordering}; use super::*; use poem::web::websocket::Message; use tokio::sync::mpsc; - use tokio::time::{timeout, Duration}; + use tokio::time::{Duration, timeout}; use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::PollSender; use tracing::{Instrument, Level}; diff --git a/golem-worker-service/src/service/worker/error.rs b/golem-worker-service/src/service/worker/error.rs index 538d077288..c988e007a8 100644 --- a/golem-worker-service/src/service/worker/error.rs +++ b/golem-worker-service/src/service/worker/error.rs @@ -16,10 +16,10 @@ use crate::service::auth::AuthServiceError; use crate::service::component::ComponentServiceError; use crate::service::limit::LimitServiceError; use crate::service::worker::CallWorkerExecutorError; +use golem_common::SafeDisplay; +use golem_common::model::WorkerId; use golem_common::model::account::AccountId; use golem_common::model::component::{ComponentFilePath, ComponentId}; -use golem_common::model::WorkerId; -use golem_common::SafeDisplay; use golem_service_base::clients::registry::RegistryServiceError; use golem_service_base::error::worker_executor::WorkerExecutorError; @@ -84,9 +84,9 @@ impl From for golem_api_grpc::proto::golem::worker::v1::Work impl From for golem_api_grpc::proto::golem::worker::v1::worker_error::Error { fn from(error: WorkerServiceError) -> Self { use golem_api_grpc::proto::golem::common::{ErrorBody, ErrorsBody}; - use golem_api_grpc::proto::golem::worker::v1::worker_execution_error::Error as GrpcError; use golem_api_grpc::proto::golem::worker::v1::UnknownError; use golem_api_grpc::proto::golem::worker::v1::WorkerExecutionError; + use golem_api_grpc::proto::golem::worker::v1::worker_execution_error::Error as GrpcError; match error { WorkerServiceError::ComponentNotFound(_) diff --git a/golem-worker-service/src/service/worker/invocation_parameters.rs b/golem-worker-service/src/service/worker/invocation_parameters.rs index 38420a3cdc..bcdf8e7c97 100644 --- a/golem-worker-service/src/service/worker/invocation_parameters.rs +++ b/golem-worker-service/src/service/worker/invocation_parameters.rs @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use golem_wasm::json::OptionallyValueAndTypeJson; use golem_wasm::ValueAndType; +use golem_wasm::json::OptionallyValueAndTypeJson; pub enum InvocationParameters { TypedProtoVals(Vec), diff --git a/golem-worker-service/src/service/worker/routing_logic.rs b/golem-worker-service/src/service/worker/routing_logic.rs index 4f525f690e..2febbdd949 100644 --- a/golem-worker-service/src/service/worker/routing_logic.rs +++ b/golem-worker-service/src/service/worker/routing_logic.rs @@ -16,11 +16,11 @@ use crate::service::worker::WorkerServiceError; use async_trait::async_trait; use golem_api_grpc::proto::golem::worker::v1::WorkerExecutionError; use golem_api_grpc::proto::golem::workerexecutor::v1::worker_executor_client::WorkerExecutorClient; +use golem_common::SafeDisplay; use golem_common::model::RetryConfig; use golem_common::model::{Pod, ShardId, WorkerId}; use golem_common::retriable_error::IsRetriableError; use golem_common::retries::get_delay; -use golem_common::SafeDisplay; use golem_service_base::error::worker_executor::WorkerExecutorError; use golem_service_base::grpc::client::MultiTargetGrpcClient; use golem_service_base::service::routing_table::{HasRoutingTableService, RoutingTableError}; @@ -29,11 +29,11 @@ use std::fmt::Debug; use std::future::Future; use std::pin::Pin; use tokio::task::JoinSet; -use tokio::time::{sleep, Instant}; -use tonic::transport::Channel; +use tokio::time::{Instant, sleep}; use tonic::Status; +use tonic::transport::Channel; use tonic_tracing_opentelemetry::middleware::client::OtelGrpcService; -use tracing::{debug, error, info, trace, warn, Instrument}; +use tracing::{Instrument, debug, error, info, trace, warn}; #[async_trait] pub trait RoutingLogic { diff --git a/golem-worker-service/src/service/worker/service.rs b/golem-worker-service/src/service/worker/service.rs index 17f3354906..054b31ef78 100644 --- a/golem-worker-service/src/service/worker/service.rs +++ b/golem-worker-service/src/service/worker/service.rs @@ -32,9 +32,9 @@ use golem_common::model::worker::{RevertWorkerTarget, WorkerMetadataDto}; use golem_common::model::{IdempotencyKey, ScanCursor, WorkerFilter, WorkerId}; use golem_service_base::model::auth::{AuthCtx, EnvironmentAction}; use golem_service_base::model::{ComponentFileSystemNode, GetOplogResponse}; +use golem_wasm::ValueAndType; use golem_wasm::protobuf::Val as ProtoVal; use golem_wasm::protobuf::Val; -use golem_wasm::ValueAndType; use std::collections::{BTreeMap, BTreeSet}; use std::pin::Pin; use std::{collections::HashMap, sync::Arc}; @@ -267,13 +267,15 @@ impl WorkerService { auth_ctx: AuthCtx, ) -> WorkerResult> { // sanity check for consistency of auth and owner information - assert!(auth_ctx - .authorize_environment_action( - account_id_owning_environment, - &BTreeSet::new(), - EnvironmentAction::UpdateWorker, - ) - .is_ok()); + assert!( + auth_ctx + .authorize_environment_action( + account_id_owning_environment, + &BTreeSet::new(), + EnvironmentAction::UpdateWorker, + ) + .is_ok() + ); let result = self .worker_client diff --git a/golem-worker-service/src/service/worker/worker_stream.rs b/golem-worker-service/src/service/worker/worker_stream.rs index fd2e427c16..af96de8f69 100644 --- a/golem-worker-service/src/service/worker/worker_stream.rs +++ b/golem-worker-service/src/service/worker/worker_stream.rs @@ -21,7 +21,7 @@ use futures::{Stream, StreamExt}; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use tonic::{Status, Streaming}; -use tracing::{error, Instrument}; +use tracing::{Instrument, error}; use golem_common::metrics::api::{ record_closed_grpc_api_active_stream, record_new_grpc_api_active_stream, diff --git a/golem-worker-service/src/swagger_ui.rs b/golem-worker-service/src/swagger_ui.rs deleted file mode 100644 index e963be003e..0000000000 --- a/golem-worker-service/src/swagger_ui.rs +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2024-2025 Golem Cloud -// -// Licensed under the Golem Source License v1.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://license.golem.cloud/LICENSE -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use crate::model::SwaggerHtml; -use golem_service_base::custom_api::openapi::HttpApiDefinitionOpenApiSpec; - -pub fn generate_swagger_html( - authority: &str, - open_api_spec: HttpApiDefinitionOpenApiSpec, -) -> Result { - let mut spec = open_api_spec.0; - - // Add server information to the OpenAPI spec - if spec.servers.is_empty() { - spec.servers = Vec::new(); - } - - spec.servers.push(openapiv3::Server { - url: format!("http://{authority}"), - description: Some("Local Development Server".to_string()), - variables: Default::default(), - extensions: Default::default(), - }); - - // Convert back to JSON - let modified_spec_json = serde_json::to_string(&spec) - .map_err(|e| format!("Failed to serialize OpenAPI spec: {e}"))?; - - // Generate Swagger UI HTML - let html = format!( - r#" - - - -Swagger UI - - - - - -
- - -"# - ); - - Ok(SwaggerHtml(html)) -} diff --git a/integration-tests/src/benchmarks/throughput.rs b/integration-tests/src/benchmarks/throughput.rs index ce787d1d50..91f9a8d13a 100644 --- a/integration-tests/src/benchmarks/throughput.rs +++ b/integration-tests/src/benchmarks/throughput.rs @@ -27,7 +27,9 @@ use golem_common::model::http_api_definition::{ GatewayBinding, HttpApiDefinitionCreation, HttpApiDefinitionName, HttpApiDefinitionVersion, HttpApiRoute, RouteMethod, WorkerGatewayBinding, }; -use golem_common::model::http_api_deployment::HttpApiDeploymentCreation; +use golem_common::model::http_api_deployment::{ + HttpApiDeploymentAgentOptions, HttpApiDeploymentCreation, +}; use golem_common::model::{RoutingTable, WorkerId}; use golem_test_framework::benchmark::{Benchmark, BenchmarkRecorder, RunConfig}; use golem_test_framework::config::benchmark::TestMode; @@ -38,7 +40,7 @@ use golem_wasm::{IntoValueAndType, ValueAndType}; use indoc::indoc; use reqwest::{Body, Method, Request, Url}; use serde_json::json; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use tracing::{info, Level}; pub struct ThroughputEcho { @@ -670,7 +672,10 @@ impl ThroughputBenchmark { let http_api_deployment_creation = HttpApiDeploymentCreation { domain: domain.clone(), - agent_types: [AgentTypeName("benchmark-agent".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("benchmark-agent".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; client diff --git a/integration-tests/tests/api/deployment.rs b/integration-tests/tests/api/deployment.rs index 9bb4516c2c..04bb33ff4a 100644 --- a/integration-tests/tests/api/deployment.rs +++ b/integration-tests/tests/api/deployment.rs @@ -31,7 +31,9 @@ use golem_common::model::http_api_definition::{ GatewayBinding, HttpApiDefinitionCreation, HttpApiDefinitionName, HttpApiDefinitionVersion, HttpApiRoute, RouteMethod, WorkerGatewayBinding, }; -use golem_common::model::http_api_deployment::HttpApiDeploymentCreation; +use golem_common::model::http_api_deployment::{ + HttpApiDeploymentAgentOptions, HttpApiDeploymentCreation, +}; use golem_test_framework::config::{EnvBasedTestDependencies, TestDependencies}; use golem_test_framework::dsl::{TestDsl, TestDslExtended}; use std::collections::BTreeMap; @@ -252,7 +254,10 @@ async fn full_deployment(deps: &EnvBasedTestDependencies) -> anyhow::Result<()> let http_api_deployment_creation = HttpApiDeploymentCreation { domain: domain.clone(), - agent_types: [AgentTypeName("shopping-cart".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("shopping-cart".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; let http_api_deployment = client diff --git a/integration-tests/tests/api/http_api_deployment.rs b/integration-tests/tests/api/http_api_deployment.rs index 80a59b49cc..441094267d 100644 --- a/integration-tests/tests/api/http_api_deployment.rs +++ b/integration-tests/tests/api/http_api_deployment.rs @@ -27,10 +27,11 @@ use golem_common::model::http_api_definition::{ HttpApiRoute, RouteMethod, WorkerGatewayBinding, }; use golem_common::model::http_api_deployment::{ - HttpApiDeploymentCreation, HttpApiDeploymentUpdate, + HttpApiDeploymentAgentOptions, HttpApiDeploymentCreation, HttpApiDeploymentUpdate, }; use golem_test_framework::config::{EnvBasedTestDependencies, TestDependencies}; use golem_test_framework::dsl::{TestDsl, TestDslExtended}; +use std::collections::BTreeMap; use test_r::{inherit_test_dep, test}; inherit_test_dep!(EnvBasedTestDependencies); @@ -47,7 +48,10 @@ async fn create_http_api_deployment_for_nonexitant_domain( let http_api_deployment_creation = HttpApiDeploymentCreation { domain: Domain("testdomain.com".to_string()), - agent_types: [AgentTypeName("test-api".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; let result = client @@ -74,7 +78,10 @@ async fn create_http_api_deployment(deps: &EnvBasedTestDependencies) -> anyhow:: let http_api_deployment_creation = HttpApiDeploymentCreation { domain, - agent_types: [AgentTypeName("test-api".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; let http_api_deployment = client @@ -118,7 +125,10 @@ async fn update_http_api_deployment(deps: &EnvBasedTestDependencies) -> anyhow:: let http_api_deployment_creation = HttpApiDeploymentCreation { domain, - agent_types: [AgentTypeName("test-api".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; let http_api_deployment = client @@ -127,13 +137,16 @@ async fn update_http_api_deployment(deps: &EnvBasedTestDependencies) -> anyhow:: let http_api_deployment_update = HttpApiDeploymentUpdate { current_revision: http_api_deployment.revision, - agent_types: Some( - [ + agents: Some(BTreeMap::from_iter([ + ( AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + ), + ( AgentTypeName("test-api-2".to_string()), - ] - .into(), - ), + HttpApiDeploymentAgentOptions::default(), + ), + ])), }; let updated_http_api_deployment = client @@ -142,9 +155,7 @@ async fn update_http_api_deployment(deps: &EnvBasedTestDependencies) -> anyhow:: assert!(updated_http_api_deployment.id == http_api_deployment.id); assert!(updated_http_api_deployment.revision == http_api_deployment.revision.next()?); - assert!( - updated_http_api_deployment.agent_types == http_api_deployment_update.agent_types.unwrap() - ); + assert!(updated_http_api_deployment.agents == http_api_deployment_update.agents.unwrap()); { let fetched_http_api_deployment = client @@ -201,7 +212,10 @@ async fn delete_http_api_deployment(deps: &EnvBasedTestDependencies) -> anyhow:: let http_api_deployment_creation = HttpApiDeploymentCreation { domain, - agent_types: [AgentTypeName("test-api".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; let http_api_deployment = client @@ -260,7 +274,10 @@ async fn cannot_create_two_http_api_deployments_for_same_domain( let http_api_deployment_creation = HttpApiDeploymentCreation { domain, - agent_types: [AgentTypeName("test-api".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; client @@ -293,7 +310,10 @@ async fn updates_with_wrong_revision_number_are_rejected( let http_api_deployment_creation = HttpApiDeploymentCreation { domain, - agent_types: [AgentTypeName("test-api".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; let http_api_deployment = client @@ -302,13 +322,16 @@ async fn updates_with_wrong_revision_number_are_rejected( let http_api_deployment_update = HttpApiDeploymentUpdate { current_revision: http_api_deployment.revision.next()?, - agent_types: Some( - [ + agents: Some(BTreeMap::from_iter([ + ( AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + ), + ( AgentTypeName("test-api-2".to_string()), - ] - .into(), - ), + HttpApiDeploymentAgentOptions::default(), + ), + ])), }; let result = client @@ -335,7 +358,10 @@ async fn http_api_deployment_recreation(deps: &EnvBasedTestDependencies) -> anyh let http_api_deployment_creation = HttpApiDeploymentCreation { domain, - agent_types: [AgentTypeName("test-api".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; let http_api_deployment_1 = client @@ -414,7 +440,10 @@ async fn fetch_in_deployment(deps: &EnvBasedTestDependencies) -> anyhow::Result< let http_api_deployment_creation = HttpApiDeploymentCreation { domain: domain.clone(), - agent_types: [AgentTypeName("ephemeral-echo-agent".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("ephemeral-echo-agent".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; let http_api_deployment = client @@ -453,7 +482,10 @@ async fn cannot_access_http_api_deployment_from_another_user( let creation = HttpApiDeploymentCreation { domain: domain.clone(), - agent_types: [AgentTypeName("test-api".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; let deployment = client_a @@ -488,7 +520,10 @@ async fn cannot_delete_http_api_deployment_from_another_user( let creation = HttpApiDeploymentCreation { domain, - agent_types: [AgentTypeName("test-api".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; let deployment = client_a @@ -520,7 +555,10 @@ async fn delete_with_wrong_revision_is_rejected( let creation = HttpApiDeploymentCreation { domain, - agent_types: [AgentTypeName("test-api".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; let deployment = client @@ -553,7 +591,10 @@ async fn deleting_twice_returns_404(deps: &EnvBasedTestDependencies) -> anyhow:: let creation = HttpApiDeploymentCreation { domain, - agent_types: [AgentTypeName("test-api".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("test-api".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; let deployment = client diff --git a/integration-tests/tests/custom_api/agent_http_routes_ts.rs b/integration-tests/tests/custom_api/agent_http_routes_ts.rs index 5f578140e7..fe16bf78ae 100644 --- a/integration-tests/tests/custom_api/agent_http_routes_ts.rs +++ b/integration-tests/tests/custom_api/agent_http_routes_ts.rs @@ -18,14 +18,16 @@ use golem_common::model::agent::AgentTypeName; use golem_common::model::deployment::DeploymentRevision; use golem_common::model::domain_registration::{Domain, DomainRegistrationCreation}; use golem_common::model::environment::EnvironmentId; -use golem_common::model::http_api_deployment::HttpApiDeploymentCreation; +use golem_common::model::http_api_deployment::{ + HttpApiDeploymentAgentOptions, HttpApiDeploymentCreation, +}; use golem_test_framework::config::dsl_impl::TestUserContext; use golem_test_framework::config::{EnvBasedTestDependencies, TestDependencies}; use golem_test_framework::dsl::{TestDsl, TestDslExtended}; use pretty_assertions::assert_eq; use reqwest::Url; use serde_json::json; -use std::collections::BTreeSet; +use std::collections::BTreeMap; use std::fmt::{Debug, Formatter}; use test_r::test_dep; use test_r::{inherit_test_dep, test}; @@ -75,7 +77,16 @@ async fn test_context_internal(deps: &EnvBasedTestDependencies) -> anyhow::Resul let http_api_deployment_creation = HttpApiDeploymentCreation { domain: domain.clone(), - agent_types: BTreeSet::from_iter([AgentTypeName("http-agent".into())]), + agents: BTreeMap::from_iter([ + ( + AgentTypeName("http-agent".to_string()), + HttpApiDeploymentAgentOptions::default(), + ), + ( + AgentTypeName("cors-agent".to_string()), + HttpApiDeploymentAgentOptions::default(), + ), + ]), }; client @@ -111,7 +122,7 @@ async fn string_path_var(agent: &TestContext) -> anyhow::Result<()> { .get( agent .base_url - .join("/agents/test-agent/string-path-var/foo")?, + .join("/http-agents/test-agent/string-path-var/foo")?, ) .send() .await?; @@ -130,7 +141,7 @@ async fn multi_path_vars(agent: &TestContext) -> anyhow::Result<()> { .get( agent .base_url - .join("/agents/test-agent/multi-path-vars/foo/bar")?, + .join("/http-agents/test-agent/multi-path-vars/foo/bar")?, ) .send() .await?; @@ -148,7 +159,11 @@ async fn multi_path_vars(agent: &TestContext) -> anyhow::Result<()> { async fn remaining_path_variable(agent: &TestContext) -> anyhow::Result<()> { let response = agent .client - .get(agent.base_url.join("/agents/test-agent/rest/a/b/c/d")?) + .get( + agent + .base_url + .join("/http-agents/test-agent/rest/a/b/c/d")?, + ) .send() .await?; @@ -170,7 +185,7 @@ async fn remaining_path_variable(agent: &TestContext) -> anyhow::Result<()> { async fn remaining_path_missing(agent: &TestContext) -> anyhow::Result<()> { let response = agent .client - .get(agent.base_url.join("/agents/test-agent/rest")?) + .get(agent.base_url.join("/http-agents/test-agent/rest")?) .send() .await?; @@ -186,7 +201,7 @@ async fn path_and_query(agent: &TestContext) -> anyhow::Result<()> { .get( agent .base_url - .join("/agents/test-agent/path-and-query/item-123?limit=10")?, + .join("/http-agents/test-agent/path-and-query/item-123?limit=10")?, ) .send() .await?; @@ -213,7 +228,7 @@ async fn path_and_header(agent: &TestContext) -> anyhow::Result<()> { .get( agent .base_url - .join("/agents/test-agent/path-and-header/res-42")?, + .join("/http-agents/test-agent/path-and-header/res-42")?, ) .header("x-request-id", "req-abc") .send() @@ -238,7 +253,11 @@ async fn path_and_header(agent: &TestContext) -> anyhow::Result<()> { async fn json_body(agent: &TestContext) -> anyhow::Result<()> { let response = agent .client - .post(agent.base_url.join("/agents/test-agent/json-body/item-1")?) + .post( + agent + .base_url + .join("/http-agents/test-agent/json-body/item-1")?, + ) .json(&json!({ "name": "test", "count": 42 @@ -259,7 +278,11 @@ async fn json_body(agent: &TestContext) -> anyhow::Result<()> { async fn json_body_missing_field(agent: &TestContext) -> anyhow::Result<()> { let response = agent .client - .post(agent.base_url.join("/agents/test-agent/json-body/item-1")?) + .post( + agent + .base_url + .join("/http-agents/test-agent/json-body/item-1")?, + ) .json(&json!({ "name": "test" })) @@ -275,7 +298,11 @@ async fn json_body_missing_field(agent: &TestContext) -> anyhow::Result<()> { async fn json_body_wrong_type(agent: &TestContext) -> anyhow::Result<()> { let response = agent .client - .post(agent.base_url.join("/agents/test-agent/json-body/item-1")?) + .post( + agent + .base_url + .join("/http-agents/test-agent/json-body/item-1")?, + ) .json(&json!({ "name": "test", "count": "not-a-number" @@ -295,7 +322,7 @@ async fn unrestricted_unstructured_binary_inline(agent: &TestContext) -> anyhow: .post( agent .base_url - .join("/agents/test-agent/unrestricted-unstructured-binary/my-bucket")?, + .join("/http-agents/test-agent/unrestricted-unstructured-binary/my-bucket")?, ) .header(reqwest::header::CONTENT_TYPE, "application/octet-stream") .body(vec![1u8, 2, 3, 4, 5]) @@ -318,7 +345,7 @@ async fn unrestricted_unstructured_binary_missing_body(agent: &TestContext) -> a .post( agent .base_url - .join("/agents/test-agent/unrestricted-unstructured-binary/my-bucket")?, + .join("/http-agents/test-agent/unrestricted-unstructured-binary/my-bucket")?, ) .send() .await?; @@ -341,7 +368,7 @@ async fn unrestricted_unstructured_binary_json_content_type( .post( agent .base_url - .join("/agents/test-agent/unrestricted-unstructured-binary/my-bucket")?, + .join("/http-agents/test-agent/unrestricted-unstructured-binary/my-bucket")?, ) .json(&json!({ "oops": true })) .send() @@ -363,7 +390,7 @@ async fn restricted_unstructured_binary_inline(agent: &TestContext) -> anyhow::R .post( agent .base_url - .join("/agents/test-agent/restricted-unstructured-binary/my-bucket")?, + .join("/http-agents/test-agent/restricted-unstructured-binary/my-bucket")?, ) .header(reqwest::header::CONTENT_TYPE, "image/gif") .body(vec![1u8, 2, 3, 4, 5]) @@ -386,7 +413,7 @@ async fn restricted_unstructured_binary_missing_body(agent: &TestContext) -> any .post( agent .base_url - .join("/agents/test-agent/restricted-unstructured-binary/my-bucket")?, + .join("/http-agents/test-agent/restricted-unstructured-binary/my-bucket")?, ) .send() .await?; @@ -406,7 +433,7 @@ async fn restricted_unstructured_binary_unsupported_mime_type( .post( agent .base_url - .join("/agents/test-agent/restricted-unstructured-binary/my-bucket")?, + .join("/http-agents/test-agent/restricted-unstructured-binary/my-bucket")?, ) .json(&json!({ "oops": true })) .send() @@ -422,7 +449,11 @@ async fn restricted_unstructured_binary_unsupported_mime_type( async fn response_no_content(agent: &TestContext) -> anyhow::Result<()> { let response = agent .client - .get(agent.base_url.join("/agents/test-agent/resp/no-content")?) + .get( + agent + .base_url + .join("/http-agents/test-agent/resp/no-content")?, + ) .send() .await?; @@ -437,7 +468,7 @@ async fn response_no_content(agent: &TestContext) -> anyhow::Result<()> { async fn response_json(agent: &TestContext) -> anyhow::Result<()> { let response = agent .client - .get(agent.base_url.join("/agents/test-agent/resp/json")?) + .get(agent.base_url.join("/http-agents/test-agent/resp/json")?) .send() .await?; @@ -457,7 +488,7 @@ async fn response_optional_found(agent: &TestContext) -> anyhow::Result<()> { .get( agent .base_url - .join("/agents/test-agent/resp/optional/true")?, + .join("/http-agents/test-agent/resp/optional/true")?, ) .send() .await?; @@ -478,7 +509,7 @@ async fn response_optional_not_found(agent: &TestContext) -> anyhow::Result<()> .get( agent .base_url - .join("/agents/test-agent/resp/optional/false")?, + .join("/http-agents/test-agent/resp/optional/false")?, ) .send() .await?; @@ -497,7 +528,7 @@ async fn response_result_ok(agent: &TestContext) -> anyhow::Result<()> { .get( agent .base_url - .join("/agents/test-agent/resp/result-json-json/true")?, + .join("/http-agents/test-agent/resp/result-json-json/true")?, ) .send() .await?; @@ -518,7 +549,7 @@ async fn response_result_err(agent: &TestContext) -> anyhow::Result<()> { .get( agent .base_url - .join("/agents/test-agent/resp/result-json-json/false")?, + .join("/http-agents/test-agent/resp/result-json-json/false")?, ) .send() .await?; @@ -542,7 +573,7 @@ async fn response_result_void_err(agent: &TestContext) -> anyhow::Result<()> { .post( agent .base_url - .join("/agents/test-agent/resp/result-void-json")?, + .join("/http-agents/test-agent/resp/result-void-json")?, ) .send() .await?; @@ -566,7 +597,7 @@ async fn response_result_json_void(agent: &TestContext) -> anyhow::Result<()> { .get( agent .base_url - .join("/agents/test-agent/resp/result-json-void")?, + .join("/http-agents/test-agent/resp/result-json-void")?, ) .send() .await?; @@ -584,7 +615,7 @@ async fn response_result_json_void(agent: &TestContext) -> anyhow::Result<()> { async fn response_binary(agent: &TestContext) -> anyhow::Result<()> { let response = agent .client - .get(agent.base_url.join("/agents/test-agent/resp/binary")?) + .get(agent.base_url.join("/http-agents/test-agent/resp/binary")?) .send() .await?; @@ -612,7 +643,7 @@ async fn negative_missing_path_var(agent: &TestContext) -> anyhow::Result<()> { .get( agent .base_url - .join("/agents/test-agent/multi-path-vars/foo")?, + .join("/http-agents/test-agent/multi-path-vars/foo")?, ) .send() .await?; @@ -629,7 +660,7 @@ async fn negative_extra_path_segment(agent: &TestContext) -> anyhow::Result<()> .get( agent .base_url - .join("/agents/test-agent/string-path-var/foo/bar")?, + .join("/http-agents/test-agent/string-path-var/foo/bar")?, ) .send() .await?; @@ -646,7 +677,7 @@ async fn negative_missing_query_param(agent: &TestContext) -> anyhow::Result<()> .get( agent .base_url - .join("/agents/test-agent/path-and-query/item-123")?, + .join("/http-agents/test-agent/path-and-query/item-123")?, ) .send() .await?; @@ -663,7 +694,7 @@ async fn negative_invalid_query_param_type(agent: &TestContext) -> anyhow::Resul .get( agent .base_url - .join("/agents/test-agent/path-and-query/item-123?limit=not-a-number")?, + .join("/http-agents/test-agent/path-and-query/item-123?limit=not-a-number")?, ) .send() .await?; @@ -680,7 +711,7 @@ async fn negative_missing_header(agent: &TestContext) -> anyhow::Result<()> { .get( agent .base_url - .join("/agents/test-agent/path-and-header/res-42")?, + .join("/http-agents/test-agent/path-and-header/res-42")?, ) // no x-request-id header .send() @@ -689,3 +720,138 @@ async fn negative_missing_header(agent: &TestContext) -> anyhow::Result<()> { assert_eq!(response.status(), reqwest::StatusCode::BAD_REQUEST); Ok(()) } + +#[test] +#[tracing::instrument] +async fn cors_preflight_wildcard(agent: &TestContext) -> anyhow::Result<()> { + let response = agent + .client + .request( + reqwest::Method::OPTIONS, + agent.base_url.join("/cors-agents/test-agent/wildcard")?, + ) + .header("Origin", "https://any-origin.com") + .send() + .await?; + + assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT); + + let allow_origin = response + .headers() + .get("access-control-allow-origin") + .unwrap() + .to_str()?; + assert_eq!(allow_origin, "https://any-origin.com"); + + let vary = response.headers().get("vary").unwrap().to_str()?; + assert_eq!(vary, "Origin"); + + Ok(()) +} + +#[test] +#[tracing::instrument] +async fn cors_preflight_specific_origin(agent: &TestContext) -> anyhow::Result<()> { + let response = agent + .client + .request( + reqwest::Method::OPTIONS, + agent + .base_url + .join("/cors-agents/test-agent/preflight-required")?, + ) + .header("Origin", "https://app.example.com") + .send() + .await?; + + assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT); + + let allow_origin = response + .headers() + .get("access-control-allow-origin") + .unwrap() + .to_str()?; + assert_eq!(allow_origin, "https://app.example.com"); + + let allow_methods = response + .headers() + .get("access-control-allow-methods") + .unwrap() + .to_str()?; + assert!(allow_methods.contains("POST")); + + let vary = response.headers().get("vary").unwrap().to_str()?; + assert_eq!(vary, "Origin"); + + Ok(()) +} + +#[test] +#[tracing::instrument] +async fn cors_get_with_origin_header(agent: &TestContext) -> anyhow::Result<()> { + let response = agent + .client + .get(agent.base_url.join("/cors-agents/test-agent/inherited")?) + .header("Origin", "https://mount.example.com") + .send() + .await?; + + assert_eq!(response.status(), reqwest::StatusCode::OK); + + let allow_origin = response + .headers() + .get("access-control-allow-origin") + .unwrap() + .to_str()?; + assert_eq!(allow_origin, "https://mount.example.com"); + + let vary = response.headers().get("vary").unwrap().to_str()?; + assert_eq!(vary, "Origin"); + + Ok(()) +} + +#[test] +#[tracing::instrument] +async fn cors_get_with_origin_header_invalid(agent: &TestContext) -> anyhow::Result<()> { + let response = agent + .client + .get(agent.base_url.join("/cors-agents/test-agent/inherited")?) + .header("Origin", "https://not-allowed.com") + .send() + .await?; + + assert_eq!(response.status(), reqwest::StatusCode::OK); + + assert!(response + .headers() + .get("access-control-allow-origin") + .is_none()); + + Ok(()) +} + +#[test] +#[tracing::instrument] +async fn cors_get_wildcard_origin(agent: &TestContext) -> anyhow::Result<()> { + let response = agent + .client + .get(agent.base_url.join("/cors-agents/test-agent/wildcard")?) + .header("Origin", "https://random-origin.com") + .send() + .await?; + + assert_eq!(response.status(), reqwest::StatusCode::OK); + + let allow_origin = response + .headers() + .get("access-control-allow-origin") + .unwrap() + .to_str()?; + assert_eq!(allow_origin, "https://random-origin.com"); + + let vary = response.headers().get("vary").unwrap().to_str()?; + assert_eq!(vary, "Origin"); + + Ok(()) +} diff --git a/integration-tests/tests/invocation_context.rs b/integration-tests/tests/invocation_context.rs index 6ab87a7ece..c2d5540ff5 100644 --- a/integration-tests/tests/invocation_context.rs +++ b/integration-tests/tests/invocation_context.rs @@ -26,14 +26,16 @@ use golem_common::model::http_api_definition::{ GatewayBinding, HttpApiDefinitionCreation, HttpApiDefinitionName, HttpApiDefinitionVersion, HttpApiRoute, RouteMethod, WorkerGatewayBinding, }; -use golem_common::model::http_api_deployment::HttpApiDeploymentCreation; +use golem_common::model::http_api_deployment::{ + HttpApiDeploymentAgentOptions, HttpApiDeploymentCreation, +}; use golem_common::model::invocation_context::{SpanId, TraceId}; use golem_test_framework::config::{EnvBasedTestDependencies, TestDependencies}; use golem_test_framework::dsl::{TestDsl, TestDslExtended}; use reqwest::header::HeaderValue; use reqwest::Client; use serde_json::Value; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use std::net::SocketAddr; use std::sync::{Arc, Mutex}; use test_r::{inherit_test_dep, test, timeout}; @@ -157,7 +159,10 @@ async fn invocation_context_test(deps: &EnvBasedTestDependencies) -> anyhow::Res let http_api_deployment_creation = HttpApiDeploymentCreation { domain: domain.clone(), - agent_types: [AgentTypeName("placeholder-agent".to_string())].into(), + agents: BTreeMap::from_iter([( + AgentTypeName("placeholder-agent".to_string()), + HttpApiDeploymentAgentOptions::default(), + )]), }; client diff --git a/openapi/golem-registry-service.yaml b/openapi/golem-registry-service.yaml index b2c88a536c..1cd55ee950 100644 --- a/openapi/golem-registry-service.yaml +++ b/openapi/golem-registry-service.yaml @@ -9213,7 +9213,7 @@ components: - environmentId - domain - hash - - agentTypes + - agents - createdAt properties: id: @@ -9230,26 +9230,36 @@ components: hash: type: string format: hash - agentTypes: - type: array - items: - type: string + agents: + type: object + additionalProperties: + $ref: '#/components/schemas/HttpApiDeploymentAgentOptions' createdAt: type: string format: date-time + HttpApiDeploymentAgentOptions: + type: object + title: HttpApiDeploymentAgentOptions + properties: + securityScheme: + type: string + description: |- + Security scheme to use for all agent methods that require auth. + Failure to provide a security scheme for an agent that requires one will lead to a deployment failure. + If the requested security scheme does not exist in the environment, the route will be disabled at runtime. HttpApiDeploymentCreation: type: object title: HttpApiDeploymentCreation required: - domain - - agentTypes + - agents properties: domain: type: string - agentTypes: - type: array - items: - type: string + agents: + type: object + additionalProperties: + $ref: '#/components/schemas/HttpApiDeploymentAgentOptions' HttpApiDeploymentUpdate: type: object title: HttpApiDeploymentUpdate @@ -9259,10 +9269,10 @@ components: currentRevision: type: integer format: uint64 - agentTypes: - type: array - items: - type: string + agents: + type: object + additionalProperties: + $ref: '#/components/schemas/HttpApiDeploymentAgentOptions' HttpApiRoute: type: object title: HttpApiRoute diff --git a/openapi/golem-service.yaml b/openapi/golem-service.yaml index dfec3d147b..5d6f684d53 100644 --- a/openapi/golem-service.yaml +++ b/openapi/golem-service.yaml @@ -13792,10 +13792,10 @@ components: hash: type: string format: hash - agentTypes: - type: array - items: - type: string + agents: + type: object + additionalProperties: + $ref: '#/components/schemas/HttpApiDeploymentAgentOptions' createdAt: type: string format: date-time @@ -13805,21 +13805,31 @@ components: - environmentId - domain - hash - - agentTypes + - agents - createdAt + HttpApiDeploymentAgentOptions: + title: HttpApiDeploymentAgentOptions + type: object + properties: + securityScheme: + description: |- + Security scheme to use for all agent methods that require auth. + Failure to provide a security scheme for an agent that requires one will lead to a deployment failure. + If the requested security scheme does not exist in the environment, the route will be disabled at runtime. + type: string HttpApiDeploymentCreation: title: HttpApiDeploymentCreation type: object properties: domain: type: string - agentTypes: - type: array - items: - type: string + agents: + type: object + additionalProperties: + $ref: '#/components/schemas/HttpApiDeploymentAgentOptions' required: - domain - - agentTypes + - agents HttpApiDeploymentUpdate: title: HttpApiDeploymentUpdate type: object @@ -13827,10 +13837,10 @@ components: currentRevision: type: integer format: uint64 - agentTypes: - type: array - items: - type: string + agents: + type: object + additionalProperties: + $ref: '#/components/schemas/HttpApiDeploymentAgentOptions' required: - currentRevision HttpApiRoute: diff --git a/test-components/agent-http-routes-ts/components-ts/golem-it-agent-http-routes-ts/src/main.ts b/test-components/agent-http-routes-ts/components-ts/golem-it-agent-http-routes-ts/src/main.ts index 155672d109..6f23b4833a 100644 --- a/test-components/agent-http-routes-ts/components-ts/golem-it-agent-http-routes-ts/src/main.ts +++ b/test-components/agent-http-routes-ts/components-ts/golem-it-agent-http-routes-ts/src/main.ts @@ -6,7 +6,7 @@ import { UnstructuredBinary, } from '@golemcloud/golem-ts-sdk'; -@agent({ mount: '/agents/{agentName}' }) +@agent({ mount: '/http-agents/{agentName}' }) class HttpAgent extends BaseAgent { constructor(readonly agentName: string) { @@ -112,3 +112,40 @@ class HttpAgent extends BaseAgent { return UnstructuredBinary.fromInline(new Uint8Array([1, 2, 3, 4]), 'application/octet-stream') } } + +@agent({ + mount: '/cors-agents/{agentName}', + cors: ["https://mount.example.com"] +}) +class CorsAgent extends BaseAgent { + + constructor(readonly agentName: string) { + super(); + } + + // GET endpoint adds additional CORS on top of mount + @endpoint({ + get: "/wildcard", + cors: ["*"] // union with mount CORS + }) + wildcard(): { ok: boolean } { + return { ok: true }; + } + + // GET endpoint inherits mount CORS if empty + @endpoint({ + get: "/inherited" + }) + inherited(): { ok: boolean } { + return { ok: true }; + } + + // POST endpoint requiring preflight + @endpoint({ + post: "/preflight-required", + cors: ["https://app.example.com"] + }) + preflight(body: { name: string }): { received: string } { + return { received: body.name }; + } +} diff --git a/test-components/golem_it_agent_http_routes_ts.wasm b/test-components/golem_it_agent_http_routes_ts.wasm index 0ed0b1b3ee..39e8c69982 100644 Binary files a/test-components/golem_it_agent_http_routes_ts.wasm and b/test-components/golem_it_agent_http_routes_ts.wasm differ