@@ -43,33 +43,33 @@ TEST(Partitioning, FallbackOnConditionalsCorrectly) {
43
43
ASSERT_TRUE (conditional_engines_count == 2 );
44
44
}
45
45
46
- // TEST(Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
47
- // torch::jit::script::Module mod;
48
- // try {
49
- // mod = torch::jit::load("tests/modules/inplace_op_if_scripted.jit.pt");
50
- // } catch (const c10::Error& e) {
51
- // std::cerr << "error loading the model\n";
52
- // return;
53
- // }
54
- //
55
- // const std::vector<std::vector<int64_t>> input_shapes = {{4, 4}, {4, 4}};
56
- // std::vector<torch::jit::IValue> jit_inputs_ivalues;
57
- // std::vector<torch::jit::IValue> trt_inputs_ivalues;
58
- // for (auto in_shape : input_shapes) {
59
- // auto in = at::randint(5, in_shape, {at::kCUDA});
60
- // jit_inputs_ivalues.push_back(in.clone());
61
- // trt_inputs_ivalues.push_back(in.clone());
62
- // }
63
- //
64
- // std::vector<torch_tensorrt::core::ir::Input> inputs{
65
- // torch_tensorrt::core::ir::Input({4, 4}), torch_tensorrt::core::ir::Input({4, 4})};
66
- // auto g = mod.get_method("forward").graph();
67
- // torch_tensorrt::core::CompileSpec cfg(inputs);
68
- // cfg.partitioning_info.enabled = true;
69
- // cfg.partitioning_info.forced_fallback_operators.push_back("prim::ListConstruct");
70
- //
71
- // auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
72
- // auto trt_mod = torch_tensorrt::core::CompileGraph(mod, cfg);
73
- // auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
74
- // ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(jit_results, trt_results));
75
- // }
46
+ TEST (Partitioning, FallbackInplaceOPInConditionalsCorrectly) {
47
+ torch::jit::script::Module mod;
48
+ try {
49
+ mod = torch::jit::load (" tests/modules/inplace_op_if_scripted.jit.pt" );
50
+ } catch (const c10::Error& e) {
51
+ std::cerr << " error loading the model\n " ;
52
+ return ;
53
+ }
54
+
55
+ const std::vector<std::vector<int64_t >> input_shapes = {{4 , 4 }, {4 , 4 }};
56
+ std::vector<torch::jit::IValue> jit_inputs_ivalues;
57
+ std::vector<torch::jit::IValue> trt_inputs_ivalues;
58
+ for (auto in_shape : input_shapes) {
59
+ auto in = at::randint (5 , in_shape, {at::kCUDA });
60
+ jit_inputs_ivalues.push_back (in.clone ());
61
+ trt_inputs_ivalues.push_back (in.clone ());
62
+ }
63
+
64
+ std::vector<torch_tensorrt::core::ir::Input> inputs{
65
+ torch_tensorrt::core::ir::Input ({4 , 4 }), torch_tensorrt::core::ir::Input ({4 , 4 })};
66
+ auto g = mod.get_method (" forward" ).graph ();
67
+ torch_tensorrt::core::CompileSpec cfg (inputs);
68
+ cfg.partitioning_info .enabled = true ;
69
+ cfg.partitioning_info .forced_fallback_operators .push_back (" prim::ListConstruct" );
70
+
71
+ auto jit_results = mod.forward (jit_inputs_ivalues).toTensor ();
72
+ auto trt_mod = torch_tensorrt::core::CompileGraph (mod, cfg);
73
+ auto trt_results = trt_mod.forward (trt_inputs_ivalues).toTensor ();
74
+ ASSERT_TRUE (torch_tensorrt::tests::util::cosineSimEqual (jit_results, trt_results));
75
+ }
0 commit comments