diff --git a/latencypredictor-v1/Dockerfile-prediction b/latencypredictor-v1/Dockerfile-prediction
index 0ec1d9540..6e3015f23 100644
--- a/latencypredictor-v1/Dockerfile-prediction
+++ b/latencypredictor-v1/Dockerfile-prediction
@@ -6,6 +6,12 @@ WORKDIR /app
 
 # Copy the requirements file and install dependencies
 # (It's good practice to manage dependencies in a requirements.txt file)
+
+
+RUN apt-get update && apt-get install -y \
+    libgomp1 \
+    && rm -rf /var/lib/apt/lists/*
+
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
 
diff --git a/latencypredictor-v1/Dockerfile-test b/latencypredictor-v1/Dockerfile-test
new file mode 100644
index 000000000..ea2335dc0
--- /dev/null
+++ b/latencypredictor-v1/Dockerfile-test
@@ -0,0 +1,38 @@
+# Dockerfile-test
+FROM python:3.9-slim
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    curl \
+    wget \
+    jq \
+    && rm -rf /var/lib/apt/lists/*
+
+# Set working directory
+WORKDIR /app
+
+# Copy requirements and install Python dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Install additional testing dependencies
+RUN pip install --no-cache-dir \
+    pytest \
+    pytest-asyncio \
+    requests \
+    httpx \
+    aiohttp
+
+# Copy test files
+COPY test_dual_server_client.py .
+
+
+# Create test results directory
+RUN mkdir -p /test-results
+
+# Set environment variables
+ENV PYTHONPATH=/app
+ENV PYTHONUNBUFFERED=1
+
+# Default command runs the specific test
+CMD ["pytest", "-v", "-s", "test_dual_server_client.py"]
\ No newline at end of file
diff --git a/latencypredictor-v1/Dockerfile-training b/latencypredictor-v1/Dockerfile-training
index 5767c59af..4a2c2ef14 100644
--- a/latencypredictor-v1/Dockerfile-training
+++ b/latencypredictor-v1/Dockerfile-training
@@ -6,6 +6,13 @@ WORKDIR /app
 
 # Copy the requirements file and install dependencies
 # (It's good practice to manage dependencies in a requirements.txt file)
+
+
+RUN apt-get update && apt-get install -y \
+    libgomp1 \
+    && rm -rf /var/lib/apt/lists/*
+
+
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
 
diff --git a/latencypredictor-v1/build-deploy.sh b/latencypredictor-v1/build-deploy.sh
index 1531dbb1a..94a3f98f7 100755
--- a/latencypredictor-v1/build-deploy.sh
+++ b/latencypredictor-v1/build-deploy.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# Build and deploy script for both servers
+# Build and deploy script for training, prediction, and test servers
 
 set -e
 
@@ -7,8 +7,9 @@ set -e
 PROJECT_ID="kaushikmitra-gke-dev"
 REGION="asia-southeast1-c"
 REPOSITORY="kaushikmitra-docker-repo"
-TRAINING_IMAGE="latencypredictor-v1-training-server"
-PREDICTION_IMAGE="latencypredictor-v1-prediction-server"
+TRAINING_IMAGE="latencypredictor-v3-training-server"
+PREDICTION_IMAGE="latencypredictor-v3-prediction-server"
+TEST_IMAGE="latencypredictor-v3-test"
 TAG="latest"
 
 # Colors for output
@@ -41,7 +42,18 @@ check_files() {
         fi
     done
 
-    echo_status "All required files found."
+    # Check for test-specific files
+    local test_files=("Dockerfile-test")
+    for file in "${test_files[@]}"; do
+        if [[ ! -f "$file" ]]; then
+            echo_warning "Test file $file not found - test image will not be built"
+            TEST_BUILD_ENABLED=false
+            return
+        fi
+    done
+
+    TEST_BUILD_ENABLED=true
+    echo_status "All required files found (including test files)."
 }
 
 # Build Docker images
@@ -50,7 +62,7 @@ build_images() {
     # Build training server image
     echo_status "Building training server image..."
-    docker build -f Dockerfile-training -t ${TRAINING_IMAGE}:${TAG} .
+ docker build -f Dockerfile-training -t ${TRAINING_IMAGE}:${TAG} . # Tag for training server docker tag ${TRAINING_IMAGE}:${TAG} \ @@ -64,7 +76,19 @@ build_images() { docker tag ${PREDICTION_IMAGE}:${TAG} \ us-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${PREDICTION_IMAGE}:${TAG} - echo_status "Images built successfully." + # Build test image if enabled + if [[ "$TEST_BUILD_ENABLED" == "true" ]]; then + echo_status "Building test image..." + docker build -f Dockerfile-test -t ${TEST_IMAGE}:${TAG} . + + # Tag for test image + docker tag ${TEST_IMAGE}:${TAG} \ + us-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${TEST_IMAGE}:${TAG} + + echo_status "All images (including test) built successfully." + else + echo_status "Images built successfully (test image skipped)." + fi } # Push images to Artifact Registry @@ -82,7 +106,14 @@ push_images() { echo_status "Pushing prediction server image..." docker push us-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${PREDICTION_IMAGE}:${TAG} - echo_status "Images pushed successfully." + # Push test image if enabled + if [[ "$TEST_BUILD_ENABLED" == "true" ]]; then + echo_status "Pushing test image..." + docker push us-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${TEST_IMAGE}:${TAG} + echo_status "All images (including test) pushed successfully." + else + echo_status "Images pushed successfully (test image skipped)." + fi } # Deploy to GKE @@ -102,6 +133,112 @@ deploy_to_gke() { echo_status "Deployment completed successfully." } +# Deploy test job +deploy_test() { + echo_status "Deploying test job..." + + if [[ "$TEST_BUILD_ENABLED" != "true" ]]; then + echo_warning "Test image not available. Skipping test deployment." + return + fi + + # Check if test manifest exists + if [[ ! -f "test-job.yaml" ]]; then + echo_warning "test-job.yaml not found. Creating a basic test job..." + create_test_manifest + fi + + # Delete existing test job if it exists + kubectl delete job latency-predictor-test --ignore-not-found=true + + # Apply test job + kubectl apply -f test-job.yaml + + echo_status "Test job deployed. Monitor with: kubectl logs -f job/latency-predictor-test" +} + +# Create a basic test manifest +create_test_manifest() { + cat > test-job.yaml << EOF +apiVersion: batch/v1 +kind: Job +metadata: + name: latency-predictor-test + namespace: default + labels: + app: latency-predictor-test + component: test +spec: + template: + metadata: + labels: + app: latency-predictor-test + component: test + spec: + nodeSelector: + cloud.google.com/gke-nodepool: "pool-2" + restartPolicy: Never + containers: + - name: test-runner + image: us-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${TEST_IMAGE}:${TAG} + imagePullPolicy: Always + command: ["pytest"] + args: ["-v", "-s", "test_dual_server_client.py"] + resources: + requests: + cpu: "500m" + memory: "1Gi" + limits: + cpu: "1000m" + memory: "2Gi" + env: + - name: TRAINING_SERVER_URL + value: "http://training-service:8000" + - name: PREDICTION_SERVER_URL + value: "http://prediction-service:80" + - name: TEST_TIMEOUT + value: "300" + volumeMounts: + - name: test-results + mountPath: /test-results + volumes: + - name: test-results + emptyDir: {} + backoffLimit: 3 +EOF + echo_status "Created basic test-job.yaml manifest." +} + +# Run tests +run_tests() { + echo_status "Running tests..." + + if [[ "$TEST_BUILD_ENABLED" != "true" ]]; then + echo_warning "Test image not available. Running basic connectivity tests instead..." 
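+        # Fall back to the basic connectivity checks in test_deployment() when no
+        # test image was built.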
+ test_deployment + return + fi + + # Deploy and run test job + deploy_test + + # Wait for job completion and show logs + echo_status "Waiting for test job to complete..." + kubectl wait --for=condition=complete job/latency-predictor-test --timeout=600s || { + echo_error "Test job did not complete successfully" + kubectl describe job latency-predictor-test + kubectl logs job/latency-predictor-test + return 1 + } + + echo_status "Test job completed. Showing logs:" + kubectl logs job/latency-predictor-test + + # Clean up test job + echo_status "Cleaning up test job..." + kubectl delete job latency-predictor-test +} + # Get service information get_service_info() { echo_status "Getting service information..." @@ -131,7 +268,7 @@ get_service_info() { kubectl get services } -# Test the deployment +# Test the deployment (basic connectivity tests) test_deployment() { echo_status "Testing deployment..." @@ -165,6 +302,18 @@ test_deployment() { fi } +# List built images +list_images() { + echo_status "Listing built images..." + + echo_status "Local images:" + docker images | grep -E "${TRAINING_IMAGE}|${PREDICTION_IMAGE}|${TEST_IMAGE}" || echo "No local images found" + + echo_status "Remote images in Artifact Registry:" + gcloud artifacts docker images list us-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY} \ + --include-tags --filter="package~(${TRAINING_IMAGE}|${PREDICTION_IMAGE}|${TEST_IMAGE})" || echo "No remote images found" +} + # Cleanup function cleanup() { echo_status "Cleaning up..." @@ -184,15 +333,27 @@ main() { build_images ;; "push") + check_files push_images ;; "deploy") deploy_to_gke ;; + "test-deploy") + check_files + deploy_test + ;; + "test") + check_files + run_tests + ;; "info") get_service_info ;; - "test") + "images") + list_images + ;; + "basic-test") test_deployment ;; "all") @@ -204,17 +365,30 @@ main() { test_deployment cleanup ;; + "full") + check_files + build_images + push_images + deploy_to_gke + get_service_info + run_tests + cleanup + ;; *) - echo "Usage: $0 {check|build|push|deploy|info|test|all}" + echo "Usage: $0 {check|build|push|deploy|test-deploy|test|info|images|basic-test|all|full}" echo "" echo "Commands:" - echo " check - Check if required files exist" - echo " build - Build Docker images" - echo " push - Push images to Artifact Registry" - echo " deploy - Deploy to GKE" - echo " info - Get service information" - echo " test - Test the deployment" - echo " all - Run complete build and deployment process" + echo " check - Check if required files exist" + echo " build - Build Docker images (including test if Dockerfile-test exists)" + echo " push - Push images to Artifact Registry" + echo " deploy - Deploy to GKE" + echo " test-deploy- Deploy test job only" + echo " test - Run comprehensive tests using test image" + echo " info - Get service information" + echo " images - List built images (local and remote)" + echo " basic-test - Run basic connectivity tests" + echo " all - Run complete build and deployment process (no tests)" + echo " full - Run complete process including comprehensive tests" exit 1 ;; esac diff --git a/latencypredictor-v1/manifests/dual-server-deployment.yaml b/latencypredictor-v1/manifests/dual-server-deployment.yaml index f337a538c..ad40d9697 100644 --- a/latencypredictor-v1/manifests/dual-server-deployment.yaml +++ b/latencypredictor-v1/manifests/dual-server-deployment.yaml @@ -84,11 +84,11 @@ spec: component: training spec: nodeSelector: - cloud.google.com/gke-nodepool: "pool-1" + cloud.google.com/gke-nodepool: "pool-2" containers: - name: 
training-server - image: us-docker.pkg.dev/kaushikmitra-gke-dev/kaushikmitra-docker-repo/latencypredictor-v1-training-server:latest - + image: us-docker.pkg.dev/kaushikmitra-gke-dev/kaushikmitra-docker-repo/latencypredictor-v3-training-server:latest + imagePullPolicy: Always ports: - containerPort: 8000 @@ -145,7 +145,7 @@ metadata: app: prediction-server component: prediction spec: - replicas: 5 + replicas: 10 selector: matchLabels: app: prediction-server @@ -157,10 +157,10 @@ spec: component: prediction spec: nodeSelector: - cloud.google.com/gke-nodepool: "pool-1" + cloud.google.com/gke-nodepool: "pool-2" containers: - name: prediction-server - image: us-docker.pkg.dev/kaushikmitra-gke-dev/kaushikmitra-docker-repo/latencypredictor-v1-prediction-server:latest + image: us-docker.pkg.dev/kaushikmitra-gke-dev/kaushikmitra-docker-repo/latencypredictor-v3-prediction-server:latest imagePullPolicy: Always ports: - containerPort: 8001 diff --git a/latencypredictor-v1/manifests/test-dual-server-deployment.yaml b/latencypredictor-v1/manifests/test-dual-server-deployment.yaml new file mode 100644 index 000000000..6d7361e66 --- /dev/null +++ b/latencypredictor-v1/manifests/test-dual-server-deployment.yaml @@ -0,0 +1,56 @@ +# --- Test Job --- +apiVersion: batch/v1 +kind: Job +metadata: + name: latency-predictor-test + namespace: default + labels: + app: latency-predictor-test + component: test +spec: + template: + metadata: + labels: + app: latency-predictor-test + component: test + spec: + # Use the same node pool as your services + nodeSelector: + cloud.google.com/gke-nodepool: "pool-2" + restartPolicy: Never + containers: + - name: test-runner + # Use your test image here + image: us-docker.pkg.dev/kaushikmitra-gke-dev/kaushikmitra-docker-repo/latencypredictor-v3-test:latest + imagePullPolicy: Always + resources: + requests: + cpu: "500m" + memory: "50Gi" + limits: + cpu: "1000m" + memory: "100Gi" + env: + # Point to your internal services + - name: TRAINING_SERVER_URL + value: "http://training-service:8000" + - name: PREDICTION_SERVER_URL + value: "http://prediction-service:80" + - name: TEST_TIMEOUT + value: "300" # 5 minutes + - name: TARGET_QPS + value: "1000" # Match the target QPS in your test script + envFrom: + - configMapRef: + name: prediction-server-config # Reuse existing config if needed + # Override the default command to run specific pytest + command: ["pytest"] + args: ["-v", "-s", "test_dual_server_client.py"] + # If your tests need to store results or logs + volumeMounts: + - name: test-results + mountPath: /test-results + volumes: + - name: test-results + emptyDir: {} + backoffLimit: 3 # Retry up to 3 times if test fails \ No newline at end of file diff --git a/latencypredictor-v1/prediction_server.py b/latencypredictor-v1/prediction_server.py index 31a6e216e..31ffe07de 100644 --- a/latencypredictor-v1/prediction_server.py +++ b/latencypredictor-v1/prediction_server.py @@ -23,10 +23,17 @@ XGBOOST_AVAILABLE = False logging.warning("XGBoost not available. Install with: pip install xgboost") - +try: + import lightgbm as lgb + LIGHTGBM_AVAILABLE = True +except ImportError: + LIGHTGBM_AVAILABLE = False + logging.warning("LightGBM not available. 
Install with: pip install lightgbm") + class ModelType(str, Enum): BAYESIAN_RIDGE = "bayesian_ridge" XGBOOST = "xgboost" + LIGHTGBM = "lightgbm" class PredictSettings: @@ -168,9 +175,15 @@ class LightweightPredictor: def __init__(self): mt = settings.MODEL_TYPE + + # Add LightGBM fallback logic if mt == ModelType.XGBOOST and not XGBOOST_AVAILABLE: - logging.warning("Falling back to Bayesian Ridge") + logging.warning("XGBoost not available. Falling back to Bayesian Ridge") mt = ModelType.BAYESIAN_RIDGE + elif mt == ModelType.LIGHTGBM and not LIGHTGBM_AVAILABLE: + logging.warning("LightGBM not available. Falling back to Bayesian Ridge") + mt = ModelType.BAYESIAN_RIDGE + self.model_type = mt self.quantile = settings.QUANTILE_ALPHA self.ttft_model = None @@ -186,7 +199,9 @@ def is_ready(self) -> bool: with self.lock: if self.model_type == ModelType.BAYESIAN_RIDGE: return all([self.ttft_model, self.tpot_model, self.ttft_scaler, self.tpot_scaler]) - return all([self.ttft_model, self.tpot_model]) + else: # XGBoost or LightGBM + return all([self.ttft_model, self.tpot_model]) + def load_models(self) -> bool: try: @@ -213,14 +228,15 @@ def load_models(self) -> bool: logging.error(f"Load error: {e}") return False - def predict(self, features: dict) -> Tuple[float, float, float, float]: +# 4. Update predict method to handle LightGBM + def predict(self, features: dict) -> Tuple[float, float]: """Make quantile predictions using the loaded models.""" try: with self.lock: if not self.is_ready: raise HTTPException(status_code=503, detail="Models not ready") - # Updated required features to include prefix_cache_score + # Validation remains the same required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting', 'num_request_running', 'num_tokens_generated', 'prefix_cache_score'] for f in required: if f not in features: @@ -228,7 +244,7 @@ def predict(self, features: dict) -> Tuple[float, float, float, float]: if not isinstance(features[f], (int, float)): raise ValueError(f"Invalid type for feature {f}: expected number") - # Updated TTFT features to include prefix_cache_score + # Feature columns remain the same ttft_cols = ['kv_cache_percentage','input_token_length','num_request_waiting','num_request_running','prefix_cache_score'] tpot_cols = ['kv_cache_percentage','input_token_length','num_request_waiting','num_request_running','num_tokens_generated'] @@ -237,33 +253,32 @@ def predict(self, features: dict) -> Tuple[float, float, float, float]: df_tpot = pd.DataFrame([{col: features[col] for col in tpot_cols}]) if self.model_type == ModelType.BAYESIAN_RIDGE: - # Use scaling for Bayesian Ridge + # Bayesian Ridge logic (unchanged) ttft_scaled = self.ttft_scaler.transform(df_ttft) tpot_scaled = self.tpot_scaler.transform(df_tpot) ttft_pred_mean, ttft_std = self.ttft_model.predict(ttft_scaled, return_std=True) tpot_pred_mean, tpot_std = self.tpot_model.predict(tpot_scaled, return_std=True) - - # Approximate quantile prediction by adding factor to mean - # This matches the logic in the training server + std_factor = 1.28 if self.quantile == 0.9 else (2.0 if self.quantile == 0.95 else 0.674) ttft_pred = ttft_pred_mean[0] + std_factor * ttft_std[0] tpot_pred = tpot_pred_mean[0] + std_factor * tpot_std[0] - return ttft_pred, tpot_pred, ttft_std[0], tpot_std[0] + return ttft_pred, tpot_pred - else: # XGBoost with true quantile regression - # XGBoost quantile regression directly predicts the quantile + elif self.model_type == ModelType.XGBOOST: + # XGBoost logic (unchanged) ttft_pred = 
self.ttft_model.predict(df_ttft) tpot_pred = self.tpot_model.predict(df_tpot) - # For XGBoost quantile regression, uncertainty estimation is more complex - # We'll use a simple heuristic based on the quantile value and prediction - # This is a rough approximation - ideally you'd train additional models for uncertainty - ttft_std = ttft_pred[0] * 0.15 # 15% of prediction as uncertainty estimate - tpot_std = tpot_pred[0] * 0.15 + return ttft_pred[0], tpot_pred[0] + + else: # LightGBM - NEW + # LightGBM quantile regression directly predicts the quantile + ttft_pred = self.ttft_model.predict(df_ttft) + tpot_pred = self.tpot_model.predict(df_tpot) - return ttft_pred[0], tpot_pred[0], ttft_std, tpot_std + return ttft_pred[0], tpot_pred[0] except ValueError as ve: logging.warning(f"Client error in predict(): {ve}") @@ -274,6 +289,70 @@ def predict(self, features: dict) -> Tuple[float, float, float, float]: logging.error("Error in predict():", exc_info=True) raise HTTPException(status_code=500, detail="Internal error during prediction") +# 5. Update predict_batch method to handle LightGBM + def predict_batch(self, features_list: List[dict]) -> Tuple[np.ndarray, np.ndarray]: + """Make batch quantile predictions using the loaded models.""" + try: + with self.lock: + if not self.is_ready: + raise HTTPException(status_code=503, detail="Models not ready") + + # Validation logic remains the same + required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting', 'num_request_running', 'num_tokens_generated', 'prefix_cache_score'] + for i, features in enumerate(features_list): + for f in required: + if f not in features: + raise ValueError(f"Missing required feature '{f}' in request {i}") + if not isinstance(features[f], (int, float)): + raise ValueError(f"Invalid type for feature '{f}' in request {i}: expected number") + + # Feature columns and DataFrame creation remains the same + ttft_cols = ['kv_cache_percentage','input_token_length','num_request_waiting','num_request_running','prefix_cache_score'] + tpot_cols = ['kv_cache_percentage','input_token_length','num_request_waiting','num_request_running','num_tokens_generated'] + + ttft_data = [{col: features[col] for col in ttft_cols} for features in features_list] + tpot_data = [{col: features[col] for col in tpot_cols} for features in features_list] + + df_ttft_batch = pd.DataFrame(ttft_data) + df_tpot_batch = pd.DataFrame(tpot_data) + + if self.model_type == ModelType.BAYESIAN_RIDGE: + # Bayesian Ridge logic (unchanged) + ttft_scaled = self.ttft_scaler.transform(df_ttft_batch) + tpot_scaled = self.tpot_scaler.transform(df_tpot_batch) + + ttft_pred_mean, ttft_std = self.ttft_model.predict(ttft_scaled, return_std=True) + tpot_pred_mean, tpot_std = self.tpot_model.predict(tpot_scaled, return_std=True) + + std_factor = 1.28 if self.quantile == 0.9 else (2.0 if self.quantile == 0.95 else 0.674) + ttft_pred = ttft_pred_mean + std_factor * ttft_std + tpot_pred = tpot_pred_mean + std_factor * tpot_std + + return ttft_pred, tpot_pred + + elif self.model_type == ModelType.XGBOOST: + # XGBoost logic (unchanged) + ttft_pred = self.ttft_model.predict(df_ttft_batch) + tpot_pred = self.tpot_model.predict(df_tpot_batch) + + return ttft_pred, tpot_pred + + else: # LightGBM - NEW + # LightGBM quantile regression directly predicts the quantile + ttft_pred = self.ttft_model.predict(df_ttft_batch) + tpot_pred = self.tpot_model.predict(df_tpot_batch) + + return ttft_pred, tpot_pred + + except ValueError as ve: + logging.warning(f"Client error in 
predict_batch(): {ve}") + raise HTTPException(status_code=400, detail=str(ve)) + except HTTPException: + raise + except Exception as e: + logging.error("Error in predict_batch():", exc_info=True) + raise HTTPException(status_code=500, detail="Internal error during batch prediction") + # Instantiate model_syncer = ModelSyncer() @@ -300,10 +379,6 @@ class PredictionRequest(BaseModel): class PredictionResponse(BaseModel): ttft_ms: float = Field(..., description=f"Predicted {settings.QUANTILE_ALPHA:.0%} quantile TTFT in milliseconds") tpot_ms: float = Field(..., description=f"Predicted {settings.QUANTILE_ALPHA:.0%} quantile TPOT in milliseconds") - ttft_uncertainty: float = Field(..., description="Uncertainty estimate for TTFT prediction") - tpot_uncertainty: float = Field(..., description="Uncertainty estimate for TPOT prediction") - ttft_prediction_bounds: Tuple[float, float] = Field(..., description="Approximate prediction bounds for TTFT") - tpot_prediction_bounds: Tuple[float, float] = Field(..., description="Approximate prediction bounds for TPOT") predicted_at: datetime model_type: str = Field(..., description="Type of model used for prediction") quantile: float = Field(..., description="Quantile being predicted") @@ -319,9 +394,8 @@ class StatusResponse(BaseModel): models_exist: dict - class BulkPredictionRequest(BaseModel): - requests: List[PredictionRequest] = Field(..., min_items=1, max_items=100, description="List of prediction requests (max 100)") + requests: List[PredictionRequest] = Field(..., min_items=1, max_items=10000, description="List of prediction requests (max 10000)") class BulkPredictionResponse(BaseModel): predictions: List[PredictionResponse] = Field(..., description="List of prediction responses") @@ -373,24 +447,15 @@ async def status_endpoint(): async def predict_endpoint(request: PredictionRequest): """Make quantile latency predictions.""" try: - ttft_pred, tpot_pred, ttft_std, tpot_std = predictor.predict(request.dict()) + ttft_pred, tpot_pred = predictor.predict(request.dict()) # Ensure non-negative predictions ttft_pred = max(0, ttft_pred) tpot_pred = max(0, tpot_pred) - # Calculate approximate confidence bounds - # For quantile predictions, these represent uncertainty around the quantile estimate - ttft_bounds = (max(0, ttft_pred - 2*ttft_std), ttft_pred + 2*ttft_std) - tpot_bounds = (max(0, tpot_pred - 2*tpot_std), tpot_pred + 2*tpot_std) - return PredictionResponse( ttft_ms=ttft_pred, tpot_ms=tpot_pred, - ttft_uncertainty=ttft_std, - tpot_uncertainty=tpot_std, - ttft_prediction_bounds=ttft_bounds, - tpot_prediction_bounds=tpot_bounds, predicted_at=datetime.now(timezone.utc), model_type=predictor.model_type.value, quantile=predictor.quantile, @@ -401,105 +466,33 @@ async def predict_endpoint(request: PredictionRequest): except Exception as e: logging.error(f"Prediction failed: {e}") raise HTTPException(status_code=500, detail="An internal error occurred during prediction") - - -# Add this endpoint after the existing predict endpoint -@app.post("/predict/bulk", response_model=BulkPredictionResponseWithErrors) -async def predict_bulk_endpoint(request: BulkPredictionRequest): - """Make bulk quantile latency predictions.""" - start_time = time.time() - - predictions = [] - errors = [] - successful_count = 0 - failed_count = 0 - - for i, pred_request in enumerate(request.requests): - try: - ttft_pred, tpot_pred, ttft_std, tpot_std = predictor.predict(pred_request.dict()) - - # Ensure non-negative predictions - ttft_pred = max(0, ttft_pred) - tpot_pred = max(0, 
tpot_pred) - - # Calculate approximate confidence bounds - ttft_bounds = (max(0, ttft_pred - 2*ttft_std), ttft_pred + 2*ttft_std) - tpot_bounds = (max(0, tpot_pred - 2*tpot_std), tpot_pred + 2*tpot_std) - - prediction_response = PredictionResponse( - ttft_ms=ttft_pred, - tpot_ms=tpot_pred, - ttft_uncertainty=ttft_std, - tpot_uncertainty=tpot_std, - ttft_prediction_bounds=ttft_bounds, - tpot_prediction_bounds=tpot_bounds, - predicted_at=datetime.now(timezone.utc), - model_type=predictor.model_type.value, - quantile=predictor.quantile, - last_model_load=predictor.last_load - ) - - predictions.append(prediction_response) - successful_count += 1 - - except HTTPException as he: - predictions.append(None) - errors.append(BulkPredictionError( - index=i, - error=he.detail, - request=pred_request - )) - failed_count += 1 - - except Exception as e: - predictions.append(None) - errors.append(BulkPredictionError( - index=i, - error=f"Internal error: {str(e)}", - request=pred_request - )) - failed_count += 1 - - processing_time_ms = (time.time() - start_time) * 1000 - - return BulkPredictionResponseWithErrors( - predictions=predictions, - errors=errors, - total_requests=len(request.requests), - successful_predictions=successful_count, - failed_predictions=failed_count, - processing_time_ms=processing_time_ms - ) -# Optional: Add a simpler bulk endpoint that fails fast on any error @app.post("/predict/bulk/strict", response_model=BulkPredictionResponse) async def predict_bulk_strict_endpoint(request: BulkPredictionRequest): - """Make bulk quantile latency predictions (fails on any single error).""" + """Make bulk quantile latency predictions using batch processing (fails on any single error).""" start_time = time.time() - predictions = [] - try: - for pred_request in request.requests: - ttft_pred, tpot_pred, ttft_std, tpot_std = predictor.predict(pred_request.dict()) - + # Convert all requests to dict format + features_list = [pred_request.dict() for pred_request in request.requests] + + # Make batch prediction + ttft_preds, tpot_preds = predictor.predict_batch(features_list) + + # Build response list + predictions = [] + current_time = datetime.now(timezone.utc) + + for i in range(len(request.requests)): # Ensure non-negative predictions - ttft_pred = max(0, ttft_pred) - tpot_pred = max(0, tpot_pred) - - # Calculate approximate confidence bounds - ttft_bounds = (max(0, ttft_pred - 2*ttft_std), ttft_pred + 2*ttft_std) - tpot_bounds = (max(0, tpot_pred - 2*tpot_std), tpot_pred + 2*tpot_std) + ttft_pred = max(0, ttft_preds[i]) + tpot_pred = max(0, tpot_preds[i]) prediction_response = PredictionResponse( ttft_ms=ttft_pred, tpot_ms=tpot_pred, - ttft_uncertainty=ttft_std, - tpot_uncertainty=tpot_std, - ttft_prediction_bounds=ttft_bounds, - tpot_prediction_bounds=tpot_bounds, - predicted_at=datetime.now(timezone.utc), + predicted_at=current_time, model_type=predictor.model_type.value, quantile=predictor.quantile, last_model_load=predictor.last_load @@ -523,6 +516,94 @@ async def predict_bulk_strict_endpoint(request: BulkPredictionRequest): logging.error(f"Bulk prediction failed: {e}") raise HTTPException(status_code=500, detail="Bulk prediction failed") + +@app.post("/predict/bulk", response_model=BulkPredictionResponseWithErrors) +async def predict_bulk_endpoint(request: BulkPredictionRequest): + """Make bulk quantile latency predictions using batch processing with error handling.""" + start_time = time.time() + + # Separate valid and invalid requests + valid_requests = [] + valid_indices = [] + errors = [] 
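+    # Requests are validated up front so that a single bad entry does not abort the
+    # whole batch: failures are reported per-index in `errors` (their slots in
+    # `predictions` stay None) while the remaining rows are scored together in one
+    # predictor.predict_batch() call below.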
+ + # Pre-validate all requests + required = ['kv_cache_percentage', 'input_token_length', 'num_request_waiting', 'num_request_running', 'num_tokens_generated', 'prefix_cache_score'] + + for i, pred_request in enumerate(request.requests): + try: + features = pred_request.dict() + # Validate features + for f in required: + if f not in features: + raise ValueError(f"Missing required feature: {f}") + if not isinstance(features[f], (int, float)): + raise ValueError(f"Invalid type for feature {f}: expected number") + + valid_requests.append(features) + valid_indices.append(i) + + except Exception as e: + errors.append(BulkPredictionError( + index=i, + error=str(e), + request=pred_request + )) + + # Initialize predictions list with None values + predictions = [None] * len(request.requests) + successful_count = len(valid_requests) + failed_count = len(errors) + + # Process valid requests in batch if any exist + if valid_requests: + try: + # Make batch prediction for all valid requests + ttft_preds, tpot_preds = predictor.predict_batch(valid_requests) + + current_time = datetime.now(timezone.utc) + + # Fill in predictions for valid requests + for batch_idx, original_idx in enumerate(valid_indices): + # Ensure non-negative predictions + ttft_pred = max(0, ttft_preds[batch_idx]) + tpot_pred = max(0, tpot_preds[batch_idx]) + + prediction_response = PredictionResponse( + ttft_ms=ttft_pred, + tpot_ms=tpot_pred, + predicted_at=current_time, + model_type=predictor.model_type.value, + quantile=predictor.quantile, + last_model_load=predictor.last_load + ) + + predictions[original_idx] = prediction_response + + except Exception as e: + # If batch prediction fails, mark all valid requests as failed + for original_idx in valid_indices: + errors.append(BulkPredictionError( + index=original_idx, + error=f"Batch prediction error: {str(e)}", + request=request.requests[original_idx] + )) + predictions[original_idx] = None + + successful_count = 0 + failed_count = len(request.requests) + + processing_time_ms = (time.time() - start_time) * 1000 + + return BulkPredictionResponseWithErrors( + predictions=predictions, + errors=errors, + total_requests=len(request.requests), + successful_predictions=successful_count, + failed_predictions=failed_count, + processing_time_ms=processing_time_ms + ) + @app.post("/reload") async def reload_models(): """Manually trigger model reload.""" diff --git a/latencypredictor-v1/requirements.txt b/latencypredictor-v1/requirements.txt index 6014c2d71..6e31838f5 100644 --- a/latencypredictor-v1/requirements.txt +++ b/latencypredictor-v1/requirements.txt @@ -8,4 +8,5 @@ river pydantic requests xgboost -aiohttp \ No newline at end of file +aiohttp +lightgbm \ No newline at end of file diff --git a/latencypredictor-v1/test_dual_server_client.py b/latencypredictor-v1/test_dual_server_client.py index b36cf7f8a..86e277153 100644 --- a/latencypredictor-v1/test_dual_server_client.py +++ b/latencypredictor-v1/test_dual_server_client.py @@ -6,21 +6,18 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from collections import defaultdict import random - -import pytest import requests - +import pytest import joblib import numpy as np import tempfile -import xgboost # Base URLs for the dual-server architecture -# Base URLs for the dual-server architecture -PREDICTION_URL = os.getenv("PREDICTION_SERVER_URL", "http://") # Update this -TRAINING_URL = os.getenv("TRAINING_SERVER_URL", "http://:8080") # Update this - +PREDICTION_URL = os.getenv("PREDICTION_SERVER_URL", 
"http://34.158.41.245:80") # Update this +TRAINING_URL = os.getenv("TRAINING_SERVER_URL", "http://34.143.208.0:8080") # Update this +TARGET_QPS = float(os.getenv("TARGET_QPS", 1000)) # Update this +TARGET_QPS_LARGE_BATCH = float(os.getenv("TARGET_QPS_LARGE_BATCH", 100)) # Update this # Helper to wait until the servers are ready def wait_for_ready(url: str, timeout: float = 30.0, interval: float = 1.0): start = time.time() @@ -81,9 +78,12 @@ def test_prediction_server_status(): assert "is_ready" in data assert "model_type" in data assert "models_exist" in data - assert data["model_type"] in ["bayesian_ridge", "xgboost"] + assert "quantile" in data + assert data["model_type"] in ["bayesian_ridge", "xgboost", "lightgbm"] + assert 0 < data["quantile"] <= 1.0 print(f"Prediction server using model type: {data['model_type']}") + print(f"Quantile: {data['quantile']}") print(f"Models ready: {data['is_ready']}") print(f"Models exist: {data['models_exist']}") @@ -96,7 +96,7 @@ def test_training_server_model_info(): data = r.json() assert "model_type" in data assert "available_endpoints" in data - assert data["model_type"] in ["bayesian_ridge", "xgboost"] + assert data["model_type"] in ["bayesian_ridge", "xgboost", "lightgbm"] print(f"Training server using model type: {data['model_type']}") @@ -162,7 +162,66 @@ def test_model_download_from_training_server(): continue time.sleep(2) # Wait before retry - +def test_lightgbm_endpoints_on_training_server(): + """Test LightGBM endpoints on training server if LightGBM is being used.""" + model_info_r = requests.get(f"{TRAINING_URL}/model/download/info") + model_type = model_info_r.json().get("model_type") + + if model_type != "lightgbm": + print("Skipping LightGBM endpoint tests - not using LightGBM model") + return + + print("Testing LightGBM endpoints on training server...") + + # Test TTFT model text format + ttft_txt_response = requests.get(f"{TRAINING_URL}/model/ttft/lgb/txt") + if ttft_txt_response.status_code == 200: + print("✓ TTFT LightGBM text model available") + assert ttft_txt_response.headers.get('content-type') == 'text/plain; charset=utf-8' + else: + print(f"TTFT LightGBM text model not yet available (status: {ttft_txt_response.status_code})") + + # Test TPOT model text format + tpot_txt_response = requests.get(f"{TRAINING_URL}/model/tpot/lgb/txt") + if tpot_txt_response.status_code == 200: + print("✓ TPOT LightGBM text model available") + assert tpot_txt_response.headers.get('content-type') == 'text/plain; charset=utf-8' + else: + print(f"TPOT LightGBM text model not yet available (status: {tpot_txt_response.status_code})") + + # Test TTFT feature importances + ttft_imp_response = requests.get(f"{TRAINING_URL}/model/ttft/lgb/importances") + if ttft_imp_response.status_code == 200: + ttft_importances = ttft_imp_response.json() + assert isinstance(ttft_importances, dict), "TTFT importances should be a dict" + + # Check for expected features including prefix_cache_score + expected_features = ["kv_cache_percentage", "input_token_length", "num_request_waiting", + "num_request_running", "prefix_cache_score"] + for feature in expected_features: + assert feature in ttft_importances, f"Missing feature importance: {feature}" + + print(f"✓ TTFT LightGBM importances available with {len(ttft_importances)} features") + else: + print(f"TTFT LightGBM importances not yet available (status: {ttft_imp_response.status_code})") + + # Test TPOT feature importances + tpot_imp_response = requests.get(f"{TRAINING_URL}/model/tpot/lgb/importances") + if 
tpot_imp_response.status_code == 200: + tpot_importances = tpot_imp_response.json() + assert isinstance(tpot_importances, dict), "TPOT importances should be a dict" + + # Check for expected features + expected_features = ["kv_cache_percentage", "input_token_length", "num_request_waiting", + "num_request_running", "num_tokens_generated"] + for feature in expected_features: + assert feature in tpot_importances, f"Missing feature importance: {feature}" + + print(f"✓ TPOT LightGBM importances available with {len(tpot_importances)} features") + else: + print(f"TPOT LightGBM importances not yet available (status: {tpot_imp_response.status_code})") + + def test_add_training_data_to_training_server(): """ Send training data to the training server. @@ -248,8 +307,7 @@ def test_prediction_via_prediction_server(): data = r.json() required_fields = [ - "ttft_ms", "tpot_ms", "ttft_uncertainty", "tpot_uncertainty", - "ttft_prediction_bounds", "tpot_prediction_bounds", + "ttft_ms", "tpot_ms", "predicted_at", "model_type", "last_model_load" ] @@ -259,13 +317,142 @@ def test_prediction_via_prediction_server(): # Verify predictions are reasonable assert data["ttft_ms"] > 0 assert data["tpot_ms"] > 0 - assert data["ttft_uncertainty"] >= 0 - assert data["tpot_uncertainty"] >= 0 + #assert data["ttft_uncertainty"] >= 0 + #assert data["tpot_uncertainty"] >= 0 print(f"Prediction successful: TTFT={data['ttft_ms']:.2f}ms, TPOT={data['tpot_ms']:.2f}ms") print(f"Model type: {data['model_type']}") +def test_bulk_prediction_strict(): + """Test bulk predictions with strict error handling.""" + print("Testing bulk prediction strict endpoint...") + + requests_data = [ + { + "kv_cache_percentage": 0.5, + "input_token_length": 200, + "num_request_waiting": 4, + "num_request_running": 1, + "num_tokens_generated": 4, + "prefix_cache_score": 0.7, + }, + { + "kv_cache_percentage": 0.3, + "input_token_length": 150, + "num_request_waiting": 2, + "num_request_running": 1, + "num_tokens_generated": 5, + "prefix_cache_score": 0.5, + } + ] + + bulk_request = {"requests": requests_data} + + r = requests.post(f"{PREDICTION_URL}/predict/bulk/strict", json=bulk_request) + assert r.status_code == 200 + + data = r.json() + + # Check bulk response structure + assert "predictions" in data + assert "total_requests" in data + assert "successful_predictions" in data + assert "failed_predictions" in data + assert "processing_time_ms" in data + + assert len(data["predictions"]) == 2 + assert data["total_requests"] == 2 + assert data["successful_predictions"] == 2 + assert data["failed_predictions"] == 0 + + # Check individual prediction structure + for prediction in data["predictions"]: + assert "ttft_ms" in prediction + assert "tpot_ms" in prediction + #assert "ttft_uncertainty" in prediction + #assert "tpot_uncertainty" in prediction + #assert "ttft_prediction_bounds" in prediction + #assert "tpot_prediction_bounds" in prediction + assert "predicted_at" in prediction + assert "model_type" in prediction + assert "quantile" in prediction + + print("✓ Bulk prediction strict endpoint test passed") + + +def test_bulk_prediction_with_validation_errors(): + """Test that bulk predictions fail completely when any request has validation errors.""" + print("Testing bulk prediction validation error handling...") + + requests_data = [ + # Valid request + { + "kv_cache_percentage": 0.5, + "input_token_length": 200, + "num_request_waiting": 4, + "num_request_running": 1, + "num_tokens_generated": 4, + "prefix_cache_score": 0.7, + }, + # Invalid request 
(missing prefix_cache_score) + { + "kv_cache_percentage": 0.3, + "input_token_length": 150, + "num_request_waiting": 2, + "num_request_running": 1, + "num_tokens_generated": 5, + # Missing prefix_cache_score + } + ] + + bulk_request = {"requests": requests_data} + + r = requests.post(f"{PREDICTION_URL}/predict/bulk", json=bulk_request) + assert r.status_code == 422 # Validation error expected + + # Check that error response contains validation details + error_data = r.json() + assert "detail" in error_data + + print("✓ Bulk prediction correctly failed when any request had validation errors") + + +def test_bulk_prediction_all_valid(): + """Test bulk predictions when all requests are valid.""" + print("Testing bulk prediction with all valid requests...") + + requests_data = [ + { + "kv_cache_percentage": 0.5, + "input_token_length": 200, + "num_request_waiting": 4, + "num_request_running": 1, + "num_tokens_generated": 4, + "prefix_cache_score": 0.7, + }, + { + "kv_cache_percentage": 0.3, + "input_token_length": 150, + "num_request_waiting": 2, + "num_request_running": 1, + "num_tokens_generated": 5, + "prefix_cache_score": 0.5, # Include required field + } + ] + + bulk_request = {"requests": requests_data} + + r = requests.post(f"{PREDICTION_URL}/predict/bulk", json=bulk_request) + assert r.status_code == 200 + + data = r.json() + assert data["total_requests"] == 2 + assert data["successful_predictions"] == 2 + assert data["failed_predictions"] == 0 + + print("✓ Bulk prediction succeeded with all valid requests") + def test_prediction_missing_prefix_cache_score(): """Test that predictions fail when prefix_cache_score is missing.""" features = { @@ -329,34 +516,38 @@ def test_model_consistency_between_servers(): print(f"Model type consistent across servers: {training_model_type}") -def test_xgboost_tree_endpoints_on_training_server(): - """Test XGBoost tree endpoints on training server if XGBoost is being used.""" +# 6. 
Update test_xgboost_tree_endpoints_on_training_server function name and add both +def test_model_specific_endpoints_on_training_server(): + """Test model-specific endpoints on training server based on model type.""" model_info_r = requests.get(f"{TRAINING_URL}/model/download/info") model_type = model_info_r.json().get("model_type") - if model_type != "xgboost": - print("Skipping XGBoost tree tests - not using XGBoost model") - return - - print("Testing XGBoost tree endpoints on training server...") + if model_type == "xgboost": + print("Testing XGBoost tree endpoints on training server...") + + # Test TTFT trees + ttft_response = requests.get(f"{TRAINING_URL}/model/ttft/xgb/json") + if ttft_response.status_code == 200: + ttft_trees = ttft_response.json() + assert isinstance(ttft_trees, list), "TTFT trees should be a list" + print(f"✓ TTFT XGBoost trees available: {len(ttft_trees)} trees") + else: + print(f"TTFT XGBoost trees not yet available (status: {ttft_response.status_code})") + + # Test TPOT trees + tpot_response = requests.get(f"{TRAINING_URL}/model/tpot/xgb/json") + if tpot_response.status_code == 200: + tpot_trees = tpot_response.json() + assert isinstance(tpot_trees, list), "TPOT trees should be a list" + print(f"✓ TPOT XGBoost trees available: {len(tpot_trees)} trees") + else: + print(f"TPOT XGBoost trees not yet available (status: {tpot_response.status_code})") + + elif model_type == "lightgbm": + test_lightgbm_endpoints_on_training_server() - # Test TTFT trees - ttft_response = requests.get(f"{TRAINING_URL}/model/ttft/xgb/json") - if ttft_response.status_code == 200: - ttft_trees = ttft_response.json() - assert isinstance(ttft_trees, list), "TTFT trees should be a list" - print(f"✓ TTFT XGBoost trees available: {len(ttft_trees)} trees") - else: - print(f"TTFT XGBoost trees not yet available (status: {ttft_response.status_code})") - - # Test TPOT trees - tpot_response = requests.get(f"{TRAINING_URL}/model/tpot/xgb/json") - if tpot_response.status_code == 200: - tpot_trees = tpot_response.json() - assert isinstance(tpot_trees, list), "TPOT trees should be a list" - print(f"✓ TPOT XGBoost trees available: {len(tpot_trees)} trees") else: - print(f"TPOT XGBoost trees not yet available (status: {tpot_response.status_code})") + print(f"No model-specific endpoints to test for {model_type}") async def async_predict_request(session, payload, request_id): @@ -385,500 +576,34 @@ async def async_predict_request(session, payload, request_id): 'model_type': None } -def test_dual_server_model_learns_equation(): - """ - Test that the dual-server architecture can learn equations end-to-end. - Updated with more robust training and validation. 
- """ - print("Testing dual-server end-to-end learning with prefix cache score...") - - # Step 1: Get current model type from training server - model_info_r = requests.get(f"{TRAINING_URL}/model/download/info") - assert model_info_r.status_code == 200 - model_type = model_info_r.json().get("model_type", "unknown") - print(f"Training server model type: {model_type}") - - # Step 2: Generate more training data with stronger signal - print("Step 1: Generating training data with known pattern (including prefix cache)...") - entries = [] - - # Generate 1000 training samples with clearer patterns and less noise - for i in range(1, 1001): - kv = random.uniform(0.1, 0.9) - input_len = random.randint(50, 1000) # Reduced range for clearer signal - waiting = random.randint(0, 10) # Reduced range - running = random.randint(1, 5) # Reduced range - tokens_gen = random.randint(1, 30) # Reduced range - prefix_cache = random.uniform(0.0, 1.0) - - # Reduced noise for clearer signal - noise_ttft = random.uniform(-2, 2) # Reduced noise - noise_tpot = random.uniform(-1, 1) # Reduced noise - - # Updated TTFT equation - actual_ttft = ( - input_len * 2.0 - + waiting * 3.0 - + running * 4.0 - + kv * 50.0 - + prefix_cache * 30.0 - + 95 - ) + noise_ttft - - # TPOT equation (no prefix cache) - actual_tpot = ( - kv * 100.0 - + input_len * 0.5 - + tokens_gen * 1.0 - + running * 5.0 - + 9 - ) + noise_tpot - - entries.append({ - "kv_cache_percentage": kv, - "input_token_length": input_len, - "num_request_waiting": waiting, - "num_request_running": running, - "actual_ttft_ms": max(1.0, actual_ttft), - "actual_tpot_ms": max(1.0, actual_tpot), - "num_tokens_generated": tokens_gen, - "prefix_cache_score": prefix_cache, - }) - - # Step 3: Send training data to training server - print(f"Step 2: Sending {len(entries)} training samples to training server...") - payload = {"entries": entries} - training_r = requests.post(f"{TRAINING_URL}/add_training_data_bulk", json=payload, timeout=60) - assert training_r.status_code == 202, f"Training data rejected: {training_r.status_code}" - print(f"✓ Training server accepted {len(entries)} samples") - - # Step 4: Wait longer for training to complete - print("Step 3: Waiting for training server to retrain models...") - training_deadline = time.time() + 180 # 3 minutes max wait for training - - while time.time() < training_deadline: - try: - metrics_r = requests.get(f"{TRAINING_URL}/metrics", timeout=10) - if metrics_r.status_code == 200: - metrics = metrics_r.text - if "ttft_r2_score" in metrics and "tpot_r2_score" in metrics: - print("✓ Training server has R² metrics - training likely completed") - break - except: - pass - - print(" Waiting for training to complete...") - time.sleep(15) # Check less frequently - - # Step 5: Trigger prediction server to sync models multiple times - print("Step 4: Syncing models to prediction server...") - sync_deadline = time.time() + 90 # 1.5 minutes max for model sync - models_synced = False - - while time.time() < sync_deadline and not models_synced: - try: - reload_r = requests.post(f"{PREDICTION_URL}/reload", timeout=20) - if reload_r.status_code == 200: - reload_data = reload_r.json() - if reload_data.get("is_ready"): - print("✓ Prediction server models are ready") - models_synced = True - break - except Exception as e: - print(f" Sync attempt failed: {e}") - - if not models_synced: - print(" Waiting for model sync...") - time.sleep(8) - - assert models_synced, "Prediction server failed to sync models within timeout" - - # Step 6: Test predictions with 
more relaxed tolerance initially - print("Step 5: Testing that predictions match learned equations...") - - # Use simpler test cases with more predictable values - test_cases = [ - { - "kv_cache_percentage": 0.5, - "input_token_length": 100, - "num_request_waiting": 2, - "num_request_running": 1, - "num_tokens_generated": 10, - "prefix_cache_score": 0.5, - }, - { - "kv_cache_percentage": 0.3, - "input_token_length": 200, - "num_request_waiting": 4, - "num_request_running": 2, - "num_tokens_generated": 15, - "prefix_cache_score": 0.8, - }, - ] - - # More relaxed tolerance, especially for XGBoost - tolerance = 0.25 if model_type == "xgboost" else 0.15 # Increased tolerance - all_predictions_correct = True - - for i, test_case in enumerate(test_cases): - # Calculate expected values - expected_ttft = ( - test_case["input_token_length"] * 2.0 - + test_case["num_request_waiting"] * 3.0 - + test_case["num_request_running"] * 4.0 - + test_case["kv_cache_percentage"] * 50.0 - + test_case["prefix_cache_score"] * 30.0 - + 95 - ) - - expected_tpot = ( - test_case["kv_cache_percentage"] * 100.0 - + test_case["input_token_length"] * 0.5 - + test_case["num_tokens_generated"] * 1.0 - + test_case["num_request_running"] * 5.0 - + 9 - ) - - # Make prediction via prediction server - pred_r = requests.post(f"{PREDICTION_URL}/predict", json=test_case, timeout=15) - assert pred_r.status_code == 200, f"Prediction failed for test case {i+1}" - - pred_data = pred_r.json() - actual_ttft = pred_data["ttft_ms"] - actual_tpot = pred_data["tpot_ms"] - - # Check if predictions are within tolerance - ttft_error = abs(actual_ttft - expected_ttft) / expected_ttft - tpot_error = abs(actual_tpot - expected_tpot) / expected_tpot - - ttft_ok = ttft_error <= tolerance - tpot_ok = tpot_error <= tolerance - - print(f" Test case {i+1} (prefix_cache={test_case['prefix_cache_score']}):") - print(f" TTFT: expected={expected_ttft:.1f}, actual={actual_ttft:.1f}, error={ttft_error*100:.1f}% {'✓' if ttft_ok else '✗'}") - print(f" TPOT: expected={expected_tpot:.1f}, actual={actual_tpot:.1f}, error={tpot_error*100:.1f}% {'✓' if tpot_ok else '✗'}") - - if not (ttft_ok and tpot_ok): - all_predictions_correct = False - - # If still failing, provide detailed diagnostics - if not all_predictions_correct: - print(f"❌ Model learning test failed with {tolerance*100:.0f}% tolerance") - print("🔍 Diagnostic information:") - - # Check if the model is learning anything at all - try: - metrics_r = requests.get(f"{TRAINING_URL}/metrics") - if metrics_r.status_code == 200: - metrics = metrics_r.text - r2_lines = [line for line in metrics.split('\n') if 'r2_score' in line] - if r2_lines: - print(" R² scores from training server:") - for line in r2_lines[:4]: - print(f" {line}") - except: - pass - - # Test if prefix cache has any impact at all - try: - low_cache_test = {**test_cases[0], "prefix_cache_score": 0.0} - high_cache_test = {**test_cases[0], "prefix_cache_score": 1.0} - - low_pred = requests.post(f"{PREDICTION_URL}/predict", json=low_cache_test) - high_pred = requests.post(f"{PREDICTION_URL}/predict", json=high_cache_test) - - if low_pred.status_code == 200 and high_pred.status_code == 200: - low_ttft = low_pred.json()["ttft_ms"] - high_ttft = high_pred.json()["ttft_ms"] - cache_impact = high_ttft - low_ttft - print(f" Prefix cache impact: {cache_impact:.1f}ms (expected ~30ms)") - except: - pass - - # Don't fail immediately - try one more relaxed check - if not all_predictions_correct: - print("🔄 Trying more relaxed validation...") - 
very_relaxed_tolerance = 0.35 # 35% tolerance - relaxed_predictions_correct = True - - for i, test_case in enumerate(test_cases): - pred_r = requests.post(f"{PREDICTION_URL}/predict", json=test_case, timeout=15) - if pred_r.status_code == 200: - pred_data = pred_r.json() - actual_ttft = pred_data["ttft_ms"] - actual_tpot = pred_data["tpot_ms"] - - expected_ttft = ( - test_case["input_token_length"] * 2.0 + test_case["num_request_waiting"] * 3.0 + - test_case["num_request_running"] * 4.0 + test_case["kv_cache_percentage"] * 50.0 + - test_case["prefix_cache_score"] * 30.0 + 95 - ) - expected_tpot = ( - test_case["kv_cache_percentage"] * 100.0 + test_case["input_token_length"] * 0.5 + - test_case["num_tokens_generated"] * 1.0 + test_case["num_request_running"] * 5.0 + 9 - ) - - ttft_error = abs(actual_ttft - expected_ttft) / expected_ttft - tpot_error = abs(actual_tpot - expected_tpot) / expected_tpot - - if ttft_error > very_relaxed_tolerance or tpot_error > very_relaxed_tolerance: - relaxed_predictions_correct = False - - if relaxed_predictions_correct: - print(f"✓ Model learning acceptable with relaxed {very_relaxed_tolerance*100:.0f}% tolerance") - return - - assert all_predictions_correct, f"Model learning failed - predictions not within ±{tolerance*100:.0f}% tolerance" - - -def test_dual_server_model_convergence_over_time(): - """ - Test that the dual-server architecture improves predictions over time - as more training data is added. - """ - print("Testing model convergence over multiple training iterations...") - - # Test features for consistent testing - test_features = { - "kv_cache_percentage": 0.6, - "input_token_length": 300, - "num_request_waiting": 5, - "num_request_running": 2, - "num_tokens_generated": 15, - "prefix_cache_score": 0.75, # Added prefix cache score - } - - # Expected values (updated with prefix cache) - expected_ttft = (300 * 2.0 + 5 * 3.0 + 2 * 4.0 + 0.6 * 50.0 + 0.75 * 30.0 + 95) - expected_tpot = (0.6 * 100.0 + 300 * 0.5 + 15 * 1.0 + 2 * 5.0 + 9) - - predictions_over_time = [] - - # Send training data in batches and test convergence - for iteration in range(1, 4): # 3 iterations - print(f"\nIteration {iteration}: Adding more training data...") - - # Generate batch of training data - batch_entries = [] - for _ in range(50): # 50 samples per batch - kv = random.uniform(0.1, 0.9) - input_len = random.randint(50, 1000) - waiting = random.randint(0, 10) - running = random.randint(1, 5) - tokens_gen = random.randint(1, 30) - prefix_cache = random.uniform(0.0, 1.0) # Added prefix cache - - # Add small amount of noise - noise_ttft = random.uniform(-3, 3) - noise_tpot = random.uniform(-2, 2) - - # Updated equations with prefix cache - actual_ttft = (input_len * 2.0 + waiting * 3.0 + running * 4.0 + kv * 50.0 + prefix_cache * 30.0 + 95) + noise_ttft - actual_tpot = (kv * 100.0 + input_len * 0.5 + tokens_gen * 1.0 + running * 5.0 + 9) + noise_tpot - - batch_entries.append({ - "kv_cache_percentage": kv, - "input_token_length": input_len, - "num_request_waiting": waiting, - "num_request_running": running, - "actual_ttft_ms": max(1.0, actual_ttft), - "actual_tpot_ms": max(1.0, actual_tpot), - "num_tokens_generated": tokens_gen, - "prefix_cache_score": prefix_cache, # Added prefix cache score - }) - - # Send to training server - training_r = requests.post(f"{TRAINING_URL}/add_training_data_bulk", - json={"entries": batch_entries}, timeout=20) - assert training_r.status_code == 202 - - # Wait for training - time.sleep(15) - - # Sync models to prediction server - for attempt 
in range(3): # Try up to 3 times - reload_r = requests.post(f"{PREDICTION_URL}/reload", timeout=15) - if reload_r.status_code == 200 and reload_r.json().get("is_ready"): - break - time.sleep(5) - - # Make prediction - pred_r = requests.post(f"{PREDICTION_URL}/predict", json=test_features, timeout=10) - assert pred_r.status_code == 200 - - pred_data = pred_r.json() - ttft_error = abs(pred_data["ttft_ms"] - expected_ttft) / expected_ttft - tpot_error = abs(pred_data["tpot_ms"] - expected_tpot) / expected_tpot - - predictions_over_time.append({ - "iteration": iteration, - "training_samples": iteration * 50, - "ttft_prediction": pred_data["ttft_ms"], - "tpot_prediction": pred_data["tpot_ms"], - "ttft_error": ttft_error, - "tpot_error": tpot_error, - }) - - print(f" After {iteration * 50} samples:") - print(f" TTFT error: {ttft_error*100:.1f}%") - print(f" TPOT error: {tpot_error*100:.1f}%") - - # Verify that errors generally decrease over time (convergence) - print(f"\nConvergence Analysis:") - for pred in predictions_over_time: - print(f" {pred['training_samples']} samples: TTFT={pred['ttft_error']*100:.1f}%, TPOT={pred['tpot_error']*100:.1f}%") - - # Check that final iteration has reasonable accuracy - final_prediction = predictions_over_time[-1] - assert final_prediction["ttft_error"] < 0.2, f"TTFT error too high after convergence: {final_prediction['ttft_error']*100:.1f}%" - assert final_prediction["tpot_error"] < 0.2, f"TPOT error too high after convergence: {final_prediction['tpot_error']*100:.1f}%" - - print(f"✓ Model convergence test passed - final errors: TTFT={final_prediction['ttft_error']*100:.1f}%, TPOT={final_prediction['tpot_error']*100:.1f}%") - - -def test_dual_server_model_persistence(): - """ - Test that models persist correctly across prediction server restarts - (simulated by reloading models). 
- """ - print("Testing model persistence across prediction server 'restarts'...") - - # Make initial prediction - test_features = { - "kv_cache_percentage": 0.4, - "input_token_length": 150, - "num_request_waiting": 3, - "num_request_running": 1, - "num_tokens_generated": 8, - "prefix_cache_score": 0.6, # Added prefix cache score - } - - pred1_r = requests.post(f"{PREDICTION_URL}/predict", json=test_features, timeout=10) - assert pred1_r.status_code == 200 - pred1_data = pred1_r.json() - - print(f"Initial prediction: TTFT={pred1_data['ttft_ms']:.2f}, TPOT={pred1_data['tpot_ms']:.2f}") - - # Simulate "restart" by manually reloading models - print("Simulating prediction server restart by reloading models...") - reload_r = requests.post(f"{PREDICTION_URL}/reload", timeout=15) - assert reload_r.status_code == 200 - assert reload_r.json().get("is_ready"), "Models should be ready after reload" - - # Make same prediction again - pred2_r = requests.post(f"{PREDICTION_URL}/predict", json=test_features, timeout=10) - assert pred2_r.status_code == 200 - pred2_data = pred2_r.json() - - print(f"Post-restart prediction: TTFT={pred2_data['ttft_ms']:.2f}, TPOT={pred2_data['tpot_ms']:.2f}") - - # Predictions should be identical (deterministic models) - ttft_diff = abs(pred1_data["ttft_ms"] - pred2_data["ttft_ms"]) - tpot_diff = abs(pred1_data["tpot_ms"] - pred2_data["tpot_ms"]) - - # Allow tiny differences due to floating point precision - assert ttft_diff < 0.01, f"TTFT predictions should be identical: {ttft_diff}" - assert tpot_diff < 0.01, f"TPOT predictions should be identical: {tpot_diff}" - - print("✓ Model persistence test passed - predictions identical after reload") - - -def test_prefix_cache_score_impact_on_ttft(): - """ - Test that prefix_cache_score has the expected impact on TTFT predictions. - Higher prefix cache scores should generally lead to lower TTFT predictions. 
- """ - print("Testing prefix cache score impact on TTFT predictions...") - - base_features = { - "kv_cache_percentage": 0.5, - "input_token_length": 300, - "num_request_waiting": 4, - "num_request_running": 2, - "num_tokens_generated": 15, - } - - prefix_cache_scores = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] - predictions = [] - - for prefix_score in prefix_cache_scores: - test_features = {**base_features, "prefix_cache_score": prefix_score} - - pred_r = requests.post(f"{PREDICTION_URL}/predict", json=test_features, timeout=10) - assert pred_r.status_code == 200 - - pred_data = pred_r.json() - predictions.append({ - "prefix_cache_score": prefix_score, - "ttft_ms": pred_data["ttft_ms"], - "tpot_ms": pred_data["tpot_ms"] - }) - - print(f" Prefix cache {prefix_score:.1f}: TTFT={pred_data['ttft_ms']:.1f}ms, TPOT={pred_data['tpot_ms']:.1f}ms") - - # Check that TTFT generally decreases as prefix cache score increases - # (assuming the model learned the positive coefficient for prefix cache) - ttft_values = [p["ttft_ms"] for p in predictions] - - # Calculate correlation between prefix cache score and TTFT - # We expect a positive correlation since higher prefix cache should reduce TTFT - # but our equation has +30*prefix_cache_score, so we expect positive correlation - first_half_avg = sum(ttft_values[:3]) / 3 # Low prefix cache scores - second_half_avg = sum(ttft_values[3:]) / 3 # High prefix cache scores - - print(f"Low prefix cache avg TTFT: {first_half_avg:.1f}ms") - print(f"High prefix cache avg TTFT: {second_half_avg:.1f}ms") - - # Since our training equation has +30*prefix_cache_score, higher prefix cache should increase TTFT - # This tests that the model learned the relationship correctly - ttft_difference = second_half_avg - first_half_avg - print(f"TTFT difference (high - low prefix cache): {ttft_difference:.1f}ms") - - # Should be positive difference (higher prefix cache = higher TTFT in our test equation) - assert ttft_difference > 10, f"Expected TTFT to increase with prefix cache score, got difference: {ttft_difference:.1f}ms" - - # TPOT should not be significantly affected by prefix cache score - tpot_values = [p["tpot_ms"] for p in predictions] - tpot_first_half = sum(tpot_values[:3]) / 3 - tpot_second_half = sum(tpot_values[3:]) / 3 - tpot_difference = abs(tpot_second_half - tpot_first_half) - - print(f"TPOT difference (should be small): {tpot_difference:.1f}ms") - assert tpot_difference < 5, f"TPOT should not be significantly affected by prefix cache, got difference: {tpot_difference:.1f}ms" - - print("✓ Prefix cache score impact test passed") - -async def run_prediction_stress_test(duration_seconds=30, target_qps=300): - """Run stress test against the prediction server only.""" - interval = 1.0 / target_qps - start = time.time() - connector = aiohttp.TCPConnector(limit=1000, limit_per_host=1000) - - async with aiohttp.ClientSession(connector=connector) as session: - tasks = [] - req_id = 0 - next_time = start - - while time.time() - start < duration_seconds: - now = time.time() - while next_time <= now: - req_id += 1 - payload = generate_random_prediction_payload() - tasks.append(asyncio.create_task(async_predict_request(session, payload, req_id))) - next_time += interval - - await asyncio.sleep(0.001) - - print(f"Waiting for {len(tasks)} prediction requests to complete...") - results = await asyncio.gather(*tasks, return_exceptions=True) - valid_results = [r for r in results if isinstance(r, dict)] - - if valid_results: - actual_qps = len(valid_results) / duration_seconds - 
print(f"Target QPS: {target_qps}, Actual QPS: {actual_qps:.1f}") - - return valid_results +async def async_bulk_predict_request(session, payload, request_id): + """Make an async bulk prediction request.""" + start_time = time.time() + try: + async with session.post(f"{PREDICTION_URL}/predict/bulk/strict", json=payload, timeout=aiohttp.ClientTimeout(total=10)) as response: + end_time = time.time() + response_data = await response.json() + return { + 'request_id': request_id, + 'status_code': response.status, + 'response_time': end_time - start_time, + 'success': response.status == 200, + 'response_data': response_data, + 'batch_size': len(payload.get('requests', [])), + 'predictions_count': len(response_data.get('predictions', [])) if response.status == 200 else 0 + } + except Exception as e: + end_time = time.time() + return { + 'request_id': request_id, + 'status_code': 0, + 'response_time': end_time - start_time, + 'success': False, + 'error': str(e), + 'batch_size': len(payload.get('requests', [])), + 'predictions_count': 0 + } def generate_random_prediction_payload(): @@ -889,10 +614,25 @@ def generate_random_prediction_payload(): "num_request_waiting": random.randint(1, 20), "num_request_running": random.randint(1, 10), "num_tokens_generated": random.randint(1, 20), - "prefix_cache_score": random.uniform(0.0, 1.0), # Added prefix cache score + "prefix_cache_score": random.uniform(0.0, 1.0), } +def generate_bulk_prediction_payload(batch_size=10): + """Generate a bulk prediction payload with specified batch size.""" + requests_data = [] + for _ in range(batch_size): + requests_data.append({ + "kv_cache_percentage": random.uniform(0.1, 0.9), + "input_token_length": random.randint(10, 1000), + "num_request_waiting": random.randint(1, 20), + "num_request_running": random.randint(1, 10), + "num_tokens_generated": random.randint(1, 20), + "prefix_cache_score": random.uniform(0.0, 1.0), + }) + return {"requests": requests_data} + + def generate_random_training_payload(): """Generate a random training payload.""" input_tokens = random.randint(10, 1000) @@ -927,6 +667,219 @@ def generate_random_training_payload(): } +def test_dual_server_quantile_regression_learns_distribution(): + """ + Quantile regression should learn the q-quantile of a Gaussian residual model + with fixed sigma, verified by (a) relative error vs μ+zσ and (b) empirical coverage. 
+ """ + import random, time, math + import numpy as np + import requests + from scipy.stats import norm + + RNG_SEED = 42 + random.seed(RNG_SEED) + np.random.seed(RNG_SEED) + + # Config + TRAIN_N = 3000 + TEST_N = 200 + TTFT_STD, TPOT_STD = 20.0, 10.0 + REL_ERR_TOL = 0.15 # 15% + COVERAGE_TOL = 0.05 # ±5% around target quantile + MAX_WAIT_S = 180 + POLL_INTERVAL_S = 3 + + # 1) Confirm server mode + r = requests.get(f"{TRAINING_URL}/model/download/info", timeout=10) + assert r.status_code == 200, "model info endpoint failed" + model_type = r.json().get("model_type", "unknown") + + s = requests.get(f"{PREDICTION_URL}/status", timeout=10) + assert s.status_code == 200, "prediction status endpoint failed" + target_quantile = float(s.json().get("quantile", 0.9)) + + assert "xgboost" in model_type.lower() or "lightgbm" in model_type.lower(), f"Model not in quantile mode: {model_type}" + + z = norm.ppf(target_quantile) + + # 2) Generate training data (vectorized) + kv = np.random.uniform(0.1, 0.9, size=TRAIN_N) + input_len = np.random.randint(50, 801, size=TRAIN_N) + waiting = np.random.randint(0, 9, size=TRAIN_N) + running = np.random.randint(1, 5, size=TRAIN_N) + tokens_gen = np.random.randint(1, 26, size=TRAIN_N) + prefix = np.random.uniform(0.0, 1.0, size=TRAIN_N) + + ttft_mu = (input_len*2.0 + waiting*3.0 + running*4.0 + kv*50.0 + prefix*30.0 + 95) + tpot_mu = (kv*100.0 + input_len*0.5 + tokens_gen*1.0 + running*5.0 + 9) + + ttft_y = np.maximum(1.0, ttft_mu + np.random.normal(0, TTFT_STD, size=TRAIN_N)) + tpot_y = np.maximum(1.0, tpot_mu + np.random.normal(0, TPOT_STD, size=TRAIN_N)) + + entries = [dict( + kv_cache_percentage=float(kv[i]), + input_token_length=int(input_len[i]), + num_request_waiting=int(waiting[i]), + num_request_running=int(running[i]), + actual_ttft_ms=float(ttft_y[i]), + actual_tpot_ms=float(tpot_y[i]), + num_tokens_generated=int(tokens_gen[i]), + prefix_cache_score=float(prefix[i]), + ) for i in range(TRAIN_N)] + + # 3) Submit training data (with a couple retries) + for _ in range(3): + tr = requests.post(f"{TRAINING_URL}/add_training_data_bulk", json={"entries": entries}, timeout=60) + if tr.status_code == 202: + break + time.sleep(2) + assert tr.status_code == 202, f"training submit failed: {tr.status_code}" + + # 4) Wait for training to complete + time.sleep(30) + # 5) Sync models to prediction server + synced = False + for _ in range(10): + rr = requests.post(f"{PREDICTION_URL}/reload", timeout=20) + if rr.status_code == 200 and rr.json().get("is_ready"): + synced = True + break + time.sleep(3) + assert synced, "Failed to sync models" + + # 6) Build test set + expected quantiles + kv_t = np.random.uniform(0.1, 0.9, size=TEST_N) + in_t = np.random.randint(100, 601, size=TEST_N) + wait_t = np.random.randint(1, 9, size=TEST_N) + run_t = np.random.randint(1, 5, size=TEST_N) + tok_t = np.random.randint(5, 21, size=TEST_N) + pre_t = np.random.uniform(0.0, 1.0, size=TEST_N) + + ttft_mu_t = (in_t*2.0 + wait_t*3.0 + run_t*4.0 + kv_t*50.0 + pre_t*30.0 + 95) + tpot_mu_t = (kv_t*100.0 + in_t*0.5 + tok_t*1.0 + run_t*5.0 + 9) + ttft_q_exp = ttft_mu_t + z*TTFT_STD + tpot_q_exp = tpot_mu_t + z*TPOT_STD + + test_cases = [dict( + kv_cache_percentage=float(kv_t[i]), + input_token_length=int(in_t[i]), + num_request_waiting=int(wait_t[i]), + num_request_running=int(run_t[i]), + num_tokens_generated=int(tok_t[i]), + prefix_cache_score=float(pre_t[i]), + ) for i in range(TEST_N)] + + # 7) Predict (bulk) + pr = requests.post(f"{PREDICTION_URL}/predict/bulk/strict", json={"requests": 
test_cases}, timeout=60) + assert pr.status_code == 200, f"predict failed: {pr.status_code}" + jd = pr.json() + assert jd["total_requests"] == TEST_N and jd["successful_predictions"] == TEST_N and jd["failed_predictions"] == 0 + preds = jd["predictions"] + + ttft_pred = np.array([p["ttft_ms"] for p in preds], dtype=float) + tpot_pred = np.array([p["tpot_ms"] for p in preds], dtype=float) + + # 8) Relative error vs μ + zσ + ttft_rel_err = np.abs(ttft_pred - ttft_q_exp) / ttft_q_exp + tpot_rel_err = np.abs(tpot_pred - tpot_q_exp) / tpot_q_exp + acc_mask = (ttft_rel_err <= REL_ERR_TOL) & (tpot_rel_err <= REL_ERR_TOL) + rel_accuracy = acc_mask.mean() + print(f"Relative-err accuracy (≤{int(REL_ERR_TOL*100)}%): {rel_accuracy*100:.1f}%") + + # 9) Coverage calibration (simulate actuals for the same test X) + # Generate fresh noise so it's an *unseen* draw from the same D|X: + ttft_actual = np.maximum(1.0, ttft_mu_t + np.random.normal(0, TTFT_STD, size=TEST_N)) + tpot_actual = np.maximum(1.0, tpot_mu_t + np.random.normal(0, TPOT_STD, size=TEST_N)) + + ttft_cov = (ttft_actual <= ttft_pred).mean() + tpot_cov = (tpot_actual <= tpot_pred).mean() + print(f"Coverage: TTFT={ttft_cov:.3f}, TPOT={tpot_cov:.3f} (target {target_quantile:.3f} ± {COVERAGE_TOL})") + + # 10) Monotonic sanity checks on a few random pairs (no hard fail, just helpful asserts) + # pick one sample index and perturb input_token_length upward + idx = 0 + base = test_cases[idx].copy(); up = test_cases[idx].copy(); up["input_token_length"] += 100 + br = requests.post(f"{PREDICTION_URL}/predict/bulk/strict", json={"requests":[base, up]}, timeout=30) + if br.status_code == 200: + _bp = br.json()["predictions"] + assert _bp[1]["ttft_ms"] >= _bp[0]["ttft_ms"] - 1e-6, "TTFT should not decrease with longer input" + + # 11) Final assertions + assert rel_accuracy >= 0.70, f"Only {rel_accuracy*100:.1f}% within ±{int(REL_ERR_TOL*100)}% (expected ≥70%)" + assert abs(ttft_cov - target_quantile) <= COVERAGE_TOL, f"TTFT coverage {ttft_cov:.3f} not within ±{COVERAGE_TOL} of {target_quantile:.3f}" + assert abs(tpot_cov - target_quantile) <= COVERAGE_TOL, f"TPOT coverage {tpot_cov:.3f} not within ±{COVERAGE_TOL} of {target_quantile:.3f}" + + + + +async def run_prediction_stress_test(duration_seconds=30, target_qps=1000): + """Run stress test against the prediction server only.""" + interval = 1.0 / target_qps + start = time.time() + connector = aiohttp.TCPConnector(limit=1000, limit_per_host=1000) + + async with aiohttp.ClientSession(connector=connector) as session: + tasks = [] + req_id = 0 + next_time = start + + while time.time() - start < duration_seconds: + now = time.time() + while next_time <= now: + req_id += 1 + payload = generate_random_prediction_payload() + tasks.append(asyncio.create_task(async_predict_request(session, payload, req_id))) + next_time += interval + + await asyncio.sleep(0.001) + + print(f"Waiting for {len(tasks)} prediction requests to complete...") + results = await asyncio.gather(*tasks, return_exceptions=True) + valid_results = [r for r in results if isinstance(r, dict)] + + if valid_results: + actual_qps = len(valid_results) / duration_seconds + print(f"Target QPS: {target_qps}, Actual QPS: {actual_qps:.1f}") + + return valid_results + + +async def run_bulk_prediction_stress_test(duration_seconds=30, target_rps=100, batch_size=10): + """Run stress test against the bulk prediction endpoint.""" + interval = 1.0 / target_rps # requests per second + start = time.time() + connector = aiohttp.TCPConnector(limit=200, 
limit_per_host=200) + + async with aiohttp.ClientSession(connector=connector) as session: + tasks = [] + req_id = 0 + next_time = start + + while time.time() - start < duration_seconds: + now = time.time() + while next_time <= now: + req_id += 1 + payload = generate_bulk_prediction_payload(batch_size) + tasks.append(asyncio.create_task(async_bulk_predict_request(session, payload, req_id))) + next_time += interval + + await asyncio.sleep(0.01) # Slightly longer sleep for bulk requests + + print(f"Waiting for {len(tasks)} bulk prediction requests to complete...") + results = await asyncio.gather(*tasks, return_exceptions=True) + valid_results = [r for r in results if isinstance(r, dict)] + + if valid_results: + actual_rps = len(valid_results) / duration_seconds + total_predictions = sum(r.get('predictions_count', 0) for r in valid_results) + actual_pps = total_predictions / duration_seconds # predictions per second + print(f"Target RPS: {target_rps}, Actual RPS: {actual_rps:.1f}") + print(f"Total Predictions: {total_predictions}, Predictions/sec: {actual_pps:.1f}") + + return valid_results + + def analyze_prediction_stress_results(results): """Analyze prediction stress test results.""" if not results: @@ -977,11 +930,60 @@ def analyze_prediction_stress_results(results): print(f" P99: {p99:.2f}ms") +def analyze_bulk_prediction_stress_results(results): + """Analyze bulk prediction stress test results.""" + if not results: + print("No results to analyze") + return + + total_requests = len(results) + successful_requests = sum(1 for r in results if r.get('success', False)) + failed_requests = total_requests - successful_requests + + total_predictions = sum(r.get('predictions_count', 0) for r in results) + total_batch_size = sum(r.get('batch_size', 0) for r in results) + + response_times = [r['response_time'] for r in results if r.get('response_time')] + avg_response_time = sum(response_times) / len(response_times) if response_times else 0 + + status_codes = defaultdict(int) + for r in results: + status_codes[r.get('status_code', 0)] += 1 + + print(f"\n{'='*50}") + print("BULK PREDICTION STRESS TEST RESULTS") + print(f"{'='*50}") + print(f"Total Bulk Requests: {total_requests}") + print(f"Successful: {successful_requests} ({successful_requests/total_requests*100:.1f}%)") + print(f"Failed: {failed_requests} ({failed_requests/total_requests*100:.1f}%)") + print(f"Total Individual Predictions: {total_predictions}") + print(f"Total Batch Size: {total_batch_size}") + print(f"Average Response Time: {avg_response_time*1000:.2f}ms") + + if total_batch_size > 0: + print(f"Average Batch Size: {total_batch_size/total_requests:.1f}") + print(f"Prediction Success Rate: {total_predictions/total_batch_size*100:.1f}%") + + print(f"\nStatus Code Distribution:") + for status, count in status_codes.items(): + print(f" {status}: {count}") + + if response_times: + sorted_times = sorted(response_times) + p50 = sorted_times[int(len(sorted_times) * 0.5)] * 1000 + p95 = sorted_times[int(len(sorted_times) * 0.95)] * 1000 + p99 = sorted_times[int(len(sorted_times) * 0.99)] * 1000 + print(f"\nResponse Time Percentiles:") + print(f" P50: {p50:.2f}ms") + print(f" P95: {p95:.2f}ms") + print(f" P99: {p99:.2f}ms") + + def test_prediction_server_stress_test(): """Stress test the prediction server.""" print("Running prediction server stress test...") - results = asyncio.run(run_prediction_stress_test(duration_seconds=60, target_qps=300)) + results = asyncio.run(run_prediction_stress_test(duration_seconds=100, 
target_qps=TARGET_QPS)) analyze_prediction_stress_results(results) @@ -995,6 +997,57 @@ def test_prediction_server_stress_test(): print(f"Prediction server stress test completed with {success_rate*100:.1f}% success rate") +def test_bulk_prediction_stress_test(): + """Stress test the bulk prediction endpoint.""" + print("Running bulk prediction stress test...") + + # Test with different batch sizes + batch_sizes = [5, 10, 25] + for batch_size in batch_sizes: + print(f"\nTesting with batch size {batch_size}...") + results = asyncio.run(run_bulk_prediction_stress_test( + duration_seconds=100, + target_rps=TARGET_QPS, # Same target rate; each request carries a full batch + batch_size=batch_size + )) + + analyze_bulk_prediction_stress_results(results) + + assert len(results) > 0, f"No bulk requests were made for batch size {batch_size}" + + successful_requests = sum(1 for r in results if r.get('success', False)) + success_rate = successful_requests / len(results) + + assert success_rate > 0.7, f"Bulk success rate too low for batch size {batch_size}: {success_rate*100:.1f}%" + + print(f"Bulk prediction stress test (batch size {batch_size}) completed with {success_rate*100:.1f}% success rate") + +def test_large_batch_prediction_stress_test(): + """Stress test the bulk prediction endpoint with a large batch size.""" + print("Running large batch prediction stress test...") + + # Test with a single large batch size + batch_sizes = [1000] + for batch_size in batch_sizes: + print(f"\nTesting with batch size {batch_size}...") + results = asyncio.run(run_bulk_prediction_stress_test( + duration_seconds=100, + target_rps=TARGET_QPS_LARGE_BATCH, # Lower RPS for large batch requests + batch_size=batch_size + )) + + analyze_bulk_prediction_stress_results(results) + + assert len(results) > 0, f"No bulk requests were made for batch size {batch_size}" + + successful_requests = sum(1 for r in results if r.get('success', False)) + success_rate = successful_requests / len(results) + + assert success_rate > 0.7, f"Bulk success rate too low for batch size {batch_size}: {success_rate*100:.1f}%" + + print(f"Large batch prediction stress test (batch size {batch_size}) completed with {success_rate*100:.1f}% success rate") + + def test_end_to_end_workflow(): """Test the complete end-to-end workflow with robust error handling.""" print("Testing end-to-end workflow...") @@ -1108,16 +1161,19 @@ def test_server_configuration(): ("Send Training Data", test_add_training_data_to_training_server), ("Model Sync", test_prediction_server_model_sync), ("Predictions", test_prediction_via_prediction_server), + ("Bulk Prediction Strict", test_bulk_prediction_strict), + ("Bulk Prediction All Valid", test_bulk_prediction_all_valid), + ("Bulk Prediction With Validation Errors", test_bulk_prediction_with_validation_errors), ("Prediction Missing Prefix Cache", test_prediction_missing_prefix_cache_score), ("Training Metrics", test_training_server_metrics), ("Model Consistency", test_model_consistency_between_servers), - ("XGBoost Trees", test_xgboost_tree_endpoints_on_training_server), - ("Prefix Cache Score Impact", test_prefix_cache_score_impact_on_ttft), - ("Dual Server Model Learns Equation", test_dual_server_model_learns_equation), - ("Dual Server Model Convergence", test_dual_server_model_convergence_over_time), - ("Model Persistence", test_dual_server_model_persistence), + ("Model-Specific Endpoints", test_model_specific_endpoints_on_training_server), + + ("Dual Server Quantile Regression", test_dual_server_quantile_regression_learns_distribution), ("End-to-End Workflow", test_end_to_end_workflow), ("Prediction Stress Test", 
test_prediction_server_stress_test), + ("Bulk Prediction Stress Test", test_bulk_prediction_stress_test), + ("Large Batch Prediction Stress Test", test_large_batch_prediction_stress_test), ] passed = 0 diff --git a/latencypredictor-v1/test_latency_predictor_client.py b/latencypredictor-v1/test_latency_predictor_client.py deleted file mode 100644 index 402f14fb7..000000000 --- a/latencypredictor-v1/test_latency_predictor_client.py +++ /dev/null @@ -1,1244 +0,0 @@ -import os -import time -import asyncio -import aiohttp -import threading -from concurrent.futures import ThreadPoolExecutor, as_completed -from collections import defaultdict -import random - -import pytest -import requests - -import joblib -import numpy as np -import tempfile -import xgboost - -# Base URL of your running FastAPI server -BASE_URL = os.getenv("TRAINING_SERVER_URL", "http://34.143.221.122:80") - -# Helper to wait until the server is ready -def wait_for_ready(timeout: float = 30.0, interval: float = 1.0): - start = time.time() - while True: - try: - r = requests.get(f"{BASE_URL}/readyz", timeout=2.0) - if r.status_code == 200: - return - except requests.RequestException: - pass - if time.time() - start > timeout: - pytest.skip("Server did not become ready in time") - time.sleep(interval) - -@pytest.fixture(scope="module", autouse=True) -def ensure_server_ready(): - """Wait for the /readyz endpoint before running tests.""" - wait_for_ready() - - -def test_healthz(): - r = requests.get(f"{BASE_URL}/healthz") - assert r.status_code == 200 - assert r.json().get("status") == "ok" - - -def test_readyz(): - r = requests.get(f"{BASE_URL}/readyz") - assert r.status_code == 200 - assert r.json().get("status") == "ready" - - -def test_model_info(): - """Test the simplified /model/download/info endpoint.""" - r = requests.get(f"{BASE_URL}/model/download/info") - assert r.status_code == 200 - - data = r.json() - assert "model_type" in data - assert "model_status" in data - assert "available_endpoints" in data - assert data["model_type"] in ["bayesian_ridge", "xgboost"] - assert isinstance(data["model_status"], dict) - - print(f"Server using model type: {data['model_type']}") - - if data["model_type"] == "bayesian_ridge": - assert "coefficients_info" in data - assert data["available_endpoints"]["coefficients"] == "/metrics" - else: # XGBoost - assert "trees" in data["available_endpoints"] - - -def test_root_endpoint_enhanced(): - """Test the enhanced root endpoint that now includes model info.""" - r = requests.get(f"{BASE_URL}/") - assert r.status_code == 200 - - data = r.json() - assert "message" in data - assert "model_type" in data - assert data["model_type"] in ["bayesian_ridge", "xgboost"] - - -def test_add_training_data_bulk(): - """ - Send 120 training samples in one bulk request so the server can retrain: - Updated equations with prefix cache score: - actual_ttft_ms = 2*input_token_length + 3*num_request_waiting + - 4*num_request_running + 50*kv_cache_percentage + - 30*prefix_cache_score + 95 - actual_tpot_ms = 100*kv_cache_percentage + 0.5*input_token_length + 1*num_tokens_generated + - 5*num_request_running + 9 - """ - entries = [] - common = { - "kv_cache_percentage": 0.5, - "num_request_running": 1, - } - - for i in range(1, 121): - waiting = i % 10 + 1 - tokens = waiting - inp_len = 10 * i - kv = common["kv_cache_percentage"] - running = common["num_request_running"] - prefix_cache = random.uniform(0.1, 0.9) # Added prefix cache score - - entries.append({ - "kv_cache_percentage": kv, - "input_token_length": inp_len, 
- "num_request_waiting": waiting, - "num_request_running": running, - # Updated TTFT formula to include prefix_cache_score - "actual_ttft_ms": (inp_len*2.0 + waiting*3.0 + running*4.0 + kv*50.0 + prefix_cache*30.0) + 95, - # TPOT formula remains unchanged - "actual_tpot_ms": (kv*100.0 + inp_len*0.5 + tokens*1.0 + running*5.0) + 9, - "num_tokens_generated": tokens, - "prefix_cache_score": prefix_cache, # Added prefix cache score - "timestamp": time.time() # FastAPI will coerce to datetime - }) - - payload = {"entries": entries} - r = requests.post(f"{BASE_URL}/add_training_data_bulk", json=payload) - assert r.status_code == 202, f"Expected 202, got {r.status_code}" - assert r.json().get("message") == "Accepted 120 training samples." - - -def test_model_learns_equation(): - """ - After sending bulk data, poll /predict until the model's predictions - match our linear equations within tolerance, or fail after 60s. - Updated to include prefix_cache_score in the test equation. - """ - # First check what model type we're using - model_info_r = requests.get(f"{BASE_URL}/model/download/info") - model_type = model_info_r.json().get("model_type", "unknown") - - features = { - "kv_cache_percentage": 0.5, - "input_token_length": 200, - "num_request_waiting": 4, - "num_request_running": 1, - "num_tokens_generated": 4, - "prefix_cache_score": 0.7, # Added prefix cache score - } - - # Updated expected TTFT to include prefix cache score - expected_ttft = ( - features["input_token_length"] * 2.0 - + features["num_request_waiting"] * 3.0 - + features["num_request_running"] * 4.0 - + features["kv_cache_percentage"] * 50.0 - + features["prefix_cache_score"] * 30.0 # New term - + 95 - ) - # TPOT formula remains unchanged - expected_tpot = ( - features["kv_cache_percentage"] * 100.0 - + features["input_token_length"] * 0.5 - + features["num_tokens_generated"] * 1.0 - + features["num_request_running"] * 5.0 + 9 - ) - - # Adjust tolerance based on model type - # XGBoost might need more tolerance for tree-based predictions - tolerance = 0.15 if model_type == "xgboost" else 0.1 - - deadline = time.time() + 60.0 - last_ttft, last_tpot = None, None - - while time.time() < deadline: - r = requests.post(f"{BASE_URL}/predict", json=features) - if r.status_code != 200: - time.sleep(1) - continue - - body = r.json() - last_ttft = body["ttft_ms"] - last_tpot = body["tpot_ms"] - - # Verify the response includes model_type - assert "model_type" in body, "Response should include model_type" - assert body["model_type"] == model_type - - ttft_ok = abs(last_ttft - expected_ttft) <= tolerance * expected_ttft - tpot_ok = abs(last_tpot - expected_tpot) <= tolerance * expected_tpot - if ttft_ok and tpot_ok: - print(f"Model converged with {model_type} in {60.0 - (deadline - time.time()):.1f}s") - print(f" Expected TTFT: {expected_ttft:.1f}, Got: {last_ttft:.1f}") - print(f" Expected TPOT: {expected_tpot:.1f}, Got: {last_tpot:.1f}") - break - - time.sleep(1) - - assert last_ttft is not None, "Never got a successful prediction." 
- assert abs(last_ttft - expected_ttft) <= tolerance * expected_ttft, ( - f"TTFT={last_ttft:.1f} not within ±{tolerance*100}% of {expected_ttft:.1f} (model: {model_type})" - ) - assert abs(last_tpot - expected_tpot) <= tolerance * expected_tpot, ( - f"TPOT={last_tpot:.1f} not within ±{tolerance*100}% of {expected_tpot:.1f} (model: {model_type})" - ) - - -def test_prediction_missing_prefix_cache_score(): - """Test that predictions fail when prefix_cache_score is missing.""" - features = { - "kv_cache_percentage": 0.5, - "input_token_length": 200, - "num_request_waiting": 4, - "num_request_running": 1, - "num_tokens_generated": 4, - # Missing prefix_cache_score - } - - r = requests.post(f"{BASE_URL}/predict", json=features) - assert r.status_code == 422 # Should fail validation - - print("✓ Prediction correctly failed when prefix_cache_score was missing") - - -def test_prefix_cache_score_impact_on_ttft(): - """ - Test that prefix_cache_score has the expected impact on TTFT predictions. - Since our test equation has +30*prefix_cache_score, higher scores should increase TTFT. - """ - print("Testing prefix cache score impact on TTFT predictions...") - - base_features = { - "kv_cache_percentage": 0.5, - "input_token_length": 300, - "num_request_waiting": 4, - "num_request_running": 2, - "num_tokens_generated": 15, - } - - prefix_cache_scores = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] - predictions = [] - - for prefix_score in prefix_cache_scores: - test_features = {**base_features, "prefix_cache_score": prefix_score} - - pred_r = requests.post(f"{BASE_URL}/predict", json=test_features, timeout=10) - assert pred_r.status_code == 200 - - pred_data = pred_r.json() - predictions.append({ - "prefix_cache_score": prefix_score, - "ttft_ms": pred_data["ttft_ms"], - "tpot_ms": pred_data["tpot_ms"] - }) - - print(f" Prefix cache {prefix_score:.1f}: TTFT={pred_data['ttft_ms']:.1f}ms, TPOT={pred_data['tpot_ms']:.1f}ms") - - # Check that TTFT increases as prefix cache score increases - # (since our test equation has +30*prefix_cache_score) - ttft_values = [p["ttft_ms"] for p in predictions] - - # Calculate correlation between prefix cache score and TTFT - first_half_avg = sum(ttft_values[:3]) / 3 # Low prefix cache scores - second_half_avg = sum(ttft_values[3:]) / 3 # High prefix cache scores - - print(f"Low prefix cache avg TTFT: {first_half_avg:.1f}ms") - print(f"High prefix cache avg TTFT: {second_half_avg:.1f}ms") - - # Since our training equation has +30*prefix_cache_score, higher prefix cache should increase TTFT - ttft_difference = second_half_avg - first_half_avg - print(f"TTFT difference (high - low prefix cache): {ttft_difference:.1f}ms") - - # Should be positive difference (higher prefix cache = higher TTFT in our test equation) - assert ttft_difference > 10, f"Expected TTFT to increase with prefix cache score, got difference: {ttft_difference:.1f}ms" - - # TPOT should not be significantly affected by prefix cache score - tpot_values = [p["tpot_ms"] for p in predictions] - tpot_first_half = sum(tpot_values[:3]) / 3 - tpot_second_half = sum(tpot_values[3:]) / 3 - tpot_difference = abs(tpot_second_half - tpot_first_half) - - print(f"TPOT difference (should be small): {tpot_difference:.1f}ms") - assert tpot_difference < 5, f"TPOT should not be significantly affected by prefix cache, got difference: {tpot_difference:.1f}ms" - - print("✓ Prefix cache score impact test passed") - - -def test_prediction_response_format(): - """Test that prediction responses include all expected fields including new model_type.""" 
- features = generate_random_prediction_payload() - - r = requests.post(f"{BASE_URL}/predict", json=features) - assert r.status_code == 200 - - data = r.json() - required_fields = [ - "ttft_ms", "tpot_ms", "ttft_uncertainty", "tpot_uncertainty", - "ttft_prediction_bounds", "tpot_prediction_bounds", - "predicted_at", "model_type" - ] - - for field in required_fields: - assert field in data, f"Missing required field: {field}" - - # Verify model_type is valid - assert data["model_type"] in ["bayesian_ridge", "xgboost"] - - # Verify numeric fields are reasonable - assert data["ttft_ms"] >= 0 - assert data["tpot_ms"] >= 0 - assert data["ttft_uncertainty"] >= 0 - assert data["tpot_uncertainty"] >= 0 - - # Verify bounds are tuples - assert len(data["ttft_prediction_bounds"]) == 2 - assert len(data["tpot_prediction_bounds"]) == 2 - - -def test_metrics_endpoint_enhanced(): - """Test that metrics endpoint includes model-specific information with proper coefficients.""" - r = requests.get(f"{BASE_URL}/metrics") - assert r.status_code == 200 - - content = r.text - - # Should contain model type metric - assert "model_type{" in content - - # Should contain either coefficients (Bayesian Ridge) or importance (XGBoost) - has_coef = "ttft_coef{" in content or "tpot_coef{" in content - has_importance = "ttft_importance{" in content or "tpot_importance{" in content - - assert has_coef or has_importance, "Should have either coefficients or feature importance metrics" - - # Should have standard metrics - assert "ttft_r2_score{" in content - assert "tpot_r2_score{" in content - assert "training_samples_count" in content - - # Check for prefix_cache_score in TTFT metrics - if has_coef: - assert 'feature="prefix_cache_score"' in content, "Should have prefix_cache_score coefficient for TTFT model" - if has_importance: - assert 'feature="prefix_cache_score"' in content, "Should have prefix_cache_score importance for TTFT model" - - # Parse and validate coefficient values for Bayesian Ridge - model_info_r = requests.get(f"{BASE_URL}/model/download/info") - model_type = model_info_r.json().get("model_type") - - if model_type == "bayesian_ridge": - # Check that coefficients are present and reasonable - lines = content.split('\n') - ttft_intercept = None - ttft_coefs = {} - tpot_intercept = None - tpot_coefs = {} - - for line in lines: - if line.startswith('ttft_intercept{'): - ttft_intercept = float(line.split('}')[1].strip()) - elif line.startswith('ttft_coef{'): - feature = line.split('feature="')[1].split('"')[0] - value = float(line.split('}')[1].strip()) - ttft_coefs[feature] = value - elif line.startswith('tpot_intercept{'): - tpot_intercept = float(line.split('}')[1].strip()) - elif line.startswith('tpot_coef{'): - feature = line.split('feature="')[1].split('"')[0] - value = float(line.split('}')[1].strip()) - tpot_coefs[feature] = value - - # Validate coefficients are present - assert ttft_intercept is not None, "TTFT intercept should be present" - assert tpot_intercept is not None, "TPOT intercept should be present" - - # Updated expected features to include prefix_cache_score for TTFT - expected_ttft_features = ["kv_cache_percentage", "input_token_length", "num_request_waiting", "num_request_running", "prefix_cache_score"] - expected_tpot_features = ["kv_cache_percentage", "input_token_length", "num_request_waiting", "num_request_running", "num_tokens_generated"] - - for feature in expected_ttft_features: - assert feature in ttft_coefs, f"TTFT coefficient for {feature} should be present" - - for feature in 
expected_tpot_features: - assert feature in tpot_coefs, f"TPOT coefficient for {feature} should be present" - - print(f"✓ Bayesian Ridge coefficients validated:") - print(f" TTFT intercept: {ttft_intercept:.4f}") - print(f" TTFT coefficients: {ttft_coefs}") - print(f" TPOT intercept: {tpot_intercept:.4f}") - print(f" TPOT coefficients: {tpot_coefs}") - - # Validate prefix_cache_score coefficient is reasonable - if "prefix_cache_score" in ttft_coefs: - prefix_coef = ttft_coefs["prefix_cache_score"] - print(f" Prefix cache coefficient: {prefix_coef:.4f}") - # Should be positive and reasonably close to our training value of 30 - assert 10 < prefix_coef < 50, f"Prefix cache coefficient should be reasonable: {prefix_coef}" - - print("✓ Training server metrics endpoint working correctly with prefix cache support") - - -def test_xgboost_tree_endpoints(): - """Test XGBoost tree endpoints if XGBoost is being used.""" - model_info_r = requests.get(f"{BASE_URL}/model/download/info") - model_type = model_info_r.json().get("model_type") - - if model_type != "xgboost": - print("Skipping XGBoost tree tests - not using XGBoost model") - return - - print("Testing XGBoost tree endpoints...") - - # Test TTFT trees - ttft_response = requests.get(f"{BASE_URL}/model/ttft/xgb/json") - assert ttft_response.status_code == 200, "TTFT XGBoost trees should be available" - ttft_trees = ttft_response.json() - assert isinstance(ttft_trees, list), "TTFT trees should be a list" - assert len(ttft_trees) > 0, "Should have TTFT trees" - assert isinstance(ttft_trees[0], dict), "Each tree should be a dict" - - # Test TPOT trees - tpot_response = requests.get(f"{BASE_URL}/model/tpot/xgb/json") - assert tpot_response.status_code == 200, "TPOT XGBoost trees should be available" - tpot_trees = tpot_response.json() - assert isinstance(tpot_trees, list), "TPOT trees should be a list" - assert len(tpot_trees) > 0, "Should have TPOT trees" - assert isinstance(tpot_trees[0], dict), "Each tree should be a dict" - - print(f"✓ XGBoost trees available: {len(ttft_trees)} TTFT trees, {len(tpot_trees)} TPOT trees") - - -def test_bayesian_ridge_coefficients(): - """Test that Bayesian Ridge coefficients are properly descaled and stored.""" - model_info_r = requests.get(f"{BASE_URL}/model/download/info") - model_type = model_info_r.json().get("model_type") - - if model_type != "bayesian_ridge": - print("Skipping Bayesian Ridge coefficient tests - not using Bayesian Ridge model") - return - - print("Testing Bayesian Ridge coefficient storage and retrieval...") - - # Get coefficients from metrics - r = requests.get(f"{BASE_URL}/metrics") - assert r.status_code == 200 - content = r.text - - # Parse coefficients from metrics - lines = content.split('\n') - ttft_coefs = {} - tpot_coefs = {} - - for line in lines: - if line.startswith('ttft_coef{'): - feature = line.split('feature="')[1].split('"')[0] - value = float(line.split('}')[1].strip()) - ttft_coefs[feature] = value - elif line.startswith('tpot_coef{'): - feature = line.split('feature="')[1].split('"')[0] - value = float(line.split('}')[1].strip()) - tpot_coefs[feature] = value - - # Test a prediction to see if coefficients make sense - test_features = { - "kv_cache_percentage": 0.5, - "input_token_length": 100, - "num_request_waiting": 2, - "num_request_running": 1, - "num_tokens_generated": 5, - "prefix_cache_score": 0.8, # Added prefix cache score - } - - # Make prediction via API - pred_response = requests.post(f"{BASE_URL}/predict", json=test_features) - assert pred_response.status_code 
== 200 - api_prediction = pred_response.json() - - print(f"✓ Coefficients extracted from metrics:") - print(f" TTFT coefficients: {ttft_coefs}") - print(f" TPOT coefficients: {tpot_coefs}") - print(f" API TTFT prediction: {api_prediction['ttft_ms']:.2f}") - print(f" API TPOT prediction: {api_prediction['tpot_ms']:.2f}") - - # Verify prefix_cache_score coefficient exists for TTFT - assert "prefix_cache_score" in ttft_coefs, "prefix_cache_score should be in TTFT coefficients" - assert "prefix_cache_score" not in tpot_coefs, "prefix_cache_score should NOT be in TPOT coefficients" - - -def test_model_endpoints_by_type(): - """Test the appropriate endpoints based on model type.""" - model_info_r = requests.get(f"{BASE_URL}/model/download/info") - model_info = model_info_r.json() - model_type = model_info["model_type"] - - print(f"Testing endpoints for model type: {model_type}") - - if model_type == "bayesian_ridge": - # For Bayesian Ridge, we should have coefficients in metrics - test_bayesian_ridge_coefficients() - - # XGBoost endpoints should return 404 - ttft_xgb_response = requests.get(f"{BASE_URL}/model/ttft/xgb/json") - assert ttft_xgb_response.status_code == 404, "XGBoost endpoints should not be available for Bayesian Ridge" - - print("✓ Bayesian Ridge: coefficients available in metrics, XGBoost endpoints properly blocked") - - else: # XGBoost - # For XGBoost, we should have tree endpoints - test_xgboost_tree_endpoints() - - print("✓ XGBoost: tree endpoints available") - - -def generate_random_prediction_payload(): - """Generate a random prediction payload for stress testing including prefix_cache_score.""" - return { - "kv_cache_percentage": random.uniform(0.1, 0.9), - "input_token_length": random.randint(10, 1000), - "num_request_waiting": random.randint(1, 20), - "num_request_running": random.randint(1, 10), - "num_tokens_generated": random.randint(1, 20), - "prefix_cache_score": random.uniform(0.0, 1.0), # Added prefix cache score - } - - -def generate_random_training_payload(): - """Generate a random training data payload for stress testing with updated TTFT formula.""" - input_tokens = random.randint(10, 1000) - waiting_requests = random.randint(1, 20) - running_requests = random.randint(1, 10) - kv = random.uniform(0.01, 0.99) - tokens_generated = random.randint(1, 20) - prefix_cache = random.uniform(0.0, 1.0) # Added prefix cache score - - return { - "kv_cache_percentage": kv, - "input_token_length": input_tokens, - "num_request_waiting": waiting_requests, - "num_request_running": running_requests, - # Updated linear TTFT with noise - now includes prefix_cache_score - "actual_ttft_ms": ( - input_tokens * 2.0 - + waiting_requests * 3.0 - + running_requests * 4.0 - + kv * 50.0 - + prefix_cache * 30.0 # New term for prefix cache - + 95 + random.uniform(-10, 10) - ), - # TPOT formula remains unchanged - "actual_tpot_ms": ( - kv * 100.0 - + input_tokens * 0.5 - + tokens_generated * 1.0 - + running_requests * 5.0 - + 9 + random.uniform(-5, 5) - ), - "num_tokens_generated": tokens_generated, - "prefix_cache_score": prefix_cache, # Added prefix cache score - } - - -def generate_bulk_training_payload(size=1000): - """Generate a bulk training payload with specified number of entries.""" - entries = [] - for _ in range(size): - entries.append(generate_random_training_payload()) - return {"entries": entries} - - -async def async_post_request(session, url, payload, request_id): - """Make an async POST request and return result with metadata.""" - start_time = time.time() - try: - async with 
session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=5)) as response: - end_time = time.time() - response_data = await response.json() - return { - 'request_id': request_id, - 'status_code': response.status, - 'response_time': end_time - start_time, - 'success': response.status in [200, 202], - 'response_data': response_data, - 'request_type': 'predict' if '/predict' in url else 'training', - 'model_type': response_data.get('model_type') if response.status == 200 else None - } - except Exception as e: - end_time = time.time() - return { - 'request_id': request_id, - 'status_code': 0, - 'response_time': end_time - start_time, - 'success': False, - 'error': str(e), - 'request_type': 'predict' if '/predict' in url else 'training', - 'model_type': None - } - -async def run_stress_test_async(duration_seconds=10, target_qps=300): - interval = 1.0/target_qps - start = time.time() - connector = aiohttp.TCPConnector(limit=10000, limit_per_host=10000, ttl_dns_cache=300, use_dns_cache=True) - async with aiohttp.ClientSession(connector=connector, timeout=aiohttp.ClientTimeout(total=2)) as sess: - tasks = [] - req_id = 0 - next_time = start - while time.time() - start < duration_seconds: - now = time.time() - while next_time <= now: - req_id += 1 - if random.random()<0.5: - url = f"{BASE_URL}/predict" - payload = generate_random_prediction_payload() - else: - url = f"{BASE_URL}/add_training_data_bulk" - payload = {"entries":[ generate_random_training_payload() ]} - tasks.append(asyncio.create_task(async_post_request(sess, url, payload, req_id))) - next_time += interval - await asyncio.sleep(0.0001) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - valid_results = [r for r in results if isinstance(r, dict)] - - # Calculate actual QPS achieved - if valid_results: - actual_duration = duration_seconds - actual_qps = len(valid_results) / actual_duration - print(f"Target QPS: {target_qps}, Actual QPS: {actual_qps:.0f}") - - return valid_results - - -def fetch_and_parse_xgb_json(path_suffix): - """ - Download the XGBoost JSON dump for `path_suffix` (ttft or tpot), - parse into a Python list of dicts, and return it. - """ - url = f"{BASE_URL}/model/{path_suffix}/xgb/json" - r = requests.get(url, timeout=10) - assert r.status_code == 200, f"Failed to fetch JSON for {path_suffix}" - trees = r.json() - assert isinstance(trees, list), "Expected a JSON array of trees" - assert len(trees) > 0, "Tree list should not be empty" - assert isinstance(trees[0], dict), "Each tree must be a JSON object" - return trees - - -async def async_fetch_and_parse_xgb_json(session, suffix, request_id): - """ - Async GET /model//xgb/json and return timing + status. 
- """ - url = f"{BASE_URL}/model/{suffix}/xgb/json" - start = time.time() - try: - async with session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp: - data = await resp.json() - elapsed = time.time() - start - return { - 'request_id': request_id, - 'request_type': f'download_{suffix}', - 'status_code': resp.status, - 'response_time': elapsed, - 'success': resp.status == 200, - 'tree_count': len(data) if isinstance(data, list) else None - } - except Exception as e: - elapsed = time.time() - start - return { - 'request_id': request_id, - 'request_type': f'download_{suffix}', - 'status_code': 0, - 'response_time': elapsed, - 'success': False, - 'error': str(e) - } - - -async def run_simplified_stress_test(duration_seconds=10, target_qps=2): - """ - Simplified stress test: bulk training vs predictions and tree downloads (XGBoost only). - """ - info_r = requests.get(f"{BASE_URL}/model/download/info", timeout=5.0) - model_type = info_r.json().get("model_type", "bayesian_ridge") - - interval = 1.0 / target_qps - start = time.time() - connector = aiohttp.TCPConnector(limit=1000, limit_per_host=1000) - async with aiohttp.ClientSession(connector=connector) as sess: - tasks = [] - req_id = 0 - next_time = start - - while time.time() - start < duration_seconds: - now = time.time() - while next_time <= now: - req_id += 1 - - if random.random() < 0.5: - # Either predictions or tree downloads (XGBoost only) - if random.random() < 0.7: # 70% predictions - url = f"{BASE_URL}/predict" - payload = generate_random_prediction_payload() - task = asyncio.create_task( - async_post_request_with_timeout( - sess, url, payload, req_id, - aiohttp.ClientTimeout(total=5), "predict" - ) - ) - else: # 30% tree downloads (only for XGBoost) - if model_type == "xgboost": - suffix = random.choice(["ttft", "tpot"]) - task = asyncio.create_task( - async_fetch_and_parse_xgb_json(sess, suffix, req_id) - ) - else: - # For Bayesian Ridge, just do another prediction - url = f"{BASE_URL}/predict" - payload = generate_random_prediction_payload() - task = asyncio.create_task( - async_post_request_with_timeout( - sess, url, payload, req_id, - aiohttp.ClientTimeout(total=5), "predict" - ) - ) - else: - # bulk training - url = f"{BASE_URL}/add_training_data_bulk" - payload = generate_bulk_training_payload(1000) - task = asyncio.create_task( - async_post_request_with_timeout( - sess, url, payload, req_id, - aiohttp.ClientTimeout(total=30), "bulk_training" - ) - ) - - tasks.append(task) - next_time += interval - - await asyncio.sleep(0.001) - - print(f"Waiting for {len(tasks)} requests to complete…") - results = await asyncio.gather(*tasks, return_exceptions=True) - valid = [r for r in results if isinstance(r, dict)] - - if valid: - actual_qps = len(valid) / duration_seconds - print(f"Target QPS: {target_qps}, Actual QPS: {actual_qps:.2f}") - - return valid - - -async def async_post_request_with_timeout(session, url, payload, request_id, timeout, request_type): - """Make an async POST request with custom timeout and return result with metadata.""" - start_time = time.time() - try: - async with session.post(url, json=payload, timeout=timeout) as response: - end_time = time.time() - response_data = await response.json() - - # Count training entries for bulk requests - training_entries = len(payload.get("entries", [])) if request_type == "bulk_training" else 1 - - return { - 'request_id': request_id, - 'status_code': response.status, - 'response_time': end_time - start_time, - 'success': response.status in [200, 202], - 
'response_data': response_data, - 'request_type': request_type, - 'training_entries': training_entries if request_type == "bulk_training" else 0, - 'model_type': response_data.get('model_type') if response.status == 200 and request_type == 'predict' else None - } - except Exception as e: - end_time = time.time() - training_entries = len(payload.get("entries", [])) if request_type == "bulk_training" else 1 - return { - 'request_id': request_id, - 'status_code': 0, - 'response_time': end_time - start_time, - 'success': False, - 'error': str(e), - 'request_type': request_type, - 'training_entries': training_entries if request_type == "bulk_training" else 0, - 'model_type': None - } - - -def analyze_stress_test_results(results): - """Analyze and print stress test results with model type information.""" - if not results: - print("No results to analyze") - return - - total_requests = len(results) - successful_requests = sum(1 for r in results if r.get('success', False)) - failed_requests = total_requests - successful_requests - - response_times = [r['response_time'] for r in results if r.get('response_time')] - avg_response_time = sum(response_times) / len(response_times) if response_times else 0 - - status_codes = defaultdict(int) - for r in results: - status_codes[r.get('status_code', 0)] += 1 - - request_types = defaultdict(int) - for r in results: - request_types[r.get('request_type', 'unknown')] += 1 - - # Analyze model types in prediction responses - model_types = defaultdict(int) - for r in results: - if r.get('model_type'): - model_types[r['model_type']] += 1 - - test_duration = max(response_times) if response_times else 0 - actual_qps = total_requests / test_duration if test_duration > 0 else 0 - - print(f"\n{'='*50}") - print("STRESS TEST RESULTS") - print(f"{'='*50}") - print(f"Total Requests: {total_requests}") - print(f"Successful: {successful_requests} ({successful_requests/total_requests*100:.1f}%)") - print(f"Failed: {failed_requests} ({failed_requests/total_requests*100:.1f}%)") - print(f"Average Response Time: {avg_response_time*1000:.2f}ms") - print(f"Actual QPS: {actual_qps:.0f}") - print(f"\nRequest Types:") - for req_type, count in request_types.items(): - print(f" {req_type}: {count}") - print(f"\nStatus Code Distribution:") - for status, count in status_codes.items(): - print(f" {status}: {count}") - - if model_types: - print(f"\nModel Types in Predictions:") - for model_type, count in model_types.items(): - print(f" {model_type}: {count}") - - if response_times: - sorted_times = sorted(response_times) - p50 = sorted_times[int(len(sorted_times) * 0.5)] * 1000 - p95 = sorted_times[int(len(sorted_times) * 0.95)] * 1000 - p99 = sorted_times[int(len(sorted_times) * 0.99)] * 1000 - print(f"\nResponse Time Percentiles:") - print(f" P50: {p50:.2f}ms") - print(f" P95: {p95:.2f}ms") - print(f" P99: {p99:.2f}ms") - - -def analyze_bulk_training_results(results): - """Analyze and print bulk training stress test results with additional metrics.""" - if not results: - print("No results to analyze") - return - - total_requests = len(results) - successful_requests = sum(1 for r in results if r.get('success', False)) - failed_requests = total_requests - successful_requests - - # Separate analysis by request type - prediction_results = [r for r in results if r.get('request_type') == 'predict'] - bulk_training_results = [r for r in results if r.get('request_type') == 'bulk_training'] - download_results = [r for r in results if r.get('request_type', '').startswith('download_')] - - # 
Calculate total training entries processed - total_training_entries = sum(r.get('training_entries', 0) for r in bulk_training_results) - - # Analyze model types in prediction responses - model_types = defaultdict(int) - for r in prediction_results: - if r.get('model_type'): - model_types[r['model_type']] += 1 - - response_times = [r['response_time'] for r in results if r.get('response_time')] - avg_response_time = sum(response_times) / len(response_times) if response_times else 0 - - status_codes = defaultdict(int) - for r in results: - status_codes[r.get('status_code', 0)] += 1 - - request_types = defaultdict(int) - for r in results: - request_types[r.get('request_type', 'unknown')] += 1 - - print(f"\n{'='*60}") - print("BULK TRAINING STRESS TEST RESULTS") - print(f"{'='*60}") - print(f"Total Requests: {total_requests}") - print(f"Successful: {successful_requests} ({successful_requests/total_requests*100:.1f}%)") - print(f"Failed: {failed_requests} ({failed_requests/total_requests*100:.1f}%)") - print(f"Average Response Time: {avg_response_time*1000:.2f}ms") - - print(f"\nRequest Type Breakdown:") - print(f" Prediction requests: {len(prediction_results)}") - print(f" Bulk training requests: {len(bulk_training_results)}") - print(f" Model download requests: {len(download_results)}") - print(f" Total training entries processed: {total_training_entries}") - - if model_types: - print(f"\nModel Types in Predictions:") - for model_type, count in model_types.items(): - print(f" {model_type}: {count}") - - print(f"\nStatus Code Distribution:") - for status, count in status_codes.items(): - print(f" {status}: {count}") - - # Response time analysis by request type - if prediction_results: - pred_times = [r['response_time'] for r in prediction_results if r.get('response_time')] - if pred_times: - avg_pred_time = sum(pred_times) / len(pred_times) - print(f"\nPrediction Request Response Times:") - print(f" Average: {avg_pred_time*1000:.2f}ms") - print(f" Min: {min(pred_times)*1000:.2f}ms") - print(f" Max: {max(pred_times)*1000:.2f}ms") - - if bulk_training_results: - bulk_times = [r['response_time'] for r in bulk_training_results if r.get('response_time')] - if bulk_times: - avg_bulk_time = sum(bulk_times) / len(bulk_times) - print(f"\nBulk Training Request Response Times:") - print(f" Average: {avg_bulk_time*1000:.2f}ms") - print(f" Min: {min(bulk_times)*1000:.2f}ms") - print(f" Max: {max(bulk_times)*1000:.2f}ms") - - if download_results: - download_times = [r['response_time'] for r in download_results if r.get('response_time')] - if download_times: - avg_download_time = sum(download_times) / len(download_times) - print(f"\nModel Download Request Response Times:") - print(f" Average: {avg_download_time*1000:.2f}ms") - print(f" Min: {min(download_times)*1000:.2f}ms") - print(f" Max: {max(download_times)*1000:.2f}ms") - - if response_times: - sorted_times = sorted(response_times) - p50 = sorted_times[int(len(sorted_times) * 0.5)] * 1000 - p95 = sorted_times[int(len(sorted_times) * 0.95)] * 1000 - p99 = sorted_times[int(len(sorted_times) * 0.99)] * 1000 - print(f"\nOverall Response Time Percentiles:") - print(f" P50: {p50:.2f}ms") - print(f" P95: {p95:.2f}ms") - print(f" P99: {p99:.2f}ms") - - -def test_stress_test_high_qps(): - """ - Stress test with 300 QPS for 10 seconds. - Sends predictions and training data in parallel. 
- """ - results = asyncio.run(run_stress_test_async(duration_seconds=10, target_qps=300)) - - analyze_stress_test_results(results) - - assert len(results) > 0, "No requests were made" - - successful_requests = sum(1 for r in results if r.get('success', False)) - success_rate = successful_requests / len(results) - - assert success_rate > 0.8, f"Success rate too low: {success_rate*100:.1f}%" - - print(f"Stress test completed successfully with {success_rate*100:.1f}% success rate") - - -def test_stress_test_mixed_load(): - """ - Alternative stress test with mixed load patterns. - Tests server stability under varying load conditions. - """ - print("Running mixed load stress test...") - - print("Phase 1: Ramping up load...") - results_phase1 = asyncio.run(run_stress_test_async(duration_seconds=5, target_qps=100)) - - print("Phase 2: High sustained load...") - results_phase2 = asyncio.run(run_stress_test_async(duration_seconds=10, target_qps=300)) - - print("Phase 3: Cooling down...") - results_phase3 = asyncio.run(run_stress_test_async(duration_seconds=5, target_qps=50)) - - all_results = results_phase1 + results_phase2 + results_phase3 - - print("\nCOMBINED RESULTS FOR ALL PHASES:") - analyze_stress_test_results(all_results) - - assert len(all_results) > 0, "No requests were made" - - successful_requests = sum(1 for r in all_results if r.get('success', False)) - success_rate = successful_requests / len(all_results) - - assert success_rate > 0.75, f"Overall success rate too low: {success_rate*100:.1f}%" - - print(f"Mixed load stress test completed with {success_rate*100:.1f}% success rate") - - -def test_simplified_stress_test(): - """Simplified stress test focusing on predictions, training, and tree downloads with prefix cache.""" - print("Running simplified stress test with prefix cache score support...") - print("Configuration: 2 QPS, 50% bulk training, 35% predictions, 15% tree downloads (XGBoost only)") - - results = asyncio.run(run_simplified_stress_test(duration_seconds=60, target_qps=2)) - - analyze_bulk_training_results(results) - - assert len(results) > 0, "No requests were made" - - successful_requests = sum(1 for r in results if r.get('success', False)) - success_rate = successful_requests / len(results) - - # Count request types - prediction_count = sum(1 for r in results if r.get('request_type') == 'predict') - bulk_training_count = sum(1 for r in results if r.get('request_type') == 'bulk_training') - download_count = sum(1 for r in results if r.get('request_type', '').startswith('download_')) - - assert success_rate > 0.8, f"Success rate too low: {success_rate*100:.1f}%" - assert prediction_count > 0, "No prediction requests were made" - assert bulk_training_count > 0, "No bulk training requests were made" - - print(f"✓ Simplified stress test with prefix cache completed:") - print(f" Success rate: {success_rate*100:.1f}%") - print(f" Prediction requests: {prediction_count}") - print(f" Tree download requests: {download_count}") - print(f" Bulk training requests: {bulk_training_count}") - - -def test_model_type_consistency(): - """ - Test that the model type is consistent across all API endpoints. 
- """ - print("Testing model type consistency across endpoints...") - - # Get model type from different endpoints - root_response = requests.get(f"{BASE_URL}/") - model_info_response = requests.get(f"{BASE_URL}/model/download/info") - - # Make a prediction to get model type from prediction response - prediction_request = generate_random_prediction_payload() - prediction_response = requests.post(f"{BASE_URL}/predict", json=prediction_request) - - # Extract model types - root_model_type = root_response.json().get("model_type") - model_info_model_type = model_info_response.json().get("model_type") - prediction_model_type = prediction_response.json().get("model_type") - - # Check consistency - assert root_model_type == model_info_model_type == prediction_model_type, ( - f"Model type inconsistency: root={root_model_type}, " - f"model_info={model_info_model_type}, prediction={prediction_model_type}" - ) - - print(f"Model type consistent across all endpoints: {root_model_type}") - - -def test_xgboost_vs_bayesian_ridge_performance(): - """ - Performance comparison test (if both models are available). - This test will check model performance differences. - """ - model_info_r = requests.get(f"{BASE_URL}/model/download/info") - model_info = model_info_r.json() - - print(f"Current model: {model_info['model_type']}") - - # Generate test predictions with prefix cache scores - test_cases = [generate_random_prediction_payload() for _ in range(10)] - - predictions = [] - response_times = [] - - for test_case in test_cases: - start_time = time.time() - response = requests.post(f"{BASE_URL}/predict", json=test_case) - end_time = time.time() - - assert response.status_code == 200 - predictions.append(response.json()) - response_times.append((end_time - start_time) * 1000) # Convert to ms - - avg_response_time = sum(response_times) / len(response_times) - avg_prefix_cache = sum(tc['prefix_cache_score'] for tc in test_cases) / len(test_cases) - - print(f"Model: {predictions[0]['model_type']}") - print(f"Average response time: {avg_response_time:.2f}ms") - print(f"Average prefix cache score: {avg_prefix_cache:.2f}") - print(f"Average TTFT prediction: {sum(p['ttft_ms'] for p in predictions)/len(predictions):.2f}ms") - print(f"Average TPOT prediction: {sum(p['tpot_ms'] for p in predictions)/len(predictions):.2f}ms") - print(f"Average TTFT uncertainty: {sum(p['ttft_uncertainty'] for p in predictions)/len(predictions):.2f}") - print(f"Average TPOT uncertainty: {sum(p['tpot_uncertainty'] for p in predictions)/len(predictions):.2f}") - - # Basic sanity checks - assert avg_response_time < 1000, f"Response time too slow: {avg_response_time:.2f}ms" - assert all(p['ttft_ms'] > 0 for p in predictions), "All TTFT predictions should be positive" - assert all(p['tpot_ms'] > 0 for p in predictions), "All TPOT predictions should be positive" - - -def test_uncertainty_estimation_quality(): - """ - Test the quality of uncertainty estimation for both model types. 
- """ - model_info_r = requests.get(f"{BASE_URL}/model/download/info") - model_type = model_info_r.json().get("model_type") - - # Generate multiple predictions for the same input - test_payload = { - "kv_cache_percentage": 0.5, - "input_token_length": 100, - "num_request_waiting": 2, - "num_request_running": 1, - "num_tokens_generated": 5, - "prefix_cache_score": 0.8, # Added prefix cache score - } - - predictions = [] - for _ in range(5): # Make multiple identical requests - response = requests.post(f"{BASE_URL}/predict", json=test_payload) - assert response.status_code == 200 - predictions.append(response.json()) - - # Check that predictions are consistent (should be identical for same input) - ttft_values = [p['ttft_ms'] for p in predictions] - tpot_values = [p['tpot_ms'] for p in predictions] - - ttft_std = sum((x - ttft_values[0])**2 for x in ttft_values)**0.5 / len(ttft_values) - tpot_std = sum((x - tpot_values[0])**2 for x in tpot_values)**0.5 / len(tpot_values) - - # For deterministic models, predictions should be identical - if model_type == "bayesian_ridge": - assert ttft_std < 0.01, f"TTFT predictions should be consistent, got std: {ttft_std}" - assert tpot_std < 0.01, f"TPOT predictions should be consistent, got std: {tpot_std}" - - # Check uncertainty values are reasonable - pred = predictions[0] - ttft_uncertainty_ratio = pred['ttft_uncertainty'] / pred['ttft_ms'] - tpot_uncertainty_ratio = pred['tpot_uncertainty'] / pred['tpot_ms'] - - print(f"Model: {model_type}") - print(f"Prefix cache score: {test_payload['prefix_cache_score']}") - print(f"TTFT: {pred['ttft_ms']:.2f} ± {pred['ttft_uncertainty']:.2f} ({ttft_uncertainty_ratio*100:.1f}%)") - print(f"TPOT: {pred['tpot_ms']:.2f} ± {pred['tpot_uncertainty']:.2f} ({tpot_uncertainty_ratio*100:.1f}%)") - - # Uncertainty should be reasonable (not too high or too low) - assert 0.01 < ttft_uncertainty_ratio < 0.5, f"TTFT uncertainty ratio should be reasonable: {ttft_uncertainty_ratio}" - assert 0.01 < tpot_uncertainty_ratio < 0.5, f"TPOT uncertainty ratio should be reasonable: {tpot_uncertainty_ratio}" - - # Check prediction bounds contain the prediction - ttft_bounds = pred['ttft_prediction_bounds'] - tpot_bounds = pred['tpot_prediction_bounds'] - - assert ttft_bounds[0] <= pred['ttft_ms'] <= ttft_bounds[1], "TTFT should be within prediction bounds" - assert tpot_bounds[0] <= pred['tpot_ms'] <= tpot_bounds[1], "TPOT should be within prediction bounds" - - -def test_edge_cases(): - """ - Test edge cases and boundary conditions with prefix cache score. 
- """ - # Test minimum values - min_payload = { - "kv_cache_percentage": 0.0, - "input_token_length": 1, - "num_request_waiting": 0, - "num_request_running": 0, - "num_tokens_generated": 1, - "prefix_cache_score": 0.0, # Added prefix cache score - } - - response = requests.post(f"{BASE_URL}/predict", json=min_payload) - assert response.status_code == 200 - data = response.json() - assert data['ttft_ms'] > 0 - assert data['tpot_ms'] > 0 - - # Test maximum reasonable values - max_payload = { - "kv_cache_percentage": 1.0, - "input_token_length": 10000, - "num_request_waiting": 100, - "num_request_running": 50, - "num_tokens_generated": 1000, - "prefix_cache_score": 1.0, # Added prefix cache score - } - - response = requests.post(f"{BASE_URL}/predict", json=max_payload) - assert response.status_code == 200 - data = response.json() - assert data['ttft_ms'] > 0 - assert data['tpot_ms'] > 0 - - # Test invalid values (should fail validation) - invalid_payloads = [ - {"kv_cache_percentage": -0.1, "input_token_length": 100, "num_request_waiting": 1, "num_request_running": 1, "num_tokens_generated": 10, "prefix_cache_score": 0.5}, - {"kv_cache_percentage": 1.1, "input_token_length": 100, "num_request_waiting": 1, "num_request_running": 1, "num_tokens_generated": 10, "prefix_cache_score": 0.5}, - {"kv_cache_percentage": 0.5, "input_token_length": -1, "num_request_waiting": 1, "num_request_running": 1, "num_tokens_generated": 10, "prefix_cache_score": 0.5}, - {"kv_cache_percentage": 0.5, "input_token_length": 100, "num_request_waiting": -1, "num_request_running": 1, "num_tokens_generated": 10, "prefix_cache_score": 0.5}, - {"kv_cache_percentage": 0.5, "input_token_length": 100, "num_request_waiting": 1, "num_request_running": -1, "num_tokens_generated": 10, "prefix_cache_score": 0.5}, - {"kv_cache_percentage": 0.5, "input_token_length": 100, "num_request_waiting": 1, "num_request_running": 1, "num_tokens_generated": -1, "prefix_cache_score": 0.5}, - {"kv_cache_percentage": 0.5, "input_token_length": 100, "num_request_waiting": 1, "num_request_running": 1, "num_tokens_generated": 10, "prefix_cache_score": -0.1}, # Invalid prefix cache - {"kv_cache_percentage": 0.5, "input_token_length": 100, "num_request_waiting": 1, "num_request_running": 1, "num_tokens_generated": 10, "prefix_cache_score": 1.1}, # Invalid prefix cache - ] - - for invalid_payload in invalid_payloads: - response = requests.post(f"{BASE_URL}/predict", json=invalid_payload) - assert response.status_code == 422, f"Should reject invalid payload: {invalid_payload}" - - -def test_concurrent_training_and_prediction(): - """ - Test that training and prediction can happen concurrently without issues. 
- """ - print("Testing concurrent training and prediction with prefix cache...") - - def make_predictions(): - results = [] - for _ in range(20): - payload = generate_random_prediction_payload() - try: - response = requests.post(f"{BASE_URL}/predict", json=payload, timeout=5) - results.append(response.status_code == 200) - except: - results.append(False) - time.sleep(0.1) - return results - - def send_training_data(): - results = [] - for _ in range(5): - payload = generate_bulk_training_payload(100) # Smaller batches for faster processing - try: - response = requests.post(f"{BASE_URL}/add_training_data_bulk", json=payload, timeout=10) - results.append(response.status_code == 202) - except: - results.append(False) - time.sleep(0.5) - return results - - # Run both functions concurrently - with ThreadPoolExecutor(max_workers=2) as executor: - prediction_future = executor.submit(make_predictions) - training_future = executor.submit(send_training_data) - - prediction_results = prediction_future.result() - training_results = training_future.result() - - prediction_success_rate = sum(prediction_results) / len(prediction_results) - training_success_rate = sum(training_results) / len(training_results) - - print(f"Prediction success rate: {prediction_success_rate*100:.1f}%") \ No newline at end of file diff --git a/latencypredictor-v1/training_server.py b/latencypredictor-v1/training_server.py index d8d504e04..559917a6d 100644 --- a/latencypredictor-v1/training_server.py +++ b/latencypredictor-v1/training_server.py @@ -32,10 +32,18 @@ XGBOOST_AVAILABLE = False logging.warning("XGBoost not available. Please install with: pip install xgboost") +try: + import lightgbm as lgb + LIGHTGBM_AVAILABLE = True +except ImportError: + LIGHTGBM_AVAILABLE = False + logging.warning("LightGBM not available. Please install with: pip install lightgbm") + class ModelType(str, Enum): BAYESIAN_RIDGE = "bayesian_ridge" XGBOOST = "xgboost" + LIGHTGBM = "lightgbm" class RandomDropDeque(deque): @@ -92,6 +100,8 @@ class Settings: class ModelInfoResponse(BaseModel): model_type: str xgboost_available: bool + lightgbm_available: bool = Field(default=False, description="Whether LightGBM is available") # FIXED: Added this field + is_ready: bool ttft_training_samples: int = Field(default=0, description="Number of TTFT training samples") tpot_training_samples: int = Field(default=0, description="Number of TPOT training samples") @@ -170,13 +180,17 @@ def __init__(self, model_type: str = None): if model_type is None: model_type = settings.MODEL_TYPE - if model_type not in [ModelType.BAYESIAN_RIDGE, ModelType.XGBOOST]: + if model_type not in [ModelType.BAYESIAN_RIDGE, ModelType.XGBOOST, ModelType.LIGHTGBM]: raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(ModelType)}") if model_type == ModelType.XGBOOST and not XGBOOST_AVAILABLE: logging.warning("XGBoost requested but not available. Falling back to Bayesian Ridge.") model_type = ModelType.BAYESIAN_RIDGE - + + if model_type == ModelType.LIGHTGBM and not LIGHTGBM_AVAILABLE: + logging.warning("LightGBM requested but not available. 
Falling back to Bayesian Ridge.") + model_type = ModelType.BAYESIAN_RIDGE + self.model_type = ModelType(model_type) self.quantile = settings.QUANTILE_ALPHA logging.info(f"Initialized LatencyPredictor with model type: {self.model_type}, quantile: {self.quantile}") @@ -290,7 +304,7 @@ def is_ready(self) -> bool: """Checks if all models and scalers are loaded/trained.""" if self.model_type == ModelType.BAYESIAN_RIDGE: return all([self.ttft_model, self.tpot_model, self.ttft_scaler, self.tpot_scaler]) - else: # XGBoost + else: # XGBoost or LightGBM return all([self.ttft_model, self.tpot_model]) @is_ready.setter @@ -305,7 +319,8 @@ def _all_samples(self, buckets: dict) -> list: samples.extend(bucket_deque) return samples - def _train_model_with_scaling(self, features: pd.DataFrame, target: pd.Series) -> Union[Tuple[BayesianRidge, StandardScaler], xgb.XGBRegressor]: + def _train_model_with_scaling(self, features: pd.DataFrame, target: pd.Series) -> Union[Tuple[BayesianRidge, StandardScaler], xgb.XGBRegressor, lgb.LGBMRegressor]: + try: if len(features) == 0 or len(target) == 0: raise ValueError("Empty training data") @@ -326,8 +341,8 @@ def _train_model_with_scaling(self, features: pd.DataFrame, target: pd.Series) - model = BayesianRidge(compute_score=True) model.fit(features_scaled, target) return model, scaler - - else: # XGBoost with quantile regression + + elif self.model_type == ModelType.XGBOOST: # XGBoost with quantile regression model = xgb.XGBRegressor( n_estimators=200, # Number of trees to build (moderate value for balanced accuracy and speed) max_depth=6, # Depth of trees; 6 is typically a sweet spot balancing bias/variance @@ -345,6 +360,25 @@ def _train_model_with_scaling(self, features: pd.DataFrame, target: pd.Series) - ) model.fit(features, target) return model + elif self.model_type == ModelType.LIGHTGBM: # LightGBM with quantile regression + model = lgb.LGBMRegressor( + n_estimators=200, # Number of trees + max_depth=6, # Maximum tree depth + learning_rate=0.05, # Learning rate + subsample=0.8, # Row sampling ratio + colsample_bytree=0.8, # Column sampling ratio + min_child_samples=20, # Minimum samples in leaf + reg_alpha=0.1, # L1 regularization + reg_lambda=0.1, # L2 regularization + objective="quantile", # Quantile regression objective + alpha=self.quantile, # Quantile level (e.g., 0.9 for p90) + n_jobs=-1, # Use all cores + random_state=42, # Reproducibility + verbosity=-1, # Suppress warnings + force_col_wise=True # Better for small datasets + ) + model.fit(features, target) + return model except Exception as e: logging.error(f"Error in _train_model_with_scaling: {e}", exc_info=True) @@ -386,7 +420,8 @@ def _calculate_quantile_metrics_on_test(self, model, scaler, test_data, feature_ logging.error(f"Error calculating quantile metrics: {e}", exc_info=True) return None, None, None - def _create_default_model(self, model_type: str) -> Union[Tuple[BayesianRidge, StandardScaler], xgb.XGBRegressor]: + def _create_default_model(self, model_type: str) -> Union[Tuple[BayesianRidge, StandardScaler], xgb.XGBRegressor, lgb.LGBMRegressor]: + """Creates and trains a simple default model with initial priors.""" try: logging.info(f"Creating default '{model_type}' model with priors.") @@ -577,16 +612,26 @@ def predict(self, features: dict) -> Tuple[float, float, float, float]: return ttft_pred, tpot_pred, ttft_std[0], tpot_std[0] - else: # XGBoost with true quantile regression + elif self.model_type == ModelType.XGBOOST: # XGBoost quantile regression directly predicts the quantile 
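            # NOTE: both tree-based models are trained with a quantile objective
            # (level = self.quantile), so .predict() already returns the configured
            # quantile (e.g. p90) of TTFT/TPOT rather than the conditional mean;
            # the 10% band applied below is only a rough heuristic, not a
            # calibrated uncertainty estimate.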
ttft_pred = self.ttft_model.predict(df_ttft) tpot_pred = self.tpot_model.predict(df_tpot) - + # For XGBoost quantile regression, uncertainty estimation is more complex - # We'll use a simple heuristic based on the quantile value ttft_std = ttft_pred[0] * 0.1 # 10% of prediction as uncertainty estimate tpot_std = tpot_pred[0] * 0.1 - + + return ttft_pred[0], tpot_pred[0], ttft_std, tpot_std + + else: # LightGBM with quantile regression + # LightGBM quantile regression directly predicts the quantile + ttft_pred = self.ttft_model.predict(df_ttft) + tpot_pred = self.tpot_model.predict(df_tpot) + + # For LightGBM quantile regression, use a similar uncertainty estimate as XGBoost + ttft_std = ttft_pred[0] * 0.1 # 10% of prediction as uncertainty estimate + tpot_std = tpot_pred[0] * 0.1 + return ttft_pred[0], tpot_pred[0], ttft_std, tpot_std except ValueError as ve: @@ -646,21 +691,21 @@ def add_training_samples(self, samples: list): logging.exception("Failed to add one sample in bulk ingestion") + # Update the _save_models_unlocked method to handle LightGBM model exports def _save_models_unlocked(self): try: if self.ttft_model: os.makedirs(os.path.dirname(settings.TTFT_MODEL_PATH), exist_ok=True) joblib.dump(self.ttft_model, settings.TTFT_MODEL_PATH) logging.info("TTFT model saved.") - - # Save XGBoost booster trees as JSON + + # Save model-specific exports if self.model_type == ModelType.XGBOOST: try: booster = self.ttft_model.get_booster() raw_trees = booster.get_dump(dump_format="json") trees = [json.loads(t) for t in raw_trees] - - # Save to JSON file alongside the model + ttft_json_path = settings.TTFT_MODEL_PATH.replace('.joblib', '_trees.json') with open(ttft_json_path, 'w') as f: json.dump(trees, f, indent=2) @@ -668,24 +713,43 @@ def _save_models_unlocked(self): except Exception as e: logging.error(f"Error saving TTFT XGBoost trees: {e}", exc_info=True) + elif self.model_type == ModelType.LIGHTGBM: + try: + # Save LightGBM model as text format + ttft_txt_path = settings.TTFT_MODEL_PATH.replace('.joblib', '_lgb.txt') + self.ttft_model.booster_.save_model(ttft_txt_path) + + # Save feature importances as JSON + feature_names = ['kv_cache_percentage', 'input_token_length', + 'num_request_waiting', 'num_request_running', 'prefix_cache_score'] + importances = dict(zip(feature_names, self.ttft_model.feature_importances_)) + + ttft_imp_path = settings.TTFT_MODEL_PATH.replace('.joblib', '_importances.json') + with open(ttft_imp_path, 'w') as f: + json.dump(importances, f, indent=2) + + logging.info(f"TTFT LightGBM model saved to {ttft_txt_path}") + logging.info(f"TTFT LightGBM importances saved to {ttft_imp_path}") + except Exception as e: + logging.error(f"Error saving TTFT LightGBM exports: {e}", exc_info=True) + if self.ttft_scaler and self.model_type == ModelType.BAYESIAN_RIDGE: os.makedirs(os.path.dirname(settings.TTFT_SCALER_PATH), exist_ok=True) joblib.dump(self.ttft_scaler, settings.TTFT_SCALER_PATH) logging.info("TTFT scaler saved.") - + if self.tpot_model: os.makedirs(os.path.dirname(settings.TPOT_MODEL_PATH), exist_ok=True) joblib.dump(self.tpot_model, settings.TPOT_MODEL_PATH) logging.info("TPOT model saved.") - - # Save XGBoost booster trees as JSON + + # Save model-specific exports if self.model_type == ModelType.XGBOOST: try: booster = self.tpot_model.get_booster() raw_trees = booster.get_dump(dump_format="json") trees = [json.loads(t) for t in raw_trees] - - # Save to JSON file alongside the model + tpot_json_path = settings.TPOT_MODEL_PATH.replace('.joblib', '_trees.json') with 
open(tpot_json_path, 'w') as f: json.dump(trees, f, indent=2) @@ -693,11 +757,31 @@ def _save_models_unlocked(self): except Exception as e: logging.error(f"Error saving TPOT XGBoost trees: {e}", exc_info=True) + elif self.model_type == ModelType.LIGHTGBM: + try: + # Save LightGBM model as text format + tpot_txt_path = settings.TPOT_MODEL_PATH.replace('.joblib', '_lgb.txt') + self.tpot_model.booster_.save_model(tpot_txt_path) + + # Save feature importances as JSON + feature_names = ['kv_cache_percentage', 'input_token_length', + 'num_request_waiting', 'num_request_running', 'num_tokens_generated'] + importances = dict(zip(feature_names, self.tpot_model.feature_importances_)) + + tpot_imp_path = settings.TPOT_MODEL_PATH.replace('.joblib', '_importances.json') + with open(tpot_imp_path, 'w') as f: + json.dump(importances, f, indent=2) + + logging.info(f"TPOT LightGBM model saved to {tpot_txt_path}") + logging.info(f"TPOT LightGBM importances saved to {tpot_imp_path}") + except Exception as e: + logging.error(f"Error saving TPOT LightGBM exports: {e}", exc_info=True) + if self.tpot_scaler and self.model_type == ModelType.BAYESIAN_RIDGE: os.makedirs(os.path.dirname(settings.TPOT_SCALER_PATH), exist_ok=True) joblib.dump(self.tpot_scaler, settings.TPOT_SCALER_PATH) logging.info("TPOT scaler saved.") - + except Exception as e: logging.error(f"Error saving models: {e}", exc_info=True) @@ -1000,11 +1084,18 @@ async def model_download_info(): "tpot_coefficients_available": predictor.tpot_coefficients is not None, "description": "Descaled coefficients available in Prometheus metrics endpoint" } - else: # XGBoost + elif predictor.model_type == ModelType.XGBOOST: info["available_endpoints"]["trees"] = { "ttft_trees": "/model/ttft/xgb/json", "tpot_trees": "/model/tpot/xgb/json" } + else: # LightGBM - FIXED: Added LightGBM endpoints + info["available_endpoints"]["lightgbm"] = { + "ttft_model_txt": "/model/ttft/lgb/txt", + "tpot_model_txt": "/model/tpot/lgb/txt", + "ttft_importances": "/model/ttft/lgb/importances", + "tpot_importances": "/model/tpot/lgb/importances" + } info["model_status"] = { "ttft_model_ready": predictor.ttft_model is not None, @@ -1167,6 +1258,88 @@ async def list_models(): } } +# Add new API endpoints for LightGBM model exports +@app.get("/model/ttft/lgb/txt") +async def ttft_lgb_txt(): + """ + Download the TTFT LightGBM model as text format. + """ + if predictor.model_type != ModelType.LIGHTGBM: + raise HTTPException(status_code=404, detail="TTFT model is not LightGBM") + + if not predictor.ttft_model: + raise HTTPException(status_code=404, detail="TTFT model not available") + + txt_path = settings.TTFT_MODEL_PATH.replace('.joblib', '_lgb.txt') + if not os.path.exists(txt_path): + raise HTTPException(status_code=404, detail="TTFT LightGBM text model not found") + + return FileResponse( + txt_path, + media_type='text/plain', + filename='ttft_lgb_model.txt' + ) + +@app.get("/model/tpot/lgb/txt") +async def tpot_lgb_txt(): + """ + Download the TPOT LightGBM model as text format. 
+ """ + if predictor.model_type != ModelType.LIGHTGBM: + raise HTTPException(status_code=404, detail="TPOT model is not LightGBM") + + if not predictor.tpot_model: + raise HTTPException(status_code=404, detail="TPOT model not available") + + txt_path = settings.TPOT_MODEL_PATH.replace('.joblib', '_lgb.txt') + if not os.path.exists(txt_path): + raise HTTPException(status_code=404, detail="TPOT LightGBM text model not found") + + return FileResponse( + txt_path, + media_type='text/plain', + filename='tpot_lgb_model.txt' + ) + +@app.get("/model/ttft/lgb/importances") +async def ttft_lgb_importances(): + """ + Get TTFT LightGBM feature importances as JSON. + """ + if predictor.model_type != ModelType.LIGHTGBM: + raise HTTPException(status_code=404, detail="TTFT model is not LightGBM") + + if not predictor.ttft_model: + raise HTTPException(status_code=404, detail="TTFT model not available") + + imp_path = settings.TTFT_MODEL_PATH.replace('.joblib', '_importances.json') + if not os.path.exists(imp_path): + raise HTTPException(status_code=404, detail="TTFT LightGBM importances not found") + + with open(imp_path, 'r') as f: + importances = json.load(f) + + return JSONResponse(content=importances) + +@app.get("/model/tpot/lgb/importances") +async def tpot_lgb_importances(): + """ + Get TPOT LightGBM feature importances as JSON. + """ + if predictor.model_type != ModelType.LIGHTGBM: + raise HTTPException(status_code=404, detail="TPOT model is not LightGBM") + + if not predictor.tpot_model: + raise HTTPException(status_code=404, detail="TPOT model not available") + + imp_path = settings.TPOT_MODEL_PATH.replace('.joblib', '_importances.json') + if not os.path.exists(imp_path): + raise HTTPException(status_code=404, detail="TPOT LightGBM importances not found") + + with open(imp_path, 'r') as f: + importances = json.load(f) + + return JSONResponse(content=importances) if __name__ == "__main__": uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True) \ No newline at end of file diff --git a/pkg/epp/latencypredictorasync/latencypredictor_async.go b/pkg/epp/latencypredictorasync/latencypredictor_async.go index 870d5e9ad..7524f3d86 100644 --- a/pkg/epp/latencypredictorasync/latencypredictor_async.go +++ b/pkg/epp/latencypredictorasync/latencypredictor_async.go @@ -36,8 +36,10 @@ type Config struct { UseNativeXGBoost bool // HTTPTimeout is the timeout for HTTP requests to the Python server. HTTPTimeout time.Duration - + // MetricsRefreshInterval determines how often to refresh cached metrics. MetricsRefreshInterval time.Duration + // MaxBulkSize is the maximum number of predictions to send in a single bulk request. + MaxBulkSize int } func DefaultConfig() *Config { @@ -49,6 +51,7 @@ func DefaultConfig() *Config { MetricsRefreshInterval: 60 * time.Second, UseNativeXGBoost: true, HTTPTimeout: 10 * time.Second, + MaxBulkSize: 100, } } @@ -87,19 +90,26 @@ func ConfigFromEnv() *Config { cfg.HTTPTimeout = time.Duration(sec) * time.Second } } - if s := os.Getenv("LATENCY_METRICS_INTERVAL_SEC"); s != "" { if sec, err := strconv.Atoi(s); err == nil && sec > 0 { cfg.MetricsRefreshInterval = time.Duration(sec) * time.Second } } + if bulkStr := os.Getenv("LATENCY_MAX_BULK_SIZE"); bulkStr != "" { + if size, err := strconv.Atoi(bulkStr); err == nil && size > 0 && size <= 100 { + cfg.MaxBulkSize = size + } + } return cfg } // Predictor defines the interface for latency prediction and training. 
type PredictorInterface interface { Predict(ctx context.Context, req PredictionRequest) (*PredictionResponse, error) + PredictBulk(ctx context.Context, requests []PredictionRequest) (*BulkPredictionResponse, error) + PredictBulkStrict(ctx context.Context, requests []PredictionRequest) (*BulkPredictionResponse, error) AddTrainingDataBulk(entry []TrainingEntry) error + GetServerStatus(ctx context.Context) (*ServerStatusResponse, error) } // --- Data Models --- @@ -112,7 +122,7 @@ type TrainingEntry struct { NumTokensGenerated int `json:"num_tokens_generated"` ActualTTFT float64 `json:"actual_ttft_ms"` ActualTPOT float64 `json:"actual_tpot_ms"` - PrefixCacheScore float64 `json:"prefix_cache_score"` // Added prefix cache score + PrefixCacheScore float64 `json:"prefix_cache_score"` Timestamp time.Time `json:"timestamp"` } @@ -126,20 +136,58 @@ type PredictionRequest struct { NumRequestWaiting int `json:"num_request_waiting"` NumRequestRunning int `json:"num_request_running"` NumTokensGenerated int `json:"num_tokens_generated"` - PrefixCacheScore float64 `json:"prefix_cache_score"` // Added prefix cache score + PrefixCacheScore float64 `json:"prefix_cache_score"` } type PredictionResponse struct { TTFT float64 `json:"ttft_ms"` TPOT float64 `json:"tpot_ms"` - TTFTUncertainty float64 `json:"ttft_uncertainty"` - TPOTUncertainty float64 `json:"tpot_uncertainty"` - TTFTPredictionBounds [2]float64 `json:"ttft_prediction_bounds"` - TPOTPredictionBounds [2]float64 `json:"tpot_prediction_bounds"` + TTFTUncertainty float64 `json:"ttft_uncertainty,omitempty"` + TPOTUncertainty float64 `json:"tpot_uncertainty,omitempty"` + TTFTPredictionBounds [2]float64 `json:"ttft_prediction_bounds,omitempty"` + TPOTPredictionBounds [2]float64 `json:"tpot_prediction_bounds,omitempty"` PredictedAt time.Time `json:"predicted_at"` ModelType string `json:"model_type"` - Quantile float64 `json:"quantile"` // Add this field - LastModelLoad *time.Time `json:"last_model_load"` // Add this field + Quantile float64 `json:"quantile"` + LastModelLoad *time.Time `json:"last_model_load"` +} + +// New data models for bulk predictions +type BulkPredictionRequest struct { + Requests []PredictionRequest `json:"requests"` +} + +type BulkPredictionResponse struct { + Predictions []PredictionResponse `json:"predictions"` + TotalRequests int `json:"total_requests"` + SuccessfulPredictions int `json:"successful_predictions"` + FailedPredictions int `json:"failed_predictions"` + ProcessingTimeMs float64 `json:"processing_time_ms"` +} + +type BulkPredictionError struct { + Index int `json:"index"` + Error string `json:"error"` + Request PredictionRequest `json:"request"` +} + +type BulkPredictionResponseWithErrors struct { + Predictions []*PredictionResponse `json:"predictions"` + Errors []BulkPredictionError `json:"errors"` + TotalRequests int `json:"total_requests"` + SuccessfulPredictions int `json:"successful_predictions"` + FailedPredictions int `json:"failed_predictions"` + ProcessingTimeMs float64 `json:"processing_time_ms"` +} + +// Server status response +type ServerStatusResponse struct { + IsReady bool `json:"is_ready"` + ModelType string `json:"model_type"` + Quantile float64 `json:"quantile"` + LastModelLoad *time.Time `json:"last_model_load"` + TrainingServerURL string `json:"training_server_url"` + ModelsExist map[string]bool `json:"models_exist"` } type ModelCoefficients struct { @@ -162,6 +210,7 @@ type BucketCounts struct { type ModelInfo struct { ModelType string `json:"model_type"` ModelStatus map[string]bool 
`json:"model_status"` + Quantile float64 `json:"quantile"` } type MetricsResponse struct { @@ -183,6 +232,7 @@ type Predictor struct { metricsMu sync.RWMutex cachedMetrics *MetricsResponse modelInfo *ModelInfo + serverStatus *ServerStatusResponse xgboostMu sync.RWMutex @@ -221,9 +271,14 @@ func (p *Predictor) getRandomPredictionURL() string { return p.config.PredictionURLs[index] } -// Start is a no-op for API compatibility. +// Start initializes the predictor by fetching server status and model info. func (p *Predictor) Start(ctx context.Context) error { - // Get initial model info + // Get initial server status + if err := p.refreshServerStatus(ctx); err != nil { + p.logger.Error(err, "Failed to get initial server status") + } + + // Get initial model info if training server is available if err := p.refreshModelInfo(ctx); err != nil { p.logger.Error(err, "Failed to get initial model info") } @@ -233,7 +288,8 @@ func (p *Predictor) Start(ctx context.Context) error { "prediction_urls", p.config.PredictionURLs, "max_sample_size", p.config.MaxSampleSize, "flush_interval", p.config.FlushInterval, - "use_native_xgboost", p.config.UseNativeXGBoost) + "use_native_xgboost", p.config.UseNativeXGBoost, + "max_bulk_size", p.config.MaxBulkSize) return nil } @@ -261,12 +317,70 @@ func (p *Predictor) backgroundLoop() { p.flushTraining() case <-metricsTicker.C: p.refreshMetrics() + // Also refresh server status periodically + ctx, cancel := context.WithTimeout(context.Background(), p.config.HTTPTimeout) + p.refreshServerStatus(ctx) + cancel() case <-p.done: return } } } +// GetServerStatus fetches the current status from a prediction server +func (p *Predictor) GetServerStatus(ctx context.Context) (*ServerStatusResponse, error) { + if err := p.refreshServerStatus(ctx); err != nil { + return nil, err + } + + p.metricsMu.RLock() + defer p.metricsMu.RUnlock() + + if p.serverStatus == nil { + return nil, fmt.Errorf("server status not available") + } + + return p.serverStatus, nil +} + +// refreshServerStatus gets current server status from a prediction server +func (p *Predictor) refreshServerStatus(ctx context.Context) error { + predictionURL := p.getRandomPredictionURL() + url := predictionURL + "/status" + + p.logger.V(1).Info("Fetching server status", "url", url) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("failed to create server status request: %w", err) + } + + resp, err := p.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to call /status endpoint: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("server %s returned non-200 status: %d %s, body: %s", url, resp.StatusCode, resp.Status, string(body)) + } + + var status ServerStatusResponse + if err := json.NewDecoder(resp.Body).Decode(&status); err != nil { + return fmt.Errorf("failed to decode server status response: %w", err) + } + + p.metricsMu.Lock() + p.serverStatus = &status + p.metricsMu.Unlock() + + p.logger.V(1).Info("Retrieved server status", + "model_type", status.ModelType, + "quantile", status.Quantile, + "is_ready", status.IsReady) + return nil +} + // refreshModelInfo gets current model type and readiness info from training server func (p *Predictor) refreshModelInfo(ctx context.Context) error { url := p.config.TrainingURL + "/model/download/info" @@ -546,53 +660,201 @@ func (p *Predictor) refreshMetrics() { p.logger.Error(err, "Failed to refresh Bayesian Ridge metrics") 
} case "xgboost": - trees, err := p.getXGBoostTrees(ctx) - if err != nil { - p.logger.Error(err, "Failed to fetch XGBoost trees") - return - } + if p.config.UseNativeXGBoost { + // Fetch XGBoost trees for native predictions + trees, err := p.getXGBoostTrees(ctx) + if err != nil { + p.logger.Error(err, "Failed to fetch XGBoost trees") + return + } + p.metricsMu.Lock() + if p.cachedMetrics == nil { + p.cachedMetrics = &MetricsResponse{} + } + p.cachedMetrics.ModelType = modelType + p.cachedMetrics.XGBoostTrees = trees + p.metricsMu.Unlock() + + p.logger.V(1).Info("Updated XGBoost trees for native predictions") + } else { + // Just update model type for HTTP-based predictions + p.metricsMu.Lock() + if p.cachedMetrics == nil { + p.cachedMetrics = &MetricsResponse{} + } + p.cachedMetrics.ModelType = modelType + p.metricsMu.Unlock() + + p.logger.V(1).Info("Updated model type for HTTP-based predictions", "model_type", modelType) + } + case "lightgbm": + // LightGBM only supports HTTP calls, no native tree caching needed p.metricsMu.Lock() if p.cachedMetrics == nil { p.cachedMetrics = &MetricsResponse{} } p.cachedMetrics.ModelType = modelType - p.cachedMetrics.XGBoostTrees = trees p.metricsMu.Unlock() - if p.IsXGBoostReady() { - p.logger.V(1).Info("Successfully refreshed XGBoost models") - } else { - p.logger.V(1).Info("XGBoost models not ready, will use HTTP fallback") - } + p.logger.V(1).Info("Updated model type for HTTP-based predictions", "model_type", modelType) default: p.logger.Info("Unknown model type, cannot refresh metrics", "model_type", modelType) } } -// Predict uses cached coefficients (Bayesian Ridge) or XGBoost models for local prediction. +// Predict uses cached coefficients (Bayesian Ridge) or HTTP calls (XGBoost/LightGBM) for prediction. 
func (p *Predictor) Predict(ctx context.Context, req PredictionRequest) (*PredictionResponse, error) { + // Get current model type from server status first, fall back to model info p.metricsMu.RLock() + modelType := "" + quantile := 0.9 // default + + if p.serverStatus != nil { + modelType = p.serverStatus.ModelType + quantile = p.serverStatus.Quantile + } else if p.modelInfo != nil { + modelType = p.modelInfo.ModelType + if p.modelInfo.Quantile > 0 { + quantile = p.modelInfo.Quantile + } + } + mr := p.cachedMetrics - modelInfo := p.modelInfo p.metricsMu.RUnlock() - if modelInfo == nil { - return nil, fmt.Errorf("model info not yet available") + if modelType == "" { + return nil, fmt.Errorf("model type not yet available from server") } - switch modelInfo.ModelType { + switch modelType { case "bayesian_ridge": - return p.predictBayesianRidge(req, mr) - case "xgboost": - return p.predictXGBoostHTTP(ctx, req) + return p.predictBayesianRidge(req, mr, quantile) + case "xgboost", "lightgbm": + return p.predictHTTP(ctx, req) default: - return nil, fmt.Errorf("unsupported or unknown model type: %s", modelInfo.ModelType) + return nil, fmt.Errorf("unsupported or unknown model type: %s", modelType) + } +} + +// PredictBulk makes bulk predictions with error handling (allows partial failures) +func (p *Predictor) PredictBulk(ctx context.Context, requests []PredictionRequest) (*BulkPredictionResponse, error) { + if len(requests) == 0 { + return nil, fmt.Errorf("no prediction requests provided") + } + + if len(requests) > p.config.MaxBulkSize { + return nil, fmt.Errorf("too many requests: %d (max: %d)", len(requests), p.config.MaxBulkSize) } + + // Validate all requests first + for i, req := range requests { + if err := p.ValidatePredictionRequest(req); err != nil { + return nil, fmt.Errorf("validation failed for request %d: %w", i, err) + } + } + + payload := BulkPredictionRequest{Requests: requests} + data, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal bulk prediction request: %w", err) + } + + predictionURL := p.getRandomPredictionURL() + url := predictionURL + "/predict/bulk" + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("failed to create bulk prediction request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to call bulk prediction endpoint %s: %w", url, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("bulk prediction server returned non-200 status: %d %s, body: %s", resp.StatusCode, resp.Status, string(body)) + } + + var bulkResp BulkPredictionResponseWithErrors + if err := json.NewDecoder(resp.Body).Decode(&bulkResp); err != nil { + return nil, fmt.Errorf("failed to decode bulk prediction response: %w", err) + } + + // Convert to standard bulk response format + var predictions []PredictionResponse + for _, pred := range bulkResp.Predictions { + if pred != nil { + predictions = append(predictions, *pred) + } + } + + return &BulkPredictionResponse{ + Predictions: predictions, + TotalRequests: bulkResp.TotalRequests, + SuccessfulPredictions: bulkResp.SuccessfulPredictions, + FailedPredictions: bulkResp.FailedPredictions, + ProcessingTimeMs: bulkResp.ProcessingTimeMs, + }, nil +} + +// PredictBulkStrict makes bulk predictions that fail if any single prediction fails +func (p 
*Predictor) PredictBulkStrict(ctx context.Context, requests []PredictionRequest) (*BulkPredictionResponse, error) { + if len(requests) == 0 { + return nil, fmt.Errorf("no prediction requests provided") + } + + if len(requests) > p.config.MaxBulkSize { + return nil, fmt.Errorf("too many requests: %d (max: %d)", len(requests), p.config.MaxBulkSize) + } + + // Validate all requests first + for i, req := range requests { + if err := p.ValidatePredictionRequest(req); err != nil { + return nil, fmt.Errorf("validation failed for request %d: %w", i, err) + } + } + + payload := BulkPredictionRequest{Requests: requests} + data, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal bulk prediction request: %w", err) + } + + predictionURL := p.getRandomPredictionURL() + url := predictionURL + "/predict/bulk/strict" + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("failed to create bulk prediction request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to call bulk prediction endpoint %s: %w", url, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("bulk prediction server returned non-200 status: %d %s, body: %s", resp.StatusCode, resp.Status, string(body)) + } + + var bulkResp BulkPredictionResponse + if err := json.NewDecoder(resp.Body).Decode(&bulkResp); err != nil { + return nil, fmt.Errorf("failed to decode bulk prediction response: %w", err) + } + + return &bulkResp, nil } // predictBayesianRidge uses cached coefficients for linear prediction -func (p *Predictor) predictBayesianRidge(req PredictionRequest, mr *MetricsResponse) (*PredictionResponse, error) { +func (p *Predictor) predictBayesianRidge(req PredictionRequest, mr *MetricsResponse, quantile float64) (*PredictionResponse, error) { if mr == nil || mr.Coefficients == nil { return nil, fmt.Errorf("no cached Bayesian Ridge coefficients available for prediction") } @@ -604,7 +866,7 @@ func (p *Predictor) predictBayesianRidge(req PredictionRequest, mr *MetricsRespo c.TTFTCoeffs["input_token_length"]*float64(req.InputTokenLength) + c.TTFTCoeffs["num_request_waiting"]*float64(req.NumRequestWaiting) + c.TTFTCoeffs["num_request_running"]*float64(req.NumRequestRunning) + - c.TTFTCoeffs["prefix_cache_score"]*req.PrefixCacheScore // Added prefix cache score + c.TTFTCoeffs["prefix_cache_score"]*req.PrefixCacheScore // Linear combination for TPOT (remains unchanged - no prefix cache effect) tpot := c.TPOTIntercept + @@ -619,12 +881,12 @@ func (p *Predictor) predictBayesianRidge(req PredictionRequest, mr *MetricsRespo TPOT: tpot, PredictedAt: time.Now(), ModelType: "bayesian_ridge", - Quantile: 0.9, + Quantile: quantile, }, nil } -// predictXGBoostHTTP makes an HTTP call to a randomly selected prediction server for XGBoost predictions -func (p *Predictor) predictXGBoostHTTP(ctx context.Context, req PredictionRequest) (*PredictionResponse, error) { +// predictHTTP makes an HTTP call to a randomly selected prediction server for XGBoost/LightGBM predictions +func (p *Predictor) predictHTTP(ctx context.Context, req PredictionRequest) (*PredictionResponse, error) { data, err := json.Marshal(req) if err != nil { return nil, fmt.Errorf("failed to marshal prediction request: %w", err) @@ -863,6 +1125,13 @@ func (p *Predictor) IsXGBoostReady() bool 
{ return p.modelInfo != nil && p.modelInfo.ModelType == "xgboost" } +// IsLightGBMReady returns true if LightGBM models are available via HTTP. +func (p *Predictor) IsLightGBMReady() bool { + p.metricsMu.RLock() + defer p.metricsMu.RUnlock() + return p.modelInfo != nil && p.modelInfo.ModelType == "lightgbm" && len(p.config.PredictionURLs) > 0 +} + // IsBayesianRidgeReady returns true if Bayesian Ridge coefficients are cached. func (p *Predictor) IsBayesianRidgeReady() bool { p.metricsMu.RLock() @@ -870,24 +1139,50 @@ func (p *Predictor) IsBayesianRidgeReady() bool { return p.cachedMetrics != nil && p.cachedMetrics.Coefficients != nil } -// GetCurrentModelType returns the current model type from cached model info. +// GetCurrentModelType returns the current model type from cached server status or model info. func (p *Predictor) GetCurrentModelType() string { p.metricsMu.RLock() defer p.metricsMu.RUnlock() + + // Prefer server status if available + if p.serverStatus != nil { + return p.serverStatus.ModelType + } + if p.modelInfo == nil { return "" } return p.modelInfo.ModelType } +// GetCurrentQuantile returns the current quantile from server status or defaults to 0.9 +func (p *Predictor) GetCurrentQuantile() float64 { + p.metricsMu.RLock() + defer p.metricsMu.RUnlock() + + // Prefer server status if available + if p.serverStatus != nil && p.serverStatus.Quantile > 0 { + return p.serverStatus.Quantile + } + + if p.modelInfo != nil && p.modelInfo.Quantile > 0 { + return p.modelInfo.Quantile + } + + return 0.9 // Default quantile +} + // IsReady returns true if a prediction method is ready based on the current model type. func (p *Predictor) IsReady() bool { switch p.GetCurrentModelType() { case "bayesian_ridge": return p.IsBayesianRidgeReady() case "xgboost": - // Ready if native models are loaded OR we have prediction URLs for HTTP fallback. 
- return p.IsXGBoostReady() || len(p.config.PredictionURLs) > 0 + // Ready if we have prediction URLs for HTTP calls + return len(p.config.PredictionURLs) > 0 + case "lightgbm": + // Ready if we have prediction URLs for HTTP calls + return p.IsLightGBMReady() default: return false } diff --git a/pkg/epp/latencypredictorasync/latencypredictor_async_test.go b/pkg/epp/latencypredictorasync/latencypredictor_async_test.go index 1fe1dfcc6..389405c13 100644 --- a/pkg/epp/latencypredictorasync/latencypredictor_async_test.go +++ b/pkg/epp/latencypredictorasync/latencypredictor_async_test.go @@ -53,6 +53,7 @@ func TestLatencyPredictorIntegration(t *testing.T) { MetricsRefreshInterval: 1 * time.Second, // Longer for metrics UseNativeXGBoost: true, HTTPTimeout: 30 * time.Second, // Longer timeout for tests + MaxBulkSize: 50, // Test bulk size } // Create predictor @@ -68,6 +69,10 @@ func TestLatencyPredictorIntegration(t *testing.T) { t.Fatalf("Failed to start predictor: %v", err) } + t.Run("TestServerStatus", func(t *testing.T) { + testServerStatus(t, ctx, predictor) + }) + t.Run("TestModelInfo", func(t *testing.T) { testModelInfo(t, ctx, predictor) }) @@ -80,6 +85,14 @@ func TestLatencyPredictorIntegration(t *testing.T) { testPrediction(t, ctx, predictor) }) + t.Run("TestBulkPredictions", func(t *testing.T) { + testBulkPredictions(t, ctx, predictor) + }) + + t.Run("TestBulkPredictionsStrict", func(t *testing.T) { + testBulkPredictionsStrict(t, ctx, predictor) + }) + t.Run("TestPredictionWithPrefixCache", func(t *testing.T) { testPredictionWithPrefixCache(t, ctx, predictor) }) @@ -88,10 +101,18 @@ func TestLatencyPredictorIntegration(t *testing.T) { testHTTPFallbackPrediction(t, ctx, predictor) }) + t.Run("TestLightGBMSupport", func(t *testing.T) { + testLightGBMSupport(t, ctx, predictor) + }) + t.Run("TestPredictionPerformance", func(t *testing.T) { testPredictionPerformance(t, ctx, predictor) }) + t.Run("TestBulkPredictionPerformance", func(t *testing.T) { + testBulkPredictionPerformance(t, ctx, predictor) + }) + t.Run("TestHTTPOnlyPerformance", func(t *testing.T) { testHTTPOnlyPerformance(t, ctx) }) @@ -119,6 +140,42 @@ func TestLatencyPredictorIntegration(t *testing.T) { t.Run("TestPredictionConstructors", func(t *testing.T) { testPredictionConstructors(t) }) + + t.Run("TestQuantileConfiguration", func(t *testing.T) { + testQuantileConfiguration(t, ctx, predictor) + }) +} + +func testServerStatus(t *testing.T, ctx context.Context, predictor *Predictor) { + t.Log("Testing server status retrieval...") + + status, err := predictor.GetServerStatus(ctx) + if err != nil { + t.Fatalf("Failed to get server status: %v", err) + } + + t.Logf("Server Status:") + t.Logf(" Is Ready: %t", status.IsReady) + t.Logf(" Model Type: %s", status.ModelType) + t.Logf(" Quantile: %.2f", status.Quantile) + t.Logf(" Training Server URL: %s", status.TrainingServerURL) + t.Logf(" Models Exist: %v", status.ModelsExist) + + if status.ModelType == "" { + t.Error("Model type should not be empty") + } + + if status.Quantile <= 0 || status.Quantile >= 1 { + t.Errorf("Quantile should be between 0 and 1, got: %.2f", status.Quantile) + } + + // Test quantile retrieval + currentQuantile := predictor.GetCurrentQuantile() + t.Logf("Current quantile from predictor: %.2f", currentQuantile) + + if currentQuantile != status.Quantile { + t.Logf("Note: Cached quantile (%.2f) differs from server status (%.2f)", currentQuantile, status.Quantile) + } } func testModelInfo(t *testing.T, ctx context.Context, predictor *Predictor) { @@ -129,8 +186,8 
@@ func testModelInfo(t *testing.T, ctx context.Context, predictor *Predictor) { t.Fatalf("Failed to get model info: %v", err) } - t.Logf("Model Info - Type: %s, Model Status: %v", - modelInfo.ModelType, modelInfo.ModelStatus) + t.Logf("Model Info - Type: %s, Model Status: %v, Quantile: %.2f", + modelInfo.ModelType, modelInfo.ModelStatus, modelInfo.Quantile) if modelInfo.ModelType == "" { t.Error("Model type should not be empty") @@ -170,8 +227,10 @@ func testPrediction(t *testing.T, ctx context.Context, predictor *Predictor) { // Log current predictor state t.Logf("Predictor state:") t.Logf(" Current model type: %s", predictor.GetCurrentModelType()) + t.Logf(" Current quantile: %.2f", predictor.GetCurrentQuantile()) t.Logf(" Overall ready: %t", predictor.IsReady()) t.Logf(" XGBoost ready: %t", predictor.IsXGBoostReady()) + t.Logf(" LightGBM ready: %t", predictor.IsLightGBMReady()) t.Logf(" Bayesian Ridge ready: %t", predictor.IsBayesianRidgeReady()) // Wait for models to be ready @@ -214,6 +273,7 @@ func testPrediction(t *testing.T, ctx context.Context, predictor *Predictor) { t.Logf(" TTFT Bounds: [%.2f, %.2f]", response.TTFTPredictionBounds[0], response.TTFTPredictionBounds[1]) t.Logf(" TPOT Bounds: [%.2f, %.2f]", response.TPOTPredictionBounds[0], response.TPOTPredictionBounds[1]) t.Logf(" Model Type: %s", response.ModelType) + t.Logf(" Quantile: %.2f", response.Quantile) t.Logf(" Predicted At: %s", response.PredictedAt.Format(time.RFC3339)) // Validate response @@ -226,6 +286,9 @@ func testPrediction(t *testing.T, ctx context.Context, predictor *Predictor) { if response.ModelType == "" { t.Error("Model type should not be empty") } + if response.Quantile <= 0 || response.Quantile >= 1 { + t.Errorf("Quantile should be between 0 and 1, got: %.2f", response.Quantile) + } // Test multiple predictions to ensure consistency t.Log("Testing multiple predictions with varying prefix cache scores...") @@ -245,8 +308,244 @@ func testPrediction(t *testing.T, ctx context.Context, predictor *Predictor) { continue } - t.Logf("Prediction %d: TTFT=%.2f, TPOT=%.2f (prefix_cache=%.1f%%)", - i+1, resp.TTFT, resp.TPOT, testReq.PrefixCacheScore*100) + t.Logf("Prediction %d: TTFT=%.2f, TPOT=%.2f (prefix_cache=%.1f%%, quantile=%.2f)", + i+1, resp.TTFT, resp.TPOT, testReq.PrefixCacheScore*100, resp.Quantile) + } +} + +func testBulkPredictions(t *testing.T, ctx context.Context, predictor *Predictor) { + t.Log("Testing bulk predictions with error tolerance...") + + if !predictor.IsReady() { + t.Skip("Predictor not ready for bulk prediction testing") + } + + // Create multiple prediction requests + requests := make([]PredictionRequest, 10) + for i := 0; i < 10; i++ { + requests[i] = PredictionRequest{ + KVCachePercentage: float64(40+i*5) / 100.0, // 40% to 85% + InputTokenLength: 200 + i*50, // 200 to 650 + NumRequestWaiting: i % 5, // 0 to 4 + NumRequestRunning: (i % 3) + 1, // 1 to 3 + NumTokensGenerated: 25 + i*10, // 25 to 115 + PrefixCacheScore: float64(i) / 9.0, // 0.0 to 1.0 + } + } + + t.Logf("Making bulk prediction request with %d requests", len(requests)) + + bulkResponse, err := predictor.PredictBulk(ctx, requests) + if err != nil { + t.Fatalf("Bulk prediction failed: %v", err) + } + + t.Logf("Bulk Prediction Response:") + t.Logf(" Total Requests: %d", bulkResponse.TotalRequests) + t.Logf(" Successful: %d", bulkResponse.SuccessfulPredictions) + t.Logf(" Failed: %d", bulkResponse.FailedPredictions) + t.Logf(" Processing Time: %.2f ms", bulkResponse.ProcessingTimeMs) + + if bulkResponse.TotalRequests != 
len(requests) { + t.Errorf("Expected %d total requests, got %d", len(requests), bulkResponse.TotalRequests) + } + + if bulkResponse.SuccessfulPredictions != len(bulkResponse.Predictions) { + t.Errorf("Successful count (%d) doesn't match predictions length (%d)", + bulkResponse.SuccessfulPredictions, len(bulkResponse.Predictions)) + } + + // Validate each prediction in the response + for i, prediction := range bulkResponse.Predictions { + if prediction.TTFT <= 0 { + t.Errorf("Prediction %d: TTFT should be positive, got %.2f", i, prediction.TTFT) + } + if prediction.TPOT <= 0 { + t.Errorf("Prediction %d: TPOT should be positive, got %.2f", i, prediction.TPOT) + } + if prediction.ModelType == "" { + t.Errorf("Prediction %d: Model type should not be empty", i) + } + + t.Logf(" Prediction %d: TTFT=%.2f, TPOT=%.2f, quantile=%.2f", + i+1, prediction.TTFT, prediction.TPOT, prediction.Quantile) + } + + // Test performance expectation + avgTimePerPrediction := bulkResponse.ProcessingTimeMs / float64(bulkResponse.SuccessfulPredictions) + t.Logf("Average time per prediction: %.2f ms", avgTimePerPrediction) + + if avgTimePerPrediction > 100 { // Bulk should be more efficient + t.Logf("Note: Bulk prediction averaging %.2f ms per request (may be acceptable)", avgTimePerPrediction) + } else { + t.Logf("✓ Good bulk prediction performance: %.2f ms per request", avgTimePerPrediction) + } +} + +func testBulkPredictionsStrict(t *testing.T, ctx context.Context, predictor *Predictor) { + t.Log("Testing strict bulk predictions...") + + if !predictor.IsReady() { + t.Skip("Predictor not ready for strict bulk prediction testing") + } + + // Create valid prediction requests + requests := make([]PredictionRequest, 5) + for i := 0; i < 5; i++ { + requests[i] = PredictionRequest{ + KVCachePercentage: 0.6, + InputTokenLength: 300 + i*100, + NumRequestWaiting: i, + NumRequestRunning: 1, + NumTokensGenerated: 50, + PrefixCacheScore: float64(i) / 4.0, // 0.0 to 1.0 + } + } + + t.Logf("Making strict bulk prediction request with %d requests", len(requests)) + + bulkResponse, err := predictor.PredictBulkStrict(ctx, requests) + if err != nil { + t.Fatalf("Strict bulk prediction failed: %v", err) + } + + t.Logf("Strict Bulk Prediction Response:") + t.Logf(" Total Requests: %d", bulkResponse.TotalRequests) + t.Logf(" Successful: %d", bulkResponse.SuccessfulPredictions) + t.Logf(" Failed: %d", bulkResponse.FailedPredictions) + t.Logf(" Processing Time: %.2f ms", bulkResponse.ProcessingTimeMs) + + // In strict mode, we expect all requests to succeed or the entire batch to fail + if bulkResponse.FailedPredictions > 0 { + t.Errorf("Strict bulk prediction should not have partial failures, got %d failed", + bulkResponse.FailedPredictions) + } + + if bulkResponse.SuccessfulPredictions != len(requests) { + t.Errorf("Expected all %d requests to succeed, got %d", + len(requests), bulkResponse.SuccessfulPredictions) + } + + // Test bulk size limits + t.Log("Testing bulk size limits...") + largeRequests := make([]PredictionRequest, 150) // Over the limit + for i := range largeRequests { + largeRequests[i] = requests[0] // Use a valid request template + } + + _, err = predictor.PredictBulkStrict(ctx, largeRequests) + if err == nil { + t.Error("Expected error for oversized bulk request, but got none") + } else { + t.Logf("✓ Correctly rejected oversized bulk request: %v", err) + } +} + +func testLightGBMSupport(t *testing.T, ctx context.Context, predictor *Predictor) { + t.Log("Testing LightGBM support...") + + currentModelType := 
predictor.GetCurrentModelType() + t.Logf("Current model type: %s", currentModelType) + + if currentModelType == "lightgbm" { + t.Log("Testing LightGBM-specific functionality...") + + // Test LightGBM readiness + isReady := predictor.IsLightGBMReady() + t.Logf("LightGBM ready: %t", isReady) + + if isReady { + // Test LightGBM prediction + req := PredictionRequest{ + KVCachePercentage: 0.7, + InputTokenLength: 400, + NumRequestWaiting: 2, + NumRequestRunning: 1, + NumTokensGenerated: 60, + PrefixCacheScore: 0.8, + } + + response, err := predictor.Predict(ctx, req) + if err != nil { + t.Errorf("LightGBM prediction failed: %v", err) + } else { + t.Logf("LightGBM prediction successful: TTFT=%.2f, TPOT=%.2f", + response.TTFT, response.TPOT) + + if response.ModelType != "lightgbm" { + t.Errorf("Expected model type 'lightgbm', got '%s'", response.ModelType) + } + } + } else { + t.Log("LightGBM not ready, skipping LightGBM-specific tests") + } + } else { + t.Logf("Current model type is %s, not LightGBM. LightGBM-specific tests skipped.", currentModelType) + } + + // Test that the client handles all model types properly + t.Log("Verifying model type handling...") + switch currentModelType { + case "bayesian_ridge": + t.Log("✓ Bayesian Ridge model type recognized") + case "xgboost": + t.Log("✓ XGBoost model type recognized") + case "lightgbm": + t.Log("✓ LightGBM model type recognized") + default: + t.Logf("⚠ Unknown model type: %s", currentModelType) + } +} + +func testQuantileConfiguration(t *testing.T, ctx context.Context, predictor *Predictor) { + t.Log("Testing quantile configuration detection...") + + // Get server status to check quantile + status, err := predictor.GetServerStatus(ctx) + if err != nil { + t.Errorf("Failed to get server status: %v", err) + return + } + + expectedQuantile := status.Quantile + currentQuantile := predictor.GetCurrentQuantile() + + t.Logf("Server quantile: %.2f", expectedQuantile) + t.Logf("Cached quantile: %.2f", currentQuantile) + + // Test that predictions use the correct quantile + req := PredictionRequest{ + KVCachePercentage: 0.6, + InputTokenLength: 300, + NumRequestWaiting: 1, + NumRequestRunning: 1, + NumTokensGenerated: 50, + PrefixCacheScore: 0.7, + } + + response, err := predictor.Predict(ctx, req) + if err != nil { + t.Errorf("Prediction failed: %v", err) + return + } + + t.Logf("Prediction quantile: %.2f", response.Quantile) + + // The response quantile should match the server's quantile configuration + if abs(response.Quantile-expectedQuantile) > 0.01 { + t.Errorf("Response quantile (%.2f) doesn't match expected (%.2f)", + response.Quantile, expectedQuantile) + } else { + t.Log("✓ Prediction correctly uses server's quantile configuration") + } + + // Test common quantile values + commonQuantiles := []float64{0.5, 0.8, 0.9, 0.95} + for _, q := range commonQuantiles { + if abs(expectedQuantile-q) < 0.01 { + t.Logf("✓ Using common quantile value: %.0f%%", q*100) + break + } } } @@ -268,6 +567,7 @@ func testPredictionWithPrefixCache(t *testing.T, ctx context.Context, predictor prefixCacheScores := []float64{0.0, 0.2, 0.4, 0.6, 0.8, 1.0} var ttftResults []float64 + var quantileResults []float64 for _, prefixScore := range prefixCacheScores { req := baseRequest @@ -280,8 +580,21 @@ func testPredictionWithPrefixCache(t *testing.T, ctx context.Context, predictor } ttftResults = append(ttftResults, response.TTFT) - t.Logf("Prefix cache %.0f%%: TTFT=%.2f ms, TPOT=%.2f ms", - prefixScore*100, response.TTFT, response.TPOT) + quantileResults = 
append(quantileResults, response.Quantile) + t.Logf("Prefix cache %.0f%%: TTFT=%.2f ms, TPOT=%.2f ms, quantile=%.2f", + prefixScore*100, response.TTFT, response.TPOT, response.Quantile) + } + + // Verify quantile consistency + if len(quantileResults) > 1 { + firstQuantile := quantileResults[0] + for i, q := range quantileResults { + if abs(q-firstQuantile) > 0.01 { + t.Errorf("Quantile inconsistency: prediction %d has quantile %.2f, expected %.2f", + i, q, firstQuantile) + } + } + t.Log("✓ Quantile values consistent across predictions") } // Analyze the relationship between prefix cache and TTFT @@ -307,12 +620,11 @@ func testPredictionWithPrefixCache(t *testing.T, ctx context.Context, predictor } func testHTTPFallbackPrediction(t *testing.T, ctx context.Context, predictor *Predictor) { - t.Log("Testing HTTP fallback prediction when native XGBoost fails...") + t.Log("Testing HTTP fallback prediction...") - // Since we know XGBoost native parsing failed from the logs, - // the predictor should fall back to HTTP predictions - if predictor.GetCurrentModelType() != "xgboost" { - t.Skip("This test is specific to XGBoost model type") + modelType := predictor.GetCurrentModelType() + if modelType == "bayesian_ridge" { + t.Skip("HTTP fallback test not applicable for Bayesian Ridge") } // Test prediction with HTTP fallback including prefix cache score @@ -325,17 +637,18 @@ func testHTTPFallbackPrediction(t *testing.T, ctx context.Context, predictor *Pr PrefixCacheScore: 0.9, // 90% prefix cache hit rate } - t.Logf("Making HTTP fallback prediction request: %+v", req) + t.Logf("Making HTTP prediction request: %+v", req) response, err := predictor.Predict(ctx, req) if err != nil { - t.Fatalf("HTTP fallback prediction failed: %v", err) + t.Fatalf("HTTP prediction failed: %v", err) } - t.Logf("HTTP Fallback Prediction Response:") + t.Logf("HTTP Prediction Response:") t.Logf(" TTFT: %.2f ms", response.TTFT) t.Logf(" TPOT: %.2f ms", response.TPOT) t.Logf(" Model Type: %s", response.ModelType) + t.Logf(" Quantile: %.2f", response.Quantile) t.Logf(" Prefix Cache Score Used: %.1f%%", req.PrefixCacheScore*100) // Validate that we got a reasonable response @@ -346,12 +659,12 @@ func testHTTPFallbackPrediction(t *testing.T, ctx context.Context, predictor *Pr t.Error("TPOT should be positive") } - // The model type should indicate it's using XGBoost (likely "xgboost" from HTTP) + // The model type should indicate the correct type if response.ModelType == "" { t.Error("Model type should not be empty") } - t.Logf("Successfully tested HTTP fallback prediction with prefix cache") + t.Logf("Successfully tested HTTP prediction with prefix cache") } func testPredictionPerformance(t *testing.T, ctx context.Context, predictor *Predictor) { @@ -415,8 +728,8 @@ func testPredictionPerformance(t *testing.T, ctx context.Context, predictor *Pre } durationMs := float64(duration.Nanoseconds()) / 1e6 - t.Logf("Prediction %d: %.2fms - TTFT: %.1fms, TPOT: %.1fms (prefix: %.0f%%)", - i+1, durationMs, response.TTFT, response.TPOT, testReq.PrefixCacheScore*100) + t.Logf("Prediction %d: %.2fms - TTFT: %.1fms, TPOT: %.1fms (prefix: %.0f%%, quantile: %.2f)", + i+1, durationMs, response.TTFT, response.TPOT, testReq.PrefixCacheScore*100, response.Quantile) } // Calculate statistics @@ -446,6 +759,91 @@ func testPredictionPerformance(t *testing.T, ctx context.Context, predictor *Pre } } +func testBulkPredictionPerformance(t *testing.T, ctx context.Context, predictor *Predictor) { + t.Log("Testing bulk prediction performance...") + + if 
!predictor.IsReady() { + t.Skip("Predictor not ready for bulk performance test") + } + + // Create batch of prediction requests + const batchSize = 20 + requests := make([]PredictionRequest, batchSize) + for i := 0; i < batchSize; i++ { + requests[i] = PredictionRequest{ + KVCachePercentage: 0.6 + float64(i%5)*0.05, // Vary between 0.6 and 0.8 + InputTokenLength: 300 + i*10, // Vary input length + NumRequestWaiting: i % 4, // 0 to 3 + NumRequestRunning: (i % 2) + 1, // 1 to 2 + NumTokensGenerated: 50 + i*2, // Vary generated tokens + PrefixCacheScore: float64(i) / float64(batchSize-1), // 0.0 to 1.0 + } + } + + // Warm up + warmupRequests := requests[:3] + _, err := predictor.PredictBulk(ctx, warmupRequests) + if err != nil { + t.Fatalf("Warmup bulk prediction failed: %v", err) + } + + // Performance test + const numTests = 5 + var totalDuration time.Duration + var totalRequests int + var totalSuccessful int + + t.Logf("Running %d bulk prediction performance tests with %d requests each...", numTests, batchSize) + + for i := 0; i < numTests; i++ { + start := time.Now() + + response, err := predictor.PredictBulk(ctx, requests) + + duration := time.Since(start) + totalDuration += duration + + if err != nil { + t.Errorf("Bulk prediction %d failed: %v", i+1, err) + continue + } + + totalRequests += response.TotalRequests + totalSuccessful += response.SuccessfulPredictions + + durationMs := float64(duration.Nanoseconds()) / 1e6 + avgPerRequest := durationMs / float64(response.SuccessfulPredictions) + + t.Logf("Bulk test %d: %.2fms total, %.2fms per request (%d/%d successful)", + i+1, durationMs, avgPerRequest, response.SuccessfulPredictions, response.TotalRequests) + } + + // Calculate bulk performance statistics + avgTotalDuration := totalDuration / numTests + avgTotalMs := float64(avgTotalDuration.Nanoseconds()) / 1e6 + avgPerRequest := avgTotalMs / float64(batchSize) + + t.Logf("Bulk Performance Results:") + t.Logf(" Average total time: %.2fms", avgTotalMs) + t.Logf(" Average per request: %.2fms", avgPerRequest) + t.Logf(" Success rate: %.1f%%", float64(totalSuccessful)/float64(totalRequests)*100) + + // Compare with single prediction performance target + singlePredictionTarget := 250.0 // ms + bulkEfficiencyThreshold := singlePredictionTarget * 0.7 // Bulk should be more efficient + + if avgPerRequest <= bulkEfficiencyThreshold { + t.Logf("✅ Bulk predictions are efficient: %.2fms per request < %.2fms threshold", + avgPerRequest, bulkEfficiencyThreshold) + } else if avgPerRequest <= singlePredictionTarget { + t.Logf("✓ Bulk predictions acceptable: %.2fms per request < %.2fms single target", + avgPerRequest, singlePredictionTarget) + } else { + t.Errorf("❌ Bulk predictions slow: %.2fms per request > %.2fms target", + avgPerRequest, singlePredictionTarget) + } +} + func testHTTPOnlyPerformance(t *testing.T, ctx context.Context) { t.Log("Testing HTTP-only prediction performance (no native XGBoost interference) with prefix cache...") @@ -485,6 +883,7 @@ func testHTTPOnlyPerformance(t *testing.T, ctx context.Context) { MetricsRefreshInterval: 1 * time.Second, // Longer for metrics UseNativeXGBoost: false, // Force HTTP-only HTTPTimeout: 5 * time.Second, // Reasonable timeout + MaxBulkSize: 50, } httpPredictor := New(httpOnlyConfig, logger) @@ -563,8 +962,8 @@ func testHTTPOnlyPerformance(t *testing.T, ctx context.Context) { status := "✅" - t.Logf("%s Test %d: %.1fms (TTFT: %.0fms, TPOT: %.0fms, prefix: %.0f%%)", - status, i+1, durationMs, response.TTFT, response.TPOT, testReq.PrefixCacheScore*100) 
+ t.Logf("%s Test %d: %.1fms (TTFT: %.0fms, TPOT: %.0fms, prefix: %.0f%%, quantile: %.2f)", + status, i+1, durationMs, response.TTFT, response.TPOT, testReq.PrefixCacheScore*100, response.Quantile) } // Calculate statistics @@ -665,6 +1064,7 @@ func testHTTPOnlyPrediction(t *testing.T, ctx context.Context) { MetricsRefreshInterval: 1 * time.Second, // Longer for metrics UseNativeXGBoost: false, // Force HTTP fallback HTTPTimeout: 30 * time.Second, + MaxBulkSize: 25, } httpPredictor := New(httpOnlyConfig, logger) @@ -714,6 +1114,7 @@ func testHTTPOnlyPrediction(t *testing.T, ctx context.Context) { t.Logf(" TTFT: %.2f ms", response.TTFT) t.Logf(" TPOT: %.2f ms", response.TPOT) t.Logf(" Model Type: %s", response.ModelType) + t.Logf(" Quantile: %.2f", response.Quantile) t.Logf(" TTFT Uncertainty: %.2f", response.TTFTUncertainty) t.Logf(" TPOT Uncertainty: %.2f", response.TPOTUncertainty) t.Logf(" Prefix Cache Score Used: %.1f%%", req.PrefixCacheScore*100) @@ -744,8 +1145,8 @@ func testHTTPOnlyPrediction(t *testing.T, ctx context.Context) { continue } - t.Logf("HTTP-only prediction %d: TTFT=%.2f, TPOT=%.2f (prefix: %.0f%%)", - i+1, resp.TTFT, resp.TPOT, testReq.PrefixCacheScore*100) + t.Logf("HTTP-only prediction %d: TTFT=%.2f, TPOT=%.2f (prefix: %.0f%%, quantile: %.2f)", + i+1, resp.TTFT, resp.TPOT, testReq.PrefixCacheScore*100, resp.Quantile) } t.Log("Successfully tested HTTP-only predictions with prefix cache") @@ -785,8 +1186,8 @@ func testLoadBalancing(t *testing.T, ctx context.Context, predictor *Predictor) } successfulPredictions++ - t.Logf("Prediction %d: TTFT=%.2f, TPOT=%.2f (prefix: %.0f%%)", - i+1, response.TTFT, response.TPOT, testReq.PrefixCacheScore*100) + t.Logf("Prediction %d: TTFT=%.2f, TPOT=%.2f (prefix: %.0f%%, quantile: %.2f)", + i+1, response.TTFT, response.TPOT, testReq.PrefixCacheScore*100, response.Quantile) } successRate := float64(successfulPredictions) / float64(numPredictions) * 100 @@ -1049,13 +1450,16 @@ func testMetricsRetrieval(t *testing.T, ctx context.Context, predictor *Predicto t.Log("Testing metrics retrieval...") modelType := predictor.GetCurrentModelType() - t.Logf("Testing metrics for model type: %s", modelType) + quantile := predictor.GetCurrentQuantile() + t.Logf("Testing metrics for model type: %s, quantile: %.2f", modelType, quantile) switch modelType { case "bayesian_ridge": testBayesianRidgeMetrics(t, ctx, predictor) case "xgboost": testXGBoostMetrics(t, ctx, predictor) + case "lightgbm": + testLightGBMMetrics(t, ctx, predictor) default: t.Logf("Unknown model type %s, testing cached metrics only", modelType) } @@ -1075,6 +1479,7 @@ func testMetricsRetrieval(t *testing.T, ctx context.Context, predictor *Predicto t.Logf("Predictor readiness status:") t.Logf(" Overall Ready: %t", predictor.IsReady()) t.Logf(" XGBoost Ready: %t", predictor.IsXGBoostReady()) + t.Logf(" LightGBM Ready: %t", predictor.IsLightGBMReady()) t.Logf(" Bayesian Ridge Ready: %t", predictor.IsBayesianRidgeReady()) } @@ -1166,6 +1571,34 @@ func testXGBoostMetrics(t *testing.T, ctx context.Context, predictor *Predictor) } } +func testLightGBMMetrics(t *testing.T, ctx context.Context, predictor *Predictor) { + t.Log("Testing LightGBM specific metrics...") + + // For LightGBM, we primarily use HTTP calls, so test the HTTP connectivity + if predictor.IsLightGBMReady() { + t.Log("LightGBM models are ready via HTTP") + + // Test a simple prediction to ensure the HTTP endpoint works + req := PredictionRequest{ + KVCachePercentage: 0.6, + InputTokenLength: 300, + NumRequestWaiting: 1, + 
NumRequestRunning: 1, + NumTokensGenerated: 50, + PrefixCacheScore: 0.7, + } + + _, err := predictor.Predict(ctx, req) + if err != nil { + t.Errorf("LightGBM test prediction failed: %v", err) + } else { + t.Log("✓ LightGBM HTTP prediction working") + } + } else { + t.Log("LightGBM models not ready") + } +} + // generateTrainingEntries creates random training data for testing with prefix cache scores func generateTrainingEntries(count int) []TrainingEntry { entries := make([]TrainingEntry, count) @@ -1233,6 +1666,7 @@ func BenchmarkPrediction(b *testing.B) { MetricsRefreshInterval: 1 * time.Second, UseNativeXGBoost: true, HTTPTimeout: 10 * time.Second, + MaxBulkSize: 100, } predictor := New(config, logger) @@ -1269,6 +1703,77 @@ func BenchmarkPrediction(b *testing.B) { }) } +// Benchmark test for bulk prediction performance +func BenchmarkBulkPrediction(b *testing.B) { + predictionURLs := os.Getenv("PREDICTION_SERVER_URL") + trainingURL := os.Getenv("TRAINING_SERVER_URL") + if predictionURLs == "" { + b.Skip("PREDICTION_SERVER_URL not set, skipping benchmark") + } + if trainingURL == "" { + urls := strings.Split(predictionURLs, ",") + if len(urls) > 0 { + trainingURL = strings.TrimSpace(urls[0]) + } else { + b.Skip("No valid URLs available for benchmarking") + } + } + + var parsedPredictionURLs []string + for _, url := range strings.Split(predictionURLs, ",") { + parsedPredictionURLs = append(parsedPredictionURLs, strings.TrimSpace(url)) + } + + logger := logr.Discard() + config := &Config{ + TrainingURL: trainingURL, + PredictionURLs: parsedPredictionURLs, + MaxSampleSize: 1000, + FlushInterval: 1 * time.Second, + MetricsRefreshInterval: 1 * time.Second, + UseNativeXGBoost: true, + HTTPTimeout: 10 * time.Second, + MaxBulkSize: 50, + } + + predictor := New(config, logger) + defer predictor.Stop() + + ctx := context.Background() + predictor.Start(ctx) + + for i := 0; i < 100; i++ { + if predictor.IsReady() { + break + } + time.Sleep(100 * time.Millisecond) + } + + // Create batch of requests + const batchSize = 20 + requests := make([]PredictionRequest, batchSize) + for i := 0; i < batchSize; i++ { + requests[i] = PredictionRequest{ + KVCachePercentage: 0.6 + float64(i%5)*0.05, + InputTokenLength: 300 + i*10, + NumRequestWaiting: i % 4, + NumRequestRunning: (i % 2) + 1, + NumTokensGenerated: 50 + i*2, + PrefixCacheScore: float64(i) / float64(batchSize-1), + } + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := predictor.PredictBulk(ctx, requests) + if err != nil { + b.Errorf("Bulk prediction failed: %v", err) + } + } + }) +} + // Test to verify config loading from environment func TestConfigFromEnv(t *testing.T) { // Save original env vars @@ -1278,6 +1783,7 @@ func TestConfigFromEnv(t *testing.T) { originalInterval := os.Getenv("LATENCY_FLUSH_INTERVAL_SEC") originalNative := os.Getenv("LATENCY_USE_NATIVE_XGBOOST") originalTimeout := os.Getenv("LATENCY_HTTP_TIMEOUT_SEC") + originalBulkSize := os.Getenv("LATENCY_MAX_BULK_SIZE") // Set test env vars os.Setenv("PREDICTION_SERVER_URL", "http://pred1.example.com,http://pred2.example.com,http://pred3.example.com") @@ -1286,6 +1792,7 @@ func TestConfigFromEnv(t *testing.T) { os.Setenv("LATENCY_FLUSH_INTERVAL_SEC", "5") os.Setenv("LATENCY_USE_NATIVE_XGBOOST", "false") os.Setenv("LATENCY_HTTP_TIMEOUT_SEC", "20") + os.Setenv("LATENCY_MAX_BULK_SIZE", "75") defer func() { // Restore original env vars (handle empty strings properly) @@ -1319,6 +1826,11 @@ func TestConfigFromEnv(t *testing.T) { } else { 
os.Unsetenv("LATENCY_HTTP_TIMEOUT_SEC") } + if originalBulkSize != "" { + os.Setenv("LATENCY_MAX_BULK_SIZE", originalBulkSize) + } else { + os.Unsetenv("LATENCY_MAX_BULK_SIZE") + } }() config := ConfigFromEnv() @@ -1359,441 +1871,130 @@ func TestConfigFromEnv(t *testing.T) { if config.HTTPTimeout != 20*time.Second { t.Errorf("Expected HTTPTimeout to be 20s, got %v", config.HTTPTimeout) } + if config.MaxBulkSize != 75 { + t.Errorf("Expected MaxBulkSize to be 75, got %d", config.MaxBulkSize) + } } -// Test URL parsing edge cases -func TestConfigURLParsing(t *testing.T) { - tests := []struct { - name string - latencyServerURL string - trainingServerURL string - expectedPredictionURLs []string - expectedTrainingURL string - }{ - { - name: "Single prediction URL", - latencyServerURL: "http://localhost:8001", - trainingServerURL: "http://localhost:8000", - expectedPredictionURLs: []string{"http://localhost:8001"}, - expectedTrainingURL: "http://localhost:8000", - }, - { - name: "Multiple prediction URLs with spaces", - latencyServerURL: "http://localhost:8001, http://localhost:8002 ,http://localhost:8003", - trainingServerURL: "http://localhost:8000", - expectedPredictionURLs: []string{"http://localhost:8001", "http://localhost:8002", "http://localhost:8003"}, - expectedTrainingURL: "http://localhost:8000", - }, - { - name: "Empty training URL with prediction URLs", - latencyServerURL: "http://localhost:8001,http://localhost:8002", - trainingServerURL: "", - expectedPredictionURLs: []string{"http://localhost:8001", "http://localhost:8002"}, - expectedTrainingURL: "http://localhost:8000", // Should use default - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Save original env vars - originalLatencyURL := os.Getenv("PREDICTION_SERVER_URL") - originalTrainingURL := os.Getenv("TRAINING_SERVER_URL") - - // Set test env vars - os.Setenv("PREDICTION_SERVER_URL", tt.latencyServerURL) - if tt.trainingServerURL != "" { - os.Setenv("TRAINING_SERVER_URL", tt.trainingServerURL) - } else { - os.Unsetenv("TRAINING_SERVER_URL") - } - - defer func() { - // Restore original env vars - if originalLatencyURL != "" { - os.Setenv("PREDICTION_SERVER_URL", originalLatencyURL) - } else { - os.Unsetenv("PREDICTION_SERVER_URL") - } - if originalTrainingURL != "" { - os.Setenv("TRAINING_SERVER_URL", originalTrainingURL) - } else { - os.Unsetenv("TRAINING_SERVER_URL") - } - }() - - config := ConfigFromEnv() - - // Check prediction URLs - if len(config.PredictionURLs) != len(tt.expectedPredictionURLs) { - t.Errorf("Expected %d prediction URLs, got %d", len(tt.expectedPredictionURLs), len(config.PredictionURLs)) - } - for i, expected := range tt.expectedPredictionURLs { - if i >= len(config.PredictionURLs) || config.PredictionURLs[i] != expected { - t.Errorf("Expected PredictionURLs[%d] to be '%s', got '%s'", i, expected, config.PredictionURLs[i]) - } - } - - // Check training URL - if config.TrainingURL != tt.expectedTrainingURL { - t.Errorf("Expected TrainingURL to be '%s', got '%s'", tt.expectedTrainingURL, config.TrainingURL) - } - }) +// Helper function for absolute value +func abs(x float64) float64 { + if x < 0 { + return -x } + return x } -// Test prefix cache score impact on training data generation -func TestTrainingDataWithPrefixCache(t *testing.T) { - t.Log("Testing training data generation with prefix cache scores...") - - entries := generateTrainingEntries(100) +// Test comprehensive bulk prediction functionality +func TestBulkPredictionValidation(t *testing.T) { + 
t.Log("Testing bulk prediction validation...") - // Validate all entries have prefix cache scores - for i, entry := range entries { - if entry.PrefixCacheScore < 0.0 || entry.PrefixCacheScore > 1.0 { - t.Errorf("Entry %d has invalid prefix cache score: %.3f", i, entry.PrefixCacheScore) - } + predictor := &Predictor{ + config: &Config{MaxBulkSize: 5}, } - // Check that prefix cache scores vary - var prefixScores []float64 - for _, entry := range entries { - prefixScores = append(prefixScores, entry.PrefixCacheScore) - } - - // Calculate variance to ensure we have variety - var sum, mean, variance float64 - for _, score := range prefixScores { - sum += score - } - mean = sum / float64(len(prefixScores)) - - for _, score := range prefixScores { - variance += (score - mean) * (score - mean) - } - variance /= float64(len(prefixScores)) - - t.Logf("Prefix cache score statistics:") - t.Logf(" Mean: %.3f", mean) - t.Logf(" Variance: %.3f", variance) - t.Logf(" Range: [%.3f, %.3f]", 0.0, 1.0) - - if variance < 0.05 { - t.Error("Prefix cache scores should have more variance for good training data") + // Test empty request list + _, err := predictor.PredictBulk(context.Background(), []PredictionRequest{}) + if err == nil { + t.Error("Expected error for empty request list") } else { - t.Log("✓ Good variance in prefix cache scores") + t.Logf("✓ Correctly rejected empty request list: %v", err) } - // Verify the training equation includes prefix cache impact - // Check that entries with higher prefix cache tend to have higher TTFT - // (based on our training equation: ttft includes +30*prefixCache) - - // Sort by prefix cache score - type entryWithIndex struct { - entry TrainingEntry - index int - } - - var sortedEntries []entryWithIndex - for i, entry := range entries { - sortedEntries = append(sortedEntries, entryWithIndex{entry, i}) - } - - // Simple sort by prefix cache score - for i := 0; i < len(sortedEntries)-1; i++ { - for j := i + 1; j < len(sortedEntries); j++ { - if sortedEntries[i].entry.PrefixCacheScore > sortedEntries[j].entry.PrefixCacheScore { - sortedEntries[i], sortedEntries[j] = sortedEntries[j], sortedEntries[i] - } + // Test oversized request list + oversizedRequests := make([]PredictionRequest, 10) // Over the limit of 5 + for i := range oversizedRequests { + oversizedRequests[i] = PredictionRequest{ + KVCachePercentage: 0.5, + InputTokenLength: 100, + NumRequestWaiting: 1, + NumRequestRunning: 1, + NumTokensGenerated: 10, + PrefixCacheScore: 0.5, } } - // Compare low vs high prefix cache entries - lowPrefixCount := len(sortedEntries) / 4 - highPrefixStart := len(sortedEntries) * 3 / 4 - - var lowPrefixTTFT, highPrefixTTFT float64 - for i := 0; i < lowPrefixCount; i++ { - lowPrefixTTFT += sortedEntries[i].entry.ActualTTFT - } - lowPrefixTTFT /= float64(lowPrefixCount) - - highPrefixCount := len(sortedEntries) - highPrefixStart - for i := highPrefixStart; i < len(sortedEntries); i++ { - highPrefixTTFT += sortedEntries[i].entry.ActualTTFT - } - highPrefixTTFT /= float64(highPrefixCount) - - ttftDifference := highPrefixTTFT - lowPrefixTTFT - - t.Logf("TTFT impact analysis:") - t.Logf(" Low prefix cache TTFT avg: %.2f ms", lowPrefixTTFT) - t.Logf(" High prefix cache TTFT avg: %.2f ms", highPrefixTTFT) - t.Logf(" Difference: %.2f ms", ttftDifference) - - if ttftDifference > 10 { - t.Log("✓ Prefix cache score appears to positively impact TTFT in training data") + _, err = predictor.PredictBulk(context.Background(), oversizedRequests) + if err == nil { + t.Error("Expected error for 
oversized request list") } else { - t.Log("ℹ Small or no prefix cache impact detected (may be due to noise)") + t.Logf("✓ Correctly rejected oversized request list: %v", err) } - t.Log("✅ Training data with prefix cache validation completed") -} - -// Test prediction request validation edge cases -func TestPredictionValidationEdgeCases(t *testing.T) { - t.Log("Testing prediction validation edge cases with prefix cache...") - - predictor := &Predictor{} // Temporary predictor for validation - - testCases := []struct { - name string - req PredictionRequest - shouldErr bool - errorMsg string - }{ + // Test invalid request in the list + invalidRequests := []PredictionRequest{ { - name: "Valid minimum values", - req: PredictionRequest{ - KVCachePercentage: 0.0, - InputTokenLength: 0, - NumRequestWaiting: 0, - NumRequestRunning: 0, - NumTokensGenerated: 0, - PrefixCacheScore: 0.0, - }, - shouldErr: false, - }, - { - name: "Valid maximum values", - req: PredictionRequest{ - KVCachePercentage: 1.0, - InputTokenLength: 10000, - NumRequestWaiting: 100, - NumRequestRunning: 50, - NumTokensGenerated: 1000, - PrefixCacheScore: 1.0, - }, - shouldErr: false, - }, - { - name: "Invalid negative prefix cache", - req: PredictionRequest{ - KVCachePercentage: 0.5, - InputTokenLength: 100, - NumRequestWaiting: 1, - NumRequestRunning: 1, - NumTokensGenerated: 10, - PrefixCacheScore: -0.001, - }, - shouldErr: true, - errorMsg: "prefix_cache_score must be between 0.0 and 1.0", - }, - { - name: "Invalid high prefix cache", - req: PredictionRequest{ - KVCachePercentage: 0.5, - InputTokenLength: 100, - NumRequestWaiting: 1, - NumRequestRunning: 1, - NumTokensGenerated: 10, - PrefixCacheScore: 1.001, - }, - shouldErr: true, - errorMsg: "prefix_cache_score must be between 0.0 and 1.0", - }, - { - name: "Invalid negative KV cache with valid prefix cache", - req: PredictionRequest{ - KVCachePercentage: -0.1, - InputTokenLength: 100, - NumRequestWaiting: 1, - NumRequestRunning: 1, - NumTokensGenerated: 10, - PrefixCacheScore: 0.8, - }, - shouldErr: true, - errorMsg: "kv_cache_percentage must be between 0.0 and 1.0", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := predictor.ValidatePredictionRequest(tc.req) - - if tc.shouldErr { - if err == nil { - t.Errorf("Expected validation error for %s, but got none", tc.name) - } else if !strings.Contains(err.Error(), tc.errorMsg) { - t.Errorf("Expected error message to contain '%s', got: %v", tc.errorMsg, err) - } else { - t.Logf("✓ Correctly rejected %s: %v", tc.name, err) - } - } else { - if err != nil { - t.Errorf("Expected no validation error for %s, but got: %v", tc.name, err) - } else { - t.Logf("✓ Correctly accepted %s", tc.name) - } - } - }) - } - - t.Log("✅ Prediction validation edge cases completed") -} - -// Test training entry validation edge cases -func TestTrainingValidationEdgeCases(t *testing.T) { - t.Log("Testing training entry validation edge cases with prefix cache...") - - predictor := &Predictor{} // Temporary predictor for validation - - testCases := []struct { - name string - entry TrainingEntry - shouldErr bool - errorMsg string - }{ - { - name: "Valid entry with prefix cache", - entry: TrainingEntry{ - KVCachePercentage: 0.6, - InputTokenLength: 200, - NumRequestWaiting: 2, - NumRequestRunning: 1, - NumTokensGenerated: 20, - ActualTTFT: 45.5, - ActualTPOT: 12.3, - PrefixCacheScore: 0.8, - Timestamp: time.Now(), - }, - shouldErr: false, - }, - { - name: "Zero prefix cache score", - entry: TrainingEntry{ - 
KVCachePercentage: 0.5, - InputTokenLength: 100, - NumRequestWaiting: 1, - NumRequestRunning: 1, - NumTokensGenerated: 10, - ActualTTFT: 30.0, - ActualTPOT: 8.0, - PrefixCacheScore: 0.0, // Valid minimum - Timestamp: time.Now(), - }, - shouldErr: false, - }, - { - name: "Maximum prefix cache score", - entry: TrainingEntry{ - KVCachePercentage: 0.5, - InputTokenLength: 100, - NumRequestWaiting: 1, - NumRequestRunning: 1, - NumTokensGenerated: 10, - ActualTTFT: 30.0, - ActualTPOT: 8.0, - PrefixCacheScore: 1.0, // Valid maximum - Timestamp: time.Now(), - }, - shouldErr: false, - }, - { - name: "Invalid negative prefix cache", - entry: TrainingEntry{ - KVCachePercentage: 0.5, - InputTokenLength: 100, - NumRequestWaiting: 1, - NumRequestRunning: 1, - NumTokensGenerated: 10, - ActualTTFT: 30.0, - ActualTPOT: 8.0, - PrefixCacheScore: -0.1, - Timestamp: time.Now(), - }, - shouldErr: true, - errorMsg: "prefix_cache_score must be between 0.0 and 1.0", + KVCachePercentage: 0.5, + InputTokenLength: 100, + NumRequestWaiting: 1, + NumRequestRunning: 1, + NumTokensGenerated: 10, + PrefixCacheScore: 0.5, }, { - name: "Invalid high prefix cache", - entry: TrainingEntry{ - KVCachePercentage: 0.5, - InputTokenLength: 100, - NumRequestWaiting: 1, - NumRequestRunning: 1, - NumTokensGenerated: 10, - ActualTTFT: 30.0, - ActualTPOT: 8.0, - PrefixCacheScore: 1.5, - Timestamp: time.Now(), - }, - shouldErr: true, - errorMsg: "prefix_cache_score must be between 0.0 and 1.0", + KVCachePercentage: 1.5, // Invalid + InputTokenLength: 100, + NumRequestWaiting: 1, + NumRequestRunning: 1, + NumTokensGenerated: 10, + PrefixCacheScore: 0.5, }, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := predictor.ValidateTrainingEntry(tc.entry) - - if tc.shouldErr { - if err == nil { - t.Errorf("Expected validation error for %s, but got none", tc.name) - } else if !strings.Contains(err.Error(), tc.errorMsg) { - t.Errorf("Expected error message to contain '%s', got: %v", tc.errorMsg, err) - } else { - t.Logf("✓ Correctly rejected %s: %v", tc.name, err) - } - } else { - if err != nil { - t.Errorf("Expected no validation error for %s, but got: %v", tc.name, err) - } else { - t.Logf("✓ Correctly accepted %s", tc.name) - } - } - }) + _, err = predictor.PredictBulk(context.Background(), invalidRequests) + if err == nil { + t.Error("Expected error for invalid request in list") + } else { + t.Logf("✓ Correctly rejected list with invalid request: %v", err) } - t.Log("✅ Training validation edge cases completed") + t.Log("✅ Bulk prediction validation tests completed") } -// Test comprehensive prefix cache feature integration -func TestPrefixCacheFeatureIntegration(t *testing.T) { - t.Log("Testing comprehensive prefix cache feature integration...") +// Test comprehensive prefix cache integration +func TestPrefixCacheIntegration(t *testing.T) { + t.Log("Testing comprehensive prefix cache integration...") - // Test that all components work together with prefix cache + // Test all components work together with prefix cache zapLog, err := zap.NewDevelopment() if err != nil { t.Fatalf("Failed to create logger: %v", err) } logger := zapr.NewLogger(zapLog) - // Create a minimal config for testing + // Use a minimal config that doesn't require network calls config := &Config{ TrainingURL: "http://mock-training.local", PredictionURLs: []string{"http://mock-prediction.local"}, MaxSampleSize: 100, - FlushInterval: 10 * time.Second, // Long interval for testing - MetricsRefreshInterval: 10 * time.Second, + FlushInterval: 1 * 
time.Hour, // Very long interval to avoid network calls + MetricsRefreshInterval: 1 * time.Hour, // Very long interval to avoid network calls UseNativeXGBoost: false, HTTPTimeout: 5 * time.Second, + MaxBulkSize: 10, } predictor := New(config, logger) - defer predictor.Stop() - // Test that training entries with prefix cache can be created - entries := make([]TrainingEntry, 10) - for i := 0; i < 10; i++ { + // Manually stop background processes without triggering final flush/refresh + defer func() { + // Stop background loop without calling Stop() which does final flush + close(predictor.done) + predictor.wg.Wait() + t.Log("Background processes stopped without network calls") + }() + + // Test training entries with prefix cache can be created and validated + entries := make([]TrainingEntry, 5) + for i := 0; i < 5; i++ { entry, err := NewTrainingEntry( float64(i)/10.0, // kv_cache_percentage 100+i*50, // input_token_length - i%5, // num_request_waiting - (i%3)+1, // num_request_running + i%3, // num_request_waiting + 1, // num_request_running (always > 0) 10+i*5, // num_tokens_generated 50.0+float64(i)*5, // actual_ttft_ms 10.0+float64(i)*2, // actual_tpot_ms - float64(i)/9.0, // prefix_cache_score (0.0 to 1.0) + float64(i)/4.0, // prefix_cache_score (0.0 to 1.0) ) if err != nil { t.Fatalf("Failed to create training entry %d: %v", i, err) @@ -1804,284 +2005,295 @@ func TestPrefixCacheFeatureIntegration(t *testing.T) { i, entry.PrefixCacheScore*100, entry.ActualTTFT, entry.ActualTPOT) } - // Test that training entries can be added to predictor + // Add training data to buffer (won't flush due to long interval) err = predictor.AddTrainingDataBulk(entries) if err != nil { - t.Fatalf("Failed to add training entries with prefix cache: %v", err) + t.Fatalf("Failed to add training entries: %v", err) } - t.Log("✓ Successfully added training entries with prefix cache scores") + t.Log("✓ Successfully added training entries with prefix cache scores to buffer") - // Test that prediction requests with prefix cache can be created - for i := 0; i < 5; i++ { + // Test prediction requests with prefix cache can be created and validated + requests := make([]PredictionRequest, 3) + for i := 0; i < 3; i++ { req, err := NewPredictionRequest( - float64(i*20)/100.0, // kv_cache_percentage: 0%, 20%, 40%, 60%, 80% + float64(i*20)/100.0, // kv_cache_percentage 200+i*100, // input_token_length - i%4, // num_request_waiting - (i%2)+1, // num_request_running + i%2, // num_request_waiting + 1, // num_request_running (always > 0) 20+i*10, // num_tokens_generated - float64(i)/4.0, // prefix_cache_score: 0.0, 0.25, 0.5, 0.75, 1.0 + float64(i)/2.0, // prefix_cache_score ) if err != nil { t.Fatalf("Failed to create prediction request %d: %v", i, err) } + requests[i] = req - t.Logf("Request %d: prefix_cache=%.1f%%, kv_cache=%.1f%%, input_len=%d", - i, req.PrefixCacheScore*100, req.KVCachePercentage*100, req.InputTokenLength) - - // Validate the request err = predictor.ValidatePredictionRequest(req) if err != nil { t.Errorf("Valid prediction request %d failed validation: %v", i, err) } + + t.Logf("Request %d: prefix_cache=%.1f%%, kv_cache=%.1f%%, input_len=%d", + i, req.PrefixCacheScore*100, req.KVCachePercentage*100, req.InputTokenLength) } - t.Log("✓ Successfully created and validated prediction requests with prefix cache scores") + t.Log("✓ Successfully created and validated prediction requests with prefix cache") - // Test validation edge cases work correctly - testCases := []struct { + // Test validation edge cases + 
edgeCases := []struct { name string prefixCache float64 shouldPass bool }{ {"Zero prefix cache", 0.0, true}, - {"Half prefix cache", 0.5, true}, - {"Full prefix cache", 1.0, true}, + {"Max prefix cache", 1.0, true}, {"Negative prefix cache", -0.1, false}, - {"Over-full prefix cache", 1.1, false}, + {"Over-max prefix cache", 1.1, false}, } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req := PredictionRequest{ - KVCachePercentage: 0.5, - InputTokenLength: 100, - NumRequestWaiting: 1, - NumRequestRunning: 1, - NumTokensGenerated: 10, - PrefixCacheScore: tc.prefixCache, - } + for _, tc := range edgeCases { + req := PredictionRequest{ + KVCachePercentage: 0.5, + InputTokenLength: 100, + NumRequestWaiting: 1, + NumRequestRunning: 1, + NumTokensGenerated: 10, + PrefixCacheScore: tc.prefixCache, + } - err := predictor.ValidatePredictionRequest(req) - if tc.shouldPass && err != nil { - t.Errorf("Expected %s to pass validation, got error: %v", tc.name, err) - } else if !tc.shouldPass && err == nil { - t.Errorf("Expected %s to fail validation, but it passed", tc.name) - } - }) + err := predictor.ValidatePredictionRequest(req) + if tc.shouldPass && err != nil { + t.Errorf("Edge case '%s' should pass but failed: %v", tc.name, err) + } else if !tc.shouldPass && err == nil { + t.Errorf("Edge case '%s' should fail but passed", tc.name) + } else { + t.Logf("✓ Edge case '%s' handled correctly", tc.name) + } } - t.Log("✅ Comprehensive prefix cache feature integration test completed") -} + // Test that we can access the configuration + t.Logf("Configuration validation:") + t.Logf(" Training URL: %s", predictor.GetTrainingURL()) + t.Logf(" Prediction URLs: %v", predictor.GetPredictionURLs()) + t.Logf(" Max Bulk Size: %d", config.MaxBulkSize) -// Test that demonstrates the prefix cache feature end-to-end -func TestPrefixCacheEndToEnd(t *testing.T) { - t.Log("Testing prefix cache feature end-to-end workflow...") - - // This test demonstrates a complete workflow with prefix cache scores - - // 1. 
Create training data that shows prefix cache impact - t.Log("Step 1: Creating training data with prefix cache impact...") - - var trainingEntries []TrainingEntry - rng := rand.New(rand.NewSource(42)) // Fixed seed for reproducible test - - for i := 0; i < 50; i++ { - kv := 0.5 + rng.Float64()*0.3 // 0.5 to 0.8 - inputLen := 200 + rng.Intn(300) // 200 to 500 - waiting := rng.Intn(5) // 0 to 4 - running := 1 + rng.Intn(3) // 1 to 3 - generated := 20 + rng.Intn(80) // 20 to 100 - prefixCache := rng.Float64() // 0.0 to 1.0 - - // Simulate the actual equation with prefix cache impact on TTFT - // TTFT = base + 2*input + 3*waiting + 4*running + 50*kv + 30*prefix_cache + noise - ttft := 95.0 + - 2.0*float64(inputLen) + - 3.0*float64(waiting) + - 4.0*float64(running) + - 50.0*kv + - 30.0*prefixCache + // Prefix cache impact - rng.NormFloat64()*5 // Small noise - - // TPOT = base + 0.5*input + 1*generated + 5*running + 100*kv + noise - // (No prefix cache impact on TPOT) - tpot := 9.0 + - 0.5*float64(inputLen) + - 1.0*float64(generated) + - 5.0*float64(running) + - 100.0*kv + - rng.NormFloat64()*3 // Small noise - - entry := TrainingEntry{ - KVCachePercentage: kv, - InputTokenLength: inputLen, - NumRequestWaiting: waiting, - NumRequestRunning: running, - NumTokensGenerated: generated, - ActualTTFT: ttft, - ActualTPOT: tpot, - PrefixCacheScore: prefixCache, - Timestamp: time.Now().Add(-time.Duration(i) * time.Minute), - } - - trainingEntries = append(trainingEntries, entry) + // Validate configuration consistency + if len(predictor.GetPredictionURLs()) != len(config.PredictionURLs) { + t.Errorf("Prediction URLs mismatch: expected %d, got %d", + len(config.PredictionURLs), len(predictor.GetPredictionURLs())) } - t.Logf("Created %d training entries with prefix cache scores", len(trainingEntries)) - - // 2. 
Analyze the training data to show prefix cache correlation - t.Log("Step 2: Analyzing prefix cache correlation in training data...") + if predictor.GetTrainingURL() != config.TrainingURL { + t.Errorf("Training URL mismatch: expected %s, got %s", + config.TrainingURL, predictor.GetTrainingURL()) + } - // Sort by prefix cache score - sortedEntries := make([]TrainingEntry, len(trainingEntries)) - copy(sortedEntries, trainingEntries) + // Test data structure integrity + t.Log("Validating data structure integrity...") - // Simple bubble sort by prefix cache score - for i := 0; i < len(sortedEntries)-1; i++ { - for j := i + 1; j < len(sortedEntries); j++ { - if sortedEntries[i].PrefixCacheScore > sortedEntries[j].PrefixCacheScore { - sortedEntries[i], sortedEntries[j] = sortedEntries[j], sortedEntries[i] - } + // Check that training entries maintain their prefix cache scores + for i, entry := range entries { + expectedPrefixCache := float64(i) / 4.0 + if abs(entry.PrefixCacheScore-expectedPrefixCache) > 0.001 { + t.Errorf("Training entry %d prefix cache score mismatch: expected %.3f, got %.3f", + i, expectedPrefixCache, entry.PrefixCacheScore) } } - // Compare bottom 25% vs top 25% - quarterSize := len(sortedEntries) / 4 - - var lowPrefixTTFT, highPrefixTTFT float64 - var lowPrefixTPOT, highPrefixTPOT float64 - var lowPrefixCacheAvg, highPrefixCacheAvg float64 - - // Calculate averages for low prefix cache group (bottom 25%) - for i := 0; i < quarterSize; i++ { - lowPrefixTTFT += sortedEntries[i].ActualTTFT - lowPrefixTPOT += sortedEntries[i].ActualTPOT - lowPrefixCacheAvg += sortedEntries[i].PrefixCacheScore + // Check that prediction requests maintain their prefix cache scores + for i, req := range requests { + expectedPrefixCache := float64(i) / 2.0 + if abs(req.PrefixCacheScore-expectedPrefixCache) > 0.001 { + t.Errorf("Prediction request %d prefix cache score mismatch: expected %.3f, got %.3f", + i, expectedPrefixCache, req.PrefixCacheScore) + } } - lowPrefixTTFT /= float64(quarterSize) - lowPrefixTPOT /= float64(quarterSize) - lowPrefixCacheAvg /= float64(quarterSize) - // Calculate averages for high prefix cache group (top 25%) - startIdx := len(sortedEntries) - quarterSize - for i := startIdx; i < len(sortedEntries); i++ { - highPrefixTTFT += sortedEntries[i].ActualTTFT - highPrefixTPOT += sortedEntries[i].ActualTPOT - highPrefixCacheAvg += sortedEntries[i].PrefixCacheScore + // Verify that training data is properly buffered (not flushed due to long interval) + predictor.bufferMu.Lock() + bufferedCount := len(predictor.pending) + predictor.bufferMu.Unlock() + + if bufferedCount != len(entries) { + t.Errorf("Expected %d buffered entries, got %d", len(entries), bufferedCount) + } else { + t.Logf("✓ Training data properly buffered: %d entries", bufferedCount) } - highPrefixTTFT /= float64(quarterSize) - highPrefixTPOT /= float64(quarterSize) - highPrefixCacheAvg /= float64(quarterSize) - ttftDiff := highPrefixTTFT - lowPrefixTTFT - tpotDiff := highPrefixTPOT - lowPrefixTPOT + t.Log("✅ Comprehensive prefix cache integration test completed (offline mode)") +} - t.Logf("Training data analysis results:") - t.Logf(" Low prefix cache group (avg=%.2f): TTFT=%.1f ms, TPOT=%.1f ms", - lowPrefixCacheAvg, lowPrefixTTFT, lowPrefixTPOT) - t.Logf(" High prefix cache group (avg=%.2f): TTFT=%.1f ms, TPOT=%.1f ms", - highPrefixCacheAvg, highPrefixTTFT, highPrefixTPOT) - t.Logf(" TTFT difference: %.1f ms (expect ~%.1f ms)", - ttftDiff, (highPrefixCacheAvg-lowPrefixCacheAvg)*30.0) - t.Logf(" TPOT difference: 
%.1f ms (expect ~0 ms)", tpotDiff) +// Test offline validation functionality without network dependencies +func TestOfflineValidation(t *testing.T) { + t.Log("Testing offline validation functionality...") - // Validate that we see the expected prefix cache impact - expectedTTFTDiff := (highPrefixCacheAvg - lowPrefixCacheAvg) * 30.0 // Our training coefficient - if ttftDiff > expectedTTFTDiff*0.5 && ttftDiff < expectedTTFTDiff*1.5 { - t.Log("✓ TTFT shows expected prefix cache correlation") - } else { - t.Logf("ℹ TTFT correlation weaker than expected (noise effects)") - } + // Create a minimal predictor for validation testing + predictor := &Predictor{} - if abs(tpotDiff) < 10 { // TPOT should not be significantly affected - t.Log("✓ TPOT correctly shows minimal prefix cache correlation") - } else { - t.Logf("⚠ TPOT unexpectedly affected by prefix cache: %.1f ms difference", tpotDiff) - } + // Test prediction request validation + t.Run("PredictionRequestValidation", func(t *testing.T) { + validReq := PredictionRequest{ + KVCachePercentage: 0.7, + InputTokenLength: 500, + NumRequestWaiting: 2, + NumRequestRunning: 1, + NumTokensGenerated: 80, + PrefixCacheScore: 0.85, + } - // 3. Create prediction scenarios to demonstrate usage - t.Log("Step 3: Creating prediction scenarios...") + err := predictor.ValidatePredictionRequest(validReq) + if err != nil { + t.Errorf("Valid prediction request failed validation: %v", err) + } - scenarios := []struct { - name string - description string - req PredictionRequest - }{ - { - name: "Cold Cache", - description: "No prefix cache hits, high latency expected", - req: PredictionRequest{ - KVCachePercentage: 0.7, - InputTokenLength: 400, - NumRequestWaiting: 2, - NumRequestRunning: 1, - NumTokensGenerated: 50, - PrefixCacheScore: 0.0, // No cache hits + // Test invalid requests + invalidTests := []struct { + name string + req PredictionRequest + }{ + { + name: "Negative KV cache", + req: PredictionRequest{ + KVCachePercentage: -0.1, + InputTokenLength: 100, + NumRequestWaiting: 1, + NumRequestRunning: 1, + NumTokensGenerated: 10, + PrefixCacheScore: 0.5, + }, }, - }, - { - name: "Warm Cache", - description: "Moderate prefix cache hits", - req: PredictionRequest{ - KVCachePercentage: 0.7, - InputTokenLength: 400, - NumRequestWaiting: 2, - NumRequestRunning: 1, - NumTokensGenerated: 50, - PrefixCacheScore: 0.5, // 50% cache hits + { + name: "Invalid prefix cache", + req: PredictionRequest{ + KVCachePercentage: 0.5, + InputTokenLength: 100, + NumRequestWaiting: 1, + NumRequestRunning: 1, + NumTokensGenerated: 10, + PrefixCacheScore: 1.5, + }, }, - }, - { - name: "Hot Cache", - description: "High prefix cache hits, low latency expected", - req: PredictionRequest{ - KVCachePercentage: 0.7, - InputTokenLength: 400, - NumRequestWaiting: 2, - NumRequestRunning: 1, - NumTokensGenerated: 50, - PrefixCacheScore: 0.9, // 90% cache hits - }, - }, - } + } - for _, scenario := range scenarios { - // Validate each scenario - predictor := &Predictor{} // Temporary for validation - err := predictor.ValidatePredictionRequest(scenario.req) + for _, test := range invalidTests { + err := predictor.ValidatePredictionRequest(test.req) + if err == nil { + t.Errorf("Invalid request '%s' should have failed validation", test.name) + } else { + t.Logf("✓ '%s' correctly rejected: %v", test.name, err) + } + } + }) + + // Test training entry validation + t.Run("TrainingEntryValidation", func(t *testing.T) { + validEntry := TrainingEntry{ + KVCachePercentage: 0.6, + InputTokenLength: 300, + 
NumRequestWaiting: 1, + NumRequestRunning: 1, + NumTokensGenerated: 50, + ActualTTFT: 45.5, + ActualTPOT: 12.3, + PrefixCacheScore: 0.75, + Timestamp: time.Now(), + } + + err := predictor.ValidateTrainingEntry(validEntry) if err != nil { - t.Errorf("Scenario '%s' failed validation: %v", scenario.name, err) - continue + t.Errorf("Valid training entry failed validation: %v", err) } - // Calculate expected TTFT using our training equation - expectedTTFT := 95.0 + - 2.0*float64(scenario.req.InputTokenLength) + - 3.0*float64(scenario.req.NumRequestWaiting) + - 4.0*float64(scenario.req.NumRequestRunning) + - 50.0*scenario.req.KVCachePercentage + - 30.0*scenario.req.PrefixCacheScore + // Test invalid entry + invalidEntry := validEntry + invalidEntry.PrefixCacheScore = -0.5 - expectedTPOT := 9.0 + - 0.5*float64(scenario.req.InputTokenLength) + - 1.0*float64(scenario.req.NumTokensGenerated) + - 5.0*float64(scenario.req.NumRequestRunning) + - 100.0*scenario.req.KVCachePercentage + err = predictor.ValidateTrainingEntry(invalidEntry) + if err == nil { + t.Error("Invalid training entry should have failed validation") + } else { + t.Logf("✓ Invalid training entry correctly rejected: %v", err) + } + }) - t.Logf("Scenario: %s", scenario.name) - t.Logf(" Description: %s", scenario.description) - t.Logf(" Prefix cache: %.0f%%", scenario.req.PrefixCacheScore*100) - t.Logf(" Expected TTFT: %.1f ms", expectedTTFT) - t.Logf(" Expected TPOT: %.1f ms", expectedTPOT) - t.Log("") - } + // Test constructor functions + t.Run("ConstructorFunctions", func(t *testing.T) { + // Test valid constructors + _, err := NewPredictionRequest(0.5, 100, 1, 1, 10, 0.8) + if err != nil { + t.Errorf("Valid prediction request constructor failed: %v", err) + } + + _, err = NewTrainingEntry(0.5, 100, 1, 1, 10, 30.0, 8.0, 0.8) + if err != nil { + t.Errorf("Valid training entry constructor failed: %v", err) + } + + // Test invalid constructors + _, err = NewPredictionRequest(0.5, 100, 1, 1, 10, 1.5) // Invalid prefix cache + if err == nil { + t.Error("Invalid prediction request constructor should have failed") + } - t.Log("✅ End-to-end prefix cache workflow demonstration completed") + _, err = NewTrainingEntry(0.5, 100, 1, 1, 10, 30.0, 8.0, -0.1) // Invalid prefix cache + if err == nil { + t.Error("Invalid training entry constructor should have failed") + } + }) + + t.Log("✅ Offline validation tests completed") } -// Helper function for absolute value -func abs(x float64) float64 { - if x < 0 { - return -x +// Test configuration handling without network calls +func TestConfigurationHandling(t *testing.T) { + t.Log("Testing configuration handling...") + + // Test default configuration + defaultConfig := DefaultConfig() + if defaultConfig.MaxBulkSize != 100 { + t.Errorf("Expected default MaxBulkSize to be 100, got %d", defaultConfig.MaxBulkSize) } - return x + + if defaultConfig.UseNativeXGBoost != true { + t.Errorf("Expected default UseNativeXGBoost to be true, got %t", defaultConfig.UseNativeXGBoost) + } + + // Test configuration with mock URLs (no network calls) + config := &Config{ + TrainingURL: "http://mock-training.local", + PredictionURLs: []string{"http://mock1.local", "http://mock2.local"}, + MaxSampleSize: 500, + FlushInterval: 2 * time.Second, + MetricsRefreshInterval: 5 * time.Second, + UseNativeXGBoost: false, + HTTPTimeout: 10 * time.Second, + MaxBulkSize: 50, + } + + // Create logger + zapLog, err := zap.NewDevelopment() + if err != nil { + t.Fatalf("Failed to create logger: %v", err) + } + logger := zapr.NewLogger(zapLog) + 
+ // Create predictor but don't start it (to avoid network calls) + predictor := New(config, logger) + + // Test configuration access + if predictor.GetTrainingURL() != config.TrainingURL { + t.Errorf("Training URL mismatch: expected %s, got %s", + config.TrainingURL, predictor.GetTrainingURL()) + } + + predictionURLs := predictor.GetPredictionURLs() + if len(predictionURLs) != len(config.PredictionURLs) { + t.Errorf("Prediction URLs length mismatch: expected %d, got %d", + len(config.PredictionURLs), len(predictionURLs)) + } + + // Cleanup without starting background processes + close(predictor.done) + predictor.wg.Wait() + + t.Log("✅ Configuration handling tests completed") } diff --git a/pkg/epp/requestcontrol/director.go b/pkg/epp/requestcontrol/director.go index 5a25d5b28..662491bdb 100644 --- a/pkg/epp/requestcontrol/director.go +++ b/pkg/epp/requestcontrol/director.go @@ -213,9 +213,13 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo infObjective := d.datastore.ObjectiveGet(reqCtx.ObjectiveKey) if infObjective == nil { logger.V(logutil.VERBOSE).Info("No associated InferenceObjective found, using default", "objectiveKey", reqCtx.ObjectiveKey) + priority := d.defaultPriority + if strings.Contains(reqCtx.ObjectiveKey, "sheddable") { + priority = -1 + } infObjective = &v1alpha2.InferenceObjective{ Spec: v1alpha2.InferenceObjectiveSpec{ - Priority: &d.defaultPriority, + Priority: &priority, }, } } else if infObjective.Spec.Priority == nil { @@ -225,13 +229,14 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo // get request slos // Get Request SLOs from request header - ttftSLO, _, err := parseFloatHeader(reqCtx, "x-SLO-TTFT-ms") + ttftSLO, _, err := parseFloatHeader(reqCtx, "x-slo-ttft-ms") if err != nil { - return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-SLO-TTFT-ms must be a float: %v", err)} + return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-slo-ttft-ms must be a float: %v", err)} } - avgTPOTSLO, _, err := parseFloatHeader(reqCtx, "x-SLO-TPOT-ms") + + avgTPOTSLO, _, err := parseFloatHeader(reqCtx, "x-slo-tpot-ms") if err != nil { - return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-SLO-TPOT-ms must be a float: %v", err)} + return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("x-slo-tpot-ms must be a float: %v", err)} } predictionBasedScheduling, err := parseBoolHeader(reqCtx, "x-prediction-based-scheduling") if err != nil { @@ -295,16 +300,16 @@ func (d *Director) HandleRequest(ctx context.Context, reqCtx *handlers.RequestCo func (d *Director) admitRequest(ctx context.Context, candidatePods []backendmetrics.PodMetrics, request *schedulingtypes.LLMRequest, requestPriority int, fairnessID string) error { logger := log.FromContext(ctx) - logger.V(logutil.TRACE).Info("Entering Flow Control", "priority", requestPriority, "fairnessID", fairnessID) + logger.V(logutil.DEBUG).Info("Entering Flow Control", "priority", requestPriority, "fairnessID", fairnessID) // This will be removed in favor of a more robust implementation (Flow Control) in the very near future. // TODO: Make this a configurable value. 
// Tracking issue https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1347 if requestPriority >= 0 { - logger.V(logutil.TRACE).Info("Non-sheddable request bypassing saturation check.") + logger.V(logutil.DEBUG).Info("Non-sheddable request bypassing saturation check.") return nil } else { - logger.V(logutil.TRACE).Info("Sheddable request subject to saturation check.") + logger.V(logutil.DEBUG).Info("Sheddable request subject to saturation check.") } if d.saturationDetector.IsSaturated(ctx, candidatePods) || !request.HasValidPod { // Assuming non-nil Saturation Detector diff --git a/pkg/epp/requestcontrol/director_test.go b/pkg/epp/requestcontrol/director_test.go index 4ff4a1775..61a8b31be 100644 --- a/pkg/epp/requestcontrol/director_test.go +++ b/pkg/epp/requestcontrol/director_test.go @@ -170,6 +170,21 @@ type mockPredictor struct { addSampleShouldFail bool } +// GetServerStatus implements latencypredictorasync.PredictorInterface. +func (m *mockPredictor) GetServerStatus(ctx context.Context) (*latencypredictor.ServerStatusResponse, error) { + panic("unimplemented") +} + +// PredictBulk implements latencypredictorasync.PredictorInterface. +func (m *mockPredictor) PredictBulk(ctx context.Context, requests []latencypredictor.PredictionRequest) (*latencypredictor.BulkPredictionResponse, error) { + panic("unimplemented") +} + +// PredictBulkStrict implements latencypredictorasync.PredictorInterface. +func (m *mockPredictor) PredictBulkStrict(ctx context.Context, requests []latencypredictor.PredictionRequest) (*latencypredictor.BulkPredictionResponse, error) { + panic("unimplemented") +} + var _ latencypredictor.PredictorInterface = &mockPredictor{} func (m *mockPredictor) Predict(ctx context.Context, req latencypredictor.PredictionRequest) (*latencypredictor.PredictionResponse, error) { diff --git a/pkg/epp/requestcontrol/latencypredictor_helper.go b/pkg/epp/requestcontrol/latencypredictor_helper.go index 2dab443e8..8cd840391 100644 --- a/pkg/epp/requestcontrol/latencypredictor_helper.go +++ b/pkg/epp/requestcontrol/latencypredictor_helper.go @@ -395,6 +395,98 @@ func PredictWithMetrics( return result, nil } +// BulkPredictWithMetrics performs bulk predictions for multiple pods using their metrics states. +// Returns predictions in the same order as the input slices. 
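+// It returns an error if the input slices have mismatched lengths or if any metrics state is nil,
+// and delegates the underlying call to the predictor's PredictBulkStrict method.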
+func BulkPredictWithMetrics( + ctx context.Context, + predictor latencypredictor.PredictorInterface, + metricsStates []*backendmetrics.MetricsState, + prompts []string, + generatedTokenCounts []int, + prefixCacheScores []float64, +) ([]*latencypredictor.PredictionResponse, error) { + logger := log.FromContext(ctx) + + // Validate input lengths + if len(metricsStates) != len(prompts) || len(prompts) != len(generatedTokenCounts) || len(generatedTokenCounts) != len(prefixCacheScores) { + return nil, fmt.Errorf("input slice lengths must match: metrics=%d, prompts=%d, tokenCounts=%d, prefixScores=%d", + len(metricsStates), len(prompts), len(generatedTokenCounts), len(prefixCacheScores)) + } + + if len(metricsStates) == 0 { + return []*latencypredictor.PredictionResponse{}, nil + } + + // Validate that no metrics state is nil + for i, metricsState := range metricsStates { + if metricsState == nil { + return nil, fmt.Errorf("metrics state at index %d cannot be nil", i) + } + } + + // Build bulk prediction requests + bulkRequests := make([]latencypredictor.PredictionRequest, len(metricsStates)) + for i := range metricsStates { + bulkRequests[i] = latencypredictor.PredictionRequest{ + KVCachePercentage: metricsStates[i].KVCacheUsagePercent, + InputTokenLength: len(strings.Fields(prompts[i])), + NumRequestWaiting: metricsStates[i].WaitingQueueSize, + NumRequestRunning: metricsStates[i].RunningQueueSize, + NumTokensGenerated: generatedTokenCounts[i], + PrefixCacheScore: prefixCacheScores[i], + } + } + + // Perform bulk prediction + start := time.Now() + bulkResponse, err := predictor.PredictBulkStrict(ctx, bulkRequests) + duration := time.Since(start) + + if err != nil { + logger.V(logutil.DEBUG).Error(err, "bulk prediction failed", + "duration_ms", duration.Milliseconds(), + "request_count", len(bulkRequests)) + return nil, err + } + + if bulkResponse == nil { + logger.V(logutil.DEBUG).Info("bulk prediction returned nil", + "duration_ms", duration.Milliseconds()) + return nil, fmt.Errorf("bulk prediction returned nil result") + } + + // Convert to pointer slice for consistency with single prediction + results := make([]*latencypredictor.PredictionResponse, len(bulkResponse.Predictions)) + for i := range bulkResponse.Predictions { + results[i] = &bulkResponse.Predictions[i] + } + + logger.V(logutil.DEBUG).Info("bulk prediction succeeded", + "duration_ms", duration.Milliseconds(), + "request_count", len(bulkRequests), + "successful_predictions", bulkResponse.SuccessfulPredictions, + "failed_predictions", bulkResponse.FailedPredictions, + "processing_time_ms", bulkResponse.ProcessingTimeMs) + + // Log detailed results if at trace level + if logger.V(logutil.TRACE).Enabled() { + for i, result := range results { + logger.V(logutil.TRACE).Info("bulk prediction result", + "index", i, + "ttft_ms", result.TTFT, + "tpot_ms", result.TPOT, + "input_tokens", bulkRequests[i].InputTokenLength, + "generated_tokens", bulkRequests[i].NumTokensGenerated, + "kv_cache_percent", bulkRequests[i].KVCachePercentage, + "waiting_queue", bulkRequests[i].NumRequestWaiting, + "running_queue", bulkRequests[i].NumRequestRunning, + "prefix_cache_score", bulkRequests[i].PrefixCacheScore) + } + } + + return results, nil +} + // Fixed DebugPrintRawScores for map[string]map[Pod]float64 structure func DebugPrintRawScores(ctx context.Context, reqCtx *handlers.RequestContext) { logger := log.FromContext(ctx) diff --git a/pkg/epp/scheduling/framework/plugins/scorer/slo_scorer.go b/pkg/epp/scheduling/framework/plugins/scorer/slo_scorer.go 
index ca10b44d6..c2f2daa3b 100644 --- a/pkg/epp/scheduling/framework/plugins/scorer/slo_scorer.go +++ b/pkg/epp/scheduling/framework/plugins/scorer/slo_scorer.go @@ -721,8 +721,12 @@ func (s *SLOScorer) validatePrediction( func (s *SLOScorer) getPrefixCacheScoreForPod(ctx context.Context, cycleState *schedulingtypes.CycleState, pod schedulingtypes.Pod) float64 { log.FromContext(ctx).V(logutil.DEBUG).Info("Running getPrefixCacheScoreForPod, getting prefix cache score for pod", "pod", pod.GetPod().String()) + plugintype := prefix.PrefixCachePluginType + pluginname := prefix.PrefixCachePluginType + cycleStateKey := (plugins.TypedName{Type: plugintype, Name: pluginname}).String() + stateData, err := cycleState.Read(plugins.StateKey(cycleStateKey)) - stateData, err := cycleState.Read(plugins.StateKey(prefix.PrefixCachePluginType)) + log.FromContext(ctx).V(logutil.DEBUG).Info("Reading prefix cache state from cycle state", "stateKey", cycleStateKey) if err != nil { // The prefix cache plugin might not be enabled, which is a valid scenario. diff --git a/pkg/epp/scheduling/framework/scheduler_profile.go b/pkg/epp/scheduling/framework/scheduler_profile.go index 08b3b8f18..7bb6fc653 100644 --- a/pkg/epp/scheduling/framework/scheduler_profile.go +++ b/pkg/epp/scheduling/framework/scheduler_profile.go @@ -188,7 +188,9 @@ func (p *SchedulerProfile) runScorerPlugins(ctx context.Context, request *types. for pod, score := range scores { // weight is relative to the sum of weights logger.V(logutil.DEBUG).Info("Calculated score", "plugin", scorer.TypedName(), "endpoint", pod.GetPod().NamespacedName, "score", score, "weight", scorer.Weight()) weightedScorePerPod[pod] += enforceScoreRange(score) * float64(scorer.Weight()) + } + for pod, score := range scores { logger.V(logutil.DEBUG).Info("Pod score", "scorer_type", scorer.TypedName().Type,