55from io import BytesIO
66
77from models .question import Question
8- from models .case_commodity import CaseCommodity
8+ from models .case_commodity import CaseCommodity , CaseCommodityType
99from models .segment import SegmentUpdateBase , SegmentAnswerBase
1010from utils .case_import_storage import load_import_file
1111from db .crud_case_import import get_case_import
@@ -25,10 +25,7 @@ def resolve_question(
2525 q = session .get (Question , int (question_id ))
2626 if q :
2727 return q
28- raise HTTPException (
29- status_code = 400 ,
30- detail = f"Question not found for id { question_id } " ,
31- )
28+ raise HTTPException (400 , f"Question not found for id { question_id } " )
3229
3330 if public_key :
3431 q = (
@@ -39,34 +36,29 @@ def resolve_question(
3936 if q :
4037 return q
4138 raise HTTPException (
42- status_code = 400 ,
43- detail = f"Question not found for public_key '{ public_key } '" ,
39+ 400 , f"Question not found for public_key '{ public_key } '"
4440 )
4541
4642 raise HTTPException (
47- status_code = 400 ,
48- detail = "Mapping row must contain either id or public_key" ,
43+ 400 , "Mapping row must contain either id or public_key"
4944 )
5045
5146
5247# --------------------------------------------------
53- # Main processor
48+ # Main processor (FINAL FIX)
5449# --------------------------------------------------
5550def process_confirmed_segmentation (
5651 * ,
5752 payload ,
5853 session ,
5954) -> Dict [str , Any ]:
6055
61- # --------------------------------------------------
62- # 1. Inputs
63- # --------------------------------------------------
6456 case_id = payload .case_id
6557 segmentation_variable = payload .segmentation_variable .strip ().lower ()
6658 segments = sorted (payload .segments , key = lambda s : s .index )
6759
6860 # --------------------------------------------------
69- # 2. Load import file (bytes)
61+ # Load import file
7062 # --------------------------------------------------
7163 case_import = get_case_import (session = session , import_id = payload .import_id )
7264 content = load_import_file (case_import .file_path )
@@ -76,195 +68,147 @@ def process_confirmed_segmentation(
7668 data_df = pd .read_excel (xls , sheet_name = "data" )
7769 mapping_df = pd .read_excel (xls , sheet_name = "mapping" )
7870
79- # normalize column names
8071 data_df .columns = data_df .columns .str .strip ().str .lower ()
8172 mapping_df .columns = mapping_df .columns .str .strip ().str .lower ()
8273 except Exception :
83- raise HTTPException (
84- status_code = 400 ,
85- detail = "Failed to read import workbook" ,
86- )
74+ raise HTTPException (400 , "Failed to read import workbook" )
8775
8876 if segmentation_variable not in data_df .columns :
8977 raise HTTPException (
90- status_code = 400 ,
91- detail = f"Segmentation variable { segmentation_variable } not found" ,
78+ 400 , f"Segmentation variable '{ segmentation_variable } ' not found"
9279 )
9380
9481 # --------------------------------------------------
95- # 3. Validate mapping sheet
82+ # Prepare segmentation series
9683 # --------------------------------------------------
97- if "variable_name" not in mapping_df .columns :
98- raise HTTPException (
99- status_code = 400 ,
100- detail = "Mapping sheet must contain 'variable_name'" ,
101- )
102-
103- if not {"id" , "public_key" } & set (mapping_df .columns ):
104- raise HTTPException (
105- status_code = 400 ,
106- detail = "Mapping sheet must contain at least 'id' or 'public_key'" ,
107- )
108-
109- # Only variables explicitly mapped are output drivers
110- output_variables = (
111- mapping_df ["variable_name" ].dropna ().str .lower ().tolist ()
112- )
113-
114- # --------------------------------------------------
115- # 4. Segment assignment (UI AUTHORITATIVE)
116- # --------------------------------------------------
117- series = data_df [segmentation_variable ]
118- is_numeric = pd .api .types .is_numeric_dtype (series )
84+ seg_series = data_df [segmentation_variable ]
85+ is_numeric = pd .api .types .is_numeric_dtype (seg_series )
11986
12087 if not is_numeric :
121- series = series .astype (str ).str .strip ().str .lower ()
122- category_map = {
123- str (seg .value ).strip ().lower (): seg .name for seg in segments
124- }
125-
126- def assign_segment (value ):
127- if pd .isna (value ):
128- return None
129-
130- # ---------- NUMERIC ----------
131- if is_numeric :
132- prev = None
133- for seg in segments :
134- bound = float (seg .value )
135- if prev is None and value <= bound :
136- return seg .name
137- if prev is not None and prev < value <= bound :
138- return seg .name
139- prev = bound
140- return segments [- 1 ].name # last open segment
141-
142- # ---------- CATEGORICAL ----------
143- return category_map .get (str (value ).strip ().lower ())
144-
145- data_df ["_segment" ] = series .apply (assign_segment )
88+ seg_series = seg_series .astype (str ).str .strip ().str .lower ()
14689
14790 # --------------------------------------------------
148- # 5. Aggregate statistics per segment
91+ # Resolve Case Commodity Levels (FIXED)
14992 # --------------------------------------------------
150- aggregated : Dict [str , Dict [str , Dict [str , float ]]] = {}
151-
152- for var in output_variables :
153- if var not in data_df .columns :
154- continue
155- if not pd .api .types .is_numeric_dtype (data_df [var ]):
156- continue
157-
158- stats = (
159- data_df .groupby ("_segment" )[var ]
160- .agg (
161- current = "median" ,
162- feasible = lambda x : x .quantile (0.9 ),
163- )
164- .dropna ()
165- .to_dict (orient = "index" )
166- )
167-
168- aggregated [var ] = stats
169-
170- # --------------------------------------------------
171- # 6. Build Segment + SegmentAnswer payload
172- # --------------------------------------------------
173- segments_payload = []
174-
175- # Generate case commodity value
17693 case_commodities = (
17794 session .query (CaseCommodity )
17895 .filter (CaseCommodity .case == case_id )
17996 .all ()
18097 )
181- case_commodities = [
182- cm .simplify_with_case_commodity_level for cm in case_commodities
183- ]
184- case_commodity_levels = {}
185- case_commodity_breakdowns = {}
98+
99+ commodity_level_map = {}
186100 for cc in case_commodities :
187- key = f" { cc [ ' commodity_type' ] } "
188- case_commodity_levels [ key ] = cc [ "id" ]
189- case_commodity_breakdowns [ key ] = cc [ "breakdown" ]
190- # eol case commodity
101+ if cc . commodity_type == CaseCommodityType . focus :
102+ commodity_level_map [ "primary" ] = cc . id
103+ else :
104+ commodity_level_map [ cc . commodity_type . value ] = cc . id
191105
192- for seg in segments :
193- seg_name = seg .name
106+ # --------------------------------------------------
107+ # Process segments with BOUNDARY FILTERING
108+ # --------------------------------------------------
109+ segment_payloads = []
110+
111+ for idx , seg in enumerate (segments ):
194112 seg_id = seg .id
113+ seg_name = seg .name
114+
115+ # ---------- APPLY SEGMENT FILTER ----------
116+ if is_numeric :
117+ lower = float (segments [idx - 1 ].value ) if idx > 0 else None
118+ upper = float (seg .value )
119+
120+ if lower is None :
121+ mask = seg_series <= upper
122+ else :
123+ mask = (seg_series > lower ) & (seg_series <= upper )
124+ else :
125+ mask = seg_series == str (seg .value ).strip ().lower ()
195126
196- seg_df = data_df [data_df [ "_segment" ] == seg_name ]
127+ seg_df = data_df [mask ]
197128 number_of_farmers = int (len (seg_df ))
198129
199- seg_answers = []
130+ answers = []
200131
132+ # ---------- PER-SEGMENT AGGREGATION ----------
201133 for _ , row in mapping_df .iterrows ():
202- var = str (row ["variable_name" ]).lower ()
134+ raw_id = row .get ("id" , None )
135+ level = None
136+ qid = None
203137
204- if var not in aggregated :
138+ if not raw_id :
205139 continue
206- if seg_name not in aggregated [var ]:
140+
141+ var = str (row ["variable_name" ]).strip ().lower ()
142+ if var not in seg_df .columns :
207143 continue
208144
209- questionID = row .get ("id" )
210- [qLevel , qID ] = (
211- questionID .split ("-" )
212- if questionID and "-" in questionID
213- else [None , None ]
214- )
145+ if not pd .api .types .is_numeric_dtype (seg_df [var ]):
146+ continue
147+
148+ values = seg_df [var ].dropna ()
149+ if values .empty :
150+ continue
151+
152+ current_value = float (values .median ())
153+ feasible_value = float (values .quantile (0.9 ))
154+
155+ if raw_id and "-" in str (raw_id ):
156+ level , qid = raw_id .split ("-" , 1 )
157+ level = level .lower ()
158+ else :
159+ qid = raw_id
160+ # Default commodity level when mapping id has no prefix
161+ level = CaseCommodityType .diversified .value
162+
163+ # check for case_commodity_id
164+ case_commodity_id = commodity_level_map .get (level )
165+ if not case_commodity_id :
166+ raise HTTPException (
167+ status_code = 400 ,
168+ detail = (
169+ f"Case commodity not found for level '{ level } ' (mapping id: { raw_id } )" # noqa
170+ ),
171+ )
172+
215173 question = resolve_question (
216174 session = session ,
217- question_id = qID ,
175+ question_id = qid ,
218176 public_key = row .get ("public_key" ),
219177 )
220178
221- stats = aggregated [var ][seg_name ]
222-
223- # primary / secondary / tertiary
224- case_commodity_id = case_commodity_levels .get (qLevel )
225- # build SegmentAnswer
226- payload = SegmentAnswerBase (
227- case_commodity = case_commodity_id ,
228- segment = seg_id ,
229- question = question .id ,
230- current_value = float (stats ["current" ]),
231- feasible_value = float (stats ["feasible" ]),
179+ answers .append (
180+ SegmentAnswerBase (
181+ case_commodity = case_commodity_id ,
182+ segment = seg_id ,
183+ question = question .id ,
184+ current_value = current_value ,
185+ feasible_value = feasible_value ,
186+ )
232187 )
233- seg_answers .append (payload )
234188
235- segments_payload .append (
189+ segment_payloads .append (
236190 SegmentUpdateBase (
237191 id = seg_id ,
238192 name = seg_name ,
239193 case = case_id ,
240194 number_of_farmers = number_of_farmers ,
241- answers = seg_answers ,
195+ answers = answers ,
242196 )
243197 )
244198
245- # Save segment answers and update segment number_of_farmers
246- update_segment (
247- session = session ,
248- payloads = segments_payload ,
249- )
250-
251199 # --------------------------------------------------
252- # 7. Cleanup
200+ # Persist
253201 # --------------------------------------------------
202+ update_segment (session = session , payloads = segment_payloads )
203+
254204 try :
255- REMOVE = False
256- if REMOVE :
257- os .remove (case_import .file_path )
205+ os .remove (case_import .file_path )
258206 except Exception :
259207 pass
260208
261- # --------------------------------------------------
262- # 8. Response
263- # --------------------------------------------------
264209 return {
265210 "status" : "success" ,
266211 "case_id" : case_id ,
267- "segments" : segments_payload ,
268- "total_segments" : len (segments_payload ),
269- "drivers_processed" : len (aggregated ),
212+ "segments" : segment_payloads ,
213+ "total_segments" : len (segment_payloads ),
270214 }
0 commit comments