|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 |
| -import csv |
15 |
| -import itertools |
16 | 14 | import os
|
17 | 15 | from unittest import mock
|
18 | 16 | from unittest.mock import MagicMock
|
@@ -172,56 +170,128 @@ def test_append_metrics_file(_, tmp_path):
|
172 | 170 |
|
def test_append_columns(tmp_path):
    """Check that the metrics CSV keeps one complete, ordered header while the logged keys change."""
    csv_logger = CSVLogger(tmp_path, flush_logs_every_n_steps=2)
    full_header = ["a", "b", "c", "step"]

    def check_file(expected_line_count):
        # Re-read the file from disk so the check reflects what has actually been written so far.
        with open(csv_logger.experiment.metrics_file_path) as handle:
            rows = handle.readlines()
        assert rows[0].strip().split(",") == full_header
        assert len(rows) == expected_line_count

    # First record uses only the initial keys.
    csv_logger.log_metrics({"a": 1, "b": 2})

    # Second record introduces key "c". No explicit save() here: the file is expected to
    # be on disk already after two logged records (flush_logs_every_n_steps=2).
    csv_logger.log_metrics({"a": 1, "b": 2, "c": 3})
    check_file(3)  # header + 2 data rows

    # Third record drops key "b"; the header must still list every column seen so far.
    csv_logger.log_metrics({"a": 1, "c": 3})
    csv_logger.save()
    check_file(4)  # header + 3 data rows
| 194 | + |
| 195 | + |
@mock.patch(
    # Patch out the log-dir existence check so a writer can be created on top of a
    # directory that already holds a metrics file.
    "lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
)
def test_rewrite_with_new_header(_, tmp_path):
    """Check that a pre-existing metrics file is rewritten with the widened header and keeps its old rows."""
    # Seed the directory with a hand-written metrics file containing a single row.
    csv_path = tmp_path / "metrics.csv"
    csv_path.write_text("a,b,step\n1,2,0\n")

    experiment = _ExperimentWriter(log_dir=str(tmp_path))

    # Log a record that carries the extra column "c", then flush to disk.
    experiment.log_metrics({"a": 2, "b": 3, "c": 4}, step=1)
    experiment.save()

    with open(csv_path) as handle:
        table = [row.strip().split(",") for row in handle.readlines()]

    # One comparison covers the row count, the new header, the preserved old row
    # (with an empty cell for the new column) and the freshly logged row.
    assert table == [
        ["a", "b", "c", "step"],
        ["1", "2", "", "0"],
        ["2", "3", "4", "1"],
    ]
| 222 | + |
| 223 | + |
def test_log_metrics_column_order_sorted(tmp_path):
    """Check that the header of the metrics file lists the columns in sorted order."""
    csv_logger = CSVLogger(tmp_path)

    # Keys are deliberately introduced out of alphabetical order before the first save.
    for record in ({"c": 0.1}, {"c": 0.2}, {"b": 0.3}, {"a": 0.4}):
        csv_logger.log_metrics(record)
    csv_logger.save()

    # A further key arrives after the file has already been written once.
    csv_logger.log_metrics({"d": 0.5})
    csv_logger.save()

    with open(csv_logger.experiment.metrics_file_path) as handle:
        header_line = handle.readlines()[0]

    assert header_line.strip() == "a,b,c,d,step"
| 239 | + |
| 240 | + |
@mock.patch("lightning.fabric.loggers.csv_logs.get_filesystem")
@mock.patch("lightning.fabric.loggers.csv_logs._ExperimentWriter._read_existing_metrics")
def test_remote_filesystem_uses_write_mode(mock_read_existing, mock_get_fs, tmp_path):
    """Test that a non-local filesystem opens the metrics file in write mode.

    ``get_filesystem`` is patched to return a ``MagicMock``, which the test asserts
    is not treated as a local filesystem. The mocked ``isfile`` reports that no
    metrics file exists yet, so no existing data should be read back.
    """
    mock_fs = MagicMock()
    mock_fs.isfile.return_value = False  # File doesn't exist
    mock_fs.makedirs = MagicMock()
    mock_get_fs.return_value = mock_fs

    logger = CSVLogger(tmp_path)
    assert not logger.experiment._is_local_fs

    logger.log_metrics({"a": 0.3}, step=1)
    logger.save()

    # Verify _read_existing_metrics was NOT called (file doesn't exist)
    mock_read_existing.assert_not_called()

    # Verify write mode was used (remote FS should never use append)
    mock_fs.open.assert_called()
    args, kwargs = mock_fs.open.call_args_list[-1]  # Get the last call

    # The mode may be passed by keyword or as the second positional argument;
    # accept both so the check does not depend on the caller's calling style.
    if "mode" in kwargs:
        mode = kwargs["mode"]
    elif len(args) > 1:
        mode = args[1]
    else:
        mode = "r"  # Mode not specified at all
    assert mode == "w", f"Expected write mode 'w', but got mode: '{mode}'"
| 267 | + |
| 268 | + |
@mock.patch("lightning.fabric.loggers.csv_logs.get_filesystem")
@mock.patch("lightning.fabric.loggers.csv_logs._ExperimentWriter._read_existing_metrics")
def test_remote_filesystem_preserves_existing_data(mock_read_existing, mock_get_fs, tmp_path):
    """Test that a non-local filesystem reads existing metrics before rewriting the file.

    ``get_filesystem`` is patched to return a ``MagicMock`` whose ``isfile`` reports
    an existing metrics file, and ``_read_existing_metrics`` is patched to return two
    earlier rows. The test checks that the read happens exactly once and that the
    final ``open`` call on the filesystem uses write mode.
    """
    # Mock remote filesystem with existing file
    mock_fs = MagicMock()
    mock_fs.isfile.return_value = True
    mock_fs.makedirs = MagicMock()
    mock_get_fs.return_value = mock_fs

    # Mock existing data
    mock_read_existing.return_value = [{"a": 0.1, "step": 0}, {"a": 0.2, "step": 1}]

    logger = CSVLogger(tmp_path)
    assert not logger.experiment._is_local_fs

    # Add new metrics - should read existing and combine
    logger.log_metrics({"a": 0.3}, step=2)
    logger.save()

    # Verify that _read_existing_metrics was called (should read existing data)
    mock_read_existing.assert_called_once()

    # Verify write mode was used
    mock_fs.open.assert_called()
    args, kwargs = mock_fs.open.call_args_list[-1]

    # The mode may be passed by keyword or as the second positional argument;
    # accept both so the check does not depend on the caller's calling style.
    if "mode" in kwargs:
        mode = kwargs["mode"]
    elif len(args) > 1:
        mode = args[1]
    else:
        mode = "r"  # Mode not specified at all
    assert mode == "w", f"Expected write mode 'w', but got mode: '{mode}'"
0 commit comments