Commit 476ee18

Fix cache args (#4522)
* reimplement Gaussian noise
* add RandomGaussianBlur aug
* minor fix
* fix unit tests
* reply to comments
* provide workaround for XPU batch search
* return back parameters for MaskRCNN
* fix unit test
* fix train args
1 parent 73ea1b0 commit 476ee18

File tree

1 file changed: +14 -11 lines changed


src/otx/backend/native/engine.py

Lines changed: 14 additions & 11 deletions
```diff
@@ -152,8 +152,7 @@ def train(
         min_epochs: int = 1,
         seed: int | None = None,
         deterministic: bool | Literal["warn"] = False,
-        precision: _PRECISION_INPUT | None = None,
-        val_check_interval: int | float | None = None,
+        precision: _PRECISION_INPUT | None = 16,
         callbacks: list[Callback] | Callback | None = None,
         logger: Logger | Iterable[Logger] | bool | None = None,
         resume: bool = False,
@@ -162,7 +161,7 @@ def train(
         adaptive_bs: Literal["None", "Safe", "Full"] = "None",
         check_val_every_n_epoch: int | None = 1,
         num_sanity_val_steps: int | None = 0,
-        log_every_n_steps: int | None = 1,
+        gradient_clip_val: float | None = None,
         **kwargs,
     ) -> dict[str, Any]:
         r"""Trains the model using the provided LightningModule and OTXDataModule.
@@ -175,7 +174,6 @@ def train(
                 Also, can be set to `warn` to avoid failures, because some operations don't
                 support deterministic mode. Defaults to False.
             precision (_PRECISION_INPUT | None, optional): The precision of the model. Defaults to 16.
-            val_check_interval (int | float | None, optional): The validation check interval. Defaults to None.
             callbacks (list[Callback] | Callback | None, optional): The callbacks to be used during training.
             logger (Logger | Iterable[Logger] | bool | None, optional): The logger(s) to be used. Defaults to None.
             resume (bool, optional): If True, tries to resume training from existing checkpoint.
@@ -188,6 +186,7 @@ def train(
                 Defaults to "None".
             check_val_every_n_epoch (int | None, optional): How often to check validation. Defaults to 1.
             num_sanity_val_steps (int | None, optional): Number of validation steps to run before training starts.
+            gradient_clip_val (float | None, optional): The value for gradient clipping. Defaults to None.
             **kwargs: Additional keyword arguments for pl.Trainer configuration.
 
         Returns:
@@ -243,10 +242,9 @@ def train(
             max_epochs=max_epochs,
             min_epochs=min_epochs,
             deterministic=deterministic,
-            val_check_interval=val_check_interval,
             check_val_every_n_epoch=check_val_every_n_epoch,
             num_sanity_val_steps=num_sanity_val_steps,
-            log_every_n_steps=log_every_n_steps,
+            gradient_clip_val=gradient_clip_val,
             **kwargs,
         )
         fit_kwargs: dict[str, Any] = {}
```
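With this change, `precision` defaults to 16 and `gradient_clip_val` becomes an explicit `train()` argument, while `val_check_interval` and `log_every_n_steps` are dropped from the signature; since the docstring says `**kwargs` feeds the `pl.Trainer` configuration, those options can presumably still be passed that way. A minimal usage sketch under that assumption; the class name `Engine` and its constructor arguments are illustrative, not taken from this diff:

```python
# Hypothetical setup: the constructor arguments below are placeholders.
from otx.backend.native.engine import Engine  # module path as shown in this commit

engine = Engine(model="my_model", data_root="data/my_dataset")  # assumed signature

results = engine.train(
    max_epochs=10,
    gradient_clip_val=1.0,  # newly exposed explicit argument
    precision=16,           # now the default, shown here for clarity
    # The removed explicit arguments are assumed to still reach pl.Trainer
    # through **kwargs, which the docstring forwards to Trainer configuration:
    val_check_interval=1.0,
    log_every_n_steps=10,
)
```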
```diff
@@ -877,13 +875,18 @@ def _apply_param_overrides(self, param_kwargs: dict[str, Any]) -> None:
         """Apply parameter overrides based on the current local variables."""
         sig = inspect.signature(self.train)
         add_kwargs = param_kwargs.pop("kwargs", {})
-        self._cache.update(**add_kwargs)
         for param_name, param in sig.parameters.items():
-            if param_name in param_kwargs:
-                current_value = param_kwargs[param_name]
-                # Apply override if current value matches default and we have an override
-                if (current_value != param.default) or (param_name not in self._cache.args):
+            if param_name in param_kwargs and param_name in self._cache.args:
+                # if both `param_kwargs` and `_cache.args` have the same parameter,
+                # we will use the value from `param_kwargs` if it is different from the default
+                # value of the parameter.
+                # Otherwise, we will keep the value from `_cache.args`.
+                current_value = param_kwargs.pop(param_name)
+                if current_value != param.default:
                     self._cache.args[param_name] = current_value
+        # update the cache with the remaining parameters
+        self._cache.update(**param_kwargs)
+        self._cache.update(**add_kwargs)
 
     def configure_accelerator(self) -> None:
         """Updates the cache arguments based on the device type."""
```
