diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index fc7254b3c4..6a94b0bb38 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -387,13 +387,15 @@ def setup(self, cfg: DictConfig) -> None:
         )
         if self._compile_optimizer_step:
             if self._optimizer_in_bwd:
-                raise ValueError(
-                    "optimizer_in_bwd not supported with compiling the optimizer step"
+                self._logger.warning(
+                    "Compile optimizer is not supported for optimizer_in_bwd. Setting self._compile_optimizer_step to False."
+                )
+                self._compile_optimizer_step = False
+            else:
+                self._optimizer.step = torch.compile(
+                    self._optimizer.step,
+                    backend=self._compile_backend,
                 )
-            self._optimizer.step = torch.compile(
-                self._optimizer.step,
-                backend=self._compile_backend,
-            )
 
         if self._resume_from_checkpoint:
             # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index 1f6b91f163..5432a43adb 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -387,13 +387,15 @@ def setup(self, cfg: DictConfig) -> None:
         )
         if self._compile_optimizer_step:
             if self._optimizer_in_bwd:
-                raise ValueError(
-                    "optimizer_in_bwd not supported with compiling the optimizer step"
+                self._logger.warning(
+                    "Compile optimizer is not supported for optimizer_in_bwd. Setting self._compile_optimizer_step to False."
+                )
+                self._compile_optimizer_step = False
+            else:
+                self._optimizer.step = torch.compile(
+                    self._optimizer.step,
+                    backend=self._compile_backend,
                 )
-            self._optimizer.step = torch.compile(
-                self._optimizer.step,
-                backend=self._compile_backend,
-            )
 
         if self._resume_from_checkpoint:
             # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
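For context, a minimal self-contained sketch of the behavior both hunks implement: compile the optimizer step with torch.compile unless the optimizer runs in the backward pass, in which case warn and fall back instead of raising. The local flags and the "inductor" backend below are stand-ins for self._compile_optimizer_step, self._optimizer_in_bwd, and self._compile_backend and are not part of the patch.

    import torch

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    compile_optimizer_step = True   # stand-in for self._compile_optimizer_step
    optimizer_in_bwd = False        # stand-in for self._optimizer_in_bwd

    if compile_optimizer_step:
        if optimizer_in_bwd:
            # With optimizer-in-backward there is no explicit step() call to
            # compile, so disable the option instead of raising as before.
            print("Compile optimizer is not supported for optimizer_in_bwd; skipping.")
            compile_optimizer_step = False
        else:
            # Wrap the bound step method; "inductor" stands in for self._compile_backend.
            optimizer.step = torch.compile(optimizer.step, backend="inductor")

    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()  # first call triggers compilation when the wrapper is installed
    optimizer.zero_grad()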