Skip to content
Merged
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 20 additions & 10 deletions crates/sail-cli/src/spark/mcp_server.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use std::fmt;
use std::fmt::Formatter;
use std::net::Ipv4Addr;
use std::sync::Arc;

use clap::ValueEnum;
use log::info;
use pyo3::prelude::PyAnyMethods;
use pyo3::{PyResult, Python};
use sail_spark_connect::entrypoint::serve;
use sail_common::config::AppConfig;
use sail_common::runtime::RuntimeManager;
use sail_spark_connect::entrypoint::{serve, SessionManagerOptions};
use sail_telemetry::telemetry::init_telemetry;
use tokio::net::TcpListener;
use tokio::runtime::Runtime;

use crate::python::Modules;

Expand Down Expand Up @@ -42,31 +44,39 @@ pub struct McpSettings {
pub spark_remote: Option<String>,
}

fn run_spark_connect_server(runtime: &Runtime) -> Result<String, Box<dyn std::error::Error>> {
let (server_port, server_task) = runtime.block_on(async move {
fn run_spark_connect_server(
options: SessionManagerOptions,
) -> Result<String, Box<dyn std::error::Error>> {
let handle = options.runtime.primary().clone();
let (server_port, server_task) = handle.block_on(async move {
// Listen on only the loopback interface for security.
let listener = TcpListener::bind((Ipv4Addr::new(127, 0, 0, 1), 0)).await?;
let port = listener.local_addr()?.port();
let task = async move {
info!("Starting the Spark Connect server on port {port}...");
let _ = serve(listener, shutdown()).await;
let _ = serve(listener, shutdown(), options).await;
info!("The Spark Connect server has stopped.");
};
<Result<_, Box<dyn std::error::Error>>>::Ok((port, task))
})?;
runtime.spawn(server_task);
handle.spawn(server_task);
Ok(format!("sc://127.0.0.1:{server_port}"))
}

pub fn run_spark_mcp_server(settings: McpSettings) -> Result<(), Box<dyn std::error::Error>> {
init_telemetry()?;

let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
let config = Arc::new(AppConfig::load()?);
let runtime = RuntimeManager::try_new(&config.runtime)?;

let spark_remote = match settings.spark_remote {
None => run_spark_connect_server(&runtime)?,
None => {
let options = SessionManagerOptions {
config: Arc::clone(&config),
runtime: runtime.handle(),
};
run_spark_connect_server(options)?
}
Some(x) => x,
};

Expand Down
21 changes: 12 additions & 9 deletions crates/sail-cli/src/spark/server.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::net::IpAddr;
use std::sync::Arc;

use log::info;
use sail_spark_connect::entrypoint::serve;
use sail_common::config::AppConfig;
use sail_common::runtime::RuntimeManager;
use sail_spark_connect::entrypoint::{serve, SessionManagerOptions};
use sail_telemetry::telemetry::init_telemetry;
use tokio::net::TcpListener;

const SERVER_STACK_SIZE: usize = 1024 * 1024 * 8;

/// Handles graceful shutdown by waiting for a `SIGINT` signal in [tokio].
///
/// The `SIGINT` signal is captured by Python if the `_signal` module is imported [1].
Expand All @@ -28,19 +29,21 @@ async fn shutdown() {
pub fn run_spark_connect_server(ip: IpAddr, port: u16) -> Result<(), Box<dyn std::error::Error>> {
init_telemetry()?;

let runtime = tokio::runtime::Builder::new_multi_thread()
.thread_stack_size(SERVER_STACK_SIZE)
.enable_all()
.build()?;
let config = Arc::new(AppConfig::load()?);
let runtime = RuntimeManager::try_new(&config.runtime)?;
let options = SessionManagerOptions {
config: Arc::clone(&config),
runtime: runtime.handle(),
};

runtime.block_on(async {
runtime.handle().primary().block_on(async {
// A secure connection can be handled by a gateway in production.
let listener = TcpListener::bind((ip, port)).await?;
info!(
"Starting the Spark Connect server on {}...",
listener.local_addr()?
);
serve(listener, shutdown()).await?;
serve(listener, shutdown(), options).await?;
info!("The Spark Connect server has stopped.");
<Result<(), Box<dyn std::error::Error>>>::Ok(())
})?;
Expand Down
21 changes: 14 additions & 7 deletions crates/sail-cli/src/spark/shell.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
use std::net::Ipv4Addr;
use std::sync::Arc;

use pyo3::prelude::PyAnyMethods;
use pyo3::{PyResult, Python};
use sail_spark_connect::entrypoint::serve;
use sail_common::config::AppConfig;
use sail_common::runtime::RuntimeManager;
use sail_spark_connect::entrypoint::{serve, SessionManagerOptions};
use tokio::net::TcpListener;
use tokio::sync::oneshot;

use crate::python::Modules;

pub fn run_pyspark_shell() -> Result<(), Box<dyn std::error::Error>> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
let config = Arc::new(AppConfig::load()?);
let runtime = RuntimeManager::try_new(&config.runtime)?;
let options = SessionManagerOptions {
config,
runtime: runtime.handle(),
};
let (_tx, rx) = oneshot::channel::<()>();
let (server_port, server_task) = runtime.block_on(async move {
let handle = runtime.handle().primary().clone();
let (server_port, server_task) = handle.block_on(async move {
// Listen on only the loopback interface for security.
let listener = TcpListener::bind((Ipv4Addr::new(127, 0, 0, 1), 0)).await?;
let port = listener.local_addr()?.port();
Expand All @@ -25,11 +32,11 @@ pub fn run_pyspark_shell() -> Result<(), Box<dyn std::error::Error>> {
let _ = rx.await;
};
let task = async {
let _ = serve(listener, shutdown).await;
let _ = serve(listener, shutdown, options).await;
};
<Result<_, Box<dyn std::error::Error>>>::Ok((port, task))
})?;
runtime.spawn(server_task);
handle.spawn(server_task);
Python::with_gil(|py| -> PyResult<_> {
let shell = Modules::SPARK_SHELL.load(py)?;
shell
Expand Down
13 changes: 8 additions & 5 deletions crates/sail-cli/src/worker/entrypoint.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use sail_common::config::AppConfig;
use sail_common::runtime::RuntimeManager;
use sail_telemetry::telemetry::init_telemetry;

pub fn run_worker() -> Result<(), Box<dyn std::error::Error>> {
init_telemetry()?;

let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;

runtime.block_on(sail_execution::run_worker())?;
let config = AppConfig::load()?;
let runtime = RuntimeManager::try_new(&config.runtime)?;
runtime
.handle()
.primary()
.block_on(sail_execution::run_worker(&config, runtime.handle()))?;

fastrace::flush();

Expand Down
1 change: 1 addition & 0 deletions crates/sail-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ figment = { workspace = true }
half = { workspace = true }
log = { workspace = true }
iana-time-zone = { workspace = true }
tokio = { workspace = true }
7 changes: 7 additions & 0 deletions crates/sail-common/src/config/application.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ const DEFAULT_CONFIG: &str = include_str!("default.toml");
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
pub mode: ExecutionMode,
pub runtime: RuntimeConfig,
pub cluster: ClusterConfig,
pub execution: ExecutionConfig,
pub kubernetes: KubernetesConfig,
Expand Down Expand Up @@ -67,6 +68,12 @@ pub enum ExecutionMode {
KubernetesCluster,
}

/// Settings for the Tokio runtimes created by `RuntimeManager`.
///
/// Populated from the `[runtime]` table of the application configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuntimeConfig {
    /// Thread stack size handed to the Tokio runtime builder
    /// (`thread_stack_size`) for every runtime that is built.
    pub stack_size: usize,
    /// When `true`, a secondary runtime is built in addition to the primary one.
    pub enable_secondary: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterConfig {
pub enable_tls: bool,
Expand Down
4 changes: 4 additions & 0 deletions crates/sail-common/src/config/default.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
mode = "local"

[runtime]
# Stack size, in bytes, for the threads of each Tokio runtime (8388608 = 8 MiB).
stack_size = 8388608
# Whether to build a secondary Tokio runtime alongside the primary one.
enable_secondary = false

[cluster]
enable_tls = false
driver_listen_host = "127.0.0.1"
Expand Down
1 change: 1 addition & 0 deletions crates/sail-common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod datetime;
pub mod debug;
pub mod error;
pub mod object;
pub mod runtime;
pub mod spec;
pub mod string;
pub mod tests;
53 changes: 53 additions & 0 deletions crates/sail-common/src/runtime.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use tokio::runtime::{Handle, Runtime};

use crate::config::RuntimeConfig;
use crate::error::{CommonError, CommonResult};

/// Owner of the Tokio runtimes used by the application.
///
/// A primary runtime always exists; a secondary runtime exists only when
/// [`RuntimeConfig::enable_secondary`] is set. Other components should hold a
/// [`RuntimeHandle`] obtained from [`RuntimeManager::handle`] rather than the
/// runtimes themselves.
#[derive(Debug)]
pub struct RuntimeManager {
    primary: Runtime,
    secondary: Option<Runtime>,
}

impl RuntimeManager {
    /// Builds the primary runtime and, if enabled in `config`, the secondary
    /// runtime. Both use `config.stack_size` as their thread stack size.
    ///
    /// # Errors
    ///
    /// Returns an internal [`CommonError`] carrying the builder's error
    /// message if any runtime fails to build.
    pub fn try_new(config: &RuntimeConfig) -> CommonResult<Self> {
        let primary = Self::build_runtime(config.stack_size)?;
        // `then` only invokes the builder when the flag is set; `transpose`
        // lets `?` surface a build failure of the optional runtime.
        let secondary = config
            .enable_secondary
            .then(|| Self::build_runtime(config.stack_size))
            .transpose()?;

        Ok(Self { primary, secondary })
    }

    /// Returns cloned handles to the managed runtimes.
    pub fn handle(&self) -> RuntimeHandle {
        RuntimeHandle {
            primary: self.primary.handle().clone(),
            secondary: self.secondary.as_ref().map(|rt| rt.handle().clone()),
        }
    }

    /// Creates a multi-threaded Tokio runtime with all drivers enabled and the
    /// given thread stack size.
    fn build_runtime(stack_size: usize) -> CommonResult<Runtime> {
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.thread_stack_size(stack_size).enable_all();
        builder
            .build()
            .map_err(|err| CommonError::internal(err.to_string()))
    }
}

/// Cloneable set of Tokio runtime handles: one for the primary runtime and,
/// optionally, one for the secondary runtime.
#[derive(Debug, Clone)]
pub struct RuntimeHandle {
    primary: Handle,
    secondary: Option<Handle>,
}

impl RuntimeHandle {
    /// Borrows the handle of the primary runtime.
    pub fn primary(&self) -> &Handle {
        let Self { primary, .. } = self;
        primary
    }

    /// Borrows the handle of the secondary runtime, or `None` if there is no
    /// secondary handle.
    pub fn secondary(&self) -> Option<&Handle> {
        let Self { secondary, .. } = self;
        secondary.as_ref()
    }
}
4 changes: 3 additions & 1 deletion crates/sail-execution/src/driver/actor/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ impl Actor for DriverActor {

fn new(options: DriverOptions) -> Self {
let worker_manager: Arc<dyn WorkerManager> = match &options.worker_manager {
WorkerManagerOptions::Local => Arc::new(LocalWorkerManager::new()),
WorkerManagerOptions::Local => {
Arc::new(LocalWorkerManager::new(options.runtime.clone()))
}
WorkerManagerOptions::Kubernetes(options) => {
Arc::new(KubernetesWorkerManager::new(options.clone()))
}
Expand Down
3 changes: 2 additions & 1 deletion crates/sail-execution/src/driver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ pub(crate) use actor::DriverActor;
pub(crate) use client::DriverClient;
pub(crate) use event::DriverEvent;
pub(crate) use gen::driver_service_client::DriverServiceClient;
pub(crate) use options::{DriverOptions, WorkerManagerOptions};
pub use options::DriverOptions;
pub(crate) use options::WorkerManagerOptions;
8 changes: 5 additions & 3 deletions crates/sail-execution/src/driver/options.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::time::Duration;

use sail_common::config::{AppConfig, ExecutionMode};
use sail_common::runtime::RuntimeHandle;
use sail_server::RetryStrategy;

use crate::error::{ExecutionError, ExecutionResult};
Expand All @@ -25,6 +26,7 @@ pub struct DriverOptions {
pub job_output_buffer: usize,
pub rpc_retry_strategy: RetryStrategy,
pub worker_manager: WorkerManagerOptions,
pub runtime: RuntimeHandle,
}

#[derive(Debug)]
Expand All @@ -33,9 +35,8 @@ pub enum WorkerManagerOptions {
Kubernetes(KubernetesWorkerManagerOptions),
}

impl TryFrom<&AppConfig> for DriverOptions {
type Error = ExecutionError;
fn try_from(config: &AppConfig) -> ExecutionResult<Self> {
impl DriverOptions {
pub fn try_new(config: &AppConfig, runtime: RuntimeHandle) -> ExecutionResult<Self> {
let worker_manager = match config.mode {
ExecutionMode::Local => {
return Err(ExecutionError::InvalidArgument(
Expand Down Expand Up @@ -79,6 +80,7 @@ impl TryFrom<&AppConfig> for DriverOptions {
task_launch_timeout: Duration::from_secs(config.cluster.task_launch_timeout_secs),
job_output_buffer: config.cluster.job_output_buffer,
worker_manager,
runtime,
})
}
}
Loading