Skip to content

Commit f781c87

Browse files
authored
Merge pull request #209 from PolicyEngine/fix/apply-labels-to-nested-breakdowns
feat: Add support for breakdown_labels and nested breakdowns
1 parent 4f1f237 commit f781c87

File tree

3 files changed

+617
-20
lines changed

3 files changed

+617
-20
lines changed

src/policyengine/utils/parameter_labels.py

Lines changed: 120 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,34 +26,140 @@ def generate_label_for_parameter(param_node, system, scale_lookup):
2626
if "[" in param_name:
2727
return _generate_bracket_label(param_name, scale_lookup)
2828

29-
if param_node.parent and param_node.parent.metadata.get("breakdown"):
30-
return _generate_breakdown_label(param_node, system)
29+
# Check for breakdown - either direct child or nested
30+
breakdown_parent = _find_breakdown_parent(param_node)
31+
if breakdown_parent:
32+
return _generate_breakdown_label(param_node, system, breakdown_parent)
3133

3234
return None
3335

3436

35-
def _generate_breakdown_label(param_node, system):
36-
"""Generate label for a breakdown parameter using enum values."""
37-
parent = param_node.parent
38-
parent_label = parent.metadata.get("label")
39-
breakdown_vars = parent.metadata.get("breakdown", [])
37+
def _find_breakdown_parent(param_node):
    """
    Walk up the tree to find the nearest ancestor with breakdown metadata.

    The search starts at ``param_node``'s parent, so ``param_node`` itself is
    never returned even if it carries ``breakdown`` metadata.

    Args:
        param_node: The CoreParameter object

    Returns:
        The nearest ancestor whose ``metadata["breakdown"]`` is truthy, or
        None if no such ancestor exists.
    """
    # Use getattr for every hop (including the first) so that a node without
    # a ``parent`` attribute ends the walk instead of raising.
    current = getattr(param_node, "parent", None)
    # Compare against None explicitly: a node object that happens to be falsy
    # must not terminate the walk early.
    while current is not None:
        metadata = getattr(current, "metadata", None)
        if metadata and metadata.get("breakdown"):
            return current
        current = getattr(current, "parent", None)
    return None
53+
54+
55+
def _generate_breakdown_label(param_node, system, breakdown_parent=None):
    """
    Build a label of the form ``"<parent label> (<dim 1>, <dim 2>, ...)"``.

    Works for single-level and nested breakdowns: every node between the
    breakdown ancestor and ``param_node`` contributes one dimension, matched
    positionally against the ancestor's ``breakdown`` and
    ``breakdown_labels`` metadata lists.

    Args:
        param_node: The CoreParameter object
        system: The tax-benefit system
        breakdown_parent: The ancestor node with breakdown metadata; looked
            up via ``_find_breakdown_parent`` when omitted

    Returns:
        str or None: Generated label, or None when there is no breakdown
        ancestor, the ancestor has no label, or no dimensions are found
    """
    owner = breakdown_parent
    if owner is None:
        owner = _find_breakdown_parent(param_node)
        if not owner:
            return None

    base_label = owner.metadata.get("label")
    if not base_label:
        return None

    variables = owner.metadata.get("breakdown", [])
    labels = owner.metadata.get("breakdown_labels", [])

    dimensions = _collect_dimension_values(param_node, owner)
    if not dimensions:
        return None

    def _at(sequence, position):
        # Positional lookup that tolerates paths deeper than the metadata.
        return sequence[position] if position < len(sequence) else None

    parts = [
        _format_dimension_value(
            raw_value, _at(variables, position), _at(labels, position), system
        )
        for position, (_, raw_value) in enumerate(dimensions)
    ]

    return f"{base_label} ({', '.join(parts)})"
103+
104+
105+
def _collect_dimension_values(param_node, breakdown_parent):
    """
    Collect the dimension values lying between breakdown parent and param_node.

    Each node on the path contributes the last dot-separated segment of its
    ``name``. The breakdown parent itself is excluded; ``param_node`` is
    included.

    Args:
        param_node: The CoreParameter object
        breakdown_parent: The ancestor node with breakdown metadata

    Returns:
        list of (depth_index, value) tuples, ordered from parent to child,
        where depth_index is the 0-based position below the breakdown parent.
        Empty when ``param_node`` is the breakdown parent or is not one of
        its descendants.
    """
    # Build path from param_node up to breakdown_parent. Nodes are compared
    # by identity: two distinct nodes that merely compare equal must not be
    # mistaken for the breakdown parent.
    path = []
    current = param_node
    while current is not None and current is not breakdown_parent:
        path.append(current)
        current = getattr(current, "parent", None)

    if current is None:
        # Ran off the top of the tree without meeting breakdown_parent, so
        # the collected path does not describe breakdown dimensions.
        return []

    # Reverse to get parent-to-child order.
    return [
        (index, node.name.split(".")[-1])
        for index, node in enumerate(reversed(path))
    ]
133+
134+
135+
def _format_dimension_value(value, var_name, dim_label, system):
    """
    Format a single dimension value with semantic label if available.

    Resolution order:
      1. If ``var_name`` names a variable in ``system.variables`` whose
         ``possible_values`` contains ``value`` as a member name, return that
         member's ``.value``.
      2. Otherwise, if ``dim_label`` is given, return ``"<dim_label> <value>"``.
      3. Otherwise return ``value`` itself.

    Args:
        value: The raw dimension value (e.g., "SINGLE", "1", "CA")
        var_name: The breakdown variable name (e.g., "filing_status", "range(1, 9)")
        dim_label: The human-readable label for this dimension (e.g., "Household size")
        system: The tax-benefit system

    Returns:
        str: Formatted dimension value
    """
    # "range(...)" and "list(...)" entries are not variable names, so there
    # is no enum to consult for them.
    if (
        var_name
        and isinstance(var_name, str)
        and not var_name.startswith(("range(", "list("))
    ):
        variable = system.variables.get(var_name)
        possible_values = getattr(variable, "possible_values", None)
        if possible_values:
            try:
                return str(possible_values[value].value)
            except (KeyError, AttributeError):
                # Not a member of this enum; fall through to the generic
                # formatting below.
                pass

    # For range() dimensions or when no enum found, use breakdown_label if available
    if dim_label:
        return f"{dim_label} {value}"

    # Always hand back a str so the caller can join the parts safely.
    return str(value)
57163

58164

59165
def _generate_bracket_label(param_name, scale_lookup):

tests/fixtures/parameter_labels_fixtures.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class MockFilingStatus(Enum):
1111
SINGLE = "Single"
1212
JOINT = "Joint"
1313
HEAD_OF_HOUSEHOLD = "Head of household"
14+
MARRIED_FILING_JOINTLY = "Married filing jointly"
1415

1516

1617
class MockStateCode(Enum):
@@ -38,16 +39,21 @@ def create_mock_parent_node(
3839
name: str,
3940
label: str | None = None,
4041
breakdown: list[str] | None = None,
42+
breakdown_labels: list[str] | None = None,
43+
parent: Any = None,
4144
) -> MagicMock:
4245
"""Create a mock parent ParameterNode with optional breakdown metadata."""
43-
parent = MagicMock()
44-
parent.name = name
45-
parent.metadata = {}
46+
node = MagicMock()
47+
node.name = name
48+
node.metadata = {}
49+
node.parent = parent
4650
if label:
47-
parent.metadata["label"] = label
51+
node.metadata["label"] = label
4852
if breakdown:
49-
parent.metadata["breakdown"] = breakdown
50-
return parent
53+
node.metadata["breakdown"] = breakdown
54+
if breakdown_labels:
55+
node.metadata["breakdown_labels"] = breakdown_labels
56+
return node
5157

5258

5359
def create_mock_scale(

0 commit comments

Comments
 (0)