|
27 | 27 | generate_categorical_segments, |
28 | 28 | generate_numerical_segments, |
29 | 29 | validate_ready_for_upload, |
| 30 | + recalculate_numerical_segments, |
30 | 31 | ) |
31 | 32 | from utils.case_import_storage import save_import_file, load_import_file |
32 | 33 | from utils.case_import_process_confirmed_segmentation import ( |
|
37 | 38 | SegmentationPreviewRequest, |
38 | 39 | SegmentationPreviewResponse, |
39 | 40 | GenerateSegmentValuesRequest, |
| 41 | + SegmentationRecalculateRequest, |
40 | 42 | ) |
41 | 43 |
|
42 | 44 | security = HTTPBearer() |
@@ -226,3 +228,73 @@ def download_upload_template(): |
226 | 228 | filename=TEMPLATE_NAME, |
227 | 229 | media_type="application/vnd.ms-excel.sheet.macroEnabled.12", |
228 | 230 | ) |
| 231 | + |
| 232 | + |
| 233 | +@case_import_route.post( |
| 234 | + "/case-import/recalculate-segmentation", |
| 235 | + summary="Recalculate segmentation after user edits segment values", |
| 236 | + name="case_import:recalculate_segmentation", |
| 237 | + tags=ROUTE_TAG_NAME, |
| 238 | +) |
| 239 | +def recalculate_segmentation( |
| 240 | + req: Request, |
| 241 | + payload: SegmentationRecalculateRequest, |
| 242 | + session: Session = Depends(get_session), |
| 243 | + credentials: credentials = Depends(security), |
| 244 | +): |
| 245 | + verify_case_creator( |
| 246 | + session=session, |
| 247 | + authenticated=req.state.authenticated, |
| 248 | + ) |
| 249 | + |
| 250 | + case_import = get_case_import( |
| 251 | + session=session, |
| 252 | + import_id=payload.import_id, |
| 253 | + ) |
| 254 | + |
| 255 | + content = load_import_file(case_import.file_path) |
| 256 | + df = load_data_dataframe_from_bytes(content) |
| 257 | + |
| 258 | + variable = payload.segmentation_variable.lower() |
| 259 | + var_type = payload.variable_type.lower() |
| 260 | + |
| 261 | + if variable not in df.columns: |
| 262 | + raise HTTPException( |
| 263 | + status_code=400, |
| 264 | + detail="Segmentation variable not found in data sheet", |
| 265 | + ) |
| 266 | + |
| 267 | + # -------- CATEGORICAL -------- |
| 268 | + if var_type == "categorical": |
| 269 | + # No recalculation needed; categories are fixed |
| 270 | + segments = generate_categorical_segments( |
| 271 | + df=df, |
| 272 | + column=variable, |
| 273 | + ) |
| 274 | + |
| 275 | + # -------- NUMERICAL -------- |
| 276 | + elif var_type == "numerical": |
| 277 | + if not payload.segments: |
| 278 | + raise HTTPException( |
| 279 | + status_code=400, |
| 280 | + detail="Segments are required for numerical recalculation", |
| 281 | + ) |
| 282 | + |
| 283 | + segments = recalculate_numerical_segments( |
| 284 | + df=df, |
| 285 | + column=variable, |
| 286 | + segments=[seg.dict() for seg in payload.segments], |
| 287 | + ) |
| 288 | + |
| 289 | + else: |
| 290 | + raise HTTPException( |
| 291 | + status_code=400, |
| 292 | + detail="Invalid variable type", |
| 293 | + ) |
| 294 | + |
| 295 | + return { |
| 296 | + "import_id": payload.import_id, |
| 297 | + "segmentation_variable": variable, |
| 298 | + "variable_type": var_type, |
| 299 | + "segments": segments, |
| 300 | + } |
0 commit comments