@@ -163,7 +163,9 @@ def _parse_scalar_argument(entry):
163163 return np .frombuffer (bytes (data ), dtype = dtype )[0 ]
164164
165165
166- def _parse_array_file (file_name : str , data_dir : str , expect_hash : str , dtype , validate : bool ):
166+ def _parse_array_file (
167+ file_name : str , data_dir : str , expect_hash : str , dtype , validate : bool
168+ ):
167169 file_path = os .path .join (data_dir , file_name )
168170
169171 if file_name .endswith (".gz" ):
@@ -190,15 +192,21 @@ def _parse_array_argument(entry: dict, data_dir: str, validate_checksum: bool):
190192 dtype = _type_name_to_dtype (type_name [:- 1 ])
191193
192194 if dtype is None :
193- logger .warning (f" unknown type \ "{ type_name } \ " , falling back to byte array" )
195+ logger .warning (f' unknown type "{ type_name } ", falling back to byte array' )
194196 dtype = np .byte
195197
196- arg = _parse_array_file (entry ["file" ], data_dir , entry .get ("file_hash" ),
197- dtype , validate_checksum )
198+ arg = _parse_array_file (
199+ entry ["file" ], data_dir , entry .get ("file_hash" ), dtype , validate_checksum
200+ )
198201
199202 if "reference_file" in entry :
200- answer = _parse_array_file (entry ["reference_file" ], data_dir , entry .get ("reference_hash" ),
201- dtype , validate_checksum )
203+ answer = _parse_array_file (
204+ entry ["reference_file" ],
205+ data_dir ,
206+ entry .get ("reference_hash" ),
207+ dtype ,
208+ validate_checksum ,
209+ )
202210 else :
203211 answer = None
204212
@@ -230,7 +238,15 @@ def __init__(self, obj, data_dir: str, validate_checksum: bool = True):
230238 self .args .append (arg )
231239 self .answers .append (answer )
232240
233- def _tune_options (self , working_dir = None , lang = "cupy" , compiler_options = None , defines = None , device = 0 , ** kwargs ):
241+ def _tune_options (
242+ self ,
243+ working_dir = None ,
244+ lang = "cupy" ,
245+ compiler_options = None ,
246+ defines = None ,
247+ device = 0 ,
248+ ** kwargs ,
249+ ):
234250 if working_dir is None :
235251 working_dir = os .getcwd ()
236252
@@ -277,20 +293,21 @@ def grid_size(config):
277293 restrictions = [e .resolve (** context ) for e in self .space .restrictions ]
278294
279295 options = dict (
280- kernel_name = self .kernel .generate_name (),
281- kernel_source = self .kernel .generate_source (working_dir ),
282- arguments = self .args ,
283- problem_size = grid_size ,
284- restrictions = lambda config : all (f (config ) for f in restrictions ),
285- defines = all_defines ,
286- compiler_options = compiler_options ,
287- block_size_names = block_size_names ,
288- grid_div_x = [],
289- grid_div_y = [],
290- grid_div_z = [],
291- lang = lang ,
292- device = device ,
293- ** kwargs )
296+ kernel_name = self .kernel .generate_name (),
297+ kernel_source = self .kernel .generate_source (working_dir ),
298+ arguments = self .args ,
299+ problem_size = grid_size ,
300+ restrictions = lambda config : all (f (config ) for f in restrictions ),
301+ defines = all_defines ,
302+ compiler_options = compiler_options ,
303+ block_size_names = block_size_names ,
304+ grid_div_x = [],
305+ grid_div_y = [],
306+ grid_div_z = [],
307+ lang = lang ,
308+ device = device ,
309+ ** kwargs ,
310+ )
294311
295312 os .chdir (working_dir )
296313 return extra_params , options
@@ -351,16 +368,25 @@ def tune(self, params=None, **kwargs):
351368 strategy = "brute_force" if total_configs < 100 else "bayes_opt"
352369
353370 return kernel_tuner .tune_kernel (
354- tune_params = params ,
355- strategy = strategy ,
356- answer = answer ,
357- verify = verify ,
358- ** options )
371+ tune_params = params ,
372+ strategy = strategy ,
373+ answer = answer ,
374+ verify = verify ,
375+ ** options ,
376+ )
359377
360378
361379def _fancy_verify (answers , outputs , * , atol = None ):
362- INTEGRAL_DTYPES = [np .int8 , np .int16 , np .int32 , np .int64 ,
363- np .uint8 , np .uint16 , np .uint32 , np .uint64 ]
380+ INTEGRAL_DTYPES = [
381+ np .int8 ,
382+ np .int16 ,
383+ np .int32 ,
384+ np .int64 ,
385+ np .uint8 ,
386+ np .uint16 ,
387+ np .uint32 ,
388+ np .uint64 ,
389+ ]
364390 FLOATING_DTYPES = [np .float16 , np .float32 , np .float64 ]
365391 PRINT_TOP_VALUES = 25
366392 DEFAULT_ATOL = 1e-8
@@ -378,7 +404,9 @@ def _fancy_verify(answers, outputs, *, atol=None):
378404 continue
379405
380406 if output .dtype != expected .dtype or output .shape != expected .shape :
381- raise RuntimeError (f"arrays data type or shape do not match: { output } and { expected } " )
407+ raise RuntimeError (
408+ f"arrays data type or shape do not match: { output } and { expected } "
409+ )
382410
383411 if output .dtype in INTEGRAL_DTYPES :
384412 matches = output == expected
@@ -401,14 +429,18 @@ def _fancy_verify(answers, outputs, *, atol=None):
401429 # indices = indices[np.argsort(errors[indices], kind="stable")][::-1]
402430
403431 percentage = nerrors / len (output ) * 100
404- print (f"argument { index + 1 } fails validation: { nerrors } incorrect values" +
405- f"({ percentage :.5} %)" )
432+ print (
433+ f"argument { index + 1 } fails validation: { nerrors } incorrect values"
434+ + f"({ percentage :.5} %)"
435+ )
406436
407437 errors = np .abs (output - expected )
408438
409439 for index in indices [:PRINT_TOP_VALUES ]:
410- print (f" * at index { index } : { output [index ]} != { expected [index ]} " +
411- f"(error: { errors [index ]} )" )
440+ print (
441+ f" * at index { index } : { output [index ]} != { expected [index ]} "
442+ + f"(error: { errors [index ]} )"
443+ )
412444
413445 if nerrors > PRINT_TOP_VALUES :
414446 print (f" * ({ nerrors - PRINT_TOP_VALUES } more entries have been omitted)" )
@@ -583,22 +615,27 @@ def resolve(self, problem_size, **kwargs):
583615
584616class DeviceAttributeExpr (Expr ):
585617 # Map cuda.h names to cupy names
586- NAME_MAPPING = dict ([
587- ('MAX_THREADS_PER_BLOCK' , 'MaxThreadsPerBlock' ),
588- ('MAX_BLOCK_DIM_X' , 'MaxBlockDimX' ),
589- ('MAX_BLOCK_DIM_Y' , 'MaxBlockDimY' ),
590- ('MAX_BLOCK_DIM_Z' , 'MaxBlockDimZ' ),
591- ('MAX_GRID_DIM_X' , 'MaxGridDimX' ),
592- ('MAX_GRID_DIM_Y' , 'MaxGridDimY' ),
593- ('MAX_GRID_DIM_Z' , 'MaxGridDimZ' ),
594- ('MAX_SHARED_MEMORY_PER_BLOCK' , 'MaxSharedMemoryPerBlock' ),
595- ('WARP_SIZE' , 'WarpSize' ),
596- ('MAX_REGISTERS_PER_BLOCK' , 'MaxRegistersPerBlock' ),
597- ('MULTIPROCESSOR_COUNT' , 'MultiProcessorCount' ),
598- ('MAX_THREADS_PER_MULTIPROCESSOR' , 'MaxThreadsPerMultiProcessor' ),
599- ('MAX_SHARED_MEMORY_PER_MULTIPROCESSOR' , 'MaxSharedMemoryPerMultiprocessor' ),
600- ('MAX_REGISTERS_PER_MULTIPROCESSOR' , 'MaxRegistersPerMultiprocessor' ),
601- ])
618+ NAME_MAPPING = dict (
619+ [
620+ ("MAX_THREADS_PER_BLOCK" , "MaxThreadsPerBlock" ),
621+ ("MAX_BLOCK_DIM_X" , "MaxBlockDimX" ),
622+ ("MAX_BLOCK_DIM_Y" , "MaxBlockDimY" ),
623+ ("MAX_BLOCK_DIM_Z" , "MaxBlockDimZ" ),
624+ ("MAX_GRID_DIM_X" , "MaxGridDimX" ),
625+ ("MAX_GRID_DIM_Y" , "MaxGridDimY" ),
626+ ("MAX_GRID_DIM_Z" , "MaxGridDimZ" ),
627+ ("MAX_SHARED_MEMORY_PER_BLOCK" , "MaxSharedMemoryPerBlock" ),
628+ ("WARP_SIZE" , "WarpSize" ),
629+ ("MAX_REGISTERS_PER_BLOCK" , "MaxRegistersPerBlock" ),
630+ ("MULTIPROCESSOR_COUNT" , "MultiProcessorCount" ),
631+ ("MAX_THREADS_PER_MULTIPROCESSOR" , "MaxThreadsPerMultiProcessor" ),
632+ (
633+ "MAX_SHARED_MEMORY_PER_MULTIPROCESSOR" ,
634+ "MaxSharedMemoryPerMultiprocessor" ,
635+ ),
636+ ("MAX_REGISTERS_PER_MULTIPROCESSOR" , "MaxRegistersPerMultiprocessor" ),
637+ ]
638+ )
602639
603640 def __init__ (self , name ):
604641 self .name = name
@@ -630,8 +667,10 @@ def evaluate(self, config):
630667 index = self .condition .evaluate (config )
631668
632669 if not is_int_like (index ) or index < 0 or index >= len (self .options ):
633- raise RuntimeError ("expression must yield an integer in " +
634- f"range 0..{ len (self .options )} : { self } " )
670+ raise RuntimeError (
671+ "expression must yield an integer in "
672+ + f"range 0..{ len (self .options )} : { self } "
673+ )
635674
636675 return self .options [int (index )].evaluate (config )
637676
@@ -641,8 +680,7 @@ def visit_children(self, fun):
641680
642681def _parse_expr (entry ) -> Expr :
643682 # literal int, str or float becomes ValueExpr.
644- if isinstance (entry , (int , str , float )) or \
645- entry is None :
683+ if isinstance (entry , (int , str , float )) or entry is None :
646684 return ValueExpr (entry )
647685
648686 # Otherwise it must be an operator expression
0 commit comments