Skip to content

Commit bf6db3f

Browse files
authored
Use match statement to determine distribution
1 parent 4f1bdbf commit bf6db3f

File tree

1 file changed

+101
-81
lines changed

1 file changed

+101
-81
lines changed

src/semeio/fmudesign/design_distributions.py

Lines changed: 101 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -143,132 +143,152 @@ def to_probabilit(
143143
distname=distname, parameters=list(dist_parameters)
144144
)
145145

146-
# Normal distribution
147-
if distname.startswith("norm"):
148-
if len(parameters) not in (2, 4):
149-
raise ValueError(
150-
f"Normal must have 2 or 4 parameters, got: {len(parameters)} ({parameters})"
151-
)
152-
if distname.lower().startswith("normal_p10_p90"):
153-
p10, p90 = parameters[:2]
146+
match (distname, parameters):
147+
# ================== NORMAL ==================
148+
149+
case [_, (p10, p90)] if distname.startswith("normal_p10_p90"):
154150
# We use the equations
155151
# p10 = mu - z*sigma and p90 = mu + z*sigma
156152
# to find mu and sigma
157153
z_score = stats.norm.ppf(0.9)
158154
mean = (p10 + p90) / 2
159-
stddev = (p90 - p10) / (2 * z_score)
160-
else:
161-
mean, stddev = parameters[:2]
162-
if len(parameters) == 2:
163-
return probabilit.distributions.Normal(mean, stddev)
164-
elif len(parameters) == 4:
165-
low, high = parameters[2:]
155+
std = (p90 - p10) / (2 * z_score)
156+
return probabilit.distributions.Normal(mean=mean, std=std)
157+
158+
case [_, (p10, p90, low, high)] if distname.startswith("normal_p10_p90"):
159+
z_score = stats.norm.ppf(0.9)
160+
mean = (p10 + p90) / 2
161+
std = (p90 - p10) / (2 * z_score)
166162
return probabilit.distributions.TruncatedNormal(
167-
mean, stddev, low=low, high=high
163+
mean=mean, std=std, low=low, high=high
168164
)
169-
elif distname.startswith("logn"):
170-
if len(parameters) not in (2, 4):
165+
166+
case [_, (mean, std)] if distname.startswith("norm"):
167+
return probabilit.distributions.Normal(mean=mean, std=std)
168+
169+
case [_, (mean, std, low, high)] if distname.startswith("norm"):
170+
return probabilit.distributions.TruncatedNormal(
171+
mean=mean, std=std, low=low, high=high
172+
)
173+
174+
case [_, parameters] if distname.startswith("norm"):
171175
raise ValueError(
172-
f"Lognormal must have 2 or 4 parameters, got: {len(parameters)} ({parameters})"
176+
f"Normal must have 2 or 4 parameters, got: {len(parameters)} ({parameters})"
173177
)
174-
mean, sigma = parameters[:2]
175-
if len(parameters) == 2:
178+
179+
# ================== LOGNORMAL ==================
180+
181+
case [_, (mu, sigma)] if distname.startswith("logn"):
176182
return probabilit.distributions.Lognormal.from_log_params(
177-
mu=mean, sigma=sigma
183+
mu=mu, sigma=sigma
178184
)
179-
elif len(parameters) == 4:
180-
low, high = parameters[2:]
185+
186+
case [_, (mu, sigma, low, high)] if distname.startswith("logn"):
187+
# (mu, sigma) are defined in log-spce, but (low, high) are defined on exp-space
181188
return probabilit.modeling.Exp(
182189
probabilit.distributions.TruncatedNormal(
183-
mean, sigma, low=np.log(low), high=np.log(high)
190+
mu, sigma, low=np.log(low), high=np.log(high)
184191
)
185192
)
186-
elif distname.startswith("unif"):
187-
if len(parameters) != 2:
193+
194+
case [_, parameters] if distname.startswith("logn"):
188195
raise ValueError(
189-
f"Uniform must have 2 parameters, got: {len(parameters)} ({parameters})"
196+
f"Lognormal must have 2 or 4 parameters, got: {len(parameters)} ({parameters})"
190197
)
191-
if distname.startswith("uniform_p10_p90"):
192-
p10, p90 = parameters
198+
199+
# ================== UNIFORM ==================
200+
201+
case [_, (p10, p90)] if distname.startswith("uniform_p10_p90"):
193202
length = (p90 - p10) / 0.8
194-
low = p10 - 0.1 * length
195-
high = p90 + 0.1 * length
196-
else:
197-
low, high = parameters
198-
return probabilit.distributions.Uniform(minimum=low, maximum=high)
199-
elif distname.startswith("triang"):
200-
if len(parameters) != 3:
203+
minimum = p10 - 0.1 * length
204+
maximum = p90 + 0.1 * length
205+
return probabilit.distributions.Uniform(minimum=minimum, maximum=maximum)
206+
207+
case [_, (minimum, maximum)] if distname.startswith("unif"):
208+
return probabilit.distributions.Uniform(minimum=minimum, maximum=maximum)
209+
210+
case [_, parameters] if distname.startswith("unif"):
201211
raise ValueError(
202-
f"Triangular must have 3 parameters, got: {len(parameters)} ({parameters})"
212+
f"Uniform must have 2 parameters, got: {len(parameters)} ({parameters})"
203213
)
204-
low, mode, high = parameters
205-
if distname.startswith("triangular_p10_p90"):
214+
215+
# ================== TRIANGULAR ==================
216+
217+
case [_, (low, mode, high)] if distname.startswith("triangular_p10_p90"):
206218
return probabilit.distributions.Triangular(
207219
low=low, mode=mode, high=high, low_perc=0.1, high_perc=0.9
208220
)
209-
else:
221+
222+
case [_, (minimum, mode, maximum)] if distname.startswith("triang"):
210223
return probabilit.distributions.Triangular(
211-
low=low, mode=mode, high=high, low_perc=0.0, high_perc=1.0
224+
low=minimum, mode=mode, high=maximum, low_perc=0.0, high_perc=1.0
225+
)
226+
227+
case [_, parameters] if distname.startswith("triang"):
228+
raise ValueError(
229+
f"Triangular must have 3 parameters, got: {len(parameters)} ({parameters})"
212230
)
213-
elif distname.startswith("beta"):
214-
if len(parameters) == 2:
215-
a, b = parameters
231+
232+
# ================== BETA ==================
233+
234+
case [_, (a, b)] if distname.startswith("beta"):
216235
# Defaults to probabilit.Distribution("beta", a=a, b=b, loc=0, scale=1)
217236
return probabilit.Distribution("beta", a=a, b=b)
218-
if len(parameters) == 4:
219-
a, b, low, high = parameters
237+
238+
case [_, (a, b, low, high)] if distname.startswith("beta"):
220239
loc = low
221240
scale = high - low
222241
return probabilit.Distribution("beta", a=a, b=b, loc=loc, scale=scale)
223-
else:
242+
243+
case [_, parameters] if distname.startswith("beta"):
224244
raise ValueError(
225245
f"Beta must have 2 or 4 parameters, got: {len(parameters)} ({parameters})"
226246
)
227-
elif distname.startswith("pert"):
228-
if len(parameters) not in (3, 4):
229-
raise ValueError(
230-
f"PERT must have 3 or 4 parameters, got: {len(parameters)} ({parameters})"
247+
248+
# ================== PERT ==================
249+
250+
case [_, (low, mode, high)] if distname.startswith("pert_p10_p90"):
251+
return probabilit.distributions.PERT(
252+
low=low, mode=mode, high=high, low_perc=0.1, high_perc=0.9
231253
)
232-
if distname.startswith("pert_p10_p90"):
233-
if len(parameters) == 3:
234-
low, mode, high = parameters
235-
return probabilit.distributions.PERT(
236-
low=low, mode=mode, high=high, low_perc=0.1, high_perc=0.9
237-
)
238-
elif len(parameters) == 4:
239-
low, mode, high, scale = parameters
240-
return probabilit.distributions.PERT(
241-
low=low,
242-
mode=mode,
243-
high=high,
244-
low_perc=0.1,
245-
high_perc=0.9,
246-
gamma=scale,
247-
)
248-
elif len(parameters) == 3:
249-
low, mode, high = parameters
254+
255+
case [_, (low, mode, high, scale)] if distname.startswith("pert_p10_p90"):
256+
return probabilit.distributions.PERT(
257+
low=low, mode=mode, high=high, low_perc=0.1, high_perc=0.9, gamma=scale
258+
)
259+
260+
case [_, (minimum, mode, maximum)] if distname.startswith("pert"):
250261
return probabilit.distributions.PERT(
251-
low=low, mode=mode, high=high, low_perc=0.0, high_perc=1.0
262+
low=minimum, mode=mode, high=maximum, low_perc=0.0, high_perc=1.0
252263
)
253-
elif len(parameters) == 4:
254-
low, mode, high, scale = parameters
264+
265+
case [_, (minimum, mode, maximum, scale)] if distname.startswith("pert"):
255266
return probabilit.distributions.PERT(
256-
low=low,
267+
low=minimum,
257268
mode=mode,
258-
high=high,
269+
high=maximum,
259270
low_perc=0.0,
260271
high_perc=1.0,
261272
gamma=scale,
262273
)
263-
elif distname.startswith("logunif"):
264-
if len(parameters) != 2:
274+
275+
case [_, parameters] if distname.startswith("pert"):
276+
raise ValueError(
277+
f"PERT must have 3 or 4 parameters, got: {len(parameters)} ({parameters})"
278+
)
279+
280+
# ================== LOGUNIFORM ==================
281+
282+
case [_, (low, high)] if distname.startswith("logunif"):
283+
return probabilit.Distribution("loguniform", low, high)
284+
285+
case [_, parameters] if distname.startswith("logunif"):
265286
raise ValueError(
266287
f"Loguniform must have 2 parameters, got: {len(parameters)} ({parameters})"
267288
)
268-
low, high = parameters
269-
return probabilit.Distribution("loguniform", a=low, b=high)
270-
else:
271-
raise ValueError(f"Distribution name {distname} is not implemented")
289+
290+
case [distname, parameters]:
291+
raise ValueError(f"Invalid combination of {distname=} and {parameters=}.")
272292

273293

274294
def is_number(teststring: str) -> bool:

0 commit comments

Comments
 (0)