Commit 319554c

✨ Optimize the CPU_cores x celery_concurrency allocation
2 parents 8a3ee4f + 87ad909 commit 319554c

24 files changed (+1228, -1035 lines)

.gitignore

Lines changed: 1 addition & 2 deletions

@@ -24,8 +24,7 @@ frontend_standalone/
 .pnpm-store/
 frontend-dist/

-backend/assets/clip-vit-base-patch32
-model-assets
+model-assets/

 # Test coverage reports
 *coverage_html

backend/apps/data_process_app.py

Lines changed: 49 additions & 1 deletion

@@ -6,7 +6,7 @@
 import time

 from consts.model import TaskResponse, TaskRequest, BatchTaskResponse, BatchTaskRequest, SimpleTaskStatusResponse, \
-    SimpleTasksListResponse
+    SimpleTasksListResponse, ConvertStateRequest, ConvertStateResponse
 from data_process.utils import get_task_info
 from data_process.tasks import process_and_forward, process_sync
 from services.data_process_service import get_data_process_service
@@ -363,3 +363,51 @@ async def process_text_file(
            status_code=500,
            detail=f"An error occurred while processing the file: {str(e)}"
        )
+
+@router.post("/convert_state", response_model=ConvertStateResponse, status_code=200)
+async def convert_state(request: ConvertStateRequest):
+    """Convert Celery task states to custom frontend state.
+
+    This helper endpoint allows callers that do **not** install Celery dependencies
+    to obtain the corresponding frontend state for a pair of Celery task states.
+    """
+    from celery import states
+
+    def _convert_to_custom_state_inner(process_celery_state: str, forward_celery_state: str) -> str:
+        """Inner helper to keep the original mapping logic in one place."""
+        # Handle failure states first
+        if process_celery_state == states.FAILURE:
+            return "PROCESS_FAILED"
+        if forward_celery_state == states.FAILURE:
+            return "FORWARD_FAILED"
+
+        # Handle completed state - both must be SUCCESS
+        if process_celery_state == states.SUCCESS and forward_celery_state == states.SUCCESS:
+            return "COMPLETED"
+
+        # Handle case where nothing has started
+        if not process_celery_state and not forward_celery_state:
+            return "WAIT_FOR_PROCESSING"
+
+        # Define state mappings
+        forward_state_map = {
+            states.PENDING: "WAIT_FOR_FORWARDING",
+            states.STARTED: "FORWARDING",
+            states.SUCCESS: "COMPLETED",
+            states.FAILURE: "FORWARD_FAILED",
+        }
+        process_state_map = {
+            states.PENDING: "WAIT_FOR_PROCESSING",
+            states.STARTED: "PROCESSING",
+            states.SUCCESS: "WAIT_FOR_FORWARDING",  # Process done, waiting for forward
+            states.FAILURE: "PROCESS_FAILED",
+        }
+
+        if forward_celery_state:
+            return forward_state_map.get(forward_celery_state, "WAIT_FOR_FORWARDING")
+        if process_celery_state:
+            return process_state_map.get(process_celery_state, "WAIT_FOR_PROCESSING")
+        return "WAIT_FOR_PROCESSING"
+
+    state = _convert_to_custom_state_inner(request.process_state or "", request.forward_state or "")
+    return ConvertStateResponse(state=state)
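
For callers of the new endpoint, a minimal client sketch; the base URL, port, and example states are illustrative, while the /tasks/convert_state path and the request/response fields come from this commit:

    import requests

    # Hypothetical service address; point this at wherever data_process_app is served.
    BASE_URL = "http://localhost:5012"

    payload = {
        "process_state": "SUCCESS",  # Celery state of the process task
        "forward_state": "STARTED",  # Celery state of the forward task
    }
    resp = requests.post(f"{BASE_URL}/tasks/convert_state", json=payload, timeout=10)
    resp.raise_for_status()
    print(resp.json())  # per the mapping above, this pair yields {"state": "FORWARDING"}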

backend/apps/elasticsearch_app.py

Lines changed: 3 additions & 4 deletions

@@ -151,12 +151,11 @@ def create_index_documents(
 @router.get("/{index_name}/files")
 async def get_index_files(
     index_name: str = Path(..., description="Name of the index"),
-    search_redis: bool = Query(True, description="Whether to search Redis to get incomplete files"),
     es_core: ElasticSearchCore = Depends(get_es_core)
 ):
     """Get all files from an index, including those that are not yet stored in ES"""
     try:
-        result = await ElasticSearchService.list_files(index_name, include_chunks=False, search_redis=search_redis, es_core=es_core)
+        result = await ElasticSearchService.list_files(index_name, include_chunks=False, es_core=es_core)
         # Transform result to match frontend expectations
         return {
             "status": "success",
@@ -188,7 +187,7 @@ def delete_documents(
             result["redis_cleanup"] = redis_cleanup_result

             # Update the message to include Redis cleanup info
-            original_message = result.get("message", f"Documents deleted successfully")
+            original_message = result.get("message", "Documents deleted successfully")
             result["message"] = (f"{original_message}. "
                                  f"Cleaned up {redis_cleanup_result['total_deleted']} Redis records "
                                  f"({redis_cleanup_result['celery_tasks_deleted']} tasks, "
@@ -205,7 +204,7 @@ def delete_documents(
             logger.warning(f"Redis cleanup failed for document {path_or_url} in index {index_name}: {str(redis_error)}")

             result["redis_cleanup_error"] = str(redis_error)
-            original_message = result.get("message", f"Documents deleted successfully")
+            original_message = result.get("message", "Documents deleted successfully")
             result["message"] = (f"{original_message}, "
                                  f"but Redis cleanup encountered an error: {str(redis_error)}")

backend/consts/model.py

Lines changed: 11 additions & 0 deletions

@@ -276,3 +276,14 @@ class ExportAndImportAgentInfo(BaseModel):
 class AgentImportRequest(BaseModel):
     agent_id: int
     agent_info: ExportAndImportAgentInfo
+
+
+class ConvertStateRequest(BaseModel):
+    """Request schema for /tasks/convert_state endpoint"""
+    process_state: str = ""
+    forward_state: str = ""
+
+
+class ConvertStateResponse(BaseModel):
+    """Response schema for /tasks/convert_state endpoint"""
+    state: str
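
As a hedged sanity check, here are a few (process_state, forward_state) pairs and the frontend state the mapping in convert_state should produce for them; the table is derived by reading the mapping logic above and is not an exhaustive test:

    # Expected outputs of the convert_state mapping for representative inputs.
    EXPECTED = {
        ("", ""): "WAIT_FOR_PROCESSING",
        ("STARTED", ""): "PROCESSING",
        ("SUCCESS", ""): "WAIT_FOR_FORWARDING",
        ("SUCCESS", "STARTED"): "FORWARDING",
        ("SUCCESS", "SUCCESS"): "COMPLETED",
        ("FAILURE", ""): "PROCESS_FAILED",
        ("SUCCESS", "FAILURE"): "FORWARD_FAILED",
    }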

backend/data_process/app.py

Lines changed: 5 additions & 5 deletions

@@ -12,16 +12,16 @@

 # Determine package path dynamically
 import_path = 'data_process.tasks'
-logger.info(f"Using import path: {import_path}")
+logger.debug(f"Using import path: {import_path}")

 REDIS_URL = config.redis_url
 REDIS_BACKEND_URL = config.redis_backend_url

 if not REDIS_URL or not REDIS_BACKEND_URL:
     raise ValueError("FATAL: REDIS_URL or REDIS_BACKEND_URL is not configured. Please check the environment variables in this container.")

-logger.info(f"Broker URL from config: {REDIS_URL}")
-logger.info(f"Backend URL from config: {REDIS_BACKEND_URL}")
+logger.debug(f"Broker URL from config: {REDIS_URL}")
+logger.debug(f"Backend URL from config: {REDIS_BACKEND_URL}")

 # Create Celery app instance
 app = Celery(
@@ -62,8 +62,8 @@
     result_backend_always_retry=True,  # Always retry backend operations
     result_backend_max_retries=10,  # Max retries for backend operations
     task_time_limit=3600,  # 1 hour time limit per task
-    worker_prefetch_multiplier=1,  # Don't prefetch tasks, process one at a time
-    worker_max_tasks_per_child=100,  # Restart worker after 100 tasks
+    worker_prefetch_multiplier=4,  # Allow prefetching for better throughput
+    worker_max_tasks_per_child=1000,  # Reduce restart frequency
     # Important for task chains
     task_acks_late=True,  # Tasks are acknowledged after completion
     task_reject_on_worker_lost=True,  # Tasks are rejected if worker is lost
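
The prefetch and recycling changes interact with worker concurrency, which is what the commit title is about. A rough sizing sketch follows; the core-splitting heuristic and variable names are illustrative and are not configuration defined in this repo:

    import os

    # Pick Celery concurrency from the host's CPU count, leaving each Ray actor
    # its RAY_ACTOR_NUM_CPUS cores for the actual parsing work.
    cpu_cores = os.cpu_count() or 4
    ray_actor_cpus = int(os.getenv("RAY_ACTOR_NUM_CPUS", "2"))
    celery_concurrency = max(1, cpu_cores // ray_actor_cpus)

    # With worker_prefetch_multiplier=4 and task_acks_late=True, each worker process
    # may reserve up to 4 unacknowledged tasks at a time.
    max_reserved_tasks = celery_concurrency * 4
    print(celery_concurrency, max_reserved_tasks)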

backend/data_process/config.py

Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ def ray_plasma_directory(self) -> str:
     @property
     def ray_object_store_memory_gb(self) -> float:
         """Ray object store memory limit (GB)"""
-        return float(os.getenv('RAY_OBJECT_STORE_MEMORY_GB', '2.0'))
+        return float(os.getenv('RAY_OBJECT_STORE_MEMORY_GB', '4.0'))

     @property
     def ray_temp_dir(self) -> str:
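
Ray sizes its object store in bytes, so a consumer of this property would convert the gigabyte value roughly as follows. This is a sketch: only the RAY_OBJECT_STORE_MEMORY_GB variable and its new 4.0 default come from this diff, and the call site is presumably wired through RayConfig's init parameters:

    import os

    # Convert the configured GB budget into the byte count ray.init() expects.
    object_store_gb = float(os.getenv("RAY_OBJECT_STORE_MEMORY_GB", "4.0"))
    object_store_bytes = int(object_store_gb * 1024 ** 3)

    # Illustrative call site:
    # ray.init(object_store_memory=object_store_bytes)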

backend/data_process/ray_actors.py

Lines changed: 6 additions & 3 deletions

@@ -7,17 +7,20 @@
 from database.attachment_db import get_file_stream

 logger = logging.getLogger(__name__)
-NUM_CPUS = int(os.getenv("RAY_NUM_CPUS", "1"))
+# This now controls the number of CPUs requested by each DataProcessorRayActor instance.
+# It allows a single file processing task to potentially use more than one core if the
+# underlying processing library (e.g., unstructured) can leverage it.
+RAY_ACTOR_NUM_CPUS = int(os.getenv("RAY_ACTOR_NUM_CPUS", "2"))


-@ray.remote(num_cpus=NUM_CPUS)
+@ray.remote(num_cpus=RAY_ACTOR_NUM_CPUS)
 class DataProcessorRayActor:
     """
     Ray actor for handling data processing tasks.
     Encapsulates the DataProcessCore to be used in a Ray cluster.
     """
     def __init__(self):
-        logger.info(f"Ray starting using {NUM_CPUS} CPUs...")
+        logger.info(f"Ray actor initialized using {RAY_ACTOR_NUM_CPUS} CPU cores...")
         self._processor = DataProcessCore()

     def process_file(self, source: str, chunking_strategy: str, destination: str, task_id: Optional[str] = None, **params) -> List[Dict[str, Any]]:
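
Because every DataProcessorRayActor now reserves RAY_ACTOR_NUM_CPUS logical CPUs, the number of actors Ray schedules concurrently is bounded by the cluster's CPU budget. A back-of-the-envelope sketch, with an illustrative cluster size:

    import os

    cluster_cpus = 16  # illustrative: total CPUs registered with the Ray cluster
    actor_cpus = int(os.getenv("RAY_ACTOR_NUM_CPUS", "2"))

    # Ray admits actors only while their num_cpus requests fit the remaining budget,
    # so at most cluster_cpus // actor_cpus processing actors run at the same time.
    max_concurrent_actors = cluster_cpus // actor_cpus
    print(f"Up to {max_concurrent_actors} file-processing actors can run in parallel")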

backend/data_process/ray_config.py

Lines changed: 39 additions & 49 deletions

@@ -10,6 +10,9 @@

 logger = logging.getLogger(__name__)

+# Forward declaration variable so runtime references succeed before instantiation
+ray_config: Optional["RayConfig"] = None
+

 class RayConfig:
     """Ray configuration manager"""
@@ -82,8 +85,13 @@ def init_ray(self, **kwargs) -> bool:

         params = self.get_init_params(**kwargs)

-        logger.info("Initializing Ray cluster...")
-        logger.debug(f"Ray configuration parameters:")
+        # Get Ray configuration from environment
+        ray_num_cpus = os.environ.get('RAY_NUM_CPUS')
+        num_cpus = int(ray_num_cpus) if ray_num_cpus else None  # None lets Ray decide
+
+        # Log the attempt to initialize
+        logger.debug("Initializing Ray cluster...")
+        logger.debug("Ray configuration parameters:")
         for key, value in params.items():
             if key.startswith('_'):
                 logger.debug(f"  {key}: {value}")
@@ -133,7 +141,7 @@ def connect_to_cluster(self, address: str = "auto") -> bool:
             return True

         except Exception as e:
-            logger.info(f"Failed to connect to Ray cluster: {str(e)}")
+            logger.info(f"Cannot connect to Ray cluster: {str(e)}")
             return False

     def start_local_cluster(self,
@@ -167,54 +175,36 @@ def log_configuration(self):
         logger.debug(f"  ObjectStore memory: {self.object_store_memory_gb} GB")
         logger.debug(f"  Temp directory: {self.temp_dir}")

+    @classmethod
+    def init_ray_for_worker(cls, address: str = "auto") -> bool:
+        """Initialize Ray connection for Celery Worker (class method wrapper)."""
+        logger.info("Initialize Ray connection for Celery Worker...")
+        ray_config.log_configuration()
+        return ray_config.connect_to_cluster(address)

-# Create a global RayConfiguration instance
-ray_config = RayConfig()
+    @classmethod
+    def init_ray_for_service(cls,
+                             num_cpus: Optional[int] = None,
+                             dashboard_port: int = 8265,
+                             try_connect_first: bool = True,
+                             include_dashboard: bool = True) -> bool:
+        """Initialize Ray for data processing service (class method wrapper)."""
+        ray_config.log_configuration()

+        if try_connect_first:
+            # Try to connect to existing cluster first
+            logger.debug("Trying to connect to existing Ray cluster...")
+            if ray_config.connect_to_cluster("auto"):
+                return True
+            logger.info("Starting local cluster...")

-def init_ray_for_worker(address: str = "auto") -> bool:
-    """
-    Initialize Ray connection for Celery Worker
-
-    Args:
-        address: Ray cluster address
-
-    Returns:
-        Whether initialization is successful
-    """
-    logger.info("Initialize Ray connection for Celery Worker...")
-    ray_config.log_configuration()
-
-    return ray_config.connect_to_cluster(address)
-
+        # Start local cluster
+        return ray_config.start_local_cluster(
+            num_cpus=num_cpus,
+            include_dashboard=include_dashboard,
+            dashboard_port=dashboard_port
+        )

-def init_ray_for_service(num_cpus: Optional[int] = None,
-                         dashboard_port: int = 8265,
-                         try_connect_first: bool = True) -> bool:
-    """
-    Initialize Ray for data processing service
-
-    Args:
-        num_cpus: Number of CPU cores
-        dashboard_port: Dashboard port
-        try_connect_first: Whether to try connecting to existing cluster first
-
-    Returns:
-        Whether initialization is successful
-    """
-    ray_config.log_configuration()
-
-    if try_connect_first:
-        # Try to connect to existing cluster first
-        logger.debug("Trying to connect to existing Ray cluster...")
-        if ray_config.connect_to_cluster("auto"):
-            return True
-
-    logger.info("Starting local cluster...")
-
-    # Start local cluster
-    return ray_config.start_local_cluster(
-        num_cpus=num_cpus,
-        dashboard_port=dashboard_port
-    )
+# Create a global RayConfig instance accessible throughout the module
+ray_config = RayConfig()
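
With the module-level helpers promoted to class methods, call sites would look roughly like this. This is a sketch: the diff does not show which modules actually make these calls, and the argument values are illustrative:

    from data_process.ray_config import RayConfig

    # Celery worker side: attach to an already-running Ray cluster.
    if not RayConfig.init_ray_for_worker(address="auto"):
        raise RuntimeError("No reachable Ray cluster for this worker")

    # Data-process service side: reuse an existing cluster if possible, otherwise
    # start a local one with the dashboard enabled on the default port.
    RayConfig.init_ray_for_service(num_cpus=None, dashboard_port=8265,
                                   try_connect_first=True, include_dashboard=True)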

backend/data_process/tasks.py

Lines changed: 8 additions & 28 deletions

@@ -70,7 +70,6 @@ def run_async(coro):
             logger.warning("nest_asyncio not available, creating new thread for async operation")
             # Fallback: run in a new thread
             import concurrent.futures
-            import threading

             def run_in_thread():
                 new_loop = asyncio.new_event_loop()
@@ -97,22 +96,15 @@ def run_in_thread():
 # This will be initialized on first task run by a worker process
 def get_ray_actor() -> Any:
     """
-    Creates or gets a handle to the named DataProcessorRayActor.
-    This is an idempotent operation, safe from race conditions.
+    Creates a new, anonymous DataProcessorRayActor instance for each call.
+    This allows for parallel execution of data processing tasks, with each
+    task running in its own actor.
     """
     with ray_init_lock:
         init_ray_in_worker()
-
-        # Use get_if_exists=True to make this operation idempotent.
-        # This will create the actor if it doesn't exist, or get a handle to it if it does.
-        # This is safe to be called from multiple workers concurrently.
-        actor = DataProcessorRayActor.options(
-            name="data_processor_actor",
-            lifetime="detached",
-            get_if_exists=True
-        ).remote()
+        actor = DataProcessorRayActor.remote()

-        logger.debug("Successfully obtained handle for DataProcessorRayActor.")
+        logger.debug("Successfully created a new DataProcessorRayActor for a task.")
        return actor

 class LoggingTask(Task):
@@ -160,18 +152,6 @@ def process(self, source: str, source_type: str,

    logger.info(f"[{self.request.id}] PROCESS TASK: source_type: {source_type}")

-    self.update_state(
-        state=states.PENDING,
-        meta={
-            'source': source,
-            'source_type': source_type,
-            'index_name': index_name,
-            'original_filename': original_filename,
-            'task_name': 'process',
-            'start_time': start_time
-        }
-    )
-
    self.update_state(
        state=states.STARTED,
        meta={
@@ -473,7 +453,7 @@ async def index_documents():
        es_result = run_async(index_documents())
        logger.debug(f"[{self.request.id}] FORWARD TASK: API response from main_server for source '{original_source}': {es_result}")

-        if isinstance(es_result, dict) and es_result.get("success") == True:
+        if isinstance(es_result, dict) and es_result.get("success"):
            total_indexed = es_result.get("total_indexed", 0)
            total_submitted = es_result.get("total_submitted", len(formatted_chunks))
            logger.debug(f"[{self.request.id}] FORWARD TASK: main_server reported {total_indexed}/{total_submitted} documents indexed successfully for '{original_source}'. Message: {es_result.get('message')}")
@@ -482,7 +462,7 @@ async def index_documents():
                logger.info("Value when raise Exception:")
                logger.info(f"original_source: {original_source}")
                logger.info(f"original_index_name: {original_index_name}")
-                logger.info(f"task_name: forward")
+                logger.info("task_name: forward")
                logger.info(f"source: {original_source}")
                raise Exception(json.dumps({
                    "message": f"Failure reported by main_server. Expected {total_submitted} chunks, indexed {total_indexed} chunks.",
@@ -491,7 +471,7 @@ async def index_documents():
                    "source": original_source,
                    "original_filename": original_filename
                }, ensure_ascii=False))
-        elif isinstance(es_result, dict) and es_result.get("success") == False:
+        elif isinstance(es_result, dict) and not es_result.get("success"):
            error_message = es_result.get("message", "Unknown error from main_server")
            raise Exception(json.dumps({
                "message": f"main_server API error: {error_message}",
