Skip to content

Commit c095b03

Browse files
corrects error with unpack_results argument (#589)
* corrects error with unpack_results argument * update bootstrap pydantic models --------- Co-authored-by: samlamont <sam.lamont@gmail.com>
1 parent 01ae4cd commit c095b03

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

src/teehr/models/metrics/basemodels.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Enums and Basemodels for metric classes."""
2-
from typing import Union, Callable
2+
from typing import Union, Callable, List
33

44
from teehr.models.str_enum import StrEnum
55
from teehr.querying.utils import unpack_sdf_dict_columns
@@ -38,9 +38,20 @@ def update_return_type(cls, values):
3838
return values
3939

4040

41-
class BootstrapBasemodel(MetricsBasemodel):
41+
class BootstrapBasemodel(PydanticBaseModel):
4242
"""Bootstrap Basemodel configuration."""
4343

44+
return_type: Union[str, T.ArrayType, T.MapType] = Field(default=None)
45+
reps: int = 1000
46+
seed: Union[int, None] = None
47+
quantiles: Union[List[float], None] = None
48+
49+
model_config = ConfigDict(
50+
arbitrary_types_allowed=True,
51+
validate_assignment=True,
52+
extra='forbid' # raise an error if extra fields are passed
53+
)
54+
4455
@model_validator(mode="before")
4556
def update_return_type(cls, values):
4657
"""Update the return type based on the quantiles."""

src/teehr/models/metrics/bootstrap_models.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,6 @@ class Gumboot(BootstrapBasemodel):
5656
5757
"""
5858

59-
reps: int = 1000
60-
seed: Union[int, None] = None
61-
quantiles: Union[List[float], None] = None
6259
boot_year_file: Union[str, Path, None] = None
6360
water_year_month: int = 10
6461
name: str = Field(default="Gumboot")
@@ -92,11 +89,8 @@ class CircularBlock(BootstrapBasemodel):
9289
The wrapper to generate the bootstrapping function.
9390
"""
9491

95-
seed: Union[int, None] = None
9692
random_state: Union[RandomState, None] = None
97-
reps: int = 1000
9893
block_size: int = 365
99-
quantiles: Union[List[float], None] = None
10094
name: str = Field(default="CircularBlock")
10195
include_value_time: bool = Field(False, frozen=True)
10296
func: Callable = Field(
@@ -131,11 +125,8 @@ class Stationary(BootstrapBasemodel):
131125
The wrapper to generate the bootstrapping function.
132126
"""
133127

134-
seed: Union[int, None] = None
135128
random_state: Union[RandomState, None] = None
136-
reps: int = 1000
137129
block_size: int = 365
138-
quantiles: Union[List[float], None] = None
139130
name: str = Field(default="Stationary")
140131
include_value_time: bool = Field(False, frozen=True)
141132
func: Callable = Field(

0 commit comments

Comments
 (0)