Skip to content

Commit f91f392

Browse files
committed
Some linting
1 parent 532558a commit f91f392

File tree

2 files changed

+35
-8
lines changed

2 files changed

+35
-8
lines changed

causal_testing/generation/enum_gen.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,40 @@
1+
"""This module contains the class EnumGen, which allows us to easily create generating uniform distributions from enums."""
2+
13
from scipy.stats import rv_discrete
4+
from enum import Enum
5+
import numpy as np
26

37

48
class EnumGen(rv_discrete):
5-
def __init__(self, dt: Enum):
6-
self.dt = dict(enumerate(dt, 1))
9+
"""This class allows us to easily create generating uniform distributions from enums. This is helpful for generating concrete test inputs from abstract test cases."""
10+
11+
def __init__(self, datatype: Enum):
12+
self.dt = dict(enumerate(datatype, 1))
713
self.inverse_dt = {v: k for k, v in self.dt.items()}
814

9-
def ppf(self, q, *args, **kwds):
15+
def ppf(self, q):
16+
"""Percent point function (inverse of `cdf`) at q of the given RV.
17+
Parameters
18+
----------
19+
q : array_like
20+
Lower tail probability.
21+
Returns
22+
-------
23+
k : array_like
24+
Quantile corresponding to the lower tail probability, q.
25+
"""
1026
return np.vectorize(self.dt.get)(np.ceil(len(self.dt) * q))
1127

12-
def cdf(self, q, *args, **kwds):
13-
return np.vectorize(self.inverse_dt.get)(q) / len(Car)
28+
def cdf(self, q):
29+
"""
30+
Cumulative distribution function of the given RV.
31+
Parameters
32+
----------
33+
q : array_like
34+
quantiles
35+
Returns
36+
-------
37+
cdf : ndarray
38+
Cumulative distribution function evaluated at `x`
39+
"""
40+
return np.vectorize(self.inverse_dt.get)(q) / len(self.dt)

causal_testing/json_front/json_class.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,11 @@ def _create_abstract_test_case(self, test, mutates, effects):
7777
assert len(test["mutations"]) == 1
7878
treatment_var = next(self.scenario.variables[v] for v in test["mutations"])
7979
if not treatment_var.distribution:
80-
fitter = Fitter(self.data[var.name], distributions=get_common_distributions())
80+
fitter = Fitter(self.data[treatment_var.name], distributions=get_common_distributions())
8181
fitter.fit()
8282
(dist, params) = list(fitter.get_best(method="sumsquare_error").items())[0]
83-
var.distribution = getattr(scipy.stats, dist)(**params)
84-
self._append_to_file(var.name + f" {dist}({params})", logging.INFO)
83+
treatment_var.distribution = getattr(scipy.stats, dist)(**params)
84+
self._append_to_file(treatment_var.name + f" {dist}({params})", logging.INFO)
8585

8686
abstract_test = AbstractCausalTestCase(
8787
scenario=self.scenario,

0 commit comments

Comments
 (0)