|
148 | 148 |
|
149 | 149 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
150 | 150 |
|
# Registry of pollable data sources, keyed by service_name.
# Each entry provides:
#   "engine_class"       - the context-engine class instantiated for that service
#   "enabled_by_default" - whether the service counts as enabled when no
#                          per-user polling state has been stored yet
DATA_SOURCES_CONFIG: Dict[str, Dict[str, Any]] = {
    "gmail": {
        "engine_class": GmailContextEngine,
        "enabled_by_default": True,  # Or fetch from user settings
        # Add other service-specific configs if needed
    },
    # Planned services, not wired up yet:
    # "gcalendar": {
    #     "engine_class": GCalendarContextEngine,
    #     "enabled_by_default": True,
    # },
    # "internet_search": {
    #     "engine_class": InternetSearchContextEngine,
    #     "enabled_by_default": False,  # Example: off by default
    # },
}

# Pause, in seconds, between two passes of the polling scheduler loop.
# Overridable through the environment; falls back to 30.
POLLING_SCHEDULER_INTERVAL_SECONDS = int(
    os.environ.get("POLLING_SCHEDULER_INTERVAL_SECONDS", "30")
)
# Strong references to in-flight poll-cycle tasks. The event loop only keeps
# weak references to tasks, so a task nobody else references may be garbage
# collected before it completes.
_polling_background_tasks = set()


def _get_or_create_polling_engine(user_id, service_name):
    """Return the engine for ``user_id``/``service_name``, building and caching it if absent.

    Looks in ``active_context_engines`` first. On a miss, the class registered
    under ``DATA_SOURCES_CONFIG[service_name]["engine_class"]`` is instantiated
    with the module-level collaborators and stored in ``active_context_engines``.

    Returns:
        The engine instance, or ``None`` when no engine class is configured
        for ``service_name``.
    """
    engine_instance = active_context_engines.get(user_id, {}).get(service_name)
    if engine_instance:
        return engine_instance

    engine_config = DATA_SOURCES_CONFIG.get(service_name)
    engine_class = engine_config.get("engine_class") if engine_config else None
    if not engine_class:
        return None

    print(f"[POLLING_SCHEDULER] Creating new {engine_class.__name__} instance for {user_id}/{service_name}")
    engine_instance = engine_class(
        user_id=user_id,
        task_queue=task_queue,
        memory_backend=memory_backend,
        websocket_manager=manager,  # Global websocket_manager
        mongo_manager_instance=mongo_manager,  # Global mongo_manager
    )
    active_context_engines.setdefault(user_id, {})[service_name] = engine_instance
    return engine_instance


async def _dispatch_polling_task(user_id, service_name):
    """Try to lock one due polling task and start its poll cycle in the background.

    If the lock is acquired but the poll cycle cannot be started (no engine
    class configured, or engine creation/scheduling raises), the lock is
    released by setting ``is_currently_polling`` back to ``False`` so the task
    is not left locked. Exceptions are re-raised after the release attempt.
    """
    print(f"[POLLING_SCHEDULER] Attempting to process task for {user_id}/{service_name}")

    # Atomically try to acquire the lock for this specific task.
    locked_task_state = await mongo_manager.set_polling_status_and_get(user_id, service_name)
    if not locked_task_state:
        # Another scheduler instance/worker picked it up, or it's no longer due.
        print(f"[POLLING_SCHEDULER] Could not acquire lock for {user_id}/{service_name} (already processing or no longer due).")
        return

    print(f"[POLLING_SCHEDULER] Acquired lock for {user_id}/{service_name}. Triggering poll.")

    poll_task = None
    try:
        engine_instance = _get_or_create_polling_engine(user_id, service_name)
        if engine_instance is not None:
            # Run the poll cycle as its own task so the scheduler doesn't block.
            # Per the original author's note, run_poll_cycle releases the lock
            # itself (via calculate_and_schedule_next_poll) once it is running.
            poll_task = asyncio.create_task(engine_instance.run_poll_cycle())
    except Exception:
        # We hold the lock but never started a poll cycle that could release it.
        await mongo_manager.update_polling_state(user_id, service_name, {"is_currently_polling": False})
        raise

    if poll_task is None:
        print(f"[POLLING_SCHEDULER_ERROR] No engine class configured for service: {service_name}")
        # Release lock as we can't process it.
        await mongo_manager.update_polling_state(user_id, service_name, {"is_currently_polling": False})
        return

    # Keep the task alive until it finishes, then drop the reference.
    _polling_background_tasks.add(poll_task)
    poll_task.add_done_callback(_polling_background_tasks.discard)


async def polling_scheduler_loop():
    """Run forever, dispatching due polling tasks every ``POLLING_SCHEDULER_INTERVAL_SECONDS``.

    On start, calls ``mongo_manager.reset_stale_polling_locks()``. Each pass
    fetches the due tasks from ``mongo_manager.get_due_polling_tasks()`` and
    dispatches them one by one. A failure while dispatching one task is logged
    and does not prevent the remaining tasks of the same pass from being
    dispatched. Any other error in a pass is logged and the loop continues
    after the usual sleep.
    """
    print(f"[POLLING_SCHEDULER] Starting loop (interval: {POLLING_SCHEDULER_INTERVAL_SECONDS}s)")
    await mongo_manager.reset_stale_polling_locks()  # Reset any locks from crashed previous runs

    while True:
        try:
            # Fetches tasks that are enabled, due, and not locked.
            due_tasks_states = await mongo_manager.get_due_polling_tasks()

            if due_tasks_states:
                print(f"[POLLING_SCHEDULER] Found {len(due_tasks_states)} due polling tasks.")

                for task_state in due_tasks_states:
                    try:
                        await _dispatch_polling_task(
                            task_state["user_id"],
                            task_state["service_name"],
                        )
                    except Exception as e:
                        # Isolate failures so one bad task cannot starve the rest.
                        print(f"[POLLING_SCHEDULER_ERROR] Error dispatching polling task: {e}")
                        traceback.print_exc()

        except Exception as e:
            print(f"[POLLING_SCHEDULER_ERROR] Error in scheduler loop: {e}")
            traceback.print_exc()  # Log full error

        await asyncio.sleep(POLLING_SCHEDULER_INTERVAL_SECONDS)
| 231 | + |
| 232 | + |
151 | 233 | class Auth: |
152 | 234 | async def _validate_token_and_get_payload(self, token: str) -> dict: |
153 | 235 | # ... (Auth class definition remains the same) ... |
@@ -711,43 +793,48 @@ async def polling_scheduler_loop(): |
711 | 793 |
|
712 | 794 |
|
async def start_user_context_engines(user_id: str):
    """Ensure polling state exists for every service that is enabled for ``user_id``.

    For each entry in ``DATA_SOURCES_CONFIG`` the stored polling-state document
    decides whether the service is enabled (its ``is_enabled`` field, ``False``
    when the field is missing). When no document exists yet, the config's
    ``enabled_by_default`` value (``True`` when absent) is used instead.
    Enabled services get ``initialize_polling_state()`` called on an engine
    instance; disabled services are only logged.
    """
    # This registry may be less critical now for *running* engines.
    active_context_engines.setdefault(user_id, {})

    profile = await mongo_manager.get_user_profile(user_id)
    # NOTE(review): user_settings is not consulted below; enablement currently
    # comes from the polling-state document or the config default.
    user_settings = profile.get("userData", {}) if profile else {}

    for service_name, config in DATA_SOURCES_CONFIG.items():
        state_doc = await mongo_manager.get_polling_state(user_id, service_name)
        if state_doc:
            enabled = state_doc.get("is_enabled", False)  # Default to False if key missing
        else:
            # No state yet, use default from config.
            enabled = config.get("enabled_by_default", True)

        if not enabled:
            print(f"[CONTEXT_ENGINE_MGR] Service {service_name} is disabled for user {user_id}. Skipping engine start/state init.")
            continue

        # Reuse a cached engine when one exists; otherwise build a transient
        # one that is deliberately not stored in active_context_engines.
        engine = active_context_engines.get(user_id, {}).get(service_name)
        if not engine:
            engine_class = config["engine_class"]
            print(f"[CONTEXT_ENGINE_MGR] Creating transient {engine_class.__name__} instance for {user_id}/{service_name} for state init.")
            engine = engine_class(
                user_id=user_id,
                task_queue=task_queue,  # Pass your global/initialized instances
                memory_backend=memory_backend,
                websocket_manager=manager,
                mongo_manager_instance=mongo_manager,
            )

        # Crucially, ensure the polling state is initialized in the DB so the
        # central scheduler can pick the service up.
        await engine.initialize_polling_state()
        print(f"[CONTEXT_ENGINE_MGR] Polling state ensured for {service_name} for user {user_id}.")
751 | 838 |
|
752 | 839 | @app.on_event("startup") |
753 | 840 | async def startup_event(): |
|
0 commit comments