Skip to content

Commit 50871f8

Browse files
authored
Merge pull request #206 from dlgallagher/requirements_txt_fixer_followup
Some style tweaks (requirements_txt_fixer)
2 parents 042c840 + 844d983 commit 50871f8

File tree

2 files changed

+37
-26
lines changed

2 files changed

+37
-26
lines changed

pre_commit_hooks/requirements_txt_fixer.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
import argparse
44

55

6+
PASS = 0
7+
FAIL = 1
8+
9+
610
class Requirement(object):
711

812
def __init__(self):
@@ -30,14 +34,14 @@ def __lt__(self, requirement):
3034

3135
def fix_requirements(f):
3236
requirements = []
33-
before = list(f)
37+
before = tuple(f)
3438
after = []
3539

3640
before_string = b''.join(before)
3741

3842
# If the file is empty (i.e. only whitespace/newlines) exit early
3943
if before_string.strip() == b'':
40-
return 0
44+
return PASS
4145

4246
for line in before:
4347
# If the most recent requirement object has a value, then it's
@@ -60,27 +64,26 @@ def fix_requirements(f):
6064
requirement.value = line
6165

6266
for requirement in sorted(requirements):
63-
for comment in requirement.comments:
64-
after.append(comment)
67+
after.extend(requirement.comments)
6568
after.append(requirement.value)
6669

6770
after_string = b''.join(after)
6871

6972
if before_string == after_string:
70-
return 0
73+
return PASS
7174
else:
7275
f.seek(0)
7376
f.write(after_string)
7477
f.truncate()
75-
return 1
78+
return FAIL
7679

7780

7881
def fix_requirements_txt(argv=None):
7982
parser = argparse.ArgumentParser()
8083
parser.add_argument('filenames', nargs='*', help='Filenames to fix')
8184
args = parser.parse_args(argv)
8285

83-
retv = 0
86+
retv = PASS
8487

8588
for arg in args.filenames:
8689
with open(arg, 'rb+') as file_obj:

tests/requirements_txt_fixer_test.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,41 @@
11
import pytest
22

3+
from pre_commit_hooks.requirements_txt_fixer import FAIL
34
from pre_commit_hooks.requirements_txt_fixer import fix_requirements_txt
5+
from pre_commit_hooks.requirements_txt_fixer import PASS
46
from pre_commit_hooks.requirements_txt_fixer import Requirement
57

6-
# Input, expected return value, expected output
7-
TESTS = (
8-
(b'', 0, b''),
9-
(b'\n', 0, b'\n'),
10-
(b'foo\nbar\n', 1, b'bar\nfoo\n'),
11-
(b'bar\nfoo\n', 0, b'bar\nfoo\n'),
12-
(b'#comment1\nfoo\n#comment2\nbar\n', 1, b'#comment2\nbar\n#comment1\nfoo\n'),
13-
(b'#comment1\nbar\n#comment2\nfoo\n', 0, b'#comment1\nbar\n#comment2\nfoo\n'),
14-
(b'#comment\n\nfoo\nbar\n', 1, b'#comment\n\nbar\nfoo\n'),
15-
(b'#comment\n\nbar\nfoo\n', 0, b'#comment\n\nbar\nfoo\n'),
16-
(b'\nfoo\nbar\n', 1, b'bar\n\nfoo\n'),
17-
(b'\nbar\nfoo\n', 0, b'\nbar\nfoo\n'),
18-
(b'pyramid==1\npyramid-foo==2\n', 0, b'pyramid==1\npyramid-foo==2\n'),
19-
(b'ocflib\nDjango\nPyMySQL\n', 1, b'Django\nocflib\nPyMySQL\n'),
20-
(b'-e git+ssh://git_url@tag#egg=ocflib\nDjango\nPyMySQL\n', 1, b'Django\n-e git+ssh://git_url@tag#egg=ocflib\nPyMySQL\n'),
21-
)
22-
238

24-
@pytest.mark.parametrize(('input_s', 'expected_retval', 'output'), TESTS)
9+
@pytest.mark.parametrize(
10+
('input_s', 'expected_retval', 'output'),
11+
(
12+
(b'', PASS, b''),
13+
(b'\n', PASS, b'\n'),
14+
(b'foo\nbar\n', FAIL, b'bar\nfoo\n'),
15+
(b'bar\nfoo\n', PASS, b'bar\nfoo\n'),
16+
(b'#comment1\nfoo\n#comment2\nbar\n', FAIL, b'#comment2\nbar\n#comment1\nfoo\n'),
17+
(b'#comment1\nbar\n#comment2\nfoo\n', PASS, b'#comment1\nbar\n#comment2\nfoo\n'),
18+
(b'#comment\n\nfoo\nbar\n', FAIL, b'#comment\n\nbar\nfoo\n'),
19+
(b'#comment\n\nbar\nfoo\n', PASS, b'#comment\n\nbar\nfoo\n'),
20+
(b'\nfoo\nbar\n', FAIL, b'bar\n\nfoo\n'),
21+
(b'\nbar\nfoo\n', PASS, b'\nbar\nfoo\n'),
22+
(b'pyramid==1\npyramid-foo==2\n', PASS, b'pyramid==1\npyramid-foo==2\n'),
23+
(b'ocflib\nDjango\nPyMySQL\n', FAIL, b'Django\nocflib\nPyMySQL\n'),
24+
(
25+
b'-e git+ssh://git_url@tag#egg=ocflib\nDjango\nPyMySQL\n',
26+
FAIL,
27+
b'Django\n-e git+ssh://git_url@tag#egg=ocflib\nPyMySQL\n'
28+
),
29+
)
30+
)
2531
def test_integration(input_s, expected_retval, output, tmpdir):
2632
path = tmpdir.join('file.txt')
2733
path.write_binary(input_s)
2834

29-
assert fix_requirements_txt([path.strpath]) == expected_retval
35+
output_retval = fix_requirements_txt([path.strpath])
36+
3037
assert path.read_binary() == output
38+
assert output_retval == expected_retval
3139

3240

3341
def test_requirement_object():

0 commit comments

Comments
 (0)