@@ -3,6 +3,9 @@ import base/metaUtils
33import base/ omp
44
55{.pragma : omp, header :" omp.h" .}
6+ {.passC :" -fcf-protection=none -no-pie -fno-stack-protector" .}
7+ {.passL :" -fcf-protection=none -no-pie -fno-stack-protector" .}
8+
69template mkMemoryPragma * :untyped =
710 {.pragma : restrict, codegenDecl : " $# __restrict__ $#" .}
811 {.pragma : aligned, codegenDecl : " $# $# __attribute__((aligned))" .}
@@ -34,7 +37,7 @@ proc alignatImpl(n:NimNode, byte:int): NimNode =
3437 for c in n:
3538 result .add c.alignatImpl byte
3639macro alignat * (byte:static [int ], n:untyped ): untyped =
37- if byte notin { 1 ,2 ,4 ,8 ,16 ,32 ,64 ,128 ,256 } :
40+ if byte notin [ 1 ,2 ,4 ,8 ,16 ,32 ,64 ,128 ,256 ] :
3841 error (" macro alignat: unsupported alignment: " & $ byte , n)
3942 # echo "alignatImpl ", byte
4043 # echo n.treerepr
@@ -61,11 +64,11 @@ proc omp_get_initial_device*: cint {.omp.}
6164proc omp_get_num_teams * : cint {.omp .}
6265proc omp_get_team_num * : cint {.omp .}
6366
64- template omp_target_alloc * (size: int ): pointer =
67+ template omp_target_alloc * (size: SomeNumber ): pointer =
6568 omp_target_alloc (csize_t size, omp_get_default_device ())
6669template omp_target_memcpy_tocpu * (dst: pointer , src: pointer ; length: csize_t ): cint =
6770 omp_target_memcpy (dst, src, length, 0 , 0 , omp_get_initial_device (), omp_get_default_device ())
68- template omp_target_memcpy_togpu * (dst: pointer , src: pointer ; length: int ): cint =
71+ template omp_target_memcpy_togpu * (dst: pointer , src: pointer ; length: csize_t ): cint =
6972 omp_target_memcpy (dst, src, csize_t length, 0 , 0 , omp_get_default_device (), omp_get_initial_device ())
7073template omp_target_free * (device_ptr: pointer ) =
7174 omp_target_free (device_ptr, omp_get_default_device ())
@@ -130,7 +133,8 @@ proc prepareVars(n:NimNode):seq[NimNode] =
130133 nnkWhileStmt, nnkForStmt} + RoutineNodes :
131134 # New lexical scope
132135 newscope = true
133- ignoreStack.add newPar ()
136+ # ignoreStack.add newPar()
137+ ignoreStack.add newNimNode (nnkTupleConstr)
134138 for i in 0 ..< n.len:
135139 # echo "### ",n[i].lisprepr
136140 case n[i].kind
@@ -176,7 +180,8 @@ proc prepareVars(n:NimNode):seq[NimNode] =
176180 let np = gensym (nsklet, " gpu_ptr_" & $ n[i])
177181 ignoreStack[0 ].add nv
178182 ignoreStack[0 ].add np
179- openvars.add newpar (n[i], nv, np)
183+ # openvars.add newpar(n[i], nv, np)
184+ openvars.add newNimNode (nnkTupleConstr).add (n[i], nv, np)
180185 n[i] = newcall (" gpuVarPtr" ,nv,np)
181186 else :
182187 discard
@@ -199,6 +204,7 @@ proc genCpuPrepare(n:seq[NimNode]):NimNode =
199204 var v{.noinit .}:OffloadDummy [typeof (x)]
200205 result = newstmtlist ()
201206 for c in n:
207+ # echo c.treerepr
202208 result .add getast r (c[0 ],c[1 ],c[2 ])
203209proc genGpuPrepare (n:seq [NimNode ]):NimNode =
204210 template r (x,v,p:untyped ):untyped =
@@ -214,31 +220,55 @@ proc genCpuFinalize(n:seq[NimNode]):NimNode =
214220 result = newstmtlist ()
215221 for c in n:
216222 result .add getast r (c[0 ],c[1 ],c[2 ])
217- proc declarePtrString (n:seq [NimNode ]):NimNode =
218- template res (ptrlist:untyped ):untyped =
219- const s = ptrlist
220- when s.len == 0 : " " else : " is_device_ptr(" & s[0 ..^ 2 ] & " )"
221- template varname (x, xp:untyped ):untyped =
222- mixin offloadPtr
223- when compiles (offloadPtr (x)): xp& " , " else : " "
224- var ps = newlit " "
223+ # proc declarePtrString(n:seq[NimNode]):NimNode =
224+ # template res(ptrlist:untyped):untyped =
225+ # const s = ptrlist
226+ # when s.len == 0: "" else: "is_device_ptr(" & s[0..^2] & ")"
227+ # template varname(x, xp:untyped):untyped =
228+ # mixin offloadPtr
229+ # when compiles(offloadPtr(x)): xp&"," else: ""
230+ # var ps = newlit""
231+ # for c in n:
232+ # ps = infix(getast varname(c[0], $c[2]), "&", ps)
233+ # result = getast res(ps)
234+ proc declarePtrTuple (n:seq [NimNode ]):NimNode =
235+ mixin offloadPtr
236+ var ps = newNimNode (nnkTupleConstr)
237+ ps.add newLit " is_device_ptr("
225238 for c in n:
226- ps = infix (getast varname (c[0 ], $ c[2 ]), " &" , ps)
227- result = getast res (ps)
239+ when compiles (offloadPtr (c[0 ])):
240+ ps.add c[2 ]
241+ if ps.len == 1 :
242+ result = newNimNode (nnkTupleConstr)
243+ else :
244+ ps.add newLit " )"
245+ result = ps
246+ # echo result.treerepr
228247
229- macro isDevicePtr (x: typed ): untyped =
230- let n = $ x
231- result = newLit (" is_device_ptr(" & n& " ) " )
248+ # macro isDevicePtr(x: typed): untyped =
249+ # let n = $x
250+ # result = newLit(" is_device_ptr("&n&")")
232251
233- macro useDevicePtr (x: typed ): untyped =
234- let n = $ x
235- let p = newLit (" #pragma omp target data use_device_ptr(" & n& " ) " )
236- result = quote do :
237- {.emit : `p`.}
252+ # macro useDevicePtr(x: auto): auto =
253+ # echo x.treerepr
254+ # let n = x.strVal
255+ # echo "useDevicePtr: ", n
256+ # let p = newLit("#pragma omp target data use_device_ptr("&n&")")
257+ # result = quote do:
258+ # {.emit: ["#pragma omp target data use_device_ptr(",`x`,")"].}
238259
239- macro mapto (x: typed ): untyped =
240- let n = $ x
241- result = newLit (" map(to:" & n& " ) " )
260+ # macro getrepr(x: untyped): auto =
261+ # echo x.treerepr
262+ # result = x
263+
264+ template useDevicePtr (x: auto ) =
265+ # getrepr:
266+ {.emit : [" #pragma omp target data use_device_ptr(" ,x," )" ].}
267+
268+ # macro mapto(x: typed): untyped =
269+ # let n = $x
270+ # result = newLit(" map(to:"&n&")")
271+ # macro mapto(x: typed): untyped =
242272
243273macro onGpu * (body: untyped ): untyped =
244274 # the architecture for cpugpuarray requires us replace body before it gets expanded, so we require untyped.
@@ -248,8 +278,8 @@ macro onGpu*(body: untyped): untyped =
248278 {.push stacktrace : off .}
249279 proc gpuProc {.gensym .} =
250280 cpuPrepare # a let section declare and save device pointers
251- const isDevicePtrList = devicePtrDeclare # is_device_ptr(ptrList) in string
252- ompBlock (" target teams " & isDevicePtrList ):
281+ # const isDevicePtrList = devicePtrDeclare # is_device_ptr(ptrList) in string
282+ ompBlock2 (" target teams " , devicePtrDeclare ):
253283 openmpDefs:
254284 gpuPrepare
255285 body
@@ -260,7 +290,7 @@ macro onGpu*(body: untyped): untyped =
260290 cpuPrepare = genCpuPrepare v
261291 gpuPrepare = genGpuPrepare v
262292 cpuFinalize = genCpuFinalize v
263- isDevicePtrs = declarePtrString v
293+ isDevicePtrs = declarePtrTuple v
264294 result = getast (target (cpuPrepare, gpuPrepare, cpuFinalize, isDevicePtrs, body))
265295 # echo result.repr
266296
@@ -306,10 +336,10 @@ proc identStr(n:NimNode):string =
306336 result = n.repr
307337 for i in 0 ..< result .len:
308338 if result [i] in {'.' ,'[' ,']' ,':' }: result [i] = '_'
309- proc isIndex (n,i:NimNode ):bool =
310- result = n.eqident i
311- if n.kind == nnkHiddenStdConv:
312- result = n[1 ].eqident i
339+ # proc isIndex(n,i:NimNode):bool =
340+ # result = n.eqident i
341+ # if n.kind == nnkHiddenStdConv:
342+ # result = n[1].eqident i
313343macro simdForImpl (n:typed ):untyped =
314344 proc getIndexedPtrs (n,i:NimNode ):(NimNode ,seq [NimNode ]) =
315345 # echo "### getIndexedPtrs: ", i.repr
@@ -323,7 +353,8 @@ macro simdForImpl(n:typed):untyped =
323353 break
324354 if m < 0 :
325355 let v = gensym (nskVar, n.cleanAst.identStr)
326- ptrs.add newPar (v, n)
356+ # ptrs.add newPar(v, n)
357+ ptrs.add newNimNode (nnkTupleConstr).add (v, n)
327358 return v
328359 else :
329360 return ptrs[m][0 ]
@@ -437,20 +468,19 @@ when isMainModule:
437468 useDevicePtr (y)
438469 discard omp_target_memcpy_togpu (y, addr x, sizeof (float32 ))
439470 # ompBlock("target teams"&isDevicePtr(x)):
440- ompBlock (" target teams" & mapto (x)):
471+ # ompBlock("target teams"&mapto(x)):
472+ ompBlock2 (" target teams" , " map(to:" , x, " )" ):
441473 {.emit :" #pragma omp parallel" .}
442474 {.emit :" for(int ii=0; ii<1; ii++)" .}
443475 block :
444476 x = 1.0
445477
446- macro dump (n:typed ):typed =
447- echo n.repr
448- n
449- #[
450- dump:
451- onGpu:
452- let i = getThreadNum()
453- if i < n:
454- c[i] = a[i] + b[i]
455- ]#
478+ # macro dump(n:auto):auto =
479+ # echo n.repr
480+ # n
481+ onGpu:
482+ let i = getThreadNum ()
483+ if i < n:
484+ c[i] = a[i] + b[i]
485+
456486 test ()
0 commit comments