11import macros
22import base/ metaUtils
33import sycl
4- setupSycl ()
4+ # setupSycl()
55
6+ let dev = defaultDevice ()
7+ let q = dev.queue
68
79proc alignatImpl (n:NimNode , byte:int ): NimNode =
810 result = n.copyNimNode
@@ -21,7 +23,7 @@ proc alignatImpl(n:NimNode, byte:int): NimNode =
2123 for c in n:
2224 result .add c.alignatImpl byte
2325macro alignat * (byte:static [int ], n:untyped ): untyped =
24- if byte notin { 1 ,2 ,4 ,8 ,16 ,32 ,64 ,128 ,256 } :
26+ if byte notin [ 1 ,2 ,4 ,8 ,16 ,32 ,64 ,128 ,256 ] :
2527 error (" macro alignat: unsupported alignment: " & $ byte , n)
2628 # echo "alignatImpl ", byte
2729 # echo n.treerepr
@@ -79,9 +81,10 @@ template dataAddr*(x: typed): pointer =
7981 # else: x
8082
8183template openmpDefs (body: untyped ): untyped =
84+ setupSycl ()
8285 var item {.item1 .}: Item1
83- template getThreadNum : untyped = item[]
84- template getNumThreads : untyped = item.getRange
86+ template getThreadNum : untyped {. used .} = item[]
87+ template getNumThreads : untyped {. used .} = item.getRange
8588 {.emit :[" #define nimZeroMem(b,len) memset((b),0,(len))" ].}
8689 # inlineProcs:
8790 body
@@ -107,7 +110,8 @@ proc prepareVars(n:NimNode):seq[NimNode] =
107110 nnkWhileStmt, nnkForStmt} + RoutineNodes :
108111 # New lexical scope
109112 newscope = true
110- ignoreStack.add newPar ()
113+ # ignoreStack.add newPar()
114+ ignoreStack.add newNimNode (nnkTupleConstr)
111115 if n.kind == nnkForStmt:
112116 ignoreStack[^ 1 ].add n[0 ]
113117 for i in 0 ..< n.len:
@@ -155,7 +159,8 @@ proc prepareVars(n:NimNode):seq[NimNode] =
155159 let np = gensym (nsklet, " gpu_ptr_" & $ n[i])
156160 ignoreStack[0 ].add nv
157161 ignoreStack[0 ].add np
158- openvars.add newpar (n[i], nv, np)
162+ # openvars.add newpar(n[i], nv, np)
163+ openvars.add newNimNode (nnkTupleConstr).add (n[i], nv, np)
159164 n[i] = newcall (" gpuVarPtr" ,nv,np)
160165 else :
161166 discard
@@ -179,7 +184,7 @@ proc genCpuPrepare(n:seq[NimNode]):NimNode =
179184 result = newstmtlist ()
180185 for c in n:
181186 result .add getast r (c[0 ],c[1 ],c[2 ])
182- echo result .repr
187+ # echo result.repr
183188proc genGpuPrepare (n:seq [NimNode ]):NimNode =
184189 template r (x,v,p:untyped ):untyped =
185190 mixin gpuPrepareOffload, rungpuPrepareOffload
@@ -206,9 +211,9 @@ proc declarePtrString(n:seq[NimNode]):NimNode =
206211 ps = infix (getast varname (c[0 ], $ c[2 ]), " &" , ps)
207212 result = getast res (ps)
208213
209- macro onGpu * (q: Queue , body: untyped ): untyped =
214+ macro onGpu * (q: Queue , body: untyped ): auto =
210215 # the architecture for cpugpuarray requires us replace body before it gets expanded, so we require untyped.
211- template target (cpuPrepare, gpuPrepare, cpuFinalize, devicePtrDeclare, body: untyped ): untyped =
216+ template target (q, cpuPrepare, gpuPrepare, cpuFinalize, devicePtrDeclare, body: untyped ): untyped =
212217 mixin hasGpuPtr, requireGpuMem
213218 {.push checks : off .}
214219 {.push stacktrace : off .}
@@ -222,6 +227,7 @@ macro onGpu*(q: Queue, body: untyped): untyped =
222227 openmpDefs:
223228 # gpuPrepare
224229 body
230+ q.wait
225231 cpuFinalize
226232 gpuProc ()
227233 let
@@ -230,20 +236,29 @@ macro onGpu*(q: Queue, body: untyped): untyped =
230236 gpuPrepare = genGpuPrepare v
231237 cpuFinalize = genCpuFinalize v
232238 isDevicePtrs = declarePtrString v
233- result = getast (target (cpuPrepare, gpuPrepare, cpuFinalize, isDevicePtrs, body))
239+ result = getast (target (q, cpuPrepare, gpuPrepare, cpuFinalize, isDevicePtrs, body))
234240 # echo result.repr
235241
236242# XXX fix the following
237- template onGpu * (totalNumThreads, body: untyped ): untyped = onGpu (body)
238- template onGpu * (totalNumThreads, numThreadsPerTeam, body: untyped ): untyped = onGpu (body)
243+ template onGpu * (body: untyped ) = onGpu (q, body)
244+ # template onGpu*(totalNumThreads, body: untyped): untyped = onGpu(body)
245+ # template onGpu*(totalNumThreads, numThreadsPerTeam, body: untyped): untyped = onGpu(body)
239246
240- template offloadUseVar * (x:SomeNumber ):bool = true
241- template offloadUsePtr * (x:SomeNumber ):bool = false
247+ # template offloadUseVar*(x:SomeNumber):bool = true
248+ # template offloadUsePtr*(x:SomeNumber):bool = false
242249template rungpuPrepareOffload * (x:SomeNumber ):bool = false
243- template runcpuFinalizeOffload * (x:SomeNumber ):bool = false
244- template gpuVarPtr * (v:SomeNumber ,p:untyped ):untyped = v
250+ # template runcpuFinalizeOffload*(x:SomeNumber):bool = false
251+ # template gpuVarPtr*(v:SomeNumber,p:untyped):untyped = v
245252template offloadVar * (x:SomeNumber ,p:untyped ):untyped = x
246253
254+ template runcpuFinalizeOffload * (x:SomeNumber ):bool = true
255+ template offloadUseVar * (x:SomeNumber ):bool = false
256+ template offloadUsePtr * (x:SomeNumber ):bool = true
257+ template gpuVarPtr * (v:untyped ,p:ptr SomeNumber ):untyped = p[]
258+ template offloadPtr * (x:SomeNumber ):untyped = unsafeAddr x
259+ template cpuFinalizeOffload * (x:SomeNumber ,v,p:untyped ) =
260+ x = p[]
261+
247262template toUArray (a:untyped ):untyped = cast [ptr UncheckedArray [typeof (a [0 ])]](a[0 ].unsafeaddr)
248263proc cleanAst (n:NimNode ):NimNode =
249264 if n.kind in {nnkHiddenDeref,nnkHiddenCallConv,nnkHiddenStdConv}:
@@ -256,10 +271,10 @@ proc identStr(n:NimNode):string =
256271 result = n.repr
257272 for i in 0 ..< result .len:
258273 if result [i] in {'.' ,'[' ,']' ,':' }: result [i] = '_'
259- proc isIndex (n,i:NimNode ):bool =
260- result = n.eqident i
261- if n.kind == nnkHiddenStdConv:
262- result = n[1 ].eqident i
274+ # proc isIndex(n,i:NimNode):bool =
275+ # result = n.eqident i
276+ # if n.kind == nnkHiddenStdConv:
277+ # result = n[1].eqident i
263278macro simdForImpl (n:typed ):untyped =
264279 proc getIndexedPtrs (n,i:NimNode ):(NimNode ,seq [NimNode ]) =
265280 # echo "### getIndexedPtrs: ", i.repr
@@ -273,7 +288,8 @@ macro simdForImpl(n:typed):untyped =
273288 break
274289 if m < 0 :
275290 let v = gensym (nskVar, n.cleanAst.identStr)
276- ptrs.add newPar (v, n)
291+ # ptrs.add newPar(v, n)
292+ ptrs.add newNimNode (nnkTupleConstr).add (v, n)
277293 return v
278294 else :
279295 return ptrs[m][0 ]
0 commit comments