@@ -28,12 +28,56 @@ class Realization:
2828 The weight of the realization, by default 1.
2929 params : Dict[str, Any], optional
3030 Additional parameters for the realization, by default an empty dict.
31+ requires : Dict[str, Any], optional
32+ Requirements for this realization to be valid, by default an empty dict.
33+ excludes : Dict[str, Any], optional
34+ Exclusions for this realization to be valid, by default an empty dict.
3135 """
3236
3337 name : str
3438 value : str | float | int
3539 weight : float = 1
3640 params : dict [str , Any ] = field (default_factory = dict )
41+ requires : dict [str , Any ] = field (default_factory = dict )
42+ excludes : dict [str , Any ] = field (default_factory = dict )
43+
44+ def is_valid (self , branch ):
45+ """
46+ Check if this realization is valid given a branch.
47+
48+ Parameters
49+ ----------
50+ branch : Branch
51+ The branch to check against.
52+
53+ Returns
54+ -------
55+ bool
56+ True if the realization is valid, False otherwise.
57+ """
58+
59+ def matches (ref , check ):
60+ if isinstance (ref , list ):
61+ ret = check in ref
62+ elif isinstance (ref , float ):
63+ ret = np .isclose (ref , check )
64+ else :
65+ ret = ref == check
66+ return ret
67+
68+ okay = True
69+
70+ if self .requires :
71+ # Check that the required realizations are present
72+ okay = all (matches (v , branch [k ].value ) for k , v in self .requires .items ())
73+
74+ if okay and self .excludes :
75+ # Check that the excludes realizations are _not_ present
76+ okay &= not all (
77+ matches (v , branch [k ].value ) for k , v in self .excludes .items ()
78+ )
79+
80+ return okay
3781
3882
3983@dataclass
@@ -150,7 +194,9 @@ def by_value(self, value):
150194 def __iter__ (self ):
151195 for a in self .alts :
152196 if a .weight > 0 :
153- yield Realization (self .name , a .value , a .weight , a .params )
197+ yield Realization (
198+ self .name , a .value , a .weight , a .params , a .requires , a .excludes
199+ )
154200
155201 def to_xarray (self , dim_name : str , name : str = "" ) -> xr .DataArray :
156202 """
@@ -347,10 +393,18 @@ class LogicTree:
347393 nodes : list [Node ]
348394
349395 def __iter__ (self ) -> Branch :
396+ # Keep track of seen branches to avoid duplicates when multiple alternatives
397+ # have the same value but different conditional requirements
398+ seen_branches = set ()
399+
350400 for reals in itertools .product (* self .nodes ):
351401 branch = Branch ({r .name : r for r in reals })
352402 if self .is_valid (branch ):
353- yield branch
403+ # Create a hashable representation of the branch values
404+ branch_key = tuple (sorted ((k , v ) for k , v in branch .as_dict ().items ()))
405+ if branch_key not in seen_branches :
406+ seen_branches .add (branch_key )
407+ yield branch
354408
355409 def is_valid (self , branch ):
356410 """
@@ -367,10 +421,10 @@ def is_valid(self, branch):
367421 True if the branch is valid, False otherwise.
368422 """
369423 for param in branch .params .values ():
370- # Select the alternative on the logic tree
371- alt = self [param .name ].by_value (param .value )
372- if not alt .is_valid (branch ):
424+ # Check if this specific realization is valid for the branch
425+ if not param .is_valid (branch ):
373426 return False
427+
374428 return True
375429
376430 def __getitem__ (self , key ):
0 commit comments