Skip to content

Commit b8348c9

Browse files
committed
Further Python 3 type annotations on top-level files
1 parent 5a3ba37 commit b8348c9

File tree

4 files changed

+51
-41
lines changed

4 files changed

+51
-41
lines changed

nbgrader/coursedir.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from traitlets import Integer, Bool, Unicode, List, default, validate, TraitError
88

99
from .utils import full_split, parse_utc
10+
from traitlets.utils.bunch import Bunch
11+
import datetime
12+
from typing import Optional
1013

1114

1215
class CourseDirectory(LoggingConfigurable):
@@ -50,7 +53,7 @@ def _validate_course_id(self, proposal):
5053
).tag(config=True)
5154

5255
@validate('student_id')
53-
def _validate_student_id(self, proposal):
56+
def _validate_student_id(self, proposal: Bunch) -> str:
5457
if proposal['value'].strip() != proposal['value']:
5558
self.log.warning("student_id '%s' has trailing whitespace, stripping it away", proposal['value'])
5659
return proposal['value'].strip()
@@ -80,7 +83,7 @@ def _validate_student_id(self, proposal):
8083
).tag(config=True)
8184

8285
@validate('assignment_id')
83-
def _validate_assignment_id(self, proposal):
86+
def _validate_assignment_id(self, proposal: Bunch) -> str:
8487
if '+' in proposal['value']:
8588
raise TraitError('Assignment names should not contain the following characters: +')
8689
if proposal['value'].strip() != proposal['value']:
@@ -98,7 +101,7 @@ def _validate_assignment_id(self, proposal):
98101
).tag(config=True)
99102

100103
@validate('notebook_id')
101-
def _validate_notebook_id(self, proposal):
104+
def _validate_notebook_id(self, proposal: Bunch) -> str:
102105
if proposal['value'].strip() != proposal['value']:
103106
self.log.warning("notebook_id '%s' has trailing whitespace, stripping it away", proposal['value'])
104107
return proposal['value'].strip()
@@ -248,11 +251,11 @@ def _db_url_default(self):
248251
).tag(config=True)
249252

250253
@default("root")
251-
def _root_default(self):
254+
def _root_default(self) -> str:
252255
return os.getcwd()
253256

254257
@validate('root')
255-
def _validate_root(self, proposal):
258+
def _validate_root(self, proposal: Bunch) -> str:
256259
path = os.path.abspath(proposal['value'])
257260
if path != proposal['value']:
258261
self.log.warning("root '%s' is not absolute, standardizing it to '%s", proposal['value'], path)
@@ -298,7 +301,7 @@ def _validate_root(self, proposal):
298301
)
299302
).tag(config=True)
300303

301-
def format_path(self, nbgrader_step, student_id, assignment_id, escape=False):
304+
def format_path(self, nbgrader_step: str, student_id: str, assignment_id: str, escape: bool = False) -> str:
302305
kwargs = dict(
303306
nbgrader_step=nbgrader_step,
304307
student_id=student_id,
@@ -314,7 +317,7 @@ def format_path(self, nbgrader_step, student_id, assignment_id, escape=False):
314317

315318
return path
316319

317-
def get_existing_timestamp(self, dest_path):
320+
def get_existing_timestamp(self, dest_path: str) -> Optional[datetime.datetime]:
318321
"""Get the timestamp, as a datetime object, of an existing submission."""
319322
timestamp_path = os.path.join(dest_path, 'timestamp.txt')
320323
if os.path.exists(timestamp_path):

nbgrader/dbutil.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88

99
from contextlib import contextmanager
1010
from subprocess import check_call
11+
from typing import Iterator
1112

1213
_here = os.path.abspath(os.path.dirname(__file__))
1314

1415
ALEMBIC_INI_TEMPLATE_PATH = os.path.join(_here, 'alembic.ini')
1516
ALEMBIC_DIR = os.path.join(_here, 'alembic')
1617

1718

18-
def write_alembic_ini(alembic_ini='alembic.ini', db_url='sqlite:///gradebook.db'):
19+
def write_alembic_ini(alembic_ini: str = 'alembic.ini', db_url: str = 'sqlite:///gradebook.db') -> None:
1920
"""Write a complete alembic.ini from our template.
2021
Parameters
2122
----------
@@ -37,17 +38,17 @@ def write_alembic_ini(alembic_ini='alembic.ini', db_url='sqlite:///gradebook.db'
3738

3839

3940
@contextmanager
40-
def _temp_alembic_ini(db_url):
41+
def _temp_alembic_ini(db_url: str) -> Iterator[str]:
4142
"""Context manager for temporary JupyterHub alembic directory
4243
Temporarily write an alembic.ini file for use with alembic migration scripts.
4344
Context manager yields alembic.ini path.
4445
Parameters
4546
----------
46-
db_url: str
47+
db_url:
4748
The SQLAlchemy database url, e.g. `sqlite:///gradebook.db`.
4849
Returns
4950
-------
50-
alembic_ini: str
51+
alembic_ini:
5152
The path to the temporary alembic.ini that we have created.
5253
This file will be cleaned up on exit from the context manager.
5354
"""

nbgrader/utils.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from tornado.log import LogFormatter
1818
from dateutil.tz import gettz
1919
from datetime import datetime
20+
from nbformat.notebooknode import NotebookNode
21+
from logging import Logger
22+
from typing import Optional, Tuple, Union, List, Iterator, Any
2023

2124
# pwd is for unix passwords only, so we shouldn't import it on
2225
# windows machines
@@ -26,28 +29,28 @@
2629
pwd = None
2730

2831

29-
def is_task(cell):
32+
def is_task(cell: NotebookNode) -> bool:
3033
"""Returns True if the cell is a task cell."""
3134
if 'nbgrader' not in cell.metadata:
3235
return False
3336
return cell.metadata['nbgrader'].get('task', False)
3437

3538

36-
def is_grade(cell):
39+
def is_grade(cell: NotebookNode) -> bool:
3740
"""Returns True if the cell is a grade cell."""
3841
if 'nbgrader' not in cell.metadata:
3942
return False
4043
return cell.metadata['nbgrader'].get('grade', False)
4144

4245

43-
def is_solution(cell):
46+
def is_solution(cell: NotebookNode) -> bool:
4447
"""Returns True if the cell is a solution cell."""
4548
if 'nbgrader' not in cell.metadata:
4649
return False
4750
return cell.metadata['nbgrader'].get('solution', False)
4851

4952

50-
def is_locked(cell):
53+
def is_locked(cell: NotebookNode) -> bool:
5154
"""Returns True if the cell source is locked (will be overwritten)."""
5255
if 'nbgrader' not in cell.metadata:
5356
return False
@@ -90,7 +93,8 @@ def get_partial_grade(output, max_points, log=None):
9093
log.warning(warning_msg)
9194
return max_points
9295

93-
def determine_grade(cell, log=None):
96+
97+
def determine_grade(cell: NotebookNode, log: Logger = None) -> Tuple[Optional[float], float]:
9498
if not is_grade(cell):
9599
raise ValueError("cell is not a grade cell")
96100

@@ -126,12 +130,12 @@ def determine_grade(cell, log=None):
126130
return None, max_points
127131

128132

129-
def to_bytes(string):
133+
def to_bytes(string: str) -> bytes:
130134
"""A helper function for converting a string to bytes with utf-8 encoding."""
131135
return bytes(string.encode('utf-8'))
132136

133137

134-
def compute_checksum(cell):
138+
def compute_checksum(cell: NotebookNode) -> str:
135139
m = hashlib.md5()
136140
# add the cell source and type
137141
m.update(to_bytes(cell.source))
@@ -152,7 +156,7 @@ def compute_checksum(cell):
152156
return m.hexdigest()
153157

154158

155-
def parse_utc(ts):
159+
def parse_utc(ts: Union[datetime, str]) -> datetime:
156160
"""Parses a timestamp into datetime format, converting it to UTC if necessary."""
157161
if ts is None:
158162
return None
@@ -237,7 +241,7 @@ def self_owned(path):
237241
return get_osusername() == find_owner(os.path.abspath(path))
238242

239243

240-
def is_ignored(filename, ignore_globs=None):
244+
def is_ignored(filename: str, ignore_globs: List[str] = None) -> bool:
241245
"""Determines whether a filename should be ignored, based on whether it
242246
matches any file glob in the given list. Note that this only matches on the
243247
base filename itself, not the full path."""
@@ -304,7 +308,7 @@ def ignore_patterns(directory, filelist):
304308
return ignore_patterns
305309

306310

307-
def find_all_files(path, exclude=None):
311+
def find_all_files(path: str, exclude: List[str] = None) -> List[str]:
308312
"""Recursively finds all filenames rooted at `path`, optionally excluding
309313
some based on filename globs."""
310314
files = []
@@ -333,7 +337,7 @@ def find_all_notebooks(path):
333337
return notebooks
334338

335339

336-
def full_split(path):
340+
def full_split(path: str) -> Tuple[str, ...]:
337341
rest, last = os.path.split(path)
338342
if last == path:
339343
return (path,)
@@ -344,7 +348,7 @@ def full_split(path):
344348

345349

346350
@contextlib.contextmanager
347-
def chdir(dirname):
351+
def chdir(dirname: str) -> Iterator:
348352
currdir = os.getcwd()
349353
if dirname:
350354
os.chdir(dirname)
@@ -355,8 +359,8 @@ def chdir(dirname):
355359

356360

357361
@contextlib.contextmanager
358-
def setenv(**kwargs):
359-
previous_env = { }
362+
def setenv(**kwargs: Any) -> Iterator:
363+
previous_env = {}
360364
for key, value in kwargs.items():
361365
previous_env[key] = os.environ.get(value)
362366
os.environ[key] = value
@@ -368,7 +372,7 @@ def setenv(**kwargs):
368372
os.environ[key] = previous_env[key]
369373

370374

371-
def rmtree(path):
375+
def rmtree(path: str) -> None:
372376
# for windows, we need to go through and make sure everything
373377
# is writeable, otherwise rmtree will fail
374378
if sys.platform == 'win32':
@@ -381,7 +385,7 @@ def rmtree(path):
381385
shutil.rmtree(path)
382386

383387

384-
def remove(path):
388+
def remove(path: str) -> None:
385389
# for windows, we need to make sure that the file is writeable,
386390
# otherwise remove will fail
387391
if sys.platform == 'win32':

nbgrader/validator.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
from .preprocessors import Execute, ClearOutput, CheckCellMetadata
1111
from . import utils
12+
from nbformat.notebooknode import NotebookNode
13+
import typing
1214

1315

1416
class Validator(LoggingConfigurable):
@@ -94,7 +96,7 @@ class Validator(LoggingConfigurable):
9496

9597
stream = sys.stdout
9698

97-
def _indent(self, val):
99+
def _indent(self, val: str) -> str:
98100
lines = val.split("\n")
99101
new_lines = []
100102
for line in lines:
@@ -104,7 +106,7 @@ def _indent(self, val):
104106
new_lines.append(new_line)
105107
return "\n".join(new_lines)
106108

107-
def _extract_error(self, cell):
109+
def _extract_error(self, cell: NotebookNode) -> str:
108110
errors = []
109111

110112
# possibilities:
@@ -129,7 +131,7 @@ def _extract_error(self, cell):
129131

130132
return "\n".join(errors)
131133

132-
def _print_type_changed(self, old_type, new_type, source):
134+
def _print_type_changed(self, old_type: str, new_type: str, source: str) -> None:
133135
self.stream.write("\n" + "=" * self.width + "\n")
134136
self.stream.write(
135137
"The following {} cell has changed to a {} cell:\n\n".format(
@@ -141,7 +143,7 @@ def _print_changed(self, source):
141143
self.stream.write("The following cell has changed:\n\n")
142144
self.stream.write(self._indent(source) + "\n\n")
143145

144-
def _print_error(self, source, error):
146+
def _print_error(self, source: str, error: str) -> None:
145147
self.stream.write("\n" + "=" * self.width + "\n")
146148
self.stream.write("The following cell failed:\n\n")
147149
self.stream.write(self._indent(source) + "\n\n")
@@ -153,7 +155,7 @@ def _print_pass(self, source):
153155
self.stream.write("The following cell passed:\n\n")
154156
self.stream.write(self._indent(source) + "\n\n")
155157

156-
def _print_num_type_changed(self, num_changed):
158+
def _print_num_type_changed(self, num_changed: int) -> None:
157159
if num_changed == 0:
158160
return
159161

@@ -165,7 +167,7 @@ def _print_num_type_changed(self, num_changed):
165167
)
166168
)
167169

168-
def _print_num_changed(self, num_changed):
170+
def _print_num_changed(self, num_changed: int) -> None:
169171
if num_changed == 0:
170172
return
171173

@@ -177,7 +179,7 @@ def _print_num_changed(self, num_changed):
177179
)
178180
)
179181

180-
def _print_num_failed(self, num_failed):
182+
def _print_num_failed(self, num_failed: int) -> None:
181183
if num_failed == 0:
182184
self.stream.write("Success! Your notebook passes all the tests.\n")
183185

@@ -201,7 +203,7 @@ def _print_num_passed(self, num_passed):
201203
)
202204
)
203205

204-
def _get_type_changed_cells(self, nb):
206+
def _get_type_changed_cells(self, nb: NotebookNode) -> typing.List[NotebookNode]:
205207
changed = []
206208

207209
for cell in nb.cells:
@@ -217,7 +219,7 @@ def _get_type_changed_cells(self, nb):
217219

218220
return changed
219221

220-
def _get_changed_cells(self, nb):
222+
def _get_changed_cells(self, nb: NotebookNode) -> typing.List:
221223
changed = []
222224
for cell in nb.cells:
223225
if not (utils.is_grade(cell) or utils.is_locked(cell)):
@@ -237,7 +239,7 @@ def _get_changed_cells(self, nb):
237239

238240
return changed
239241

240-
def _get_failed_cells(self, nb):
242+
def _get_failed_cells(self, nb: NotebookNode) -> typing.List[NotebookNode]:
241243
failed = []
242244
for cell in nb.cells:
243245
if not (self.validate_all or utils.is_grade(cell) or utils.is_locked(cell)):
@@ -260,7 +262,7 @@ def _get_failed_cells(self, nb):
260262

261263
return failed
262264

263-
def _get_passed_cells(self, nb):
265+
def _get_passed_cells(self, nb: NotebookNode) -> typing.List[NotebookNode]:
264266
passed = []
265267
for cell in nb.cells:
266268
if not (utils.is_grade(cell) or utils.is_locked(cell)):
@@ -278,7 +280,7 @@ def _get_passed_cells(self, nb):
278280

279281
return passed
280282

281-
def _preprocess(self, nb):
283+
def _preprocess(self, nb: NotebookNode) -> NotebookNode:
282284
resources = {}
283285
with utils.setenv(NBGRADER_VALIDATING='1'):
284286
for preprocessor in self.preprocessors:
@@ -290,7 +292,7 @@ def _preprocess(self, nb):
290292
nb, resources = pp.preprocess(nb, resources)
291293
return nb
292294

293-
def validate(self, filename):
295+
def validate(self, filename: str) -> typing.Dict[str, typing.List[typing.Dict[str, str]]]:
294296
self.log.info("Validating '{}'".format(os.path.abspath(filename)))
295297
basename = os.path.basename(filename)
296298
dirname = os.path.dirname(filename)
@@ -336,7 +338,7 @@ def validate(self, filename):
336338

337339
return results
338340

339-
def validate_and_print(self, filename):
341+
def validate_and_print(self, filename: str) -> None:
340342
results = self.validate(filename)
341343
type_changed = results.get('type_changed', [])
342344
changed = results.get('changed', [])

0 commit comments

Comments
 (0)