diff --git a/latencypredictor-v1/training_server.py b/latencypredictor-v1/training_server.py index fa8a81118..51425bfe0 100644 --- a/latencypredictor-v1/training_server.py +++ b/latencypredictor-v1/training_server.py @@ -628,22 +628,22 @@ def train(self): if len(df_ttft) >= settings.MIN_SAMPLES_FOR_RETRAIN: # Updated TTFT features to include prefix_cache_score ttft_feature_cols_tree = [ - 'kv_cache_percentage','input_token_length','num_request_waiting', - 'num_request_running','prefix_cache_score','effective_input_tokens','prefill_score_bucket' - ] - ttft_feature_cols_br = [ - 'kv_cache_percentage','input_token_length','num_request_waiting', - 'num_request_running','prefix_cache_score','effective_input_tokens' - ] - - # Build X_ttft for all model types, then trim for BR - X_ttft = df_ttft[ttft_feature_cols_tree] - if self.model_type == ModelType.BAYESIAN_RIDGE: - X_ttft = X_ttft[ttft_feature_cols_br] + 'kv_cache_percentage','input_token_length','num_request_waiting', + 'num_request_running','prefix_cache_score','effective_input_tokens','prefill_score_bucket' + ] + ttft_feature_cols_br = [ + 'kv_cache_percentage','input_token_length','num_request_waiting', + 'num_request_running','prefix_cache_score','effective_input_tokens' + ] + + # Build X_ttft for all model types, then trim for BR + X_ttft = df_ttft[ttft_feature_cols_tree] + if self.model_type == ModelType.BAYESIAN_RIDGE: + X_ttft = X_ttft[ttft_feature_cols_br] - y_ttft = raw_ttft['actual_ttft_ms'] + y_ttft = raw_ttft['actual_ttft_ms'] - try: + try: # raw_ttft still has the original columns including 'prefix_cache_score' raw_ttft['_prefix_bucket'] = raw_ttft['prefix_cache_score'].clip(0, 1).apply( lambda s: min(int(s * self.prefix_buckets), self.prefix_buckets - 1) @@ -677,8 +677,6 @@ def train(self): new_ttft_model, new_ttft_scaler, test_records, cols, 'actual_ttft_ms' ) - - if ql is not None: self.ttft_quantile_loss_scores.append(ql) self.ttft_coverage_scores.append(coverage) @@ -690,7 +688,7 @@ def train(self): else: logging.info(f"TTFT model trained on {len(df_ttft)} samples. Quantile metrics = N/A (insufficient test data)") - except Exception: + except Exception: logging.error("Error training TTFT model", exc_info=True)