Skip to content

Commit 6b46ce2

Browse files
committed
start moving single loop kernels to separate files
1 parent 500d269 commit 6b46ce2

File tree

6 files changed

+66
-91
lines changed

6 files changed

+66
-91
lines changed

ext/ext/ck3.py

Lines changed: 12 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ class Variable:
307307
def __init__(self, iork:str, defstr:str) -> None:
308308
vs = defstr.split()
309309
self.iork = iork
310-
self.location = vs[0] # shared / register
310+
self.location = vs[0] # shared / register
311311
self.type = vs[1]
312312
self.name = vs[2]
313313
self.readfrom, self.addto, self.onlyif = None, None, None
@@ -323,7 +323,6 @@ def __init__(self, iork:str, defstr:str) -> None:
323323
s3 = s2.replace('onlyif:', '')
324324
self.onlyif = s3
325325

326-
327326
@staticmethod
328327
def _get_src(src, index) -> str:
329328
if ',' in src:
@@ -332,18 +331,16 @@ def _get_src(src, index) -> str:
332331
else:
333332
return '{}[{}]'.format(src, index)
334333

335-
336334
def zero(self) -> str:
337335
v1 = ''
338336
if self.location == 'shared':
339337
v1 = '{}[threadIdx.x] = 0;'.format(self.name)
340338
else:
341339
v1 = '{} = 0;'.format(self.name)
342-
if self.onlyif != None:
343-
v1 = 'if CONSTEXPR ({}) {}'.format(self.onlyif, v1)
340+
if self.onlyif is not None:
341+
v1 = 'if CONSTEXPR ({}) {}'.format(self.onlyif, v1)
344342
return v1
345343

346-
347344
def save(self) -> str:
348345
dst = self.addto
349346
v1, suffix = '', ''
@@ -354,11 +351,10 @@ def save(self) -> str:
354351
else:
355352
vs = dst.split(',')
356353
v1 = 'atomic_add({}{}, &{}[{}][{}]);'.format(self.name, suffix, vs[0], self.iork, vs[1])
357-
if self.onlyif != None:
354+
if self.onlyif is not None:
358355
v1 = 'if CONSTEXPR ({}) {}'.format(self.onlyif, v1)
359356
return v1
360357

361-
362358
def init_exclude(self) -> str:
363359
rhs = self._get_src(self.readfrom, self.iork)
364360
if self.location == 'shared':
@@ -369,7 +365,6 @@ def init_exclude(self) -> str:
369365
else:
370366
return '{} = {};'.format(self.name, rhs)
371367

372-
373368
def init_block(self) -> str:
374369
if self.readfrom in ['x', 'y', 'z']:
375370
if self.location == 'shared':
@@ -382,14 +377,12 @@ def init_block(self) -> str:
382377
else:
383378
return '{} = {};'.format(self.name, self._get_src(self.readfrom, self.iork))
384379

385-
386380
def shuffle(self) -> str:
387381
if self.location == 'register':
388382
return '{0:} = __shfl_sync(ALL_LANES, {0:}, ilane + 1);'.format(self.name)
389383
else:
390384
raise ValueError('Cannot shuffle variables in the shared memory.')
391385

392-
393386
def ikreplace(self, code:str) -> str:
394387
old_name = '@{}@'.format(self.name)
395388
new_name = self.name
@@ -402,7 +395,6 @@ def ikreplace(self, code:str) -> str:
402395
code = code.replace(old_name, new_name)
403396
return code
404397

405-
406398
def iterreplace(self, code:str) -> str:
407399
old_name = '@{}@'.format('i')
408400
new_name = self.name
@@ -425,7 +417,6 @@ def __init__(self, iork:str, lst:list) -> None:
425417
else:
426418
d[v.type] = [v]
427419

428-
429420
def declare(self) -> str:
430421
s = ''
431422
for t in self.shared.keys():
@@ -441,7 +432,6 @@ def declare(self) -> str:
441432
s = s.replace(',;', ';')
442433
return s
443434

444-
445435
def zero(self) -> str:
446436
s = ''
447437
for t in self.shared.keys():
@@ -452,7 +442,6 @@ def zero(self) -> str:
452442
s = s + v.zero()
453443
return s
454444

455-
456445
def save(self) -> str:
457446
s = ''
458447
for t in self.shared.keys():
@@ -463,7 +452,6 @@ def save(self) -> str:
463452
s = s + v.save()
464453
return s
465454

466-
467455
def init_exclude(self) -> str:
468456
s = ''
469457
for t in self.shared.keys():
@@ -474,7 +462,6 @@ def init_exclude(self) -> str:
474462
s = s + v.init_exclude()
475463
return s
476464

477-
478465
def init_block(self) -> str:
479466
s = ''
480467
for t in self.shared.keys():
@@ -485,15 +472,13 @@ def init_block(self) -> str:
485472
s = s + v.init_block()
486473
return s
487474

488-
489475
def shuffle(self) -> str:
490476
s = ''
491477
for t in self.register.keys():
492478
for v in self.register[t]:
493479
s = s + v.shuffle()
494480
return s
495481

496-
497482
def ikreplace(self, code:str) -> str:
498483
for t in self.shared.keys():
499484
for v in self.shared[t]:
@@ -531,7 +516,6 @@ def _func_param(ptype:str, pname:str) -> str:
531516
else:
532517
raise ValueError('Do not know how to parse type: {}'.format(ptype))
533518

534-
535519
@staticmethod
536520
def _load_scale_param(ptype:str, stem:str, input:str, separate_scaled_pairwise:bool) -> str:
537521
if ptype == 'real_const_array':
@@ -551,7 +535,7 @@ def _load_scale_param(ptype:str, stem:str, input:str, separate_scaled_pairwise:b
551535
# dim = match.group(2)
552536
ss = ptype.split(',')
553537
v = ''
554-
for i in range(1,len(ss)):
538+
for i in range(1, len(ss)):
555539
idx = ss[i]
556540
al = rc_alphabets[idx]
557541
if input is None:
@@ -561,7 +545,6 @@ def _load_scale_param(ptype:str, stem:str, input:str, separate_scaled_pairwise:b
561545
v = v + '{} {}{} = {}[ii][{}];'.format(t, stem, al, input, idx)
562546
return v
563547

564-
565548
def __init__(self, config) -> None:
566549
self.config = config
567550

@@ -598,7 +581,6 @@ def _kv(self, k:str):
598581
else:
599582
return ''
600583

601-
602584
def cudaReplaceDict(self) -> dict:
603585
d = {}
604586
config = self.config
@@ -679,10 +661,10 @@ def cudaReplaceDict(self) -> dict:
679661
if kcfg in keys:
680662
vcfg, decl, zero, total = config[kcfg], '', '', ''
681663
for t in vcfg:
682-
v1 = v1 + ', CountBuffer restrict {}'.format(t)
683-
decl = decl + 'int {}tl;'.format(t)
684-
zero = zero + '{}tl = 0;'.format(t)
685-
total = total + 'atomic_add({}tl, {}, ithread);'.format(t, t)
664+
v1 = v1 + ', CountBuffer restrict {}'.format(t)
665+
decl = decl + 'int {}tl;'.format(t)
666+
zero = zero + '{}tl = 0;'.format(t)
667+
total = total + 'atomic_add({}tl, {}, ithread);'.format(t, t)
686668
v2 = '%s if CONSTEXPR (do_a) {%s}' % (decl, zero)
687669
v3 = 'if CONSTEXPR (do_a) {%s}' % (total)
688670
d[k1], d[k2], d[k3] = v1, v2, v3
@@ -809,7 +791,7 @@ def cudaReplaceDict(self) -> dict:
809791
v1 = kvars.ikreplace(v1)
810792
v1 = ifrcs.ikreplace(v1)
811793
v1 = kfrcs.ikreplace(v1)
812-
v2 = v1 # in case no scaled pairwise interaction is given
794+
v2 = v1 # in case no scaled pairwise interaction is given
813795
kcfg = self.yk_scaled_pairwise
814796
if kcfg in keys:
815797
v2 = config[kcfg]
@@ -840,27 +822,24 @@ def cudaReplaceDict(self) -> dict:
840822

841823
return d
842824

843-
844825
@staticmethod
845826
def version() -> str:
846827
return '3.1.0'
847828

848-
849829
@staticmethod
850830
def _replace(s:str, d:dict) -> str:
851831
output = s
852832
for k in d.keys():
853833
v = d[k]
854-
if v == None:
834+
if v is None:
855835
v = ''
856836
output = output.replace(k, v)
857837
return output
858838

859-
860839
def write(self, output) -> None:
861840
d = self.cudaReplaceDict()
862841
outstr = '// ck.py Version {}'.format(self.version())
863-
kernel_num = 21 # default
842+
kernel_num = 21 # default
864843
if self.yk_kernel_version_number in self.config.keys():
865844
kernel_num = self.config[self.yk_kernel_version_number]
866845
if kernel_num == 11:
@@ -896,7 +875,6 @@ def show_command(argv):
896875
d2 = os.path.join(d, '../..')
897876
d = os.path.abspath(d2)
898877

899-
900878
yaml_file = argv[1]
901879
with open(yaml_file) as input_file:
902880
config = yaml.full_load(input_file)

ext/ext/y3/mdPos_cu1.yaml renamed to ext/ext/y31/mdPos_cu1.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
KERNEL_VERSION_NUMBER: 11
2-
32
KERNEL_IS_STATIC: True
3+
4+
OUTPUT_DIR: src/cu
45
KERNEL_NAME: mdPos_cu1
56
SINGLE_LOOP_LIMIT: int n
67
SINGLE_LOOP_ITER: int i

src/cu/mdPos_cu1.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// ck.py Version 3.1.0
2+
__global__
3+
static void mdPos_cu1(int n, time_prec dt, pos_prec* restrict qx, pos_prec* restrict qy, pos_prec* restrict qz,
4+
const vel_prec* restrict vlx, const vel_prec* restrict vly, const vel_prec* restrict vlz)
5+
{
6+
for (int i = ITHREAD; i < n; i += STRIDE) {
7+
qx[i] += dt * vlx[i];
8+
qy[i] += dt * vly[i];
9+
qz[i] += dt * vlz[i];
10+
}
11+
}

src/cu/mdpq.cu

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,18 @@
44
#include <tinker/detail/units.hh>
55

66
namespace tinker {
7-
__global__
8-
static void mdPos_cu1(int n, time_prec dt, //
9-
pos_prec* restrict qx, pos_prec* restrict qy, pos_prec* restrict qz, //
10-
const vel_prec* restrict vlx, const vel_prec* restrict vly, const vel_prec* restrict vlz)
11-
{
12-
for (int i = ITHREAD; i < n; i += STRIDE) {
13-
qx[i] += dt * vlx[i];
14-
qy[i] += dt * vly[i];
15-
qz[i] += dt * vlz[i];
16-
}
17-
}
18-
19-
void mdPos_cu(time_prec dt, pos_prec* qx, pos_prec* qy, pos_prec* qz, const vel_prec* vlx,
20-
const vel_prec* vly, const vel_prec* vlz)
7+
#include "mdPos_cu1.cc"
8+
void mdPos_cu(time_prec dt, pos_prec* qx, pos_prec* qy, pos_prec* qz, const vel_prec* vlx, const vel_prec* vly,
9+
const vel_prec* vlz)
2110
{
2211
launch_k1s(g::s0, n, mdPos_cu1, //
2312
n, dt, qx, qy, qz, vlx, vly, vlz);
2413
}
2514

2615
__global__
2716
static void mdPosAxbv_cu1(int n, pos_prec sa, pos_prec sb, //
28-
pos_prec* restrict xpos, pos_prec* restrict ypos, pos_prec* restrict zpos,
29-
const vel_prec* restrict vx, const vel_prec* restrict vy, const vel_prec* restrict vz)
17+
pos_prec* restrict xpos, pos_prec* restrict ypos, pos_prec* restrict zpos, const vel_prec* restrict vx,
18+
const vel_prec* restrict vy, const vel_prec* restrict vz)
3019
{
3120
for (int i = ITHREAD; i < n; i += STRIDE) {
3221
xpos[i] = sa * xpos[i] + sb * vx[i];
@@ -45,8 +34,8 @@ __global__
4534
static void mdPosAxbvAn_cu1(int n, //
4635
double3 a0, double3 a1, double3 a2, //
4736
double3 b0, double3 b1, double3 b2, //
48-
pos_prec* restrict xpos, pos_prec* restrict ypos, pos_prec* restrict zpos,
49-
const vel_prec* restrict vx, const vel_prec* restrict vy, const vel_prec* restrict vz)
37+
pos_prec* restrict xpos, pos_prec* restrict ypos, pos_prec* restrict zpos, const vel_prec* restrict vx,
38+
const vel_prec* restrict vy, const vel_prec* restrict vz)
5039
{
5140
pos_prec a00 = a0.x, a01 = a0.y, a02 = a0.z;
5241
pos_prec a10 = a1.x, a11 = a1.y, a12 = a1.z;
@@ -94,20 +83,19 @@ static void mdVel_cu1(int n, time_prec dt, const double* restrict massinv, //
9483
}
9584
}
9685

97-
void mdVel_cu(time_prec dt, vel_prec* vlx, vel_prec* vly, vel_prec* vlz, const grad_prec* grx,
98-
const grad_prec* gry, const grad_prec* grz)
86+
void mdVel_cu(time_prec dt, vel_prec* vlx, vel_prec* vly, vel_prec* vlz, const grad_prec* grx, const grad_prec* gry,
87+
const grad_prec* grz)
9988
{
10089
launch_k1s(g::s0, n, mdVel_cu1, //
10190
n, dt, massinv, vlx, vly, vlz, grx, gry, grz);
10291
}
10392

10493
__global__
105-
static void mdVel2_cu1(int n, const double* restrict massinv, vel_prec* restrict vlx,
106-
vel_prec* restrict vly, vel_prec* restrict vlz, //
94+
static void mdVel2_cu1(int n, const double* restrict massinv, vel_prec* restrict vlx, vel_prec* restrict vly,
95+
vel_prec* restrict vlz, //
10796
time_prec dt1, const grad_prec* restrict grx1, const grad_prec* restrict gry1,
10897
const grad_prec* restrict grz1, //
109-
time_prec dt2, const grad_prec* restrict grx2, const grad_prec* restrict gry2,
110-
const grad_prec* restrict grz2)
98+
time_prec dt2, const grad_prec* restrict grx2, const grad_prec* restrict gry2, const grad_prec* restrict grz2)
11199
{
112100
const vel_prec ekcal = units::ekcal;
113101
for (int i = ITHREAD; i < n; i += STRIDE) {
@@ -124,8 +112,8 @@ static void mdVel2_cu1(int n, const double* restrict massinv, vel_prec* restrict
124112
}
125113
}
126114

127-
void mdVel2_cu(time_prec dt, const grad_prec* grx, const grad_prec* gry, const grad_prec* grz,
128-
time_prec dt2, const grad_prec* grx2, const grad_prec* gry2, const grad_prec* grz2)
115+
void mdVel2_cu(time_prec dt, const grad_prec* grx, const grad_prec* gry, const grad_prec* grz, time_prec dt2,
116+
const grad_prec* grx2, const grad_prec* gry2, const grad_prec* grz2)
129117
{
130118
launch_k1s(g::s0, n, mdVel2_cu1, //
131119
n, massinv, vx, vy, vz, //

0 commit comments

Comments
 (0)