diff --git a/Lib/csv.py b/Lib/csv.py index 0a627ba7a512fa..3e23fa5781f25a 100644 --- a/Lib/csv.py +++ b/Lib/csv.py @@ -62,6 +62,10 @@ class excel: and when writing, each quote character embedded in the data is written as two quotes """ +from typing import ( + Any, Iterable, Iterator, Optional, Dict, Union, List, + Callable, Type, cast +) import types from _csv import Error, writer, reader, register_dialect, \ @@ -142,22 +146,30 @@ class unix_dialect(Dialect): class DictReader: - def __init__(self, f, fieldnames=None, restkey=None, restval=None, - dialect="excel", *args, **kwds): + def __init__( + self, + f: Iterable[str], + fieldnames: Optional[Iterable[str]] = None, + restkey: Optional[str] = None, + restval: Optional[Any] = None, + dialect: Union[str, Dialect] = "excel", + *args: Any, + **kwds: Any + ) -> None: if fieldnames is not None and iter(fieldnames) is fieldnames: fieldnames = list(fieldnames) - self._fieldnames = fieldnames # list of keys for the dict - self.restkey = restkey # key to catch long rows - self.restval = restval # default value for short rows - self.reader = reader(f, dialect, *args, **kwds) - self.dialect = dialect - self.line_num = 0 - - def __iter__(self): + self._fieldnames: Optional[List[str]] = fieldnames # Explicit type + self.restkey: Optional[str] = restkey + self.restval: Optional[Any] = restval + self.reader: Iterator[List[str]] = reader(f, dialect, *args, **kwds) + self.dialect: Union[str, Dialect] = dialect + self.line_num: int = 0 + + def __iter__(self) -> Iterator[Dict[str, Any]]: return self @property - def fieldnames(self): + def fieldnames(self) -> Optional[List[str]]: if self._fieldnames is None: try: self._fieldnames = next(self.reader) @@ -167,28 +179,30 @@ def fieldnames(self): return self._fieldnames @fieldnames.setter - def fieldnames(self, value): - self._fieldnames = value + def fieldnames(self, value: Optional[Iterable[str]]) -> None: + if value is not None and iter(value) is value: + value = list(value) + self._fieldnames = cast(Optional[List[str]], value) - def __next__(self): + def __next__(self) -> Dict[str, Any]: if self.line_num == 0: - # Used only for its side effect. - self.fieldnames + self.fieldnames # Force header parsing row = next(self.reader) self.line_num = self.reader.line_num - # unlike the basic reader, we prefer not to return blanks, - # because we will typically wind up with a dict full of None - # values while row == []: row = next(self.reader) - d = dict(zip(self.fieldnames, row)) - lf = len(self.fieldnames) - lr = len(row) + + fieldnames = self.fieldnames + if fieldnames is None: + raise ValueError("fieldnames must be set before iteration") + + d: Dict[str, Any] = dict(zip(fieldnames, row)) + lf, lr = len(fieldnames), len(row) if lf < lr: d[self.restkey] = row[lf:] elif lf > lr: - for key in self.fieldnames[lr:]: + for key in fieldnames[lr:]: d[key] = self.restval return d