Skip to content

Commit b72a5fe

Browse files
committed
refactor: apply linting
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 4e15605 commit b72a5fe

File tree

11 files changed

+49
-22
lines changed

11 files changed

+49
-22
lines changed

core/lowering/LowerInfo.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ namespace core {
99
namespace lowering {
1010

1111
std::ostream& operator<<(std::ostream& os, const LowerInfo& l) {
12-
os << "Settings requested for Lowering:" << std::endl;
13-
os << " Forced Fallback Modules: [" << std::endl;
14-
for (auto i : l.forced_fallback_modules) {
15-
os << " " << i << std::endl;
16-
}
17-
os << " ]";
18-
return os;
12+
os << "Settings requested for Lowering:" << std::endl;
13+
os << " Forced Fallback Modules: [" << std::endl;
14+
for (auto i : l.forced_fallback_modules) {
15+
os << " " << i << std::endl;
16+
}
17+
os << " ]";
18+
return os;
1919
}
2020

2121
} // namespace lowering

core/lowering/lowering.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
6060
LOG_GRAPH(*g);
6161
}
6262

63-
torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, std::unordered_set<std::string> forced_fallback_modules) {
63+
torch::jit::Module LowerModule(
64+
const torch::jit::Module& mod,
65+
std::string method_name,
66+
std::unordered_set<std::string> forced_fallback_modules) {
6467
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
6568
LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph());
6669
auto mod_ = torch::jit::freeze_module(mod);
@@ -70,7 +73,8 @@ torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method
7073

7174
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
7275
const torch::jit::Module& mod,
73-
std::string method_name, const LowerInfo& lower_info) {
76+
std::string method_name,
77+
const LowerInfo& lower_info) {
7478
LOG_DEBUG(lower_info);
7579
LOG_GRAPH("Before lowering: " << *mod.get_method(method_name).graph());
7680
std::unordered_set<std::string> forced_fallback_modules(

core/lowering/lowering.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,14 @@ struct LowerInfo {
2121

2222
void LowerBlock(torch::jit::Block* b);
2323
void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
24-
torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, std::unordered_set<std::string> forced_fallback_modules);
24+
torch::jit::Module LowerModule(
25+
const torch::jit::Module& mod,
26+
std::string method_name,
27+
std::unordered_set<std::string> forced_fallback_modules);
2528
std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
2629
const torch::jit::Module& mod,
27-
std::string method_name, const LowerInfo& lower_info);
30+
std::string method_name,
31+
const LowerInfo& lower_info);
2832

2933
} // namespace lowering
3034
} // namespace core

core/lowering/passes/module_fallback.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ std::string unmangle_cls_name(const std::string& name) {
2525
return unmangled;
2626
}
2727

28-
void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name, std::string method_name, std::unordered_set<std::string> forced_fallback_modules) {
28+
void NotateModuleForFallback(
29+
const torch::jit::Module& mod,
30+
std::string mod_name,
31+
std::string method_name,
32+
std::unordered_set<std::string> forced_fallback_modules) {
2933
auto cls_name = unmangle_cls_name(mod.type()->name()->qualifiedName());
3034

3135
auto g = mod.get_method(method_name).graph();
@@ -35,7 +39,9 @@ void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name
3539
if (n->kind() == torch::jit::prim::GetAttr) {
3640
auto out_type = unmangle_cls_name(c10::toString(n->output(0)->type()));
3741
if (forced_fallback_modules.find(out_type) != forced_fallback_modules.end()) {
38-
LOG_DEBUG("Notating module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name << " (" << cls_name << ")]");
42+
LOG_DEBUG(
43+
"Notating module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name
44+
<< " (" << cls_name << ")]");
3945
auto uses = n->output(0)->uses();
4046
for (const auto u : uses) {
4147
auto user = u.user;

core/lowering/passes/passes.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ namespace core {
77
namespace lowering {
88
namespace passes {
99

10-
void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name, std::string method_name, std::unordered_set<std::string> forced_fallback_modules);
10+
void NotateModuleForFallback(
11+
const torch::jit::Module& mod,
12+
std::string mod_name,
13+
std::string method_name,
14+
std::unordered_set<std::string> forced_fallback_modules);
1115
void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1216
void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1317
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);

core/partitioning/partitioning.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,8 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
275275

276276
std::string node_string(n->kind().toQualString());
277277
auto has_compile_attribute = n->hasAttribute(c10::Symbol::attr("to_compile"));
278-
if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string) && (!has_compile_attribute || n->i(c10::Symbol::attr("to_compile")) == (int64_t) true)) {
278+
if (conversion::OpSupported(n) && !forced_fallback_operators.count(node_string) &&
279+
(!has_compile_attribute || n->i(c10::Symbol::attr("to_compile")) == (int64_t) true)) {
279280
tensorrt_nodes.push_back(n);
280281
if (tensorrt_nodes.size() >= min_block_size && !pytorch_nodes.empty()) {
281282
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);

py/trtorch/csrc/tensorrt_backend.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::
3535
auto raw_spec = it->value().toCustomClass<trtorch::pyapi::CompileSpec>();
3636
LOG_DEBUG(raw_spec->stringify());
3737
auto cfg = raw_spec->toInternalCompileSpec();
38-
auto graph_and_ivals = Lower(mod_, method_name, cfg.lower_info);
38+
auto graph_and_ivals = core::lowering::Lower(mod_, method_name, cfg.lower_info);
3939

4040
auto g = graph_and_ivals.first;
4141
auto params = graph_and_ivals.second;

tests/core/lowering/test_module_fallback_passes.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
#include <unordered_set>
33
#include "core/compiler.h"
44
#include "core/lowering/lowering.h"
5+
#include "core/lowering/passes/passes.h"
56
#include "gtest/gtest.h"
67
#include "tests/util/util.h"
7-
#include "torch/script.h"
8-
#include "core/lowering/passes/passes.h"
98
#include "torch/csrc/jit/passes/freeze_module.h"
9+
#include "torch/script.h"
1010

1111
TEST(Lowering, NotateModuleForFallbackWorksCorrectly) {
1212
torch::jit::script::Module mod;
@@ -124,4 +124,3 @@ TEST(Lowering, LowerAndPartitionSimpleModuleFallbackCorrectly) {
124124
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
125125
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
126126
}
127-

tests/cpp/test_module_fallback.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ TEST(CppAPITests, LowerResNetModuleFallbackCorrectly) {
2222
trt_inputs_ivalues.push_back(in.clone());
2323
}
2424

25-
std::vector<trtorch::CompileSpec::Input> input_ranges{trtorch::CompileSpec::Input(std::vector<int64_t>({1, 3, 224, 224}))};
25+
std::vector<trtorch::CompileSpec::Input> input_ranges{
26+
trtorch::CompileSpec::Input(std::vector<int64_t>({1, 3, 224, 224}))};
2627
trtorch::CompileSpec cfg(input_ranges);
2728
cfg.torch_fallback.enabled = true;
2829
cfg.torch_fallback.forced_fallback_modules.push_back("torchvision.models.resnet.BasicBlock");
@@ -51,7 +52,8 @@ TEST(CppAPITests, LowerAndPartitionMobileNetModuleFallbackCorrectly) {
5152
trt_inputs_ivalues.push_back(in.clone());
5253
}
5354

54-
std::vector<trtorch::CompileSpec::Input> input_ranges{trtorch::CompileSpec::Input(std::vector<int64_t>({1, 3, 224, 224}))};
55+
std::vector<trtorch::CompileSpec::Input> input_ranges{
56+
trtorch::CompileSpec::Input(std::vector<int64_t>({1, 3, 224, 224}))};
5557
trtorch::CompileSpec cfg(input_ranges);
5658
cfg.torch_fallback.enabled = true;
5759
cfg.torch_fallback.min_block_size = 5;

tests/modules/hub.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def forward(self, x):
9797
trace_model = torch.jit.trace(model, x)
9898
torch.jit.save(trace_model, "pooling_traced.jit.pt")
9999

100+
100101
# Sample Nested Module (for module-level fallback testing)
101102
class ModuleFallbackSub(nn.Module):
102103

@@ -108,6 +109,7 @@ def __init__(self):
108109
def forward(self, x):
109110
return self.relu(self.conv(x))
110111

112+
111113
class ModuleFallbackMain(nn.Module):
112114

113115
def __init__(self):
@@ -119,10 +121,12 @@ def __init__(self):
119121
def forward(self, x):
120122
return self.relu(self.conv(self.layer1(x)))
121123

124+
122125
module_fallback_model = ModuleFallbackMain().eval().cuda()
123126
module_fallback_script_model = torch.jit.script(module_fallback_model)
124127
torch.jit.save(module_fallback_script_model, "module_fallback_scripted.jit.pt")
125128

129+
126130
# Sample Conditional Model (for testing partitioning and fallback in conditionals)
127131
class FallbackIf(torch.nn.Module):
128132

0 commit comments

Comments (0)