1+ import difflib
12import os
23from pathlib import Path
34from textwrap import dedent
@@ -58,6 +59,7 @@ def run_and_assert(
5859 tmpdir ,
5960 input_code ,
6061 expected ,
62+ expected_diff_per_change : list [str ] = [],
6163 num_changes : int = 1 ,
6264 min_num_changes : int | None = None ,
6365 root : Path | None = None ,
@@ -99,13 +101,19 @@ def run_and_assert(
99101 tmp_file_path ,
100102 input_code ,
101103 expected ,
102- changes [0 ],
104+ expected_diff_per_change ,
105+ num_changes ,
106+ changes ,
103107 )
104108
105109 def assert_num_changes (self , changes , expected_num_changes , min_num_changes ):
106- assert len (changes ) == 1
110+ print (len (changes ))
111+ print (changes )
112+ for c in changes :
113+ print (c .diff )
114+ assert len (changes ) == expected_num_changes
107115
108- actual_num = len (changes [ 0 ]. changes )
116+ actual_num = len (changes )
109117
110118 if min_num_changes is not None :
111119 assert (
@@ -116,25 +124,54 @@ def assert_num_changes(self, changes, expected_num_changes, min_num_changes):
116124 actual_num == expected_num_changes
117125 ), f"Expected { expected_num_changes } changes but { actual_num } were created."
118126
119- def assert_changes (self , root , file_path , input_code , expected , changes ):
120- assert os .path .relpath (file_path , root ) == changes .path
121- assert all (change .description for change in changes .changes )
122-
123- expected_diff = create_diff (
124- dedent (input_code ).splitlines (keepends = True ),
125- dedent (expected ).splitlines (keepends = True ),
127+ def assert_changes (
128+ self ,
129+ root ,
130+ file_path ,
131+ input_code ,
132+ expected ,
133+ expected_diff_per_change ,
134+ num_changes ,
135+ changes ,
136+ ):
137+ assert all (
138+ os .path .relpath (file_path , root ) == change .path for change in changes
126139 )
127- try :
128- assert expected_diff == changes .diff
129- except AssertionError :
130- raise DiffError (expected_diff , changes .diff )
131-
132- output_code = file_path .read_bytes ().decode ("utf-8" )
133-
134- try :
135- assert output_code == (format_expected := dedent (expected ))
136- except AssertionError :
137- raise DiffError (format_expected , output_code )
140+ assert all (c .description for change in changes for c in change .changes )
141+
142+ # assert each change individually
143+ if num_changes > 1 :
144+ assert num_changes == len (expected_diff_per_change )
145+ for change , diff in zip (changes , expected_diff_per_change ):
146+ print (change .diff )
147+ print ("-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-" )
148+ print (diff )
149+ print ("+++++++++++++++++++++++++++++++" )
150+ print (
151+ "\n " .join (
152+ difflib .ndiff (diff .splitlines (), change .diff .splitlines ())
153+ )
154+ .replace (" " , "␣" )
155+ .replace ("\t " , "→" )
156+ )
157+ assert change .diff == diff
158+ else :
159+ # generate diff from expected code
160+ expected_diff = create_diff (
161+ dedent (input_code ).splitlines (keepends = True ),
162+ dedent (expected ).splitlines (keepends = True ),
163+ )
164+ try :
165+ assert expected_diff == changes .diff
166+ except AssertionError :
167+ raise DiffError (expected_diff , changes .diff )
168+
169+ output_code = file_path .read_bytes ().decode ("utf-8" )
170+
171+ try :
172+ assert output_code == (format_expected := dedent (expected ))
173+ except AssertionError :
174+ raise DiffError (format_expected , output_code )
138175
139176 def run_and_assert_filepath (
140177 self ,
@@ -171,6 +208,7 @@ def run_and_assert(
171208 tmpdir ,
172209 input_code ,
173210 expected ,
211+ expected_diff_per_change : list [str ] | None = None ,
174212 num_changes : int = 1 ,
175213 min_num_changes : int | None = None ,
176214 root : Path | None = None ,
@@ -217,7 +255,9 @@ def run_and_assert(
217255 tmp_file_path ,
218256 input_code ,
219257 expected ,
220- changes [0 ],
258+ expected_diff_per_change ,
259+ num_changes ,
260+ changes ,
221261 )
222262
223263 return changes
0 commit comments