29
29
'20' : 'u' , '21' : 'v' , '22' : 'w' , '23' : 'x' , '24' : 'y' ,
30
30
'25' : 'z' }
31
31
32
- rc_kernel2c = '''
32
+ rc_kernel11 = '''
33
+ TEMPLATE_PARAMS __global__
34
+ STATIC_KERNEL void KERNEL_NAME(SINGLE_LOOP_LIMIT_PARAM EXTRA_KERNEL_PARAMS)
35
+ {
36
+ KERNEL_SINGLE_LOOP_BEGIN
37
+ KERNEL_SINGLE_LOOP_CODE
38
+ KERNEL_SINGLE_LOOP_END
39
+ }
40
+ '''
41
+
42
+ rc_kernel23c = '''
33
43
TEMPLATE_PARAMS \
34
44
__global__ \
35
45
void KERNEL_NAMEc( \
69
79
'''
70
80
71
81
72
- rc_kernel2b = '''
82
+ rc_kernel23b = '''
73
83
TEMPLATE_PARAMS \
74
84
__global__ \
75
85
void KERNEL_NAMEb( \
130
140
'''
131
141
132
142
133
- rc_kernel2a = '''
143
+ rc_kernel23a = '''
134
144
TEMPLATE_PARAMS \
135
145
__global__ \
136
146
void KERNEL_NAMEa( \
181
191
'''
182
192
183
193
184
- rc_kernel1 = '''
194
+ rc_kernel21 = '''
185
195
TEMPLATE_PARAMS \
186
196
__global__ \
187
197
void KERNEL_NAME(int n, \
@@ -393,6 +403,13 @@ def ikreplace(self, code:str) -> str:
393
403
return code
394
404
395
405
406
+ def iterreplace (self , code :str ) -> str :
407
+ old_name = '@{}@' .format ('i' )
408
+ new_name = self .name
409
+ code = code .replace (old_name , new_name )
410
+ return code
411
+
412
+
396
413
class VariableDefinitions :
397
414
def __init__ (self , iork :str , lst :list ) -> None :
398
415
self .shared , self .register = {}, {}
@@ -548,9 +565,10 @@ def _load_scale_param(ptype:str, stem:str, input:str, separate_scaled_pairwise:b
548
565
def __init__ (self , config ) -> None :
549
566
self .config = config
550
567
551
- self .yk_split_kernel = 'SPLIT_KERNEL '
568
+ self .yk_kernel_version_number = 'KERNEL_VERSION_NUMBER '
552
569
553
570
self .yk_output_dir = 'OUTPUT_DIR'
571
+ self .yk_kernel_is_static = 'KERNEL_IS_STATIC'
554
572
self .yk_kernel_name = 'KERNEL_NAME'
555
573
self .yk_template_params = 'TEMPLATE_PARAMS'
556
574
self .yk_constexpr_flags = 'CONSTEXPR_FLAGS'
@@ -570,6 +588,9 @@ def __init__(self, config) -> None:
570
588
self .yk_scaled_pairwise = 'SCALED_PAIRWISE_INTERACTION'
571
589
self .yk_full_pairwise = 'FULL_PAIRWISE_INTERACTION'
572
590
591
+ self .yk_single_loop_limit = 'SINGLE_LOOP_LIMIT'
592
+ self .yk_single_loop_iter = 'SINGLE_LOOP_ITER'
593
+ self .yk_single_loop_code = 'SINGLE_LOOP_CODE'
573
594
574
595
def _kv (self , k :str ):
575
596
if k in self .config .keys ():
@@ -588,6 +609,13 @@ def cudaReplaceDict(self) -> dict:
588
609
d [k ] = v
589
610
590
611
# kernel name
612
+ k , v = 'STATIC_KERNEL' , ''
613
+ kcfg , vcfg = self .yk_kernel_is_static , False
614
+ if kcfg in keys :
615
+ vcfg = config [kcfg ]
616
+ if vcfg :
617
+ v = 'static'
618
+ d [k ] = v
591
619
k , v = 'KERNEL_NAME' , self ._kv (self .yk_kernel_name )
592
620
d [k ] = v
593
621
@@ -619,25 +647,29 @@ def cudaReplaceDict(self) -> dict:
619
647
k , v = 'KERNEL_CONSTEXPR_FLAGS' , self ._kv (self .yk_constexpr_flags )
620
648
d [k ] = v
621
649
622
- # i and k declaration
623
- ivars , kvars = VariableDefinitions ('i' , config [self .yk_i_variables ]), VariableDefinitions ('k' , config [self .yk_k_variables ])
624
- ifrcs , kfrcs = VariableDefinitions ('i' , config [self .yk_i_force ]), VariableDefinitions ('k' , config [self .yk_k_force ])
625
- if len (ifrcs .shared .keys ()):
626
- raise ValueError ('I_FORCE cannot be put on shared memory.' )
627
- if len (kfrcs .shared .keys ()):
628
- raise ValueError ('F_FORCE cannot be put on shared memory.' )
629
- k1 , v1 = 'DECLARE_PARAMS_I_AND_K' , ivars .declare () + kvars .declare ()
630
- k2 , v2 = 'DECLARE_FORCE_I_AND_K' , ifrcs .declare () + kfrcs .declare ()
631
- d [k1 ], d [k2 ] = v1 , v2
632
-
633
- # i and k in exclude block
634
- k1 , v1 = 'KERNEL_INIT_EXCLUDE_PARAMS_I_AND_K' , ''
635
- k2 , v2 = 'KERNEL_INIT_PARAMS_I_AND_K' , ''
636
- k3 , v3 = 'KERNEL_SHUFFLE_PARAMS_I' , ''
637
- v1 = v1 + ivars .init_exclude () + kvars .init_exclude ()
638
- v2 = v2 + ivars .init_block () + kvars .init_block ()
639
- v3 = v3 + ivars .shuffle ()
640
- d [k1 ], d [k2 ], d [k3 ] = v1 , v2 , v3
650
+ use_ikvars = False
651
+ if self .yk_i_variables in keys and self .yk_k_variables in keys :
652
+ use_ikvars = True
653
+ if use_ikvars :
654
+ # i and k declaration
655
+ ivars , kvars = VariableDefinitions ('i' , config [self .yk_i_variables ]), VariableDefinitions ('k' , config [self .yk_k_variables ])
656
+ ifrcs , kfrcs = VariableDefinitions ('i' , config [self .yk_i_force ]), VariableDefinitions ('k' , config [self .yk_k_force ])
657
+ if len (ifrcs .shared .keys ()):
658
+ raise ValueError ('I_FORCE cannot be put on shared memory.' )
659
+ if len (kfrcs .shared .keys ()):
660
+ raise ValueError ('F_FORCE cannot be put on shared memory.' )
661
+ k1 , v1 = 'DECLARE_PARAMS_I_AND_K' , ivars .declare () + kvars .declare ()
662
+ k2 , v2 = 'DECLARE_FORCE_I_AND_K' , ifrcs .declare () + kfrcs .declare ()
663
+ d [k1 ], d [k2 ] = v1 , v2
664
+
665
+ # i and k in exclude block
666
+ k1 , v1 = 'KERNEL_INIT_EXCLUDE_PARAMS_I_AND_K' , ''
667
+ k2 , v2 = 'KERNEL_INIT_PARAMS_I_AND_K' , ''
668
+ k3 , v3 = 'KERNEL_SHUFFLE_PARAMS_I' , ''
669
+ v1 = v1 + ivars .init_exclude () + kvars .init_exclude ()
670
+ v2 = v2 + ivars .init_block () + kvars .init_block ()
671
+ v3 = v3 + ivars .shuffle ()
672
+ d [k1 ], d [k2 ], d [k3 ] = v1 , v2 , v3
641
673
642
674
# count
643
675
k1 , v1 = 'COUNT_KERNEL_PARAMS' , ''
@@ -696,39 +728,46 @@ def cudaReplaceDict(self) -> dict:
696
728
v3 = 'if CONSTEXPR (do_v) {%s}' % total
697
729
d [k1 ], d [k2 ], d [k3 ] = v1 , v2 , v3
698
730
699
- # gradient
700
- k , v = 'GRADIENT_KERNEL_PARAMS' , ''
701
- kcfg = self .yk_gradient
702
- if kcfg in keys :
703
- vcfg = config [kcfg ]
704
- for t in vcfg :
705
- v = v + ', grad_prec* restrict {}' .format (t )
706
- k1 , v1 = 'KERNEL_ZERO_LOCAL_FORCE' , ifrcs .zero () + kfrcs .zero ()
707
- k2 , v2 = 'KERNEL_SAVE_LOCAL_FORCE' , ifrcs .save () + kfrcs .save ()
708
- k3 , v3 = 'KERNEL_SHUFFLE_LOCAL_FORCE_I' , ifrcs .shuffle ()
709
- kcfg = self .yk_constexpr_flags
710
- if kcfg in keys :
711
- vcfg = config [kcfg ]
712
- if 'constexpr bool do_g =' in vcfg :
713
- v1 = 'if CONSTEXPR (do_g) {%s}' % v1
714
- v2 = 'if CONSTEXPR (do_g) {%s}' % v2
715
- if v3 != '' :
716
- v3 = 'if CONSTEXPR (do_g) {%s}' % v3
717
- d [k ], d [k1 ], d [k2 ], d [k3 ] = v , v1 , v2 , v3
718
-
719
- # klane -- True only if ivar uses shared memory
720
- k1 , v1 = 'KERNEL_KLANE1' , ''
721
- k2 , v2 = 'KERNEL_KLANE2' , ''
722
- k3 , v3 = 'KERNEL_SCALED_KLANE' , ''
723
- use_klane = False
724
- if len (ivars .shared .keys ()):
725
- use_klane = True
726
- if use_klane :
727
- v1 = 'int klane = srclane + threadIdx.x - ilane;'
728
- v2 = v2 + 'int srclane = (ilane + j) & (WARP_SIZE - 1);'
729
- v2 = v2 + 'int klane = srclane + threadIdx.x - ilane;'
730
- v3 = 'const int klane = threadIdx.x;'
731
- d [k1 ], d [k2 ], d [k3 ] = v1 , v2 , v3
731
+ if use_ikvars :
732
+ # sync warp
733
+ k1 , v1 = 'KERNEL_SYNCWARP' , '__syncwarp();'
734
+ if len (ivars .shared ) == 0 and len (kvars .shared ) == 0 and len (ifrcs .shared ) == 0 and len (kfrcs .shared ) == 0 :
735
+ v1 = ''
736
+ d [k1 ] = v1
737
+
738
+ # gradient
739
+ k , v = 'GRADIENT_KERNEL_PARAMS' , ''
740
+ kcfg = self .yk_gradient
741
+ if kcfg in keys :
742
+ vcfg = config [kcfg ]
743
+ for t in vcfg :
744
+ v = v + ', grad_prec* restrict {}' .format (t )
745
+ k1 , v1 = 'KERNEL_ZERO_LOCAL_FORCE' , ifrcs .zero () + kfrcs .zero ()
746
+ k2 , v2 = 'KERNEL_SAVE_LOCAL_FORCE' , ifrcs .save () + kfrcs .save ()
747
+ k3 , v3 = 'KERNEL_SHUFFLE_LOCAL_FORCE_I' , ifrcs .shuffle ()
748
+ kcfg = self .yk_constexpr_flags
749
+ if kcfg in keys :
750
+ vcfg = config [kcfg ]
751
+ if 'constexpr bool do_g =' in vcfg :
752
+ v1 = 'if CONSTEXPR (do_g) {%s}' % v1
753
+ v2 = 'if CONSTEXPR (do_g) {%s}' % v2
754
+ if v3 != '' :
755
+ v3 = 'if CONSTEXPR (do_g) {%s}' % v3
756
+ d [k ], d [k1 ], d [k2 ], d [k3 ] = v , v1 , v2 , v3
757
+
758
+ # klane -- True only if ivar uses shared memory
759
+ k1 , v1 = 'KERNEL_KLANE1' , ''
760
+ k2 , v2 = 'KERNEL_KLANE2' , ''
761
+ k3 , v3 = 'KERNEL_SCALED_KLANE' , ''
762
+ use_klane = False
763
+ if len (ivars .shared .keys ()):
764
+ use_klane = True
765
+ if use_klane :
766
+ v1 = 'int klane = srclane + threadIdx.x - ilane;'
767
+ v2 = v2 + 'int srclane = (ilane + j) & (WARP_SIZE - 1);'
768
+ v2 = v2 + 'int klane = srclane + threadIdx.x - ilane;'
769
+ v3 = 'const int klane = threadIdx.x;'
770
+ d [k1 ], d [k2 ], d [k3 ] = v1 , v2 , v3
732
771
733
772
# exclude
734
773
k1 , v1 = 'EXCLUDE_INFO_KERNEL_PARAMS' , ''
@@ -780,18 +819,31 @@ def cudaReplaceDict(self) -> dict:
780
819
v2 = kfrcs .ikreplace (v2 )
781
820
d [k1 ], d [k2 ] = v1 , v2
782
821
783
- # sync warp
784
- k1 , v1 = 'KERNEL_SYNCWARP' , '__syncwarp();'
785
- if len (ivars .shared ) == 0 and len (kvars .shared ) == 0 and len (ifrcs .shared ) == 0 and len (kfrcs .shared ) == 0 :
786
- v1 = ''
787
- d [k1 ] = v1
822
+ # single loop
823
+ k0 , v0 = 'SINGLE_LOOP_LIMIT_PARAM' , ''
824
+ k1 , v1 = 'KERNEL_SINGLE_LOOP_CODE' , ''
825
+ k2 , v2 = 'KERNEL_SINGLE_LOOP_BEGIN' , ''
826
+ k3 , v3 = 'KERNEL_SINGLE_LOOP_END' , ''
827
+ kcfg = self .yk_single_loop_code
828
+ if kcfg in keys :
829
+ v0 = config [self .yk_single_loop_limit ]
830
+ v1 = config [kcfg ]
831
+ sl_limit , sl_iter = config [self .yk_single_loop_limit ], config [self .yk_single_loop_iter ]
832
+ sl_limit = 'register ' + sl_limit + ' from:dummy'
833
+ sl_iter = 'register ' + sl_iter + ' from:dummy'
834
+ var_limit = Variable ('k' , sl_limit )
835
+ var_iter = Variable ('k' , sl_iter )
836
+ v2 = 'for(%s %s = ITHREAD; %s < %s; %s += STRIDE) {' % (var_iter .type , var_iter .name , var_iter .name , var_limit .name , var_iter .name )
837
+ v3 = '}'
838
+ v1 = var_iter .iterreplace (v1 )
839
+ d [k0 ], d [k1 ], d [k2 ], d [k3 ] = v0 , v1 , v2 , v3
788
840
789
841
return d
790
842
791
843
792
844
@staticmethod
793
845
def version () -> str :
794
- return '3.0.2 '
846
+ return '3.1.0 '
795
847
796
848
797
849
@staticmethod
@@ -808,13 +860,18 @@ def _replace(s:str, d:dict) -> str:
808
860
def write (self , output ) -> None :
809
861
d = self .cudaReplaceDict ()
810
862
outstr = '// ck.py Version {}' .format (self .version ())
811
- if self .yk_split_kernel in self .config .keys ():
863
+ kernel_num = 21 # default
864
+ if self .yk_kernel_version_number in self .config .keys ():
865
+ kernel_num = self .config [self .yk_kernel_version_number ]
866
+ if kernel_num == 11 :
867
+ outstr = outstr + self ._replace (rc_kernel11 , d )
868
+ elif kernel_num == 23 :
812
869
if self .yk_scale_1x_type in self .config .keys ():
813
- outstr = outstr + self ._replace (rc_kernel2c , d )
814
- outstr = outstr + self ._replace (rc_kernel2b , d )
815
- outstr = outstr + self ._replace (rc_kernel2a , d )
870
+ outstr = outstr + self ._replace (rc_kernel23c , d )
871
+ outstr = outstr + self ._replace (rc_kernel23b , d )
872
+ outstr = outstr + self ._replace (rc_kernel23a , d )
816
873
else :
817
- outstr = outstr + self ._replace (rc_kernel1 , d )
874
+ outstr = outstr + self ._replace (rc_kernel21 , d )
818
875
print (outstr , file = output )
819
876
820
877
0 commit comments