Skip to content

Commit c3b8583

Browse files
authored
Merge pull request #439 from NVIDIA/clean_up_ptq_recipe
refactor: Fix python linting issues
2 parents 2af2c11 + a39dea7 commit c3b8583

File tree

5 files changed

+5
-3
lines changed

5 files changed

+5
-3
lines changed

.github/scripts/run_cpp_linter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
if output.returncode != 0:
2424
comment = '''There are some changes that do not conform to C++ style guidelines:\n ```diff\n{}```'''.format(output.stdout.decode("utf-8"))
2525
approval = 'REQUEST_CHANGES'
26+
exit(1)
2627

2728
pr.create_review(commit, comment, approval)
2829

.github/scripts/run_py_linter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@
2323
if output.returncode != 0:
2424
comment = '''There are some changes that do not conform to Python style guidelines:\n ```diff\n{}```'''.format(output.stdout.decode("utf-8"))
2525
approval = 'REQUEST_CHANGES'
26+
exit(1)
2627

2728
pr.create_review(commit, comment, approval)

cpp/ptq/training/vgg16/export_ckpt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def test(model, dataloader, crit):
2222

2323
with torch.no_grad():
2424
for data, labels in dataloader:
25-
data, labels = data.cuda(), labels.cuda(async=True)
25+
data, labels = data.cuda(), labels.cuda(non_blocking=True)
2626
out = model(data)
2727
loss += crit(out, labels)
2828
preds = torch.max(out, 1)[1]

cpp/ptq/training/vgg16/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def train(model, dataloader, crit, opt, epoch):
141141
model.train()
142142
running_loss = 0.0
143143
for batch, (data, labels) in enumerate(dataloader):
144-
data, labels = data.cuda(), labels.cuda(async=True)
144+
data, labels = data.cuda(), labels.cuda(non_blocking=True)
145145
opt.zero_grad()
146146
out = model(data)
147147
loss = crit(out, labels)

py/trtorch/_compile_spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
176176
if "max_batch_size" in compile_spec:
177177
assert type(compile_spec["max_batch_size"]) is int
178178
info.max_batch_size = compile_spec["max_batch_size"]
179-
179+
180180
if "truncate_long_and_double" in compile_spec:
181181
assert type(compile_spec["truncate_long_and_double"]) is bool
182182
info.truncate_long_and_double = compile_spec["truncate_long_and_double"]

0 commit comments

Comments
 (0)