Skip to content

Commit 61aff7f

Browse files
committed
fix
1 parent 148d9ec commit 61aff7f

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _fit(self, context: Context, sampler: SamplerType) -> None:
130130
self.context = context
131131
self._logger.info("starting pipeline optimization...")
132132
self.context.callback_handler.start_run(
133-
run_name=self.context.logging_config.run_name,
133+
run_name=self.context.logging_config.get_run_name(),
134134
dirpath=self.context.logging_config.dirpath,
135135
)
136136
for node_type in NodeType:

autointent/configs/_optimization.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,24 @@ class DataConfig(BaseModel):
1717
"""Hold-out or cross-validation."""
1818
n_folds: PositiveInt = Field(3, description="Number of folds in cross-validation.")
1919
"""Number of folds in cross-validation."""
20-
validation_size: FloatFromZeroToOne = Field(0.2, description="Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split).")
20+
validation_size: FloatFromZeroToOne = Field(
21+
0.2,
22+
description=(
23+
"Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."
24+
),
25+
)
2126
"""Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split)."""
22-
separation_ratio: FloatFromZeroToOne | None = Field(0.5, description="Set to float to prevent data leak between scoring and decision nodes.")
27+
separation_ratio: FloatFromZeroToOne | None = Field(
28+
0.5, description="Set to float to prevent data leak between scoring and decision nodes."
29+
)
2330
"""Set to float to prevent data leak between scoring and decision nodes."""
2431

2532

2633
class LoggingConfig(BaseModel):
2734
"""Configuration for the logging."""
2835

36+
_run_name = get_run_name()
37+
2938
project_dir: Path | str | None = Field(None, description="Path to the directory with different runs.")
3039
"""Path to the directory with different runs."""
3140
run_name: str | None = Field(None, description="Name of the run. If None, a random name will be generated.")
@@ -34,16 +43,17 @@ class LoggingConfig(BaseModel):
3443
"""Whether to dump the modules or not"""
3544
clear_ram: bool = Field(False, description="Whether to clear the RAM after dumping the modules")
3645
"""Whether to clear the RAM after dumping the modules"""
37-
report_to: list[REPORTERS_NAMES] | None = Field(None, description="List of callbacks to report to. If None, no callbacks will be used") # type: ignore[valid-type]
46+
report_to: list[REPORTERS_NAMES] | None = Field( # type: ignore[valid-type]
47+
None, description="List of callbacks to report to. If None, no callbacks will be used"
48+
)
3849
"""List of callbacks to report to. If None, no callbacks will be used"""
3950

4051
@property
4152
def dirpath(self) -> Path:
4253
"""Path to the directory where the logs will be saved."""
43-
run_name = self.run_name or get_run_name()
4454
project_dir = self.project_dir or Path.cwd() / "runs"
4555
if not hasattr(self, "_dirpath"):
46-
self._dirpath = Path(project_dir) / run_name
56+
self._dirpath = Path(project_dir) / self.get_run_name()
4757
return self._dirpath
4858

4959
@property
@@ -53,6 +63,10 @@ def dump_dir(self) -> Path:
5363
self._dump_dir = self.dirpath / "modules_dumps"
5464
return self._dump_dir
5565

66+
def get_run_name(self) -> str:
67+
"""Get the run name."""
68+
return self.run_name or self._run_name
69+
5670

5771
class VectorIndexConfig(BaseModel):
5872
"""Configuration for the vector index."""

0 commit comments

Comments
 (0)