@@ -855,3 +855,76 @@ TEST(Converters, ATenUnbindNegativeAxisConvertsCorrectly) {
855
855
ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt, 2e-6 ));
856
856
}
857
857
}
858
+
859
+ TEST (Converters, ScatterValueConvertsCorrectly) {
860
+ const auto graph = R"IR(
861
+ graph(%data : Tensor,
862
+ %index.1 : Tensor):
863
+ %value : int = prim::Constant[value=100]()
864
+ %dim : int = prim::Constant[value=1]()
865
+ %5 : NoneType = prim::Constant()
866
+ %6 : bool = prim::Constant[value=0]()
867
+ %7 : int = prim::Constant[value=4]()
868
+ %index : Tensor = aten::to(%index.1, %7, %6, %6, %5)
869
+ %10 : Tensor = aten::scatter(%data, %dim, %index, %value)
870
+ return (%10))IR" ;
871
+
872
+ auto g = std::make_shared<torch::jit::Graph>();
873
+
874
+ torch::jit::parseIR (graph, g.get ());
875
+
876
+ auto index = at::randint (0 , 5 , {2 , 2 }, {at::kCUDA });
877
+ auto data = at::randn ({5 , 5 }, {at::kCUDA });
878
+
879
+ auto jit_index = at::clone (index);
880
+ auto jit_data = at::clone (data);
881
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
882
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_data, jit_index});
883
+
884
+ auto trt_index = at::clone (index);
885
+ auto trt_data = at::clone (data);
886
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_data, trt_index});
887
+
888
+ for (size_t i = 0 ; i < jit_results.size (); i++) {
889
+ auto trt = trt_results[i].reshape (jit_results[i].sizes ());
890
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt, 2e-6 ));
891
+ }
892
+ }
893
+
894
+ TEST (Converters, ScatterSrcConvertsCorrectly) {
895
+ const auto graph = R"IR(
896
+ graph(%data : Tensor,
897
+ %src : Tensor,
898
+ %index.1 : Tensor):
899
+ %dim : int = prim::Constant[value=1]()
900
+ %5 : NoneType = prim::Constant()
901
+ %6 : bool = prim::Constant[value=0]()
902
+ %7 : int = prim::Constant[value=4]()
903
+ %index : Tensor = aten::to(%index.1, %7, %6, %6, %5)
904
+ %10 : Tensor = aten::scatter(%data, %dim, %index, %src)
905
+ return (%10))IR" ;
906
+
907
+ auto g = std::make_shared<torch::jit::Graph>();
908
+
909
+ torch::jit::parseIR (graph, g.get ());
910
+
911
+ auto index = at::randint (0 , 4 , {2 , 2 }, {at::kCUDA });
912
+ auto data = at::randn ({5 , 5 }, {at::kCUDA });
913
+ auto src = at::randn ({2 , 2 }, {at::kCUDA });
914
+
915
+ auto jit_index = at::clone (index);
916
+ auto jit_data = at::clone (data);
917
+ auto jit_src = at::clone (src);
918
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
919
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {jit_data, jit_src, jit_index});
920
+
921
+ auto trt_index = at::clone (index);
922
+ auto trt_data = at::clone (data);
923
+ auto trt_src = at::clone (src);
924
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {trt_data, trt_src, trt_index});
925
+
926
+ for (size_t i = 0 ; i < jit_results.size (); i++) {
927
+ auto trt = trt_results[i].reshape (jit_results[i].sizes ());
928
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[i], trt, 2e-6 ));
929
+ }
930
+ }
0 commit comments