@@ -307,7 +307,7 @@ class Variable:
307
307
def __init__ (self , iork :str , defstr :str ) -> None :
308
308
vs = defstr .split ()
309
309
self .iork = iork
310
- self .location = vs [0 ] # shared / register
310
+ self .location = vs [0 ] # shared / register
311
311
self .type = vs [1 ]
312
312
self .name = vs [2 ]
313
313
self .readfrom , self .addto , self .onlyif = None , None , None
@@ -323,7 +323,6 @@ def __init__(self, iork:str, defstr:str) -> None:
323
323
s3 = s2 .replace ('onlyif:' , '' )
324
324
self .onlyif = s3
325
325
326
-
327
326
@staticmethod
328
327
def _get_src (src , index ) -> str :
329
328
if ',' in src :
@@ -332,18 +331,16 @@ def _get_src(src, index) -> str:
332
331
else :
333
332
return '{}[{}]' .format (src , index )
334
333
335
-
336
334
def zero (self ) -> str :
337
335
v1 = ''
338
336
if self .location == 'shared' :
339
337
v1 = '{}[threadIdx.x] = 0;' .format (self .name )
340
338
else :
341
339
v1 = '{} = 0;' .format (self .name )
342
- if self .onlyif != None :
343
- v1 = 'if CONSTEXPR ({}) {}' .format (self .onlyif , v1 )
340
+ if self .onlyif is not None :
341
+ v1 = 'if CONSTEXPR ({}) {}' .format (self .onlyif , v1 )
344
342
return v1
345
343
346
-
347
344
def save (self ) -> str :
348
345
dst = self .addto
349
346
v1 , suffix = '' , ''
@@ -354,11 +351,10 @@ def save(self) -> str:
354
351
else :
355
352
vs = dst .split (',' )
356
353
v1 = 'atomic_add({}{}, &{}[{}][{}]);' .format (self .name , suffix , vs [0 ], self .iork , vs [1 ])
357
- if self .onlyif != None :
354
+ if self .onlyif is not None :
358
355
v1 = 'if CONSTEXPR ({}) {}' .format (self .onlyif , v1 )
359
356
return v1
360
357
361
-
362
358
def init_exclude (self ) -> str :
363
359
rhs = self ._get_src (self .readfrom , self .iork )
364
360
if self .location == 'shared' :
@@ -369,7 +365,6 @@ def init_exclude(self) -> str:
369
365
else :
370
366
return '{} = {};' .format (self .name , rhs )
371
367
372
-
373
368
def init_block (self ) -> str :
374
369
if self .readfrom in ['x' , 'y' , 'z' ]:
375
370
if self .location == 'shared' :
@@ -382,14 +377,12 @@ def init_block(self) -> str:
382
377
else :
383
378
return '{} = {};' .format (self .name , self ._get_src (self .readfrom , self .iork ))
384
379
385
-
386
380
def shuffle (self ) -> str :
387
381
if self .location == 'register' :
388
382
return '{0:} = __shfl_sync(ALL_LANES, {0:}, ilane + 1);' .format (self .name )
389
383
else :
390
384
raise ValueError ('Cannot shuffle variables in the shared memory.' )
391
385
392
-
393
386
def ikreplace (self , code :str ) -> str :
394
387
old_name = '@{}@' .format (self .name )
395
388
new_name = self .name
@@ -402,7 +395,6 @@ def ikreplace(self, code:str) -> str:
402
395
code = code .replace (old_name , new_name )
403
396
return code
404
397
405
-
406
398
def iterreplace (self , code :str ) -> str :
407
399
old_name = '@{}@' .format ('i' )
408
400
new_name = self .name
@@ -425,7 +417,6 @@ def __init__(self, iork:str, lst:list) -> None:
425
417
else :
426
418
d [v .type ] = [v ]
427
419
428
-
429
420
def declare (self ) -> str :
430
421
s = ''
431
422
for t in self .shared .keys ():
@@ -441,7 +432,6 @@ def declare(self) -> str:
441
432
s = s .replace (',;' , ';' )
442
433
return s
443
434
444
-
445
435
def zero (self ) -> str :
446
436
s = ''
447
437
for t in self .shared .keys ():
@@ -452,7 +442,6 @@ def zero(self) -> str:
452
442
s = s + v .zero ()
453
443
return s
454
444
455
-
456
445
def save (self ) -> str :
457
446
s = ''
458
447
for t in self .shared .keys ():
@@ -463,7 +452,6 @@ def save(self) -> str:
463
452
s = s + v .save ()
464
453
return s
465
454
466
-
467
455
def init_exclude (self ) -> str :
468
456
s = ''
469
457
for t in self .shared .keys ():
@@ -474,7 +462,6 @@ def init_exclude(self) -> str:
474
462
s = s + v .init_exclude ()
475
463
return s
476
464
477
-
478
465
def init_block (self ) -> str :
479
466
s = ''
480
467
for t in self .shared .keys ():
@@ -485,15 +472,13 @@ def init_block(self) -> str:
485
472
s = s + v .init_block ()
486
473
return s
487
474
488
-
489
475
def shuffle (self ) -> str :
490
476
s = ''
491
477
for t in self .register .keys ():
492
478
for v in self .register [t ]:
493
479
s = s + v .shuffle ()
494
480
return s
495
481
496
-
497
482
def ikreplace (self , code :str ) -> str :
498
483
for t in self .shared .keys ():
499
484
for v in self .shared [t ]:
@@ -531,7 +516,6 @@ def _func_param(ptype:str, pname:str) -> str:
531
516
else :
532
517
raise ValueError ('Do not know how to parse type: {}' .format (ptype ))
533
518
534
-
535
519
@staticmethod
536
520
def _load_scale_param (ptype :str , stem :str , input :str , separate_scaled_pairwise :bool ) -> str :
537
521
if ptype == 'real_const_array' :
@@ -551,7 +535,7 @@ def _load_scale_param(ptype:str, stem:str, input:str, separate_scaled_pairwise:b
551
535
# dim = match.group(2)
552
536
ss = ptype .split (',' )
553
537
v = ''
554
- for i in range (1 ,len (ss )):
538
+ for i in range (1 , len (ss )):
555
539
idx = ss [i ]
556
540
al = rc_alphabets [idx ]
557
541
if input is None :
@@ -561,7 +545,6 @@ def _load_scale_param(ptype:str, stem:str, input:str, separate_scaled_pairwise:b
561
545
v = v + '{} {}{} = {}[ii][{}];' .format (t , stem , al , input , idx )
562
546
return v
563
547
564
-
565
548
def __init__ (self , config ) -> None :
566
549
self .config = config
567
550
@@ -598,7 +581,6 @@ def _kv(self, k:str):
598
581
else :
599
582
return ''
600
583
601
-
602
584
def cudaReplaceDict (self ) -> dict :
603
585
d = {}
604
586
config = self .config
@@ -679,10 +661,10 @@ def cudaReplaceDict(self) -> dict:
679
661
if kcfg in keys :
680
662
vcfg , decl , zero , total = config [kcfg ], '' , '' , ''
681
663
for t in vcfg :
682
- v1 = v1 + ', CountBuffer restrict {}' .format (t )
683
- decl = decl + 'int {}tl;' .format (t )
684
- zero = zero + '{}tl = 0;' .format (t )
685
- total = total + 'atomic_add({}tl, {}, ithread);' .format (t , t )
664
+ v1 = v1 + ', CountBuffer restrict {}' .format (t )
665
+ decl = decl + 'int {}tl;' .format (t )
666
+ zero = zero + '{}tl = 0;' .format (t )
667
+ total = total + 'atomic_add({}tl, {}, ithread);' .format (t , t )
686
668
v2 = '%s if CONSTEXPR (do_a) {%s}' % (decl , zero )
687
669
v3 = 'if CONSTEXPR (do_a) {%s}' % (total )
688
670
d [k1 ], d [k2 ], d [k3 ] = v1 , v2 , v3
@@ -809,7 +791,7 @@ def cudaReplaceDict(self) -> dict:
809
791
v1 = kvars .ikreplace (v1 )
810
792
v1 = ifrcs .ikreplace (v1 )
811
793
v1 = kfrcs .ikreplace (v1 )
812
- v2 = v1 # in case no scaled pairwise interaction is given
794
+ v2 = v1 # in case no scaled pairwise interaction is given
813
795
kcfg = self .yk_scaled_pairwise
814
796
if kcfg in keys :
815
797
v2 = config [kcfg ]
@@ -840,27 +822,24 @@ def cudaReplaceDict(self) -> dict:
840
822
841
823
return d
842
824
843
-
844
825
@staticmethod
845
826
def version () -> str :
846
827
return '3.1.0'
847
828
848
-
849
829
@staticmethod
850
830
def _replace (s :str , d :dict ) -> str :
851
831
output = s
852
832
for k in d .keys ():
853
833
v = d [k ]
854
- if v == None :
834
+ if v is None :
855
835
v = ''
856
836
output = output .replace (k , v )
857
837
return output
858
838
859
-
860
839
def write (self , output ) -> None :
861
840
d = self .cudaReplaceDict ()
862
841
outstr = '// ck.py Version {}' .format (self .version ())
863
- kernel_num = 21 # default
842
+ kernel_num = 21 # default
864
843
if self .yk_kernel_version_number in self .config .keys ():
865
844
kernel_num = self .config [self .yk_kernel_version_number ]
866
845
if kernel_num == 11 :
@@ -896,7 +875,6 @@ def show_command(argv):
896
875
d2 = os .path .join (d , '../..' )
897
876
d = os .path .abspath (d2 )
898
877
899
-
900
878
yaml_file = argv [1 ]
901
879
with open (yaml_file ) as input_file :
902
880
config = yaml .full_load (input_file )
0 commit comments