File tree Expand file tree Collapse file tree 3 files changed +193
-151
lines changed Expand file tree Collapse file tree 3 files changed +193
-151
lines changed Original file line number Diff line number Diff line change @@ -47,7 +47,17 @@ void main() {
4747
4848 // Compute the start and end of the input indices to load. Padding is assumed
4949 // to be constant 0 padding, so reads from the padding region are skipped.
50- const ivec2 start = max (ivec2 (0 ), ipos);
50+ ivec2 start = ipos;
51+ if (start.x < 0 ) {
52+ // number of "steps" to get to >= zero is div_up(-start, dilation)
53+ int num_steps = ((- ipos.x) + dilation.x - 1 ) / dilation.x;
54+ start.x = ipos.x + num_steps * dilation.x;
55+ }
56+ if (start.y < 0 ) {
57+ // number of "steps" to get to >= zero is div_up(-start, dilation)
58+ int num_steps = ((- ipos.y) + dilation.y - 1 ) / dilation.y;
59+ start.y = ipos.y + num_steps * dilation.y;
60+ }
5161 const ivec2 end = min (ipos + overlay_region.xy, ivec2 (in_sizes.xy));
5262 // Compute the start of the kernel based on how far we are skipping ahead when
5363 // reading the input. Note that these are "canonical" indices.
Original file line number Diff line number Diff line change @@ -262,11 +262,6 @@ void check_conv2d_params(const Kernel2dParams& p, const bool transposed) {
262262 " aten.convolution.default: transposed = true, dilation > 1 is not supported yet!" );
263263 }
264264 }
265- if ((p.padding [0 ] > 0 && p.kernel_size [0 ] > 1 && p.dilation [0 ] > 1 ) ||
266- (p.padding [1 ] > 0 && p.kernel_size [1 ] > 1 && p.dilation [1 ] > 1 )) {
267- VK_THROW (
268- " aten.convolution.default: padding > 0 while dilation, kernel_size > 1 is not supported yet!" );
269- }
270265}
271266
272267Conv2dMethod get_conv2d_method (
Original file line number Diff line number Diff line change @@ -226,153 +226,190 @@ def get_max_pool2d_inputs():
226226
227227@register_test_suite ("aten.convolution.default" )
228228def get_conv_inputs ():
229- test_suite = VkTestSuite (
229+ Test = namedtuple (
230+ "ConvTest" ,
230231 [
231- (
232- (1 , 6 , 40 , 50 ),
233- (8 , 6 , 3 , 3 ),
234- (8 ,),
235- [1 , 2 ],
236- [2 , 3 ],
237- [1 , 1 ],
238- False ,
239- [0 , 0 ],
240- 1 ,
241- ),
242- (
243- (1 , 6 , 40 , 50 ),
244- (6 , 8 , 3 , 3 ),
245- (8 ,),
246- [1 , 2 ],
247- [2 , 3 ],
248- [1 , 1 ],
249- True ,
250- [0 , 1 ],
251- 1 ,
252- ),
253- (
254- (1 , 8 , 72 , 96 ),
255- (8 , 1 , 3 , 3 ),
256- (8 ,),
257- [1 , 1 ],
258- [1 , 1 ],
259- [1 , 1 ],
260- False ,
261- [0 , 0 ],
262- 8 ,
263- ),
264- (
265- (1 , 8 , 72 , 96 ),
266- (8 , 8 , 1 , 1 ),
267- (8 ,),
268- [1 , 1 ],
269- [1 , 1 ],
270- [1 , 1 ],
271- False ,
272- [0 , 0 ],
273- 1 ,
274- ),
275- (
276- (1 , 6 , 40 , 50 ),
277- (8 , 6 , 3 , 3 ),
278- None ,
279- [1 , 2 ],
280- [2 , 3 ],
281- [1 , 1 ],
282- False ,
283- [0 , 0 ],
284- 1 ,
285- ),
286- (
287- (1 , 6 , 7 ),
288- (6 , 1 , 3 ),
289- (6 ,),
290- [1 ],
291- [0 ],
292- [1 ],
293- False ,
294- [0 ],
295- 6 ,
296- ),
297- (
298- (2 , 20 , 30 ),
299- (10 , 4 , 6 ),
300- (10 ,),
301- [5 ],
302- [5 ],
303- [3 ],
304- False ,
305- [0 ],
306- 5 ,
307- ),
308- (
309- (1 , 9 , 11 ),
310- (9 , 1 , 3 ),
311- None ,
312- [1 ],
313- [0 ],
314- [1 ],
315- False ,
316- [0 ],
317- 9 ,
318- ),
319- (
320- (5 , 15 , 30 ),
321- (20 , 3 , 3 ),
322- None ,
323- [3 ],
324- [5 ],
325- [7 ],
326- False ,
327- [0 ],
328- 5 ,
329- ),
330- (
331- (1 , 16 , 672 , 512 ),
332- (64 , 16 , 1 , 1 ),
333- (64 ,),
334- [1 , 1 ],
335- [0 , 0 ],
336- [1 , 1 ],
337- False ,
338- [0 , 0 ],
339- 1 ,
340- ),
341- (
342- (1 , 4 , 234 , 234 ),
343- (4 , 1 , 3 , 3 ),
344- (4 ,),
345- [2 , 1 ],
346- [1 , 1 ],
347- [1 , 1 ],
348- False ,
349- [0 , 0 ],
350- 4 ,
351- ),
352- (
353- (1 , 4 , 234 , 234 ),
354- (4 , 1 , 3 , 3 ),
355- (4 ,),
356- [1 , 2 ],
357- [1 , 1 ],
358- [1 , 1 ],
359- False ,
360- [0 , 0 ],
361- 4 ,
362- ),
363- (
364- (1 , 4 , 234 , 234 ),
365- (4 , 1 , 3 , 3 ),
366- (4 ,),
367- [2 , 2 ],
368- [1 , 1 ],
369- [1 , 1 ],
370- False ,
371- [0 , 0 ],
372- 4 ,
373- ),
374- ]
232+ "self" ,
233+ "weight" ,
234+ "bias" ,
235+ "stride" ,
236+ "padding" ,
237+ "dilation" ,
238+ "transposed" ,
239+ "output_padding" ,
240+ "groups" ,
241+ ],
242+ )
243+ Test .__new__ .__defaults__ = (
244+ None ,
245+ None ,
246+ None ,
247+ [1 , 1 ],
248+ [0 , 0 ],
249+ [1 , 1 ],
250+ False ,
251+ [9 , 0 ],
252+ 1 ,
375253 )
254+ test_cases = []
255+ test_cases = [
256+ Test (
257+ self = (1 , 6 , 40 , 50 ),
258+ weight = (8 , 6 , 3 , 3 ),
259+ bias = (8 ,),
260+ stride = [1 , 2 ],
261+ padding = [2 , 3 ],
262+ dilation = [1 , 1 ],
263+ transposed = False ,
264+ output_padding = [0 , 0 ],
265+ groups = 1 ,
266+ ),
267+ Test (
268+ self = (1 , 6 , 40 , 50 ),
269+ weight = (6 , 8 , 3 , 3 ),
270+ bias = (8 ,),
271+ stride = [1 , 2 ],
272+ padding = [2 , 3 ],
273+ dilation = [1 , 1 ],
274+ transposed = True ,
275+ output_padding = [0 , 1 ],
276+ groups = 1 ,
277+ ),
278+ Test (
279+ self = (1 , 8 , 72 , 96 ),
280+ weight = (8 , 1 , 3 , 3 ),
281+ bias = (8 ,),
282+ stride = [1 , 1 ],
283+ padding = [1 , 1 ],
284+ dilation = [1 , 1 ],
285+ transposed = False ,
286+ output_padding = [0 , 0 ],
287+ groups = 8 ,
288+ ),
289+ Test (
290+ self = (1 , 8 , 72 , 96 ),
291+ weight = (8 , 8 , 1 , 1 ),
292+ bias = (8 ,),
293+ stride = [1 , 1 ],
294+ padding = [1 , 1 ],
295+ dilation = [1 , 1 ],
296+ transposed = False ,
297+ output_padding = [0 , 0 ],
298+ groups = 1 ,
299+ ),
300+ Test (
301+ self = (1 , 6 , 40 , 50 ),
302+ weight = (8 , 6 , 3 , 3 ),
303+ bias = None ,
304+ stride = [1 , 2 ],
305+ padding = [2 , 3 ],
306+ dilation = [1 , 1 ],
307+ transposed = False ,
308+ output_padding = [0 , 0 ],
309+ groups = 1 ,
310+ ),
311+ Test (
312+ self = (1 , 6 , 7 ),
313+ weight = (6 , 1 , 3 ),
314+ bias = (6 ,),
315+ stride = [1 ],
316+ padding = [0 ],
317+ dilation = [1 ],
318+ transposed = False ,
319+ output_padding = [0 ],
320+ groups = 6 ,
321+ ),
322+ Test (
323+ self = (2 , 20 , 30 ),
324+ weight = (10 , 4 , 6 ),
325+ bias = (10 ,),
326+ stride = [5 ],
327+ padding = [5 ],
328+ dilation = [3 ],
329+ transposed = False ,
330+ output_padding = [0 ],
331+ groups = 5 ,
332+ ),
333+ Test (
334+ self = (1 , 9 , 11 ),
335+ weight = (9 , 1 , 3 ),
336+ bias = None ,
337+ stride = [1 ],
338+ padding = [0 ],
339+ dilation = [1 ],
340+ transposed = False ,
341+ output_padding = [0 ],
342+ groups = 9 ,
343+ ),
344+ Test (
345+ self = (5 , 15 , 30 ),
346+ weight = (20 , 3 , 3 ),
347+ bias = None ,
348+ stride = [3 ],
349+ padding = [5 ],
350+ dilation = [7 ],
351+ transposed = False ,
352+ output_padding = [0 ],
353+ groups = 5 ,
354+ ),
355+ Test (
356+ self = (1 , 16 , 672 , 512 ),
357+ weight = (64 , 16 , 1 , 1 ),
358+ bias = (64 ,),
359+ stride = [1 , 1 ],
360+ padding = [0 , 0 ],
361+ dilation = [1 , 1 ],
362+ transposed = False ,
363+ output_padding = [0 , 0 ],
364+ groups = 1 ,
365+ ),
366+ Test (
367+ self = (1 , 4 , 234 , 234 ),
368+ weight = (4 , 1 , 3 , 3 ),
369+ bias = (4 ,),
370+ stride = [2 , 1 ],
371+ padding = [1 , 1 ],
372+ dilation = [1 , 1 ],
373+ transposed = False ,
374+ output_padding = [0 , 0 ],
375+ groups = 4 ,
376+ ),
377+ Test (
378+ self = (1 , 4 , 234 , 234 ),
379+ weight = (4 , 1 , 3 , 3 ),
380+ bias = (4 ,),
381+ stride = [1 , 2 ],
382+ padding = [1 , 1 ],
383+ dilation = [1 , 1 ],
384+ transposed = False ,
385+ output_padding = [0 , 0 ],
386+ groups = 4 ,
387+ ),
388+ Test (
389+ self = (1 , 4 , 234 , 234 ),
390+ weight = (4 , 1 , 3 , 3 ),
391+ bias = (4 ,),
392+ stride = [2 , 2 ],
393+ padding = [1 , 1 ],
394+ dilation = [1 , 1 ],
395+ transposed = False ,
396+ output_padding = [0 , 0 ],
397+ groups = 4 ,
398+ ),
399+ Test (
400+ self = (1 , 8 , 90 , 77 ),
401+ weight = (1 , 8 , 3 , 3 ),
402+ bias = (1 ,),
403+ stride = [1 , 1 ],
404+ padding = [2 , 2 ],
405+ dilation = [2 , 2 ],
406+ transposed = False ,
407+ output_padding = [0 , 0 ],
408+ groups = 1 ,
409+ ),
410+ ]
411+
412+ test_suite = VkTestSuite (test_cases )
376413 return test_suite
377414
378415
You can’t perform that action at this time.
0 commit comments