@@ -13,8 +13,9 @@
 import tempfile
 
 # Base URLs for the dual-server architecture
-PREDICTION_URL = os.getenv("PREDICTION_SERVER_URL", "http://34.158.41.245:80")  # Update this
-TRAINING_URL = os.getenv("TRAINING_SERVER_URL", "http://34.143.208.0:8080")  # Update this
+
+PREDICTION_URL = os.getenv("PREDICTION_SERVER_URL", "http://<PREDICTION_IP>:80")  # Update this
+TRAINING_URL = os.getenv("TRAINING_SERVER_URL", "http://<TRAINING_IP>:8080")  # Update this
 
 TARGET_QPS = float(os.getenv("TARGET_QPS", 1000))  # Update this
 TARGET_QPS_LARGE_BATCH = float(os.getenv("TARGET_QPS_LARGE_BATCH", 100))  # Update this
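A quick sketch of how the new placeholder defaults interact with the environment overrides; the address below is a made-up example, not a real endpoint:

```python
import os

# Simulate a deployment-provided override (example address only).
os.environ["PREDICTION_SERVER_URL"] = "http://10.0.0.5:80"

# The module-level lookup in the diff now resolves to the override rather
# than falling back to the "<PREDICTION_IP>" placeholder.
url = os.getenv("PREDICTION_SERVER_URL", "http://<PREDICTION_IP>:80")
assert url == "http://10.0.0.5:80"
```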
@@ -1133,6 +1134,204 @@ def test_server_configuration():
     print(f"Training server: {train_root_data.get('message')}")
     print(f" Model type: {train_root_data.get('model_type')}")
 
+def test_training_server_flush_api():
+    """Test the training server flush API and data status endpoint."""
+    print("Testing training server flush API...")
+
+    # 1. Check initial data status
+    print("Step 1: Checking initial data status...")
+    initial_status_r = requests.get(f"{TRAINING_URL}/data/status")
+    assert initial_status_r.status_code == 200
+    initial_status = initial_status_r.json()
+
+    print(f" Initial training samples: TTFT={initial_status['training_data']['ttft_samples']}, "
+          f"TPOT={initial_status['training_data']['tpot_samples']}")
+    print(f" Initial test samples: TTFT={initial_status['test_data']['ttft_samples']}, "
+          f"TPOT={initial_status['test_data']['tpot_samples']}")
+
+    # 2. Add training data
+    print("Step 2: Adding training data...")
+    training_entries = [generate_random_training_payload() for _ in range(100)]
+    training_payload = {"entries": training_entries}
+
+    add_r = requests.post(f"{TRAINING_URL}/add_training_data_bulk", json=training_payload)
+    assert add_r.status_code == 202
+    print(" Added 100 training samples")
+
+    # Wait a bit for the data to be processed
+    time.sleep(2)
+
+    # 3. Verify data was added
+    print("Step 3: Verifying data was added...")
+    after_add_status_r = requests.get(f"{TRAINING_URL}/data/status")
+    assert after_add_status_r.status_code == 200
+    after_add_status = after_add_status_r.json()
+
+    total_samples_after = after_add_status['training_data']['total_samples'] + after_add_status['test_data']['total_samples']
+    print(f" After adding - Training: {after_add_status['training_data']['total_samples']}, "
+          f"Test: {after_add_status['test_data']['total_samples']}, Total: {total_samples_after}")
+
+    # Should have more data now (some goes to training, some to test based on TEST_TRAIN_RATIO)
+    assert total_samples_after > 0, "No samples were added"
+
+    # 4. Test flush with only training data
+    print("Step 4: Testing flush with only training data...")
+    flush_training_only = {
+        "flush_training_data": True,
+        "flush_test_data": False,
+        "flush_metrics": False,
+        "reason": "Test flush training data only"
+    }
+
+    flush_r = requests.post(f"{TRAINING_URL}/flush", json=flush_training_only)
+    assert flush_r.status_code == 200
+    flush_response = flush_r.json()
+
+    assert flush_response["success"] is True
+    assert flush_response["metrics_cleared"] is False
+    assert flush_response["reason"] == "Test flush training data only"
+
+    print(f" Flushed {flush_response['ttft_training_samples_flushed']} TTFT training samples")
+    print(f" Flushed {flush_response['tpot_training_samples_flushed']} TPOT training samples")
+    print(f" Test samples flushed: {flush_response['ttft_test_samples_flushed']} TTFT, "
+          f"{flush_response['tpot_test_samples_flushed']} TPOT (should be 0)")
+
+    # Verify training data was flushed but test data remains
+    after_flush_training_r = requests.get(f"{TRAINING_URL}/data/status")
+    after_flush_training = after_flush_training_r.json()
+
+    assert after_flush_training['training_data']['total_samples'] == 0, "Training data should be empty"
+    # Test data should still exist if any was added
+    print(f" After training flush - Training: {after_flush_training['training_data']['total_samples']}, "
+          f"Test: {after_flush_training['test_data']['total_samples']}")
+
+    # 5. Add more data
+    print("Step 5: Adding more training data...")
+    more_entries = [generate_random_training_payload() for _ in range(50)]
+    requests.post(f"{TRAINING_URL}/add_training_data_bulk", json={"entries": more_entries})
+    time.sleep(2)
+
+    # 6. Test flushing everything
+    print("Step 6: Testing flush everything...")
+    flush_all = {
+        "flush_training_data": True,
+        "flush_test_data": True,
+        "flush_metrics": True,
+        "reason": "Complete flush test"
+    }
+
+    flush_all_r = requests.post(f"{TRAINING_URL}/flush", json=flush_all)
+    assert flush_all_r.status_code == 200
+    flush_all_response = flush_all_r.json()
+
+    assert flush_all_response["success"] is True
+    assert flush_all_response["metrics_cleared"] is True
+    assert "Successfully flushed" in flush_all_response["message"]
+
+    print(f" Complete flush message: {flush_all_response['message']}")
+
+    # Verify everything was flushed
+    after_flush_all_r = requests.get(f"{TRAINING_URL}/data/status")
+    after_flush_all = after_flush_all_r.json()
+
+    assert after_flush_all['training_data']['total_samples'] == 0, "Training data should be empty"
+    assert after_flush_all['test_data']['total_samples'] == 0, "Test data should be empty"
+
+    print(f" After complete flush - Training: {after_flush_all['training_data']['total_samples']}, "
+          f"Test: {after_flush_all['test_data']['total_samples']}")
+
+    # 7. Test flush with default parameters (should flush everything)
+    print("Step 7: Testing default flush (no body)...")
+
+    # Add some data first
+    requests.post(f"{TRAINING_URL}/add_training_data_bulk",
+                  json={"entries": [generate_random_training_payload() for _ in range(20)]})
+    time.sleep(1)
+
+    # Flush with an empty body (uses the defaults)
+    default_flush_r = requests.post(f"{TRAINING_URL}/flush")
+    assert default_flush_r.status_code == 200
+    default_flush_response = default_flush_r.json()
+
+    assert default_flush_response["success"] is True
+    print(f" Default flush result: {default_flush_response['message']}")
+
+    # 8. Test flush with only test data
+    print("Step 8: Testing flush with only test data...")
+
+    # Add data
+    requests.post(f"{TRAINING_URL}/add_training_data_bulk",
+                  json={"entries": [generate_random_training_payload() for _ in range(50)]})
+    time.sleep(2)
+
+    # Get status before flushing
+    before_test_flush_r = requests.get(f"{TRAINING_URL}/data/status")
+    before_test_flush = before_test_flush_r.json()
+
+    # Flush only the test data
+    flush_test_only = {
+        "flush_training_data": False,
+        "flush_test_data": True,
+        "flush_metrics": False,
+        "reason": "Test flush test data only"
+    }
+
+    flush_test_r = requests.post(f"{TRAINING_URL}/flush", json=flush_test_only)
+    assert flush_test_r.status_code == 200
+    flush_test_response = flush_test_r.json()
+
+    print(f" Test data flush: {flush_test_response['ttft_test_samples_flushed']} TTFT, "
+          f"{flush_test_response['tpot_test_samples_flushed']} TPOT")
+
+    # Verify only the test data was flushed
+    after_test_flush_r = requests.get(f"{TRAINING_URL}/data/status")
+    after_test_flush = after_test_flush_r.json()
+
+    assert after_test_flush['test_data']['total_samples'] == 0, "Test data should be empty"
+    # Training data should still exist
+    print(f" After test flush - Training: {after_test_flush['training_data']['total_samples']}, "
+          f"Test: {after_test_flush['test_data']['total_samples']}")
+
+    # 9. Test bucket distribution in status
+    print("Step 9: Testing bucket distribution in status...")
+    if "bucket_distribution" in after_flush_all:
+        print(f" Bucket distribution available: {len(after_flush_all.get('bucket_distribution', {}))} buckets with data")
+
+    print("✓ Flush API tests passed!")
+
+
+def test_training_server_flush_error_handling():
+    """Test error handling in the flush API."""
+    print("Testing flush API error handling...")
+
+    # Test with a payload that is valid JSON but has the wrong field type
+    invalid_json = '{"flush_training_data": "not_a_boolean"}'
+    headers = {'Content-Type': 'application/json'}
+
+    try:
+        r = requests.post(f"{TRAINING_URL}/flush", data=invalid_json, headers=headers)
+        # Should get a validation error
+        assert r.status_code in [400, 422], f"Expected 400 or 422, got {r.status_code}"
+        print("✓ Invalid payload handled correctly")
+    except Exception as e:
+        print(f"⚠️ Error handling test skipped: {e}")
+
+    # Test with valid parameters
+    valid_flush = {
+        "flush_training_data": False,
+        "flush_test_data": False,
+        "flush_metrics": True,
+        "reason": "Metrics only flush"
+    }
+
+    r = requests.post(f"{TRAINING_URL}/flush", json=valid_flush)
+    assert r.status_code == 200
+    response = r.json()
+    assert response["metrics_cleared"] is True
+    assert response["ttft_training_samples_flushed"] == 0
+    assert response["tpot_training_samples_flushed"] == 0
+
+    print("✓ Flush error handling tests passed!")
 
 if __name__ == "__main__":
     print("Running dual-server architecture tests with prefix cache score support...")
@@ -1168,6 +1367,8 @@ def test_server_configuration():
         ("Training Metrics", test_training_server_metrics),
         ("Model Consistency", test_model_consistency_between_servers),
         ("XGBoost Trees", test_model_specific_endpoints_on_training_server),
+        ("Flush API", test_training_server_flush_api),
+        ("Flush Error Handling", test_training_server_flush_error_handling),
 
         ("Dual Server Model Learns Equation", test_dual_server_quantile_regression_learns_distribution),
         ("End-to-End Workflow", test_end_to_end_workflow),
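For reference, the `/flush` request and response shapes that these tests rely on can be sketched as Pydantic models. This is a hypothetical reconstruction inferred purely from the fields the tests read and write (the defaults follow Step 7, where an empty body flushes everything); the server's actual models may differ.

```python
from pydantic import BaseModel


class FlushRequest(BaseModel):
    # Step 7 shows an empty body flushes everything, so all flags default to True.
    flush_training_data: bool = True
    flush_test_data: bool = True
    flush_metrics: bool = True
    reason: str = ""


class FlushResponse(BaseModel):
    success: bool
    message: str       # e.g. "Successfully flushed ..."
    reason: str        # echoed back from the request
    metrics_cleared: bool
    ttft_training_samples_flushed: int
    tpot_training_samples_flushed: int
    ttft_test_samples_flushed: int
    tpot_test_samples_flushed: int
```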
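Outside the test suite, the same endpoint can be exercised directly. A minimal standalone sketch mirroring Step 4's selective flush, assuming the `<TRAINING_IP>` placeholder from the configuration change above is replaced with a real address:

```python
import requests

TRAINING_URL = "http://<TRAINING_IP>:8080"  # replace with a real training server

# Selective flush: drop accumulated training samples but keep the held-out
# test split and the server-side metrics (mirrors Step 4 of the test above).
resp = requests.post(
    f"{TRAINING_URL}/flush",
    json={
        "flush_training_data": True,
        "flush_test_data": False,
        "flush_metrics": False,
        "reason": "manual reset before a fresh benchmarking run",
    },
    timeout=10,
)
resp.raise_for_status()
print(resp.json()["message"])
```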