Skip to content

Commit eea6b44

Browse files
alvitawacarmocca
authored andcommitted
Fixed encoding issues on terminals that do not support unicode characters (#12828)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 194c7c7 commit eea6b44

File tree

3 files changed

+35
-1
lines changed

3 files changed

+35
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1616
- When using custom DataLoaders in LightningDataModule, multiple inheritance is resolved properly ([#12716](https://github.com/PyTorchLightning/pytorch-lightning/pull/12716))
1717

1818

19+
- Fixed encoding issues on terminals that do not support unicode characters ([#12828](https://github.com/PyTorchLightning/pytorch-lightning/pull/12828))
20+
21+
1922
- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783))
2023

2124

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import os
1515
import shutil
16+
import sys
1617
from collections import ChainMap, OrderedDict
1718
from functools import partial
1819
from typing import Any, IO, Iterable, List, Optional, Sequence, Type, Union
@@ -336,6 +337,10 @@ def _find_value(data: dict, target: str) -> Iterable[Any]:
336337

337338
@staticmethod
338339
def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] = None) -> None:
340+
# print to stdout by default
341+
if file is None:
342+
file = sys.stdout
343+
339344
# remove the dl idx suffix
340345
results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
341346
metrics = sorted({k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys})
@@ -384,7 +389,16 @@ def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]]
384389
row_format = f"{{:^{max_length}}}" * len(table_headers)
385390
half_term_size = int(term_size / 2)
386391

387-
bar = "─" * term_size
392+
try:
393+
# some terminals do not support this character
394+
if hasattr(file, "encoding") and file.encoding is not None:
395+
"─".encode(file.encoding)
396+
except UnicodeEncodeError:
397+
bar_character = "-"
398+
else:
399+
bar_character = "─"
400+
bar = bar_character * term_size
401+
388402
lines = [bar, row_format.format(*table_headers).rstrip(), bar]
389403
for metric, row in zip(metrics, table_rows):
390404
# deal with column overflow

tests/trainer/logging_/test_eval_loop_logging.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,23 @@ def test_native_print_results(monkeypatch, inputs, expected):
870870
assert out.getvalue().replace(os.linesep, "\n") == expected.lstrip()
871871

872872

873+
@pytest.mark.parametrize("encoding", ["latin-1", "utf-8"])
874+
def test_native_print_results_encodings(monkeypatch, encoding):
875+
import pytorch_lightning.loops.dataloader.evaluation_loop as imports
876+
877+
monkeypatch.setattr(imports, "_RICH_AVAILABLE", False)
878+
879+
out = mock.Mock()
880+
out.encoding = encoding
881+
EvaluationLoop._print_results(*inputs0, file=out)
882+
883+
# Attempt to encode everything the file is told to write with the given encoding
884+
for call_ in out.method_calls:
885+
name, args, kwargs = call_
886+
if name == "write":
887+
args[0].encode(encoding)
888+
889+
873890
expected0 = """
874891
┏━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
875892
┃ Test metric ┃ DataLoader 0 ┃ DataLoader 1 ┃

0 commit comments

Comments
 (0)