Skip to content

Commit b9a53da

Browse files
committed
Add 3d convolution support with testcases
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 2809f2b commit b9a53da

File tree

2 files changed

+165
-1
lines changed

2 files changed

+165
-1
lines changed

core/lowering/passes/conv3d_to_convolution.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
1515
std::string convolution_pattern = R"IR(
1616
graph(%x, %w, %b, %s, %p, %d, %g):
1717
%1 : bool = prim::Constant[value=0]()
18-
%2 : int[] = prim::Constant[value=[0, 0]]()
18+
%2 : int[] = prim::Constant[value=[0, 0, 0]]()
1919
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1)
2020
return (%4))IR";;
2121

tests/core/converters/test_conv_deconv.cpp

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,170 @@ TEST(Converters, ATenConvolutionWithPaddingConvertsCorrectly) {
203203
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
204204
}
205205

206+
TEST(Converters, ATenConvolution3dConvertsCorrectly) {
  // Basic 3D convolution: stride 1, padding 0, groups 1, with bias.
  // NOTE(review): the stride list %s is also passed as the dilation argument of
  // aten::_convolution; all entries are 1 here, so dilation is effectively identity.
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Float(32:81, 3:27, 3:9, 3:3, 3:1),
            %2 : Float(32:1)):
        %sv : int = prim::Constant[value=1]()
        %s : int[] = prim::ListConstruct(%sv, %sv, %sv)
        %pv : int = prim::Constant[value=0]()
        %p : int[] = prim::ListConstruct(%pv, %pv, %pv)
        %transposed : bool = prim::Constant[value=0]()
        %opv : int = prim::Constant[value=0]()
        %op : int[] = prim::ListConstruct(%opv, %opv, %opv)
        %g : int = prim::Constant[value=1]()
        %fb : bool = prim::Constant[value=0]()
        %out : Tensor = aten::_convolution(%0, %1, %2, %s, %p, %s, %transposed, %op, %g, %fb, %fb, %fb)
        return (%out))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*parsed_g);

  // Random integer-valued tensors: NCDHW activation, 32 filters of 3x3x3x3, per-filter bias.
  auto input = at::randint(1, 10, {1, 3, 5, 5, 5}, {at::kCUDA});
  auto weights = at::randint(1, 10, {32, 3, 3, 3, 3}, {at::kCUDA});
  auto bias = at::randint(1, 10, {32}, {at::kCUDA});

  // Reference run through the TorchScript interpreter.
  auto jit_input = at::clone(input);
  auto jit_weights = at::clone(weights);
  auto jit_bias = at::clone(bias);
  auto params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {jit_weights, jit_bias});
  auto jit_results = trtorch::tests::util::RunGraph(parsed_g, params, {jit_input});

  // Same graph converted to a TensorRT engine.
  auto trt_input = at::clone(input);
  auto trt_weights = at::clone(weights);
  auto trt_bias = at::clone(bias);
  params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {trt_weights, trt_bias});
  auto trt_results = trtorch::tests::util::RunGraphEngine(parsed_g, params, {trt_input});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
247+
248+
TEST(Converters, ATenConvolution3dNoBiasConvertsCorrectly) {
  // 3D convolution with the bias input bound to prim::Constant() (None),
  // exercising the bias-less converter path.
  // NOTE(review): at::randint's upper bound is exclusive, so randint(1, 2, ...)
  // produces all-ones tensors — presumably deliberate to keep values small; confirm.
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Float(32:81, 3:27, 3:9, 3:3, 3:1)):
        %bias : None = prim::Constant()
        %sv : int = prim::Constant[value=1]()
        %s : int[] = prim::ListConstruct(%sv, %sv, %sv)
        %pv : int = prim::Constant[value=0]()
        %p : int[] = prim::ListConstruct(%pv, %pv, %pv)
        %transposed : bool = prim::Constant[value=0]()
        %opv : int = prim::Constant[value=0]()
        %op : int[] = prim::ListConstruct(%opv, %opv, %opv)
        %g : int = prim::Constant[value=1]()
        %fb : bool = prim::Constant[value=0]()
        %out : Tensor = aten::_convolution(%0, %1, %bias, %s, %p, %s, %transposed, %op, %g, %fb, %fb, %fb)
        return (%out))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*parsed_g);

  auto input = at::randint(1, 2, {1, 3, 5, 5, 5}, {at::kCUDA});
  auto weights = at::randint(1, 2, {32, 3, 3, 3, 3}, {at::kCUDA});

  // Reference run through the TorchScript interpreter.
  auto jit_input = at::clone(input);
  auto jit_weights = at::clone(weights);
  auto params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {jit_weights});
  auto jit_results = trtorch::tests::util::RunGraph(parsed_g, params, {jit_input});

  // Same graph converted to a TensorRT engine.
  auto trt_input = at::clone(input);
  auto trt_weights = at::clone(weights);
  params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {trt_weights});
  auto trt_results = trtorch::tests::util::RunGraphEngine(parsed_g, params, {trt_input});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
285+
286+
TEST(Converters, ATenConvolution3dWithPaddingConvertsCorrectly) {
  // 3D convolution with symmetric padding of 1 on every spatial dim (stride 1).
  // NOTE(review): the stride list %s doubles as the dilation argument; with all
  // entries 1 that leaves dilation at its identity value.
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Float(32:81, 3:27, 3:9, 3:3, 3:1),
            %2 : Float(32:1)):
        %sv : int = prim::Constant[value=1]()
        %s : int[] = prim::ListConstruct(%sv, %sv, %sv)
        %pv : int = prim::Constant[value=1]()
        %p : int[] = prim::ListConstruct(%pv, %pv, %pv)
        %transposed : bool = prim::Constant[value=0]()
        %opv : int = prim::Constant[value=0]()
        %op : int[] = prim::ListConstruct(%opv, %opv, %opv)
        %g : int = prim::Constant[value=1]()
        %fb : bool = prim::Constant[value=0]()
        %out : Tensor = aten::_convolution(%0, %1, %2, %s, %p, %s, %transposed, %op, %g, %fb, %fb, %fb)
        return (%out))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*parsed_g);

  auto input = at::randint(1, 10, {1, 3, 5, 5, 5}, {at::kCUDA});
  auto weights = at::randint(1, 10, {32, 3, 3, 3, 3}, {at::kCUDA});
  auto bias = at::randint(1, 10, {32}, {at::kCUDA});

  // Reference run through the TorchScript interpreter.
  auto jit_input = at::clone(input);
  auto jit_weights = at::clone(weights);
  auto jit_bias = at::clone(bias);
  auto params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {jit_weights, jit_bias});
  auto jit_results = trtorch::tests::util::RunGraph(parsed_g, params, {jit_input});

  // Same graph converted to a TensorRT engine.
  auto trt_input = at::clone(input);
  auto trt_weights = at::clone(weights);
  auto trt_bias = at::clone(bias);
  params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {trt_weights, trt_bias});
  auto trt_results = trtorch::tests::util::RunGraphEngine(parsed_g, params, {trt_input});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
327+
328+
TEST(Converters, ATenConvolution3dWithStrideDilationConvertsCorrectly) {
  // 3D convolution with stride AND dilation both set to 2: the single list %s
  // is intentionally passed for both arguments of aten::_convolution, so the
  // two parameters are exercised together (padding 1, groups 1).
  const auto graph = R"IR(
      graph(%0 : Tensor,
            %1 : Float(32:81, 3:27, 3:9, 3:3, 3:1),
            %2 : Float(32:1)):
        %sv : int = prim::Constant[value=2]()
        %s : int[] = prim::ListConstruct(%sv, %sv, %sv)
        %pv : int = prim::Constant[value=1]()
        %p : int[] = prim::ListConstruct(%pv, %pv, %pv)
        %transposed : bool = prim::Constant[value=0]()
        %opv : int = prim::Constant[value=0]()
        %op : int[] = prim::ListConstruct(%opv, %opv, %opv)
        %g : int = prim::Constant[value=1]()
        %fb : bool = prim::Constant[value=0]()
        %out : Tensor = aten::_convolution(%0, %1, %2, %s, %p, %s, %transposed, %op, %g, %fb, %fb, %fb)
        return (%out))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*parsed_g);

  auto input = at::randint(1, 10, {1, 3, 5, 5, 5}, {at::kCUDA});
  auto weights = at::randint(1, 10, {32, 3, 3, 3, 3}, {at::kCUDA});
  auto bias = at::randint(1, 10, {32}, {at::kCUDA});

  // Reference run through the TorchScript interpreter.
  auto jit_input = at::clone(input);
  auto jit_weights = at::clone(weights);
  auto jit_bias = at::clone(bias);
  auto params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {jit_weights, jit_bias});
  auto jit_results = trtorch::tests::util::RunGraph(parsed_g, params, {jit_input});

  // Same graph converted to a TensorRT engine.
  auto trt_input = at::clone(input);
  auto trt_weights = at::clone(weights);
  auto trt_bias = at::clone(bias);
  params = trtorch::core::conversion::get_named_params(parsed_g->inputs(), {trt_weights, trt_bias});
  auto trt_results = trtorch::tests::util::RunGraphEngine(parsed_g, params, {trt_input});

  auto trt = trt_results[0].reshape(jit_results[0].sizes());
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
369+
206370
TEST(Converters, ATenConvTransposeConvertsCorrectly) {
207371
const auto graph = R"IR(
208372
graph(%0 : Tensor,

0 commit comments

Comments
 (0)