Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.lib_debug
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export RUST_LOG=warn,cocoindex_engine=trace,tower_http=trace
export RUST_BACKTRACE=1

export COCOINDEX_SERVER_CORS_ORIGIN=http://localhost:3000
export COCOINDEX_SERVER_CORS_ORIGINS=http://localhost:3000,https://cocoindex.io
32 changes: 27 additions & 5 deletions python/cocoindex/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import click
import datetime
import urllib.parse

from rich.console import Console

from . import flow, lib
Expand Down Expand Up @@ -151,30 +153,50 @@ def evaluate(flow_name: str | None, output_dir: str | None, cache: bool = True):

_default_server_settings = lib.ServerSettings.from_env()

COCOINDEX_HOST = 'https://cocoindex.io'

@cli.command()
@click.option(
"-a", "--address", type=str, default=_default_server_settings.address,
help="The address to bind the server to, in the format of IP:PORT.")
@click.option(
"-c", "--cors-origin", type=str, default=_default_server_settings.cors_origin,
help="The origin of the client (e.g. CocoInsight UI) to allow CORS from. "
"e.g. `http://cocoindex.io` if you want to allow CocoInsight to access the server.")
"-c", "--cors-origin", type=str,
default=_default_server_settings.cors_origins and ','.join(_default_server_settings.cors_origins),
help="The origins of the clients (e.g. CocoInsight UI) to allow CORS from. "
"Multiple origins can be specified as a comma-separated list. "
"e.g. `https://cocoindex.io,http://localhost:3000`")
@click.option(
"-ci", "--cors-cocoindex", is_flag=True, show_default=True, default=False,
help=f"Allow {COCOINDEX_HOST} to access the server.")
@click.option(
"-cl", "--cors-local", type=int,
help=f"Allow http://localhost:<port> to access the server.")
@click.option(
"-L", "--live-update", is_flag=True, show_default=True, default=False,
help="Continuously watch changes from data sources and apply to the target index.")
@click.option(
"-q", "--quiet", is_flag=True, show_default=True, default=False,
help="Avoid printing anything to the standard output, e.g. statistics.")
def server(address: str, live_update: bool, quiet: bool, cors_origin: str | None):
def server(address: str, live_update: bool, quiet: bool, cors_origin: str | None,
cors_cocoindex: bool, cors_local: int | None):
"""
Start a HTTP server providing REST APIs.

It will allow tools like CocoInsight to access the server.
"""
lib.start_server(lib.ServerSettings(address=address, cors_origin=cors_origin))
cors_origins : set[str] = set()
if cors_origin is not None:
cors_origins.update(s for o in cors_origin.split(',') if (s:= o.strip()) != '')
if cors_cocoindex:
cors_origins.add(COCOINDEX_HOST)
if cors_local is not None:
cors_origins.add(f"http://localhost:{cors_local}")
lib.start_server(lib.ServerSettings(address=address, cors_origins=list(cors_origins)))
if live_update:
options = flow.FlowLiveUpdaterOptions(live_mode=True, print_stats=not quiet)
execution_context.run(flow.update_all_flows(options))
if COCOINDEX_HOST in cors_origins:
click.echo(f"Open CocoInsight at: {COCOINDEX_HOST}/cocoinsight")
input("Press Enter to stop...")


Expand Down
18 changes: 9 additions & 9 deletions python/cocoindex/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,22 @@
import functools
import inspect

from typing import Callable, Self
from typing import Callable, Self, Any
from dataclasses import dataclass

from . import _engine
from . import flow, query, cli
from .convert import dump_engine_object


def _load_field(target: dict[str, str], name: str, env_name: str, required: bool = False):
def _load_field(target: dict[str, Any], name: str, env_name: str, required: bool = False,
parse: Callable[[str], Any] | None = None):
value = os.getenv(env_name)
if value is None:
if required:
raise ValueError(f"{env_name} is not set")
else:
target[name] = value
target[name] = value if parse is None else parse(value)

@dataclass
class DatabaseConnectionSpec:
Expand Down Expand Up @@ -56,17 +57,16 @@ class ServerSettings:
# The address to bind the server to.
address: str = "127.0.0.1:8080"

# The origin of the client (e.g. CocoInsight UI) to allow CORS from.
cors_origin: str | None = None
# The origins of the clients (e.g. CocoInsight UI) to allow CORS from.
cors_origins: list[str] | None = None

@classmethod
def from_env(cls) -> Self:
"""Load settings from environment variables."""

kwargs: dict[str, str] = dict()
kwargs: dict[str, Any] = dict()
_load_field(kwargs, "address", "COCOINDEX_SERVER_ADDRESS")
_load_field(kwargs, "cors_origin", "COCOINDEX_SERVER_CORS_ORIGIN")

_load_field(kwargs, "cors_origins", "COCOINDEX_SERVER_CORS_ORIGINS",
parse=lambda s: [o for e in s.split(",") if (o := e.strip()) != ""])
return cls(**kwargs)


Expand Down
20 changes: 12 additions & 8 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use crate::{lib_context::LibContext, service};
use crate::prelude::*;

use anyhow::Result;
use crate::{lib_context::LibContext, service};
use axum::{routing, Router};
use futures::{future::BoxFuture, FutureExt};
use serde::Deserialize;
use std::sync::Arc;
use tower::ServiceBuilder;
use tower_http::{
cors::{AllowOrigin, CorsLayer},
Expand All @@ -14,7 +11,8 @@ use tower_http::{
#[derive(Deserialize, Debug)]
pub struct ServerSettings {
pub address: String,
pub cors_origin: Option<String>,
#[serde(default)]
pub cors_origins: Vec<String>,
}

/// Initialize the server and return a future that will actually handle requests.
Expand All @@ -23,9 +21,15 @@ pub async fn init_server(
settings: ServerSettings,
) -> Result<BoxFuture<'static, ()>> {
let mut cors = CorsLayer::default();
if let Some(ui_cors_origin) = &settings.cors_origin {
debug!("cors_origins: {:?}", settings.cors_origins);
if !settings.cors_origins.is_empty() {
let origins: Vec<_> = settings
.cors_origins
.iter()
.map(|origin| origin.parse())
.collect::<Result<_, _>>()?;
cors = cors
.allow_origin(AllowOrigin::exact(ui_cors_origin.parse()?))
.allow_origin(AllowOrigin::list(origins))
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
Expand Down