Skip to content

Commit b0bee7f

Browse files
committed
Fixed logic tree with requirements.
1 parent e699a4f commit b0bee7f

File tree

3 files changed

+364
-5
lines changed

3 files changed

+364
-5
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
[
2+
{
3+
"name": "site_classification",
4+
"alts": [
5+
{
6+
"value": "C",
7+
"weight": 0.3,
8+
"params": {
9+
"vs30_range": [
10+
360,
11+
760
12+
]
13+
}
14+
},
15+
{
16+
"value": "D",
17+
"weight": 0.6,
18+
"params": {
19+
"vs30_range": [
20+
180,
21+
360
22+
]
23+
}
24+
},
25+
{
26+
"value": "E",
27+
"weight": 0.1,
28+
"params": {
29+
"vs30_range": [
30+
0,
31+
180
32+
]
33+
}
34+
}
35+
]
36+
},
37+
{
38+
"name": "kappa",
39+
"alts": [
40+
{
41+
"value": 0.02,
42+
"weight": 0.2,
43+
"requires": {
44+
"site_classification": "C"
45+
}
46+
},
47+
{
48+
"value": 0.03,
49+
"weight": 0.6,
50+
"requires": {
51+
"site_classification": "C"
52+
}
53+
},
54+
{
55+
"value": 0.04,
56+
"weight": 0.2,
57+
"requires": {
58+
"site_classification": "C"
59+
}
60+
},
61+
{
62+
"value": 0.03,
63+
"weight": 0.2,
64+
"requires": {
65+
"site_classification": "D"
66+
}
67+
},
68+
{
69+
"value": 0.04,
70+
"weight": 0.6,
71+
"requires": {
72+
"site_classification": "D"
73+
}
74+
},
75+
{
76+
"value": 0.05,
77+
"weight": 0.2,
78+
"requires": {
79+
"site_classification": "D"
80+
}
81+
},
82+
{
83+
"value": 0.05,
84+
"weight": 0.4,
85+
"requires": {
86+
"site_classification": "E"
87+
}
88+
},
89+
{
90+
"value": 0.06,
91+
"weight": 0.6,
92+
"requires": {
93+
"site_classification": "E"
94+
}
95+
}
96+
]
97+
},
98+
{
99+
"name": "randomization_method",
100+
"alts": [
101+
{
102+
"value": "monte_carlo",
103+
"weight": 0.4,
104+
"params": {
105+
"n_realizations": 1000
106+
}
107+
},
108+
{
109+
"value": "latin_hypercube",
110+
"weight": 0.4,
111+
"params": {
112+
"n_realizations": 100
113+
}
114+
},
115+
{
116+
"value": "deterministic",
117+
"weight": 0.2,
118+
"params": {
119+
"n_realizations": 1
120+
}
121+
}
122+
]
123+
}
124+
]

src/pystrata/logic_tree.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,56 @@ class Realization:
2828
The weight of the realization, by default 1.
2929
params : Dict[str, Any], optional
3030
Additional parameters for the realization, by default an empty dict.
31+
requires : Dict[str, Any], optional
32+
Requirements for this realization to be valid, by default an empty dict.
33+
excludes : Dict[str, Any], optional
34+
Exclusions for this realization to be valid, by default an empty dict.
3135
"""
3236

3337
name: str
3438
value: str | float | int
3539
weight: float = 1
3640
params: dict[str, Any] = field(default_factory=dict)
41+
requires: dict[str, Any] = field(default_factory=dict)
42+
excludes: dict[str, Any] = field(default_factory=dict)
43+
44+
def is_valid(self, branch):
45+
"""
46+
Check if this realization is valid given a branch.
47+
48+
Parameters
49+
----------
50+
branch : Branch
51+
The branch to check against.
52+
53+
Returns
54+
-------
55+
bool
56+
True if the realization is valid, False otherwise.
57+
"""
58+
59+
def matches(ref, check):
60+
if isinstance(ref, list):
61+
ret = check in ref
62+
elif isinstance(ref, float):
63+
ret = np.isclose(ref, check)
64+
else:
65+
ret = ref == check
66+
return ret
67+
68+
okay = True
69+
70+
if self.requires:
71+
# Check that the required realizations are present
72+
okay = all(matches(v, branch[k].value) for k, v in self.requires.items())
73+
74+
if okay and self.excludes:
75+
# Check that the excludes realizations are _not_ present
76+
okay &= not all(
77+
matches(v, branch[k].value) for k, v in self.excludes.items()
78+
)
79+
80+
return okay
3781

3882

3983
@dataclass
@@ -150,7 +194,9 @@ def by_value(self, value):
150194
def __iter__(self):
151195
for a in self.alts:
152196
if a.weight > 0:
153-
yield Realization(self.name, a.value, a.weight, a.params)
197+
yield Realization(
198+
self.name, a.value, a.weight, a.params, a.requires, a.excludes
199+
)
154200

155201
def to_xarray(self, dim_name: str, name: str = "") -> xr.DataArray:
156202
"""
@@ -347,10 +393,18 @@ class LogicTree:
347393
nodes: list[Node]
348394

349395
def __iter__(self) -> Branch:
396+
# Keep track of seen branches to avoid duplicates when multiple alternatives
397+
# have the same value but different conditional requirements
398+
seen_branches = set()
399+
350400
for reals in itertools.product(*self.nodes):
351401
branch = Branch({r.name: r for r in reals})
352402
if self.is_valid(branch):
353-
yield branch
403+
# Create a hashable representation of the branch values
404+
branch_key = tuple(sorted((k, v) for k, v in branch.as_dict().items()))
405+
if branch_key not in seen_branches:
406+
seen_branches.add(branch_key)
407+
yield branch
354408

355409
def is_valid(self, branch):
356410
"""
@@ -367,10 +421,10 @@ def is_valid(self, branch):
367421
True if the branch is valid, False otherwise.
368422
"""
369423
for param in branch.params.values():
370-
# Select the alternative on the logic tree
371-
alt = self[param.name].by_value(param.value)
372-
if not alt.is_valid(branch):
424+
# Check if this specific realization is valid for the branch
425+
if not param.is_valid(branch):
373426
return False
427+
374428
return True
375429

376430
def __getitem__(self, key):

0 commit comments

Comments
 (0)