File tree Expand file tree Collapse file tree 3 files changed +193
-151
lines changed Expand file tree Collapse file tree 3 files changed +193
-151
lines changed Original file line number Diff line number Diff line change @@ -47,7 +47,17 @@ void main() {
4747
4848  //  Compute the start and end of the input indices to load. Padding is assumed
4949  //  to be constant 0 padding, so reads from the padding region are skipped.
50-   const  ivec2  start =  max (ivec2 (0 ), ipos);
50+   ivec2  start =  ipos;
51+   if  (start.x <  0 ) {
52+     //  number of "steps" to get to >= zero is div_up(-start, dilation)
53+     int  num_steps =  ((- ipos.x) +  dilation.x -  1 ) /  dilation.x;
54+     start.x =  ipos.x +  num_steps *  dilation.x;
55+   }
56+   if  (start.y <  0 ) {
57+     //  number of "steps" to get to >= zero is div_up(-start, dilation)
58+     int  num_steps =  ((- ipos.y) +  dilation.y -  1 ) /  dilation.y;
59+     start.y =  ipos.y +  num_steps *  dilation.y;
60+   }
5161  const  ivec2  end =  min (ipos +  overlay_region.xy, ivec2 (in_sizes.xy));
5262  //  Compute the start of the kernel based on how far we are skipping ahead when
5363  //  reading the input. Note that these are "canonical" indices.
Original file line number Diff line number Diff line change @@ -262,11 +262,6 @@ void check_conv2d_params(const Kernel2dParams& p, const bool transposed) {
262262          " aten.convolution.default: transposed = true, dilation > 1 is not supported yet!"  );
263263    }
264264  }
265-   if  ((p.padding [0 ] > 0  && p.kernel_size [0 ] > 1  && p.dilation [0 ] > 1 ) ||
266-       (p.padding [1 ] > 0  && p.kernel_size [1 ] > 1  && p.dilation [1 ] > 1 )) {
267-     VK_THROW (
268-         " aten.convolution.default: padding > 0 while dilation, kernel_size > 1 is not supported yet!"  );
269-   }
270265}
271266
272267Conv2dMethod get_conv2d_method (
Original file line number Diff line number Diff line change @@ -226,153 +226,190 @@ def get_max_pool2d_inputs():
226226
227227@register_test_suite ("aten.convolution.default" ) 
228228def  get_conv_inputs ():
229-     test_suite  =  VkTestSuite (
229+     Test  =  namedtuple (
230+         "ConvTest" ,
230231        [
231-             (
232-                 (1 , 6 , 40 , 50 ),
233-                 (8 , 6 , 3 , 3 ),
234-                 (8 ,),
235-                 [1 , 2 ],
236-                 [2 , 3 ],
237-                 [1 , 1 ],
238-                 False ,
239-                 [0 , 0 ],
240-                 1 ,
241-             ),
242-             (
243-                 (1 , 6 , 40 , 50 ),
244-                 (6 , 8 , 3 , 3 ),
245-                 (8 ,),
246-                 [1 , 2 ],
247-                 [2 , 3 ],
248-                 [1 , 1 ],
249-                 True ,
250-                 [0 , 1 ],
251-                 1 ,
252-             ),
253-             (
254-                 (1 , 8 , 72 , 96 ),
255-                 (8 , 1 , 3 , 3 ),
256-                 (8 ,),
257-                 [1 , 1 ],
258-                 [1 , 1 ],
259-                 [1 , 1 ],
260-                 False ,
261-                 [0 , 0 ],
262-                 8 ,
263-             ),
264-             (
265-                 (1 , 8 , 72 , 96 ),
266-                 (8 , 8 , 1 , 1 ),
267-                 (8 ,),
268-                 [1 , 1 ],
269-                 [1 , 1 ],
270-                 [1 , 1 ],
271-                 False ,
272-                 [0 , 0 ],
273-                 1 ,
274-             ),
275-             (
276-                 (1 , 6 , 40 , 50 ),
277-                 (8 , 6 , 3 , 3 ),
278-                 None ,
279-                 [1 , 2 ],
280-                 [2 , 3 ],
281-                 [1 , 1 ],
282-                 False ,
283-                 [0 , 0 ],
284-                 1 ,
285-             ),
286-             (
287-                 (1 , 6 , 7 ),
288-                 (6 , 1 , 3 ),
289-                 (6 ,),
290-                 [1 ],
291-                 [0 ],
292-                 [1 ],
293-                 False ,
294-                 [0 ],
295-                 6 ,
296-             ),
297-             (
298-                 (2 , 20 , 30 ),
299-                 (10 , 4 , 6 ),
300-                 (10 ,),
301-                 [5 ],
302-                 [5 ],
303-                 [3 ],
304-                 False ,
305-                 [0 ],
306-                 5 ,
307-             ),
308-             (
309-                 (1 , 9 , 11 ),
310-                 (9 , 1 , 3 ),
311-                 None ,
312-                 [1 ],
313-                 [0 ],
314-                 [1 ],
315-                 False ,
316-                 [0 ],
317-                 9 ,
318-             ),
319-             (
320-                 (5 , 15 , 30 ),
321-                 (20 , 3 , 3 ),
322-                 None ,
323-                 [3 ],
324-                 [5 ],
325-                 [7 ],
326-                 False ,
327-                 [0 ],
328-                 5 ,
329-             ),
330-             (
331-                 (1 , 16 , 672 , 512 ),
332-                 (64 , 16 , 1 , 1 ),
333-                 (64 ,),
334-                 [1 , 1 ],
335-                 [0 , 0 ],
336-                 [1 , 1 ],
337-                 False ,
338-                 [0 , 0 ],
339-                 1 ,
340-             ),
341-             (
342-                 (1 , 4 , 234 , 234 ),
343-                 (4 , 1 , 3 , 3 ),
344-                 (4 ,),
345-                 [2 , 1 ],
346-                 [1 , 1 ],
347-                 [1 , 1 ],
348-                 False ,
349-                 [0 , 0 ],
350-                 4 ,
351-             ),
352-             (
353-                 (1 , 4 , 234 , 234 ),
354-                 (4 , 1 , 3 , 3 ),
355-                 (4 ,),
356-                 [1 , 2 ],
357-                 [1 , 1 ],
358-                 [1 , 1 ],
359-                 False ,
360-                 [0 , 0 ],
361-                 4 ,
362-             ),
363-             (
364-                 (1 , 4 , 234 , 234 ),
365-                 (4 , 1 , 3 , 3 ),
366-                 (4 ,),
367-                 [2 , 2 ],
368-                 [1 , 1 ],
369-                 [1 , 1 ],
370-                 False ,
371-                 [0 , 0 ],
372-                 4 ,
373-             ),
374-         ]
232+             "self" ,
233+             "weight" ,
234+             "bias" ,
235+             "stride" ,
236+             "padding" ,
237+             "dilation" ,
238+             "transposed" ,
239+             "output_padding" ,
240+             "groups" ,
241+         ],
242+     )
243+     Test .__new__ .__defaults__  =  (
244+         None ,
245+         None ,
246+         None ,
247+         [1 , 1 ],
248+         [0 , 0 ],
249+         [1 , 1 ],
250+         False ,
251+         [0 , 0 ],
252+         1 ,
375253    )
254+     test_cases  =  []
255+     test_cases  =  [
256+         Test (
257+             self = (1 , 6 , 40 , 50 ),
258+             weight = (8 , 6 , 3 , 3 ),
259+             bias = (8 ,),
260+             stride = [1 , 2 ],
261+             padding = [2 , 3 ],
262+             dilation = [1 , 1 ],
263+             transposed = False ,
264+             output_padding = [0 , 0 ],
265+             groups = 1 ,
266+         ),
267+         Test (
268+             self = (1 , 6 , 40 , 50 ),
269+             weight = (6 , 8 , 3 , 3 ),
270+             bias = (8 ,),
271+             stride = [1 , 2 ],
272+             padding = [2 , 3 ],
273+             dilation = [1 , 1 ],
274+             transposed = True ,
275+             output_padding = [0 , 1 ],
276+             groups = 1 ,
277+         ),
278+         Test (
279+             self = (1 , 8 , 72 , 96 ),
280+             weight = (8 , 1 , 3 , 3 ),
281+             bias = (8 ,),
282+             stride = [1 , 1 ],
283+             padding = [1 , 1 ],
284+             dilation = [1 , 1 ],
285+             transposed = False ,
286+             output_padding = [0 , 0 ],
287+             groups = 8 ,
288+         ),
289+         Test (
290+             self = (1 , 8 , 72 , 96 ),
291+             weight = (8 , 8 , 1 , 1 ),
292+             bias = (8 ,),
293+             stride = [1 , 1 ],
294+             padding = [1 , 1 ],
295+             dilation = [1 , 1 ],
296+             transposed = False ,
297+             output_padding = [0 , 0 ],
298+             groups = 1 ,
299+         ),
300+         Test (
301+             self = (1 , 6 , 40 , 50 ),
302+             weight = (8 , 6 , 3 , 3 ),
303+             bias = None ,
304+             stride = [1 , 2 ],
305+             padding = [2 , 3 ],
306+             dilation = [1 , 1 ],
307+             transposed = False ,
308+             output_padding = [0 , 0 ],
309+             groups = 1 ,
310+         ),
311+         Test (
312+             self = (1 , 6 , 7 ),
313+             weight = (6 , 1 , 3 ),
314+             bias = (6 ,),
315+             stride = [1 ],
316+             padding = [0 ],
317+             dilation = [1 ],
318+             transposed = False ,
319+             output_padding = [0 ],
320+             groups = 6 ,
321+         ),
322+         Test (
323+             self = (2 , 20 , 30 ),
324+             weight = (10 , 4 , 6 ),
325+             bias = (10 ,),
326+             stride = [5 ],
327+             padding = [5 ],
328+             dilation = [3 ],
329+             transposed = False ,
330+             output_padding = [0 ],
331+             groups = 5 ,
332+         ),
333+         Test (
334+             self = (1 , 9 , 11 ),
335+             weight = (9 , 1 , 3 ),
336+             bias = None ,
337+             stride = [1 ],
338+             padding = [0 ],
339+             dilation = [1 ],
340+             transposed = False ,
341+             output_padding = [0 ],
342+             groups = 9 ,
343+         ),
344+         Test (
345+             self = (5 , 15 , 30 ),
346+             weight = (20 , 3 , 3 ),
347+             bias = None ,
348+             stride = [3 ],
349+             padding = [5 ],
350+             dilation = [7 ],
351+             transposed = False ,
352+             output_padding = [0 ],
353+             groups = 5 ,
354+         ),
355+         Test (
356+             self = (1 , 16 , 672 , 512 ),
357+             weight = (64 , 16 , 1 , 1 ),
358+             bias = (64 ,),
359+             stride = [1 , 1 ],
360+             padding = [0 , 0 ],
361+             dilation = [1 , 1 ],
362+             transposed = False ,
363+             output_padding = [0 , 0 ],
364+             groups = 1 ,
365+         ),
366+         Test (
367+             self = (1 , 4 , 234 , 234 ),
368+             weight = (4 , 1 , 3 , 3 ),
369+             bias = (4 ,),
370+             stride = [2 , 1 ],
371+             padding = [1 , 1 ],
372+             dilation = [1 , 1 ],
373+             transposed = False ,
374+             output_padding = [0 , 0 ],
375+             groups = 4 ,
376+         ),
377+         Test (
378+             self = (1 , 4 , 234 , 234 ),
379+             weight = (4 , 1 , 3 , 3 ),
380+             bias = (4 ,),
381+             stride = [1 , 2 ],
382+             padding = [1 , 1 ],
383+             dilation = [1 , 1 ],
384+             transposed = False ,
385+             output_padding = [0 , 0 ],
386+             groups = 4 ,
387+         ),
388+         Test (
389+             self = (1 , 4 , 234 , 234 ),
390+             weight = (4 , 1 , 3 , 3 ),
391+             bias = (4 ,),
392+             stride = [2 , 2 ],
393+             padding = [1 , 1 ],
394+             dilation = [1 , 1 ],
395+             transposed = False ,
396+             output_padding = [0 , 0 ],
397+             groups = 4 ,
398+         ),
399+         Test (
400+             self = (1 , 8 , 90 , 77 ),
401+             weight = (1 , 8 , 3 , 3 ),
402+             bias = (1 ,),
403+             stride = [1 , 1 ],
404+             padding = [2 , 2 ],
405+             dilation = [2 , 2 ],
406+             transposed = False ,
407+             output_padding = [0 , 0 ],
408+             groups = 1 ,
409+         ),
410+     ]
411+ 
412+     test_suite  =  VkTestSuite (test_cases )
376413    return  test_suite 
377414
378415
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments