Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 37 additions & 23 deletions Lib/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ class excel:
and when writing, each quote character embedded in the data is
written as two quotes
"""
from typing import (
Any, Iterable, Iterator, Optional, Dict, Union, List,
Callable, Type, cast
)

import types
from _csv import Error, writer, reader, register_dialect, \
Expand Down Expand Up @@ -142,22 +146,30 @@ class unix_dialect(Dialect):


class DictReader:
def __init__(self, f, fieldnames=None, restkey=None, restval=None,
dialect="excel", *args, **kwds):
def __init__(
self,
f: Iterable[str],
fieldnames: Optional[Iterable[str]] = None,
restkey: Optional[str] = None,
restval: Optional[Any] = None,
dialect: Union[str, Dialect] = "excel",
*args: Any,
**kwds: Any
) -> None:
if fieldnames is not None and iter(fieldnames) is fieldnames:
fieldnames = list(fieldnames)
self._fieldnames = fieldnames # list of keys for the dict
self.restkey = restkey # key to catch long rows
self.restval = restval # default value for short rows
self.reader = reader(f, dialect, *args, **kwds)
self.dialect = dialect
self.line_num = 0

def __iter__(self):
self._fieldnames: Optional[List[str]] = fieldnames # Explicit type
self.restkey: Optional[str] = restkey
self.restval: Optional[Any] = restval
self.reader: Iterator[List[str]] = reader(f, dialect, *args, **kwds)
self.dialect: Union[str, Dialect] = dialect
self.line_num: int = 0

def __iter__(self) -> Iterator[Dict[str, Any]]:
return self

@property
def fieldnames(self):
def fieldnames(self) -> Optional[List[str]]:
if self._fieldnames is None:
try:
self._fieldnames = next(self.reader)
Expand All @@ -167,28 +179,30 @@ def fieldnames(self):
return self._fieldnames

@fieldnames.setter
def fieldnames(self, value):
self._fieldnames = value
def fieldnames(self, value: Optional[Iterable[str]]) -> None:
if value is not None and iter(value) is value:
value = list(value)
self._fieldnames = cast(Optional[List[str]], value)

def __next__(self):
def __next__(self) -> Dict[str, Any]:
if self.line_num == 0:
# Used only for its side effect.
self.fieldnames
self.fieldnames # Force header parsing
row = next(self.reader)
self.line_num = self.reader.line_num

# unlike the basic reader, we prefer not to return blanks,
# because we will typically wind up with a dict full of None
# values
while row == []:
row = next(self.reader)
d = dict(zip(self.fieldnames, row))
lf = len(self.fieldnames)
lr = len(row)

fieldnames = self.fieldnames
if fieldnames is None:
raise ValueError("fieldnames must be set before iteration")

d: Dict[str, Any] = dict(zip(fieldnames, row))
lf, lr = len(fieldnames), len(row)
if lf < lr:
d[self.restkey] = row[lf:]
elif lf > lr:
for key in self.fieldnames[lr:]:
for key in fieldnames[lr:]:
d[key] = self.restval
return d

Expand Down
Loading