diff --git a/formtools/wizard/views.py b/formtools/wizard/views.py index f16f59c..ea19ec2 100644 --- a/formtools/wizard/views.py +++ b/formtools/wizard/views.py @@ -209,9 +209,18 @@ def get_form_list(self): and respect the result. (True means add the form, False means ignore the form) - The form_list is always generated on the fly because condition methods - could use data from other (maybe previous forms). + The form_list is generated once per wizard instance to avoid repeated + expensive condition evaluations (e.g., database queries). """ + # Check if condition_dict has been modified since last resolution + condition_dict_signature = ( + id(self.condition_dict), + tuple(sorted(self.condition_dict.items())), + ) + if (hasattr(self, '_resolved_form_list') and + self._condition_dict_signature == condition_dict_signature): + return self._resolved_form_list + form_list = OrderedDict() if getattr(self, '_check_cond_started', False): # Guard against infinite recursion, in the case a get_form_list is @@ -228,6 +237,8 @@ def get_form_list(self): if condition: form_list[form_key] = form_class del self._check_cond_started + self._resolved_form_list = form_list + self._condition_dict_signature = condition_dict_signature return form_list def dispatch(self, request, *args, **kwargs): @@ -301,6 +312,12 @@ def post(self, *args, **kwargs): # if the form is valid, store the cleaned data and files. self.storage.set_step_data(self.steps.current, self.process_step(form)) self.storage.set_step_files(self.steps.current, self.process_step_files(form)) + # Clear caches as changed step data could affect conditions + del self._resolved_form_list + del self._condition_dict_signature + for attr_name in list(self.__dict__.keys()): + if attr_name.startswith('_cleaned_data_cache_'): + delattr(self, attr_name) # check if the current step is the last step if self.steps.current == self.steps.last: @@ -497,13 +514,20 @@ def get_cleaned_data_for_step(self, step): If the data doesn't validate, None will be returned. """ if step in self.form_list: + cache_key = f'_cleaned_data_cache_{step}' + if cached_data := getattr(self, cache_key, None): + return cached_data form_obj = self.get_form( step=step, data=self.storage.get_step_data(step), files=self.storage.get_step_files(step), ) if form_obj.is_valid(): - return form_obj.cleaned_data + cleaned_data = form_obj.cleaned_data + setattr(self, cache_key, cleaned_data) + return cleaned_data + else: + setattr(self, cache_key, None) return None def get_next_step(self, step=None): diff --git a/tests/wizard/test_forms.py b/tests/wizard/test_forms.py index e637cae..73956a2 100644 --- a/tests/wizard/test_forms.py +++ b/tests/wizard/test_forms.py @@ -179,6 +179,32 @@ def subsequent_step_check(wizard): finally: sys.setrecursionlimit(old_limit) + def test_form_initial_multiple_calls_regression(self): + def step2_condition(wizard): + wizard.get_cleaned_data_for_step('start') + return True + + class TestWizardWithTracking(TestWizard): + condition_dict = {'step2': step2_condition} + initial_call_count = 0 + + def get_form_initial(self, step): + self.initial_call_count += 1 + return super().get_form_initial(step) + + testform = TestWizardWithTracking.as_view( + [('start', Step1), ('step2', Step2)] + ) + request = get_request( + { + 'test_wizard_with_tracking-current_step': 'start', + 'start-name': 'test' + } + ) + response, instance = testform(request) + calls_during_submission = instance.initial_call_count + self.assertLessEqual(calls_during_submission, 4) + def test_form_condition_unstable(self): request = get_request() testform = TestWizard.as_view(