Skip to content

Commit 1c72261

Browse files
authored
feat(server): allow multiple cors, add -ci / -cl shortcut (#419)
1 parent c342be6 commit 1c72261

File tree

4 files changed

+49
-23
lines changed

4 files changed

+49
-23
lines changed

.env.lib_debug

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export RUST_LOG=warn,cocoindex_engine=trace,tower_http=trace
22
export RUST_BACKTRACE=1
33

4-
export COCOINDEX_SERVER_CORS_ORIGIN=http://localhost:3000
4+
export COCOINDEX_SERVER_CORS_ORIGINS=http://localhost:3000,https://cocoindex.io

python/cocoindex/cli.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import click
22
import datetime
3+
import urllib.parse
4+
35
from rich.console import Console
46

57
from . import flow, lib
@@ -151,30 +153,50 @@ def evaluate(flow_name: str | None, output_dir: str | None, cache: bool = True):
151153

152154
_default_server_settings = lib.ServerSettings.from_env()
153155

156+
COCOINDEX_HOST = 'https://cocoindex.io'
157+
154158
@cli.command()
155159
@click.option(
156160
"-a", "--address", type=str, default=_default_server_settings.address,
157161
help="The address to bind the server to, in the format of IP:PORT.")
158162
@click.option(
159-
"-c", "--cors-origin", type=str, default=_default_server_settings.cors_origin,
160-
help="The origin of the client (e.g. CocoInsight UI) to allow CORS from. "
161-
"e.g. `http://cocoindex.io` if you want to allow CocoInsight to access the server.")
163+
"-c", "--cors-origin", type=str,
164+
default=_default_server_settings.cors_origins and ','.join(_default_server_settings.cors_origins),
165+
help="The origins of the clients (e.g. CocoInsight UI) to allow CORS from. "
166+
"Multiple origins can be specified as a comma-separated list. "
167+
"e.g. `https://cocoindex.io,http://localhost:3000`")
168+
@click.option(
169+
"-ci", "--cors-cocoindex", is_flag=True, show_default=True, default=False,
170+
help=f"Allow {COCOINDEX_HOST} to access the server.")
171+
@click.option(
172+
"-cl", "--cors-local", type=int,
173+
help=f"Allow http://localhost:<port> to access the server.")
162174
@click.option(
163175
"-L", "--live-update", is_flag=True, show_default=True, default=False,
164176
help="Continuously watch changes from data sources and apply to the target index.")
165177
@click.option(
166178
"-q", "--quiet", is_flag=True, show_default=True, default=False,
167179
help="Avoid printing anything to the standard output, e.g. statistics.")
168-
def server(address: str, live_update: bool, quiet: bool, cors_origin: str | None):
180+
def server(address: str, live_update: bool, quiet: bool, cors_origin: str | None,
181+
cors_cocoindex: bool, cors_local: int | None):
169182
"""
170183
Start a HTTP server providing REST APIs.
171184
172185
It will allow tools like CocoInsight to access the server.
173186
"""
174-
lib.start_server(lib.ServerSettings(address=address, cors_origin=cors_origin))
187+
cors_origins : set[str] = set()
188+
if cors_origin is not None:
189+
cors_origins.update(s for o in cors_origin.split(',') if (s:= o.strip()) != '')
190+
if cors_cocoindex:
191+
cors_origins.add(COCOINDEX_HOST)
192+
if cors_local is not None:
193+
cors_origins.add(f"http://localhost:{cors_local}")
194+
lib.start_server(lib.ServerSettings(address=address, cors_origins=list(cors_origins)))
175195
if live_update:
176196
options = flow.FlowLiveUpdaterOptions(live_mode=True, print_stats=not quiet)
177197
execution_context.run(flow.update_all_flows(options))
198+
if COCOINDEX_HOST in cors_origins:
199+
click.echo(f"Open CocoInsight at: {COCOINDEX_HOST}/cocoinsight")
178200
input("Press Enter to stop...")
179201

180202

python/cocoindex/lib.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,22 @@
66
import functools
77
import inspect
88

9-
from typing import Callable, Self
9+
from typing import Callable, Self, Any
1010
from dataclasses import dataclass
1111

1212
from . import _engine
1313
from . import flow, query, cli
1414
from .convert import dump_engine_object
1515

1616

17-
def _load_field(target: dict[str, str], name: str, env_name: str, required: bool = False):
17+
def _load_field(target: dict[str, Any], name: str, env_name: str, required: bool = False,
18+
parse: Callable[[str], Any] | None = None):
1819
value = os.getenv(env_name)
1920
if value is None:
2021
if required:
2122
raise ValueError(f"{env_name} is not set")
2223
else:
23-
target[name] = value
24+
target[name] = value if parse is None else parse(value)
2425

2526
@dataclass
2627
class DatabaseConnectionSpec:
@@ -56,17 +57,16 @@ class ServerSettings:
5657
# The address to bind the server to.
5758
address: str = "127.0.0.1:8080"
5859

59-
# The origin of the client (e.g. CocoInsight UI) to allow CORS from.
60-
cors_origin: str | None = None
60+
# The origins of the clients (e.g. CocoInsight UI) to allow CORS from.
61+
cors_origins: list[str] | None = None
6162

6263
@classmethod
6364
def from_env(cls) -> Self:
6465
"""Load settings from environment variables."""
65-
66-
kwargs: dict[str, str] = dict()
66+
kwargs: dict[str, Any] = dict()
6767
_load_field(kwargs, "address", "COCOINDEX_SERVER_ADDRESS")
68-
_load_field(kwargs, "cors_origin", "COCOINDEX_SERVER_CORS_ORIGIN")
69-
68+
_load_field(kwargs, "cors_origins", "COCOINDEX_SERVER_CORS_ORIGINS",
69+
parse=lambda s: [o for e in s.split(",") if (o := e.strip()) != ""])
7070
return cls(**kwargs)
7171

7272

src/server.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
use crate::{lib_context::LibContext, service};
1+
use crate::prelude::*;
22

3-
use anyhow::Result;
3+
use crate::{lib_context::LibContext, service};
44
use axum::{routing, Router};
5-
use futures::{future::BoxFuture, FutureExt};
6-
use serde::Deserialize;
7-
use std::sync::Arc;
85
use tower::ServiceBuilder;
96
use tower_http::{
107
cors::{AllowOrigin, CorsLayer},
@@ -14,7 +11,8 @@ use tower_http::{
1411
#[derive(Deserialize, Debug)]
1512
pub struct ServerSettings {
1613
pub address: String,
17-
pub cors_origin: Option<String>,
14+
#[serde(default)]
15+
pub cors_origins: Vec<String>,
1816
}
1917

2018
/// Initialize the server and return a future that will actually handle requests.
@@ -23,9 +21,15 @@ pub async fn init_server(
2321
settings: ServerSettings,
2422
) -> Result<BoxFuture<'static, ()>> {
2523
let mut cors = CorsLayer::default();
26-
if let Some(ui_cors_origin) = &settings.cors_origin {
24+
debug!("cors_origins: {:?}", settings.cors_origins);
25+
if !settings.cors_origins.is_empty() {
26+
let origins: Vec<_> = settings
27+
.cors_origins
28+
.iter()
29+
.map(|origin| origin.parse())
30+
.collect::<Result<_, _>>()?;
2731
cors = cors
28-
.allow_origin(AllowOrigin::exact(ui_cors_origin.parse()?))
32+
.allow_origin(AllowOrigin::list(origins))
2933
.allow_methods([
3034
axum::http::Method::GET,
3135
axum::http::Method::POST,

0 commit comments

Comments
 (0)