|
1 | | -""" |
2 | | -Copyright 2023-present PyMC Labs |
3 | | -
|
4 | | -Licensed under the Apache License, Version 2.0 (the "License"); |
5 | | -you may not use this file except in compliance with the License. |
6 | | -You may obtain a copy of the License at |
7 | | -
|
8 | | - http://www.apache.org/licenses/LICENSE-2.0 |
9 | | -
|
10 | | -Unless required by applicable law or agreed to in writing, software |
11 | | -distributed under the License is distributed on an "AS IS" BASIS, |
12 | | -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | | -See the License for the specific language governing permissions and |
14 | | -limitations under the License. |
15 | | -""" |
16 | | - |
| 1 | +# Copyright 2023 - present PyMC Labs |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
17 | 14 |
|
18 | 15 | import warnings |
19 | | - |
20 | 16 | from typing import Dict, Iterable, List, Optional |
21 | 17 |
|
22 | | -from pytensor.tensor import TensorVariable |
23 | 18 | from arviz import InferenceData |
24 | 19 | from pandas import DataFrame |
| 20 | +from pytensor.tensor import TensorVariable |
25 | 21 |
|
26 | 22 | from homepy.nested_hierarchy_utils import ( |
27 | 23 | NestedHierarchy, |
@@ -94,9 +90,15 @@ class ModelBlock: |
94 | 90 | default_config_key = "base" |
95 | 91 |
|
96 | 92 | def __init__( |
97 | | - self, *, data=None, config: Optional[Dict] = None, config_key: Optional[str] = None |
| 93 | + self, |
| 94 | + *, |
| 95 | + data=None, |
| 96 | + config: Optional[Dict] = None, |
| 97 | + config_key: Optional[str] = None, |
98 | 98 | ): |
99 | | - self.config_key = config_key if config_key is not None else self.default_config_key |
| 99 | + self.config_key = ( |
| 100 | + config_key if config_key is not None else self.default_config_key |
| 101 | + ) |
100 | 102 | self.level_random_effects = None |
101 | 103 |
|
102 | 104 | def read_config(self, config=None): |
@@ -125,7 +127,9 @@ def make_observations(self, inputs, *, data=None, config: Optional[Dict] = None) |
125 | 127 | f"{self.__class__} class does not implement a make_observations method" |
126 | 128 | ) |
127 | 129 |
|
128 | | - def sum_reduce_inputs(self, inputs: Iterable[Optional[TensorVariable]]) -> TensorVariable: |
| 130 | + def sum_reduce_inputs( |
| 131 | + self, inputs: Iterable[Optional[TensorVariable]] |
| 132 | + ) -> TensorVariable: |
129 | 133 | return sum(i for i in inputs if i is not None) |
130 | 134 |
|
131 | 135 | def make_summaries( |
@@ -153,7 +157,8 @@ def make_summaries( |
153 | 157 | selected_summaries = [ |
154 | 158 | method_name |
155 | 159 | for method_name in dir(self) |
156 | | - if callable(getattr(self, method_name)) and method_name.startswith("summary_") |
| 160 | + if callable(getattr(self, method_name)) |
| 161 | + and method_name.startswith("summary_") |
157 | 162 | ] |
158 | 163 | summaries = {} |
159 | 164 | for summary_method in selected_summaries: |
@@ -181,7 +186,8 @@ def make_plots( |
181 | 186 | selected_plots = [ |
182 | 187 | method_name |
183 | 188 | for method_name in dir(self) |
184 | | - if callable(getattr(self, method_name)) and method_name.startswith("plot_") |
| 189 | + if callable(getattr(self, method_name)) |
| 190 | + and method_name.startswith("plot_") |
185 | 191 | ] |
186 | 192 | plots = {} |
187 | 193 | for plot_method in selected_plots: |
|
0 commit comments