Skip to content

Commit a49046c

Browse files
authored
some more typechecking (ipython#14535)
2 parents d2b6c2a + b7cc43d commit a49046c

File tree

4 files changed

+108
-45
lines changed

4 files changed

+108
-45
lines changed

IPython/testing/tests/test_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_full_path_posix():
3232
spath = "/foo"
3333
result = tt.full_path(spath, ["a.txt", "b.txt"])
3434
assert result, ["/a.txt" == "/b.txt"]
35-
result = tt.full_path(spath, "a.txt")
35+
result = tt.full_path(spath, ["a.txt"])
3636
assert result == ["/a.txt"]
3737

3838

@@ -44,7 +44,7 @@ def test_full_path_win32():
4444
spath = "c:\\foo"
4545
result = tt.full_path(spath, ["a.txt", "b.txt"])
4646
assert result, ["c:\\a.txt" == "c:\\b.txt"]
47-
result = tt.full_path(spath, "a.txt")
47+
result = tt.full_path(spath, ["a.txt"])
4848
assert result == ["c:\\a.txt"]
4949

5050

IPython/testing/tools.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
3737

3838
@doctest_deco
39-
def full_path(startPath,files):
39+
def full_path(startPath: str, files: list[str]) -> list[str]:
4040
"""Make full paths for all the listed files, based on startPath.
4141
4242
Only the base part of startPath is kept, since this routine is typically
@@ -49,7 +49,7 @@ def full_path(startPath,files):
4949
Initial path to use as the base for the results. This path is split
5050
using os.path.split() and only its first component is kept.
5151
52-
files : string or list
52+
files : list
5353
One or more files.
5454
5555
Examples
@@ -61,13 +61,8 @@ def full_path(startPath,files):
6161
>>> full_path('/foo',['a.txt','b.txt'])
6262
['/a.txt', '/b.txt']
6363
64-
If a single file is given, the output is still a list::
65-
66-
>>> full_path('/foo','a.txt')
67-
['/a.txt']
6864
"""
69-
70-
files = list_strings(files)
65+
assert isinstance(files, list)
7166
base = os.path.split(startPath)[0]
7267
return [ os.path.join(base,f) for f in files ]
7368

IPython/utils/text.py

Lines changed: 91 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,20 @@
1616
from string import Formatter
1717
from pathlib import Path
1818

19-
from typing import List, Dict, Tuple, Optional, cast, Sequence, Mapping, Any
19+
from typing import (
20+
List,
21+
Dict,
22+
Tuple,
23+
Optional,
24+
cast,
25+
Sequence,
26+
Mapping,
27+
Any,
28+
Union,
29+
Callable,
30+
Iterator,
31+
TypeVar,
32+
)
2033

2134
if sys.version_info < (3, 12):
2235
from typing_extensions import Self
@@ -138,8 +151,13 @@ def get_paths(self) -> List[Path]:
138151

139152
p = paths = property(get_paths)
140153

141-
def grep(self, pattern, prune = False, field = None):
142-
""" Return all strings matching 'pattern' (a regex or callable)
154+
def grep(
155+
self,
156+
pattern: Union[str, Callable[[Any], re.Match[str] | None]],
157+
prune: bool = False,
158+
field: Optional[int] = None,
159+
) -> Self:
160+
"""Return all strings matching 'pattern' (a regex or callable)
143161
144162
This is case-insensitive. If prune is true, return all items
145163
NOT matching the pattern.
@@ -154,7 +172,7 @@ def grep(self, pattern, prune = False, field = None):
154172
a.grep('chm', field=-1)
155173
"""
156174

157-
def match_target(s):
175+
def match_target(s: str) -> str:
158176
if field is None:
159177
return s
160178
parts = s.split()
@@ -169,12 +187,12 @@ def match_target(s):
169187
else:
170188
pred = pattern
171189
if not prune:
172-
return SList([el for el in self if pred(match_target(el))])
190+
return type(self)([el for el in self if pred(match_target(el))])
173191
else:
174-
return SList([el for el in self if not pred(match_target(el))])
192+
return type(self)([el for el in self if not pred(match_target(el))])
175193

176-
def fields(self, *fields):
177-
""" Collect whitespace-separated fields from string list
194+
def fields(self, *fields: List[str]) -> List[List[str]]:
195+
"""Collect whitespace-separated fields from string list
178196
179197
Allows quick awk-like usage of string lists.
180198
@@ -209,8 +227,12 @@ def fields(self, *fields):
209227

210228
return res
211229

212-
def sort(self,field= None, nums = False):
213-
""" sort by specified fields (see fields())
230+
def sort( # type:ignore[override]
231+
self,
232+
field: Optional[List[str]] = None,
233+
nums: bool = False,
234+
) -> Self:
235+
"""sort by specified fields (see fields())
214236
215237
Example::
216238
@@ -236,7 +258,7 @@ def sort(self,field= None, nums = False):
236258

237259

238260
dsu.sort()
239-
return SList([t[1] for t in dsu])
261+
return type(self)([t[1] for t in dsu])
240262

241263

242264
# FIXME: We need to reimplement type specific displayhook and then add this
@@ -255,7 +277,7 @@ def sort(self,field= None, nums = False):
255277
# print_slist = result_display.register(SList)(print_slist)
256278

257279

258-
def indent(instr,nspaces=4, ntabs=0, flatten=False):
280+
def indent(instr: str, nspaces: int = 4, ntabs: int = 0, flatten: bool = False) -> str:
259281
"""Indent a string a given number of spaces or tabstops.
260282
261283
indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
@@ -275,7 +297,7 @@ def indent(instr,nspaces=4, ntabs=0, flatten=False):
275297
276298
Returns
277299
-------
278-
str|unicode : string indented by ntabs and nspaces.
300+
str : string indented by ntabs and nspaces.
279301
280302
"""
281303
if instr is None:
@@ -292,7 +314,7 @@ def indent(instr,nspaces=4, ntabs=0, flatten=False):
292314
return outstr
293315

294316

295-
def list_strings(arg):
317+
def list_strings(arg: Union[str, List[str]]) -> List[str]:
296318
"""Always return a list of strings, given a string or list of strings
297319
as input.
298320
@@ -316,7 +338,7 @@ def list_strings(arg):
316338
return arg
317339

318340

319-
def marquee(txt='',width=78,mark='*'):
341+
def marquee(txt: str = "", width: int = 78, mark: str = "*") -> str:
320342
"""Return the input string centered in a 'marquee'.
321343
322344
Examples
@@ -343,7 +365,8 @@ def marquee(txt='',width=78,mark='*'):
343365

344366
ini_spaces_re = re.compile(r'^(\s+)')
345367

346-
def num_ini_spaces(strng):
368+
369+
def num_ini_spaces(strng: str) -> int:
347370
"""Return the number of initial spaces in a string"""
348371
warnings.warn(
349372
"`num_ini_spaces` is Pending Deprecation since IPython 8.17."
@@ -359,7 +382,7 @@ def num_ini_spaces(strng):
359382
return 0
360383

361384

362-
def format_screen(strng):
385+
def format_screen(strng: str) -> str:
363386
"""Format a string for screen printing.
364387
365388
This removes some latex-type format codes."""
@@ -396,7 +419,7 @@ def dedent(text: str) -> str:
396419
return '\n'.join([first, rest])
397420

398421

399-
def wrap_paragraphs(text, ncols=80):
422+
def wrap_paragraphs(text: str, ncols: int = 80) -> List[str]:
400423
"""Wrap multiple paragraphs to fit a specified width.
401424
402425
This is equivalent to textwrap.wrap, but with support for multiple
@@ -428,7 +451,7 @@ def wrap_paragraphs(text, ncols=80):
428451
return out_ps
429452

430453

431-
def strip_email_quotes(text):
454+
def strip_email_quotes(text: str) -> str:
432455
"""Strip leading email quotation characters ('>').
433456
434457
Removes any combination of leading '>' interspersed with whitespace that
@@ -478,7 +501,7 @@ def strip_email_quotes(text):
478501
return text
479502

480503

481-
def strip_ansi(source):
504+
def strip_ansi(source: str) -> str:
482505
"""
483506
Remove ansi escape codes from text.
484507
@@ -519,7 +542,8 @@ class EvalFormatter(Formatter):
519542
In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
520543
Out[3]: 'll'
521544
"""
522-
def get_field(self, name, args, kwargs):
545+
546+
def get_field(self, name: str, args: Any, kwargs: Any) -> Tuple[Any, str]:
523547
v = eval(name, kwargs)
524548
return v, name
525549

@@ -606,11 +630,15 @@ class DollarFormatter(FullEvalFormatter):
606630
In [4]: f.format('$a or {b}', a=1, b=2)
607631
Out[4]: '1 or 2'
608632
"""
609-
_dollar_pattern_ignore_single_quote = re.compile(r"(.*?)\$(\$?[\w\.]+)(?=([^']*'[^']*')*[^']*$)")
610-
def parse(self, fmt_string):
611-
for literal_txt, field_name, format_spec, conversion \
612-
in Formatter.parse(self, fmt_string):
613-
633+
634+
_dollar_pattern_ignore_single_quote = re.compile(
635+
r"(.*?)\$(\$?[\w\.]+)(?=([^']*'[^']*')*[^']*$)"
636+
)
637+
638+
def parse(self, fmt_string: str) -> Iterator[Tuple[Any, Any, Any, Any]]: # type: ignore
639+
for literal_txt, field_name, format_spec, conversion in Formatter.parse(
640+
self, fmt_string
641+
):
614642
# Find $foo patterns in the literal text.
615643
continue_from = 0
616644
txt = ""
@@ -627,14 +655,17 @@ def parse(self, fmt_string):
627655
# Re-yield the {foo} style pattern
628656
yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
629657

630-
def __repr__(self):
658+
def __repr__(self) -> str:
631659
return "<DollarFormatter>"
632660

633661
#-----------------------------------------------------------------------------
634662
# Utils to columnize a list of string
635663
#-----------------------------------------------------------------------------
636664

637-
def _col_chunks(l, max_rows, row_first=False):
665+
666+
def _col_chunks(
667+
l: List[int], max_rows: int, row_first: bool = False
668+
) -> Iterator[List[int]]:
638669
"""Yield successive max_rows-sized column chunks from l."""
639670
if row_first:
640671
ncols = (len(l) // max_rows) + (len(l) % max_rows > 0)
@@ -646,7 +677,7 @@ def _col_chunks(l, max_rows, row_first=False):
646677

647678

648679
def _find_optimal(
649-
rlist: List[str], row_first: bool, separator_size: int, displaywidth: int
680+
rlist: List[int], row_first: bool, separator_size: int, displaywidth: int
650681
) -> Dict[str, Any]:
651682
"""Calculate optimal info to columnize a list of string"""
652683
for max_rows in range(1, len(rlist) + 1):
@@ -662,7 +693,10 @@ def _find_optimal(
662693
}
663694

664695

665-
def _get_or_default(mylist, i, default=None):
696+
T = TypeVar("T")
697+
698+
699+
def _get_or_default(mylist: List[T], i: int, default: T) -> T:
666700
"""return list item number, or default if don't exist"""
667701
if i >= len(mylist):
668702
return default
@@ -740,9 +774,31 @@ def compute_item_matrix(
740774
)
741775
nrow, ncol = info["max_rows"], info["num_columns"]
742776
if row_first:
743-
return ([[_get_or_default(items, r * ncol + c, default=empty) for c in range(ncol)] for r in range(nrow)], info)
777+
return (
778+
[
779+
[
780+
_get_or_default(
781+
items, r * ncol + c, default=empty
782+
) # type:ignore[misc]
783+
for c in range(ncol)
784+
]
785+
for r in range(nrow)
786+
],
787+
info,
788+
)
744789
else:
745-
return ([[_get_or_default(items, c * nrow + r, default=empty) for c in range(ncol)] for r in range(nrow)], info)
790+
return (
791+
[
792+
[
793+
_get_or_default(
794+
items, c * nrow + r, default=empty
795+
) # type:ignore[misc]
796+
for c in range(ncol)
797+
]
798+
for r in range(nrow)
799+
],
800+
info,
801+
)
746802

747803

748804
def columnize(
@@ -795,7 +851,9 @@ def columnize(
795851
return "\n".join(map(sjoin, fmatrix)) + "\n"
796852

797853

798-
def get_text_list(list_, last_sep=' and ', sep=", ", wrap_item_with=""):
854+
def get_text_list(
855+
list_: List[str], last_sep: str = " and ", sep: str = ", ", wrap_item_with: str = ""
856+
) -> str:
799857
"""
800858
Return a string with a natural enumeration of items
801859

pyproject.toml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,22 @@ warn_redundant_casts = true
150150
module = [
151151
"IPython.utils.text",
152152
]
153+
disallow_untyped_defs = true
154+
check_untyped_defs = false
155+
disallow_untyped_decorators = true
156+
157+
[[tool.mypy.overrides]]
158+
module = [
159+
]
153160
disallow_untyped_defs = false
161+
ignore_errors = true
162+
ignore_missing_imports = true
163+
disallow_untyped_calls = false
164+
disallow_incomplete_defs = false
154165
check_untyped_defs = false
155166
disallow_untyped_decorators = false
156167

157-
158-
# global ignore error
168+
# gloabl ignore error
159169
[[tool.mypy.overrides]]
160170
module = [
161171
"IPython",

0 commit comments

Comments
 (0)