Skip to content

Commit 86d225d

Browse files
[#700] Update case import process segmentation logic with correct thresholds for each segment
1 parent 95f6852 commit 86d225d

File tree

1 file changed: +93 additions, -149 deletions

backend/utils/case_import_process_confirmed_segmentation.py

Lines changed: 93 additions & 149 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from io import BytesIO
66

77
from models.question import Question
8-
from models.case_commodity import CaseCommodity
8+
from models.case_commodity import CaseCommodity, CaseCommodityType
99
from models.segment import SegmentUpdateBase, SegmentAnswerBase
1010
from utils.case_import_storage import load_import_file
1111
from db.crud_case_import import get_case_import
@@ -25,10 +25,7 @@ def resolve_question(
2525
q = session.get(Question, int(question_id))
2626
if q:
2727
return q
28-
raise HTTPException(
29-
status_code=400,
30-
detail=f"Question not found for id {question_id}",
31-
)
28+
raise HTTPException(400, f"Question not found for id {question_id}")
3229

3330
if public_key:
3431
q = (
@@ -39,34 +36,29 @@ def resolve_question(
3936
if q:
4037
return q
4138
raise HTTPException(
42-
status_code=400,
43-
detail=f"Question not found for public_key '{public_key}'",
39+
400, f"Question not found for public_key '{public_key}'"
4440
)
4541

4642
raise HTTPException(
47-
status_code=400,
48-
detail="Mapping row must contain either id or public_key",
43+
400, "Mapping row must contain either id or public_key"
4944
)
5045

5146

5247
# --------------------------------------------------
53-
# Main processor
48+
# Main processor (FINAL FIX)
5449
# --------------------------------------------------
5550
def process_confirmed_segmentation(
5651
*,
5752
payload,
5853
session,
5954
) -> Dict[str, Any]:
6055

61-
# --------------------------------------------------
62-
# 1. Inputs
63-
# --------------------------------------------------
6456
case_id = payload.case_id
6557
segmentation_variable = payload.segmentation_variable.strip().lower()
6658
segments = sorted(payload.segments, key=lambda s: s.index)
6759

6860
# --------------------------------------------------
69-
# 2. Load import file (bytes)
61+
# Load import file
7062
# --------------------------------------------------
7163
case_import = get_case_import(session=session, import_id=payload.import_id)
7264
content = load_import_file(case_import.file_path)
@@ -76,195 +68,147 @@ def process_confirmed_segmentation(
7668
data_df = pd.read_excel(xls, sheet_name="data")
7769
mapping_df = pd.read_excel(xls, sheet_name="mapping")
7870

79-
# normalize column names
8071
data_df.columns = data_df.columns.str.strip().str.lower()
8172
mapping_df.columns = mapping_df.columns.str.strip().str.lower()
8273
except Exception:
83-
raise HTTPException(
84-
status_code=400,
85-
detail="Failed to read import workbook",
86-
)
74+
raise HTTPException(400, "Failed to read import workbook")
8775

8876
if segmentation_variable not in data_df.columns:
8977
raise HTTPException(
90-
status_code=400,
91-
detail=f"Segmentation variable {segmentation_variable} not found",
78+
400, f"Segmentation variable '{segmentation_variable}' not found"
9279
)
9380

9481
# --------------------------------------------------
95-
# 3. Validate mapping sheet
82+
# Prepare segmentation series
9683
# --------------------------------------------------
97-
if "variable_name" not in mapping_df.columns:
98-
raise HTTPException(
99-
status_code=400,
100-
detail="Mapping sheet must contain 'variable_name'",
101-
)
102-
103-
if not {"id", "public_key"} & set(mapping_df.columns):
104-
raise HTTPException(
105-
status_code=400,
106-
detail="Mapping sheet must contain at least 'id' or 'public_key'",
107-
)
108-
109-
# Only variables explicitly mapped are output drivers
110-
output_variables = (
111-
mapping_df["variable_name"].dropna().str.lower().tolist()
112-
)
113-
114-
# --------------------------------------------------
115-
# 4. Segment assignment (UI AUTHORITATIVE)
116-
# --------------------------------------------------
117-
series = data_df[segmentation_variable]
118-
is_numeric = pd.api.types.is_numeric_dtype(series)
84+
seg_series = data_df[segmentation_variable]
85+
is_numeric = pd.api.types.is_numeric_dtype(seg_series)
11986

12087
if not is_numeric:
121-
series = series.astype(str).str.strip().str.lower()
122-
category_map = {
123-
str(seg.value).strip().lower(): seg.name for seg in segments
124-
}
125-
126-
def assign_segment(value):
127-
if pd.isna(value):
128-
return None
129-
130-
# ---------- NUMERIC ----------
131-
if is_numeric:
132-
prev = None
133-
for seg in segments:
134-
bound = float(seg.value)
135-
if prev is None and value <= bound:
136-
return seg.name
137-
if prev is not None and prev < value <= bound:
138-
return seg.name
139-
prev = bound
140-
return segments[-1].name # last open segment
141-
142-
# ---------- CATEGORICAL ----------
143-
return category_map.get(str(value).strip().lower())
144-
145-
data_df["_segment"] = series.apply(assign_segment)
88+
seg_series = seg_series.astype(str).str.strip().str.lower()
14689

14790
# --------------------------------------------------
148-
# 5. Aggregate statistics per segment
91+
# Resolve Case Commodity Levels (FIXED)
14992
# --------------------------------------------------
150-
aggregated: Dict[str, Dict[str, Dict[str, float]]] = {}
151-
152-
for var in output_variables:
153-
if var not in data_df.columns:
154-
continue
155-
if not pd.api.types.is_numeric_dtype(data_df[var]):
156-
continue
157-
158-
stats = (
159-
data_df.groupby("_segment")[var]
160-
.agg(
161-
current="median",
162-
feasible=lambda x: x.quantile(0.9),
163-
)
164-
.dropna()
165-
.to_dict(orient="index")
166-
)
167-
168-
aggregated[var] = stats
169-
170-
# --------------------------------------------------
171-
# 6. Build Segment + SegmentAnswer payload
172-
# --------------------------------------------------
173-
segments_payload = []
174-
175-
# Generate case commodity value
17693
case_commodities = (
17794
session.query(CaseCommodity)
17895
.filter(CaseCommodity.case == case_id)
17996
.all()
18097
)
181-
case_commodities = [
182-
cm.simplify_with_case_commodity_level for cm in case_commodities
183-
]
184-
case_commodity_levels = {}
185-
case_commodity_breakdowns = {}
98+
99+
commodity_level_map = {}
186100
for cc in case_commodities:
187-
key = f"{cc['commodity_type']}"
188-
case_commodity_levels[key] = cc["id"]
189-
case_commodity_breakdowns[key] = cc["breakdown"]
190-
# eol case commodity
101+
if cc.commodity_type == CaseCommodityType.focus:
102+
commodity_level_map["primary"] = cc.id
103+
else:
104+
commodity_level_map[cc.commodity_type.value] = cc.id
191105

192-
for seg in segments:
193-
seg_name = seg.name
106+
# --------------------------------------------------
107+
# Process segments with BOUNDARY FILTERING
108+
# --------------------------------------------------
109+
segment_payloads = []
110+
111+
for idx, seg in enumerate(segments):
194112
seg_id = seg.id
113+
seg_name = seg.name
114+
115+
# ---------- APPLY SEGMENT FILTER ----------
116+
if is_numeric:
117+
lower = float(segments[idx - 1].value) if idx > 0 else None
118+
upper = float(seg.value)
119+
120+
if lower is None:
121+
mask = seg_series <= upper
122+
else:
123+
mask = (seg_series > lower) & (seg_series <= upper)
124+
else:
125+
mask = seg_series == str(seg.value).strip().lower()
195126

196-
seg_df = data_df[data_df["_segment"] == seg_name]
127+
seg_df = data_df[mask]
197128
number_of_farmers = int(len(seg_df))
198129

199-
seg_answers = []
130+
answers = []
200131

132+
# ---------- PER-SEGMENT AGGREGATION ----------
201133
for _, row in mapping_df.iterrows():
202-
var = str(row["variable_name"]).lower()
134+
raw_id = row.get("id", None)
135+
level = None
136+
qid = None
203137

204-
if var not in aggregated:
138+
if not raw_id:
205139
continue
206-
if seg_name not in aggregated[var]:
140+
141+
var = str(row["variable_name"]).strip().lower()
142+
if var not in seg_df.columns:
207143
continue
208144

209-
questionID = row.get("id")
210-
[qLevel, qID] = (
211-
questionID.split("-")
212-
if questionID and "-" in questionID
213-
else [None, None]
214-
)
145+
if not pd.api.types.is_numeric_dtype(seg_df[var]):
146+
continue
147+
148+
values = seg_df[var].dropna()
149+
if values.empty:
150+
continue
151+
152+
current_value = float(values.median())
153+
feasible_value = float(values.quantile(0.9))
154+
155+
if raw_id and "-" in str(raw_id):
156+
level, qid = raw_id.split("-", 1)
157+
level = level.lower()
158+
else:
159+
qid = raw_id
160+
# Default commodity level when mapping id has no prefix
161+
level = CaseCommodityType.diversified.value
162+
163+
# check for case_commodity_id
164+
case_commodity_id = commodity_level_map.get(level)
165+
if not case_commodity_id:
166+
raise HTTPException(
167+
status_code=400,
168+
detail=(
169+
f"Case commodity not found for level '{level}' (mapping id: {raw_id})" # noqa
170+
),
171+
)
172+
215173
question = resolve_question(
216174
session=session,
217-
question_id=qID,
175+
question_id=qid,
218176
public_key=row.get("public_key"),
219177
)
220178

221-
stats = aggregated[var][seg_name]
222-
223-
# primary / secondary / tertiary
224-
case_commodity_id = case_commodity_levels.get(qLevel)
225-
# build SegmentAnswer
226-
payload = SegmentAnswerBase(
227-
case_commodity=case_commodity_id,
228-
segment=seg_id,
229-
question=question.id,
230-
current_value=float(stats["current"]),
231-
feasible_value=float(stats["feasible"]),
179+
answers.append(
180+
SegmentAnswerBase(
181+
case_commodity=case_commodity_id,
182+
segment=seg_id,
183+
question=question.id,
184+
current_value=current_value,
185+
feasible_value=feasible_value,
186+
)
232187
)
233-
seg_answers.append(payload)
234188

235-
segments_payload.append(
189+
segment_payloads.append(
236190
SegmentUpdateBase(
237191
id=seg_id,
238192
name=seg_name,
239193
case=case_id,
240194
number_of_farmers=number_of_farmers,
241-
answers=seg_answers,
195+
answers=answers,
242196
)
243197
)
244198

245-
# Save segment answers and update segment number_of_farmers
246-
update_segment(
247-
session=session,
248-
payloads=segments_payload,
249-
)
250-
251199
# --------------------------------------------------
252-
# 7. Cleanup
200+
# Persist
253201
# --------------------------------------------------
202+
update_segment(session=session, payloads=segment_payloads)
203+
254204
try:
255-
REMOVE = False
256-
if REMOVE:
257-
os.remove(case_import.file_path)
205+
os.remove(case_import.file_path)
258206
except Exception:
259207
pass
260208

261-
# --------------------------------------------------
262-
# 8. Response
263-
# --------------------------------------------------
264209
return {
265210
"status": "success",
266211
"case_id": case_id,
267-
"segments": segments_payload,
268-
"total_segments": len(segments_payload),
269-
"drivers_processed": len(aggregated),
212+
"segments": segment_payloads,
213+
"total_segments": len(segment_payloads),
270214
}

0 commit comments

Comments
 (0)