Commit e5f7849

test: enhance CSVLogger tests for column handling and remote filesystem behavior
1 parent d85dbe4 commit e5f7849

1 file changed: tests/tests_fabric/loggers/test_csv.py (+118 additions, -48 deletions)
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import csv
-import itertools
 import os
 from unittest import mock
 from unittest.mock import MagicMock
@@ -172,56 +170,128 @@ def test_append_metrics_file(_, tmp_path):
 
 def test_append_columns(tmp_path):
     """Test that the CSV file gets rewritten with new headers if the columns change."""
-    logger = CSVLogger(tmp_path, flush_logs_every_n_steps=1)
+    logger = CSVLogger(tmp_path, flush_logs_every_n_steps=2)
 
     # initial metrics
     logger.log_metrics({"a": 1, "b": 2})
-    _assert_csv_content(
-        logger.experiment.metrics_file_path,
-        expected_headers={"step", "a", "b"},
-        expected_content=[{"step": "0", "a": "1", "b": "2"}],
-    )
 
     # new key appears
-    logger.log_metrics({"a": 11, "b": 22, "c": 33})
-    _assert_csv_content(
-        logger.experiment.metrics_file_path,
-        expected_headers={"step", "a", "b", "c"},
-        expected_content=[
-            {"step": "0", "a": "1", "b": "2", "c": ""},
-            {"step": "0", "a": "11", "b": "22", "c": "33"},
-        ],
-    )
+    logger.log_metrics({"a": 1, "b": 2, "c": 3})
+    with open(logger.experiment.metrics_file_path) as file:
+        lines = file.readlines()
+    header = lines[0].strip()
+    assert header.split(",") == ["a", "b", "c", "step"]
+    assert len(lines) == 3  # header + 2 data rows
 
     # key disappears
-    logger.log_metrics({"a": 111, "c": 333})
-    _assert_csv_content(
-        logger.experiment.metrics_file_path,
-        expected_headers={"step", "a", "b", "c"},
-        expected_content=[
-            {"step": "0", "a": "1", "b": "2", "c": ""},
-            {"step": "0", "a": "11", "b": "22", "c": "33"},
-            {"step": "0", "a": "111", "b": "", "c": "333"},
-        ],
-    )
-
-
-def _assert_csv_content(
-    path: str,
-    expected_headers: set[str],
-    expected_content: list[dict[str, str]],
-) -> None:
-    """Verifies the content of a local csv file with the expected ones."""
-    headers, content = _read_csv(path)
-    assert headers == expected_headers
-    for actual, expected in itertools.zip_longest(content, expected_content):
-        assert actual == expected
-
-
-def _read_csv(path: str) -> tuple[set[str], list[dict[str, str]]]:
-    """Reads a local csv file and returns the headers and content."""
-    with open(path) as file:
-        reader = csv.DictReader(file)
-        headers = set(reader.fieldnames)
-        content = list(reader)
-    return headers, content
+    logger.log_metrics({"a": 1, "c": 3})
+    logger.save()
+    with open(logger.experiment.metrics_file_path) as file:
+        lines = file.readlines()
+    header = lines[0].strip()
+    assert header.split(",") == ["a", "b", "c", "step"]
+    assert len(lines) == 4  # header + 3 data rows
+
+
+@mock.patch(
+    # Mock the existence check, so we can simulate appending to the metrics file
+    "lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists"
+)
+def test_rewrite_with_new_header(_, tmp_path):
+    """Test that existing files get rewritten correctly when new columns are added."""
+    # write a csv file manually to simulate existing data
+    csv_path = tmp_path / "metrics.csv"
+    with open(csv_path, "w") as file:
+        file.write("a,b,step\n")
+        file.write("1,2,0\n")
+
+    writer = _ExperimentWriter(log_dir=str(tmp_path))
+
+    # Add metrics with a new column
+    writer.log_metrics({"a": 2, "b": 3, "c": 4}, step=1)
+    writer.save()
+    # The rewritten file should have the new columns and preserve old data
+    with open(csv_path) as file:
+        lines = file.readlines()
+    assert len(lines) == 3  # header + 2 data rows
+    header = lines[0].strip()
+    assert header.split(",") == ["a", "b", "c", "step"]
+    # verify old data is preserved
+    assert lines[1].strip().split(",") == ["1", "2", "", "0"]  # old row with empty new column
+    assert lines[2].strip().split(",") == ["2", "3", "4", "1"]
+
+
+def test_log_metrics_column_order_sorted(tmp_path):
+    """Test that the columns in the output metrics file are sorted by name."""
+    logger = CSVLogger(tmp_path)
+    logger.log_metrics({"c": 0.1})
+    logger.log_metrics({"c": 0.2})
+    logger.log_metrics({"b": 0.3})
+    logger.log_metrics({"a": 0.4})
+    logger.save()
+    logger.log_metrics({"d": 0.5})
+    logger.save()
+
+    with open(logger.experiment.metrics_file_path) as fp:
+        lines = fp.readlines()
+
+    assert lines[0].strip() == "a,b,c,d,step"
+
+
+@mock.patch("lightning.fabric.loggers.csv_logs.get_filesystem")
+@mock.patch("lightning.fabric.loggers.csv_logs._ExperimentWriter._read_existing_metrics")
+def test_remote_filesystem_uses_write_mode(mock_read_existing, mock_get_fs, tmp_path):
+    """Test that remote filesystems use write mode."""
+    mock_fs = MagicMock()
+    mock_fs.isfile.return_value = False  # File doesn't exist
+    mock_fs.makedirs = MagicMock()
+    mock_get_fs.return_value = mock_fs
+
+    logger = CSVLogger(tmp_path)
+    assert not logger.experiment._is_local_fs
+
+    logger.log_metrics({"a": 0.3}, step=1)
+    logger.save()
+
+    # Verify _read_existing_metrics was NOT called (file doesn't exist)
+    mock_read_existing.assert_not_called()
+
+    # Verify write mode was used (remote FS should never use append)
+    mock_fs.open.assert_called()
+    call_args = mock_fs.open.call_args_list[-1]  # Get the last call
+
+    # Extract the mode parameter specifically
+    args, kwargs = call_args
+    mode = kwargs.get("mode", "r")  # Default to 'r' if mode not specified
+    assert mode == "w", f"Expected write mode 'w', but got mode: '{mode}'"
+
+
+@mock.patch("lightning.fabric.loggers.csv_logs.get_filesystem")
+@mock.patch("lightning.fabric.loggers.csv_logs._ExperimentWriter._read_existing_metrics")
+def test_remote_filesystem_preserves_existing_data(mock_read_existing, mock_get_fs, tmp_path):
+    """Test that remote filesystem reads existing data and preserves it when rewriting."""
+    # Mock remote filesystem with existing file
+    mock_fs = MagicMock()
+    mock_fs.isfile.return_value = True
+    mock_fs.makedirs = MagicMock()
+    mock_get_fs.return_value = mock_fs
+
+    # Mock existing data
+    mock_read_existing.return_value = [{"a": 0.1, "step": 0}, {"a": 0.2, "step": 1}]
+
+    logger = CSVLogger(tmp_path)
+    assert not logger.experiment._is_local_fs
+
+    # Add new metrics - should read existing and combine
+    logger.log_metrics({"a": 0.3}, step=2)
+    logger.save()
+
+    # Verify that _read_existing_metrics was called (should read existing data)
+    mock_read_existing.assert_called_once()
+
+    # Verify write mode was used
+    mock_fs.open.assert_called()
+    last_call = mock_fs.open.call_args_list[-1]
+    args, kwargs = last_call
+    mode = kwargs.get("mode", "r")
+    assert mode == "w", f"Expected write mode 'w', but got mode: '{mode}'"
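
For readers skimming the commit, here is a minimal usage sketch of the column behavior these tests pin down. It is not part of the commit: the root directory and metric values are invented, and only CSVLogger, log_metrics, save, and experiment.metrics_file_path from the diff above are assumed.

# illustrative sketch only, not part of this commit
from lightning.fabric.loggers import CSVLogger

logger = CSVLogger("logs")                      # hypothetical root dir
logger.log_metrics({"a": 1, "b": 2})            # first row has columns a, b, step
logger.log_metrics({"a": 1, "b": 2, "c": 3})    # a new key "c" appears mid-run
logger.log_metrics({"a": 1, "c": 3})            # "b" disappears again
logger.save()

with open(logger.experiment.metrics_file_path) as f:
    print(f.read())

# Per the assertions in test_append_columns and test_log_metrics_column_order_sorted,
# the header is the sorted union of all keys ("a,b,c,step") and earlier rows are kept,
# with empty cells for columns that did not exist when those rows were logged.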
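The two remote-filesystem tests assert that a non-local filesystem never appends: when the file already exists, previously written rows are read back (via _read_existing_metrics) and the whole file is rewritten in "w" mode. Below is a rough, standalone sketch of that rewrite pattern written against fsspec (the kind of filesystem object get_filesystem returns); it is not Lightning's _ExperimentWriter implementation, and the helper name and in-memory filesystem demo are invented for illustration.

# rewrite-instead-of-append sketch, not Lightning's implementation
import csv

import fsspec


def rewrite_csv(fs, path, rows, fieldnames):
    """Merge previously written rows with the new ones, then rewrite the file in 'w' mode."""
    existing = []
    if fs.isfile(path):
        with fs.open(path, mode="r") as f:    # read back what is already there
            existing = list(csv.DictReader(f))
    with fs.open(path, mode="w") as f:        # always "w"; remote object stores cannot append
        writer = csv.DictWriter(f, fieldnames=fieldnames, lineterminator="\n")
        writer.writeheader()
        writer.writerows(existing + rows)


fs = fsspec.filesystem("memory")              # in-memory stand-in for a remote filesystem
rewrite_csv(fs, "/metrics.csv", [{"a": 0.1, "step": 0}], ["a", "step"])
rewrite_csv(fs, "/metrics.csv", [{"a": 0.3, "step": 1}], ["a", "step"])
with fs.open("/metrics.csv", mode="r") as f:
    print(f.read())                           # both rows survive the rewrite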
