frontend/rust-lib/flowy-ai/src/local_ai/controller.rs (merged): 261 changes, 107 additions & 154 deletions
```diff
@@ -23,6 +23,8 @@ use flowy_ai_pub::persistence::{
 use flowy_ai_pub::user_service::AIUserService;
 use futures_util::SinkExt;
 use lib_infra::util::get_operating_system;
+use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
+use ollama_rs::Ollama;
 use serde::{Deserialize, Serialize};
 use std::ops::Deref;
 use std::path::PathBuf;
```
```diff
@@ -32,11 +34,6 @@ use tokio_stream::StreamExt;
 use tracing::{debug, error, info, instrument, warn};
 use uuid::Uuid;
 
-#[cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))]
-use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
-#[cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))]
-use ollama_rs::Ollama;
-
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct LocalAISetting {
   pub ollama_server_url: String,
```
```diff
@@ -62,7 +59,6 @@ pub struct LocalAIController {
   current_chat_id: ArcSwapOption<Uuid>,
   store_preferences: Weak<KVStorePreferences>,
   user_service: Arc<dyn AIUserService>,
-  #[cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))]
   ollama: ArcSwapOption<Ollama>,
 }
```
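For context on this field: `ArcSwapOption` (from the `arc-swap` crate) holds an optional `Arc` that can be atomically swapped, which is how the controller can publish an `Ollama` client after construction without locking. A minimal self-contained sketch of the pattern, using a placeholder `String` instead of the real client:

```rust
use arc_swap::ArcSwapOption;
use std::sync::Arc;

fn main() {
    // Start empty; readers see None until a value is published.
    let slot: ArcSwapOption<String> = ArcSwapOption::default();

    // Atomically publish a value (the controller does this with an
    // Ollama client built from the configured server URL).
    slot.store(Some(Arc::new("http://localhost:11434".to_string())));

    // load_full() clones out the current Arc, or None if unset.
    if let Some(url) = slot.load_full() {
        println!("configured server: {url}");
    }
}
```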

```diff
@@ -95,105 +91,87 @@ impl LocalAIController {
       res_impl,
     ));
 
-    #[cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))]
-    let ollama = {
-      let mut ollama = ArcSwapOption::default();
-      let sys = get_operating_system();
-      if sys.is_desktop() {
-        let setting = local_ai_resource.get_llm_setting();
-        ollama.store(
-          Ollama::try_new(&setting.ollama_server_url)
-            .map(Arc::new)
-            .ok(),
-        );
-      }
-      ollama
-    };
-
-    // Subscribe to state changes
-    let mut running_state_rx = local_ai.subscribe_running_state();
-    let cloned_llm_res = Arc::clone(&local_ai_resource);
-    let cloned_store_preferences = store_preferences.clone();
-    let cloned_local_ai = Arc::clone(&local_ai);
-    let cloned_user_service = Arc::clone(&user_service);
-
-    // Spawn a background task to listen for plugin state changes
-    tokio::spawn(async move {
-      while let Some(state) = running_state_rx.next().await {
-        // Skip if we can't get workspace_id
-        let Ok(workspace_id) = cloned_user_service.workspace_id() else {
-          continue;
-        };
-
-        let key = crate::local_ai::controller::local_ai_enabled_key(&workspace_id.to_string());
-        info!("[AI Plugin] state: {:?}", state);
+    let ollama = ArcSwapOption::default();
+    let sys = get_operating_system();
+    if sys.is_desktop() {
+      let setting = local_ai_resource.get_llm_setting();
+      ollama.store(
+        Ollama::try_new(&setting.ollama_server_url)
+          .map(Arc::new)
+          .ok(),
+      );
+
+      // Subscribe to state changes
+      let mut running_state_rx = local_ai.subscribe_running_state();
+      let cloned_llm_res = Arc::clone(&local_ai_resource);
+      let cloned_store_preferences = store_preferences.clone();
+      let cloned_local_ai = Arc::clone(&local_ai);
+      let cloned_user_service = Arc::clone(&user_service);
+
+      // Spawn a background task to listen for plugin state changes
+      tokio::spawn(async move {
+        while let Some(state) = running_state_rx.next().await {
+          // Skip if we can't get workspace_id
+          let Ok(workspace_id) = cloned_user_service.workspace_id() else {
+            continue;
+          };
+
+          let key = crate::local_ai::controller::local_ai_enabled_key(&workspace_id.to_string());
+          info!("[AI Plugin] state: {:?}", state);
```
A reviewer (Contributor) left an inline comment on this hunk:
issue (complexity): Consider extracting the inner loop of the spawned async task into a separate helper function to improve code structure and readability.

Consider extracting the inner loop of the spawned async task into its own helper function. This reduces nesting and makes the control flow easier to follow. For example:

```rust
async fn process_plugin_state(
    mut running_state_rx: impl StreamExt<Item = RunningState> + Unpin,
    cloned_llm_res: Arc<LocalAIResourceController>,
    cloned_store_preferences: Weak<KVStorePreferences>,
    cloned_local_ai: Arc<OllamaAIPlugin>,
    cloned_user_service: Arc<dyn AIUserService>,
) {
    while let Some(state) = running_state_rx.next().await {
        // Early exit if workspace_id is not accessible
        let workspace_id = match cloned_user_service.workspace_id() {
            Ok(id) => id,
            Err(_) => continue,
        };

        let key = crate::local_ai::controller::local_ai_enabled_key(&workspace_id.to_string());
        info!("[AI Plugin] state: {:?}", state);

        if let Some(store_preferences) = cloned_store_preferences.upgrade() {
            let enabled = store_preferences.get_bool(&key).unwrap_or(true);

            let (plugin_downloaded, lack_of_resource) = if !matches!(state, RunningState::UnexpectedStop { .. }) && enabled {
                let downloaded = is_plugin_ready();
                let resource_lack = cloned_llm_res.get_lack_of_resource().await;
                (downloaded, resource_lack)
            } else {
                (false, None)
            };

            let plugin_version = if matches!(state, RunningState::Running { .. }) {
                match cloned_local_ai.plugin_info().await {
                    Ok(info) => Some(info.version),
                    Err(_) => None,
                }
            } else {
                None
            };

            let new_state = RunningStatePB::from(state);
            chat_notification_builder(
                APPFLOWY_AI_NOTIFICATION_KEY,
                ChatNotification::UpdateLocalAIState,
            )
            .payload(LocalAIPB {
                enabled,
                plugin_downloaded,
                lack_of_resource,
                state: new_state,
                plugin_version,
            })
            .send();
        } else {
            warn!("[AI Plugin] store preferences is dropped");
        }
    }
}
```

Then call it from the spawn block like:

```rust
tokio::spawn(async move {
    process_plugin_state(
        running_state_rx,
        cloned_llm_res,
        cloned_store_preferences,
        cloned_local_ai,
        cloned_user_service,
    )
    .await;
});
```

This extraction flattens the async task, reduces the nesting in your main flow, and keeps the functionality intact.


```diff
-
-        // Read whether plugin is enabled from store; default to true
-        if let Some(store_preferences) = cloned_store_preferences.upgrade() {
-          let enabled = store_preferences.get_bool(&key).unwrap_or(true);
-          // Only check resource status if the plugin isn't in "UnexpectedStop" and is enabled
-          let (plugin_downloaded, lack_of_resource) =
-            if !matches!(state, RunningState::UnexpectedStop { .. }) && enabled {
-              // Possibly check plugin readiness and resource concurrency in parallel,
-              // but here we do it sequentially for clarity.
-              let downloaded = is_plugin_ready();
-              let resource_lack = cloned_llm_res.get_lack_of_resource().await;
-              (downloaded, resource_lack)
-            } else {
-              (false, None)
-            };
-
-          // If plugin is running, retrieve version
-          let plugin_version = if matches!(state, RunningState::Running { .. }) {
-            match cloned_local_ai.plugin_info().await {
-              Ok(info) => Some(info.version),
-              Err(_) => None,
-            }
-          } else {
-            None
-          };
-
-          // Broadcast the new local AI state
-          let new_state = RunningStatePB::from(state);
-          chat_notification_builder(
-            APPFLOWY_AI_NOTIFICATION_KEY,
-            ChatNotification::UpdateLocalAIState,
-          )
-          .payload(LocalAIPB {
-            enabled,
-            plugin_downloaded,
-            lack_of_resource,
-            state: new_state,
-            plugin_version,
-          })
-          .send();
-        } else {
-          warn!("[AI Plugin] store preferences is dropped");
-        }
-      }
-    });
-
-    #[cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))]
-    {
-      Self {
-        ai_plugin: local_ai,
-        resource: local_ai_resource,
-        current_chat_id: ArcSwapOption::default(),
-        store_preferences,
-        user_service,
-        ollama,
-      }
-    }
-
-    #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
-    {
-      Self {
-        ai_plugin: local_ai,
-        resource: local_ai_resource,
-        current_chat_id: ArcSwapOption::default(),
-        store_preferences,
-        user_service,
-      }
-    }
+
+          // Read whether plugin is enabled from store; default to true
+          if let Some(store_preferences) = cloned_store_preferences.upgrade() {
+            let enabled = store_preferences.get_bool(&key).unwrap_or(true);
+            // Only check resource status if the plugin isn't in "UnexpectedStop" and is enabled
+            let (plugin_downloaded, lack_of_resource) =
+              if !matches!(state, RunningState::UnexpectedStop { .. }) && enabled {
+                // Possibly check plugin readiness and resource concurrency in parallel,
+                // but here we do it sequentially for clarity.
+                let downloaded = is_plugin_ready();
+                let resource_lack = cloned_llm_res.get_lack_of_resource().await;
+                (downloaded, resource_lack)
+              } else {
+                (false, None)
+              };
+
+            // If plugin is running, retrieve version
+            let plugin_version = if matches!(state, RunningState::Running { .. }) {
+              match cloned_local_ai.plugin_info().await {
+                Ok(info) => Some(info.version),
+                Err(_) => None,
+              }
+            } else {
+              None
+            };
+
+            // Broadcast the new local AI state
+            let new_state = RunningStatePB::from(state);
+            chat_notification_builder(
+              APPFLOWY_AI_NOTIFICATION_KEY,
+              ChatNotification::UpdateLocalAIState,
+            )
+            .payload(LocalAIPB {
+              enabled,
+              plugin_downloaded,
+              lack_of_resource,
+              state: new_state,
+              plugin_version,
+            })
+            .send();
+          } else {
+            warn!("[AI Plugin] store preferences is dropped");
+          }
+        }
+      });
+    }
+
+    Self {
+      ai_plugin: local_ai,
+      resource: local_ai_resource,
+      current_chat_id: ArcSwapOption::default(),
+      store_preferences,
+      user_service,
+      ollama,
+    }
   }
 
   #[instrument(level = "debug", skip_all)]
```
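The comment in the new code notes that the readiness and resource checks could run in parallel rather than sequentially. A minimal sketch of that variant using `tokio::join!`, with hypothetical stub functions standing in for `is_plugin_ready` and `get_lack_of_resource`:

```rust
use std::time::Duration;
use tokio::time::sleep;

// Hypothetical stand-ins for the checks used in the diff.
fn is_plugin_ready() -> bool {
    true
}

async fn get_lack_of_resource() -> Option<String> {
    sleep(Duration::from_millis(10)).await; // simulate async I/O
    None
}

#[tokio::main]
async fn main() {
    // Run the synchronous readiness probe and the async resource check
    // concurrently instead of one after the other.
    let (downloaded, resource_lack) =
        tokio::join!(async { is_plugin_ready() }, get_lack_of_resource());
    println!("downloaded: {downloaded}, lack_of_resource: {resource_lack:?}");
}
```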
```diff
@@ -329,35 +307,18 @@ impl LocalAIController {
   }
 
   pub async fn get_all_chat_local_models(&self) -> Vec<AIModel> {
-    #[cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))]
-    {
-      self
-        .get_filtered_local_models(|name| !name.contains("embed"))
-        .await
-    }
-
-    #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
-    {
-      vec![]
-    }
+    self
+      .get_filtered_local_models(|name| !name.contains("embed"))
+      .await
   }
 
   pub async fn get_all_embedded_local_models(&self) -> Vec<AIModel> {
-    #[cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))]
-    {
-      self
-        .get_filtered_local_models(|name| name.contains("embed"))
-        .await
-    }
-
-    #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
-    {
-      vec![]
-    }
+    self
+      .get_filtered_local_models(|name| name.contains("embed"))
+      .await
   }
 
   // Helper function to avoid code duplication in model retrieval
-  #[cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))]
   async fn get_filtered_local_models<F>(&self, filter_fn: F) -> Vec<AIModel>
   where
     F: Fn(&str) -> bool,
```
```diff
@@ -383,43 +344,35 @@ impl LocalAIController {
     let mut conn = self.user_service.sqlite_connection(uid)?;
     match select_local_ai_model(&mut conn, model_name) {
       None => {
-        #[cfg(any(target_os = "macos", target_os = "linux", target_os = "windows"))]
-        {
-          let ollama = self
-            .ollama
-            .load_full()
-            .ok_or_else(|| FlowyError::local_ai().with_context("ollama is not initialized"))?;
-
-          let request = GenerateEmbeddingsRequest::new(
-            model_name.to_string(),
-            EmbeddingsInput::Single("Hello".to_string()),
-          );
-
-          let model_type = match ollama.generate_embeddings(request).await {
-            Ok(value) => {
-              if value.embeddings.is_empty() {
-                ModelType::Chat
-              } else {
-                ModelType::Embedding
-              }
-            },
-            Err(_) => ModelType::Chat,
-          };
-
-          upsert_local_ai_model(
-            &mut conn,
-            &LocalAIModelTable {
-              name: model_name.to_string(),
-              model_type: model_type as i16,
-            },
-          )?;
-          Ok(model_type)
-        }
-
-        #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
-        {
-          Ok(ModelType::Chat)
-        }
+        let ollama = self
+          .ollama
+          .load_full()
+          .ok_or_else(|| FlowyError::local_ai().with_context("ollama is not initialized"))?;
+
+        let request = GenerateEmbeddingsRequest::new(
+          model_name.to_string(),
+          EmbeddingsInput::Single("Hello".to_string()),
+        );
+
+        let model_type = match ollama.generate_embeddings(request).await {
+          Ok(value) => {
+            if value.embeddings.is_empty() {
+              ModelType::Chat
+            } else {
+              ModelType::Embedding
+            }
+          },
+          Err(_) => ModelType::Chat,
+        };
+
+        upsert_local_ai_model(
+          &mut conn,
+          &LocalAIModelTable {
+            name: model_name.to_string(),
+            model_type: model_type as i16,
+          },
+        )?;
+        Ok(model_type)
       },
       Some(r) => Ok(ModelType::from(r.model_type)),
     }
```
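For reference, the `None` branch above classifies an unknown model by probing it with a tiny embedding request and falling back to a chat model on failure. A standalone sketch of the same heuristic, assuming a local Ollama server on the default port and a hypothetical model name:

```rust
use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
use ollama_rs::Ollama;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let ollama = Ollama::try_new("http://localhost:11434")?;

    // Probe with a one-word embedding request; a model that cannot embed
    // errors out (or returns no vectors) and is treated as a chat model.
    let request = GenerateEmbeddingsRequest::new(
        "nomic-embed-text".to_string(), // hypothetical model name
        EmbeddingsInput::Single("Hello".to_string()),
    );

    let kind = match ollama.generate_embeddings(request).await {
        Ok(resp) if !resp.embeddings.is_empty() => "embedding",
        _ => "chat",
    };
    println!("model type: {kind}");
    Ok(())
}
```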