diff --git a/.gitignore b/.gitignore index d0d6b8e..1dcfb94 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,6 @@ __pycache__ -output.png \ No newline at end of file +output.png + +*.egg-info + +.idea \ No newline at end of file diff --git a/DeepCache/extension/deepcache.py b/DeepCache/extension/deepcache.py index e96ddce..b625ac5 100644 --- a/DeepCache/extension/deepcache.py +++ b/DeepCache/extension/deepcache.py @@ -1,15 +1,26 @@ class DeepCacheSDHelper(object): def __init__(self, pipe=None): - if pipe is not None: self.pipe = pipe + if pipe is not None: + self.pipe = pipe + self._enabled = False - def enable(self, pipe=None): - assert self.pipe is not None - self.reset_states() - self.wrap_modules() + def enable(self): + if not self._enabled: + self.reset_states() + self.wrap_modules() + self._enabled = True + print("Enabling Deepcache") + else: + print("DeepCache is already enabled.") def disable(self): - self.unwrap_modules() - self.reset_states() + if self._enabled: + self.unwrap_modules() + self.reset_states() + self._enabled = False + print("Disabling Deepcache") + else: + print("DeepCache is already disabled.") def set_params(self,cache_interval=1, cache_branch_id=0, skip_mode='uniform'): cache_layer_id = cache_branch_id % 3