Skip to content

Commit b83f9fe

Browse files
committed
start unified gpu interface
1 parent b1aac2f commit b83f9fe

File tree

8 files changed

+152
-51
lines changed

8 files changed

+152
-51
lines changed

src/backend/accel.nim

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#const Backend {.strdefine.} = "OpenMP"
2+
#const Backend {.strdefine.} = "CUDA"
3+
#const Backend {.strdefine.} = "SYCL"
4+
# Compute backend chosen at compile time.  The default is "CPU"; override it
# on the command line with `-d:Backend=OpenMP`, `-d:Backend=CUDA` or
# `-d:Backend=SYCL`.
const Backend {.strdefine.} = "CPU"

# Import and re-export the module that implements the selected backend, so
# that importing this module is enough to get the backend's public API.
when Backend == "OpenMP":
  #const useGPU = true
  import openmp
  export openmp
elif Backend == "CUDA":
  #const useGPU = true
  import cuda
  export cuda
elif Backend == "SYCL":
  #const useGPU = true
  import syclbe
  export syclbe
else:
  # "CPU" is the default value of `Backend`, so it must not be reported as
  # unknown.  Warn only when the requested name matches none of the
  # backends; either way the CPU module is used.
  when Backend != "CPU":
    {.warning: "Backend unknown, use CPU only.".}
  #const useGPU = false
  import cpu
  export cpu
23+
24+
#when useGPU:
25+
# import expr
26+
# import gpuarray
27+
# export gpuarray
28+
#else:
29+
# template onGpu*(x:untyped) = threads: x
30+
# template onGpu*(n,x:untyped) = threads: x
31+
# template onGpu*(n,t,x:untyped) = threads: x
32+
# template packVarsStmt(x,y:untyped) = discard
33+
34+
when isMainModule:
  # Smoke test for the backend selected above: assign to a scalar inside an
  # `onGpu` block and print its value before and after.
  #import qex
  #qexInit()
  proc test1() =
    var x: float32 = 1.0
    #var yp = cast[ptr float32](gpuMalloc(sizeof(float32)))
    echo("x: ", x)
    #threads:
    onGpu:
      x = 2.0
    echo("x: ", x)

  test1()

src/backend/cpu.nim

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import base/threading
2+
3+
#type
4+
# Buffer*[T] = object
5+
# data: ptr UncheckedArray[T]
6+
# n: int
7+
8+
#proc init[T](b: Buffer[T], n: int) =
9+
# let p = allocShared(n*sizeof(T))
10+
# b.data = cast[type b.data](p)
11+
# b.n = n
12+
13+
#proc free[T](b: Buffer[T]) =
14+
# free b.data
15+
16+
# CPU stand-ins for the `onGpu` construct.  Every overload simply hands its
# body to `threads` (presumably provided by base/threading -- confirm).
template onGpu*(x: untyped) =
  ## Runs `x` inside a `threads` block.
  threads:
    x

template onGpu*(n, x: untyped) =
  ## Same expansion as the one-argument form; `n` is accepted but not used.
  threads:
    x

template onGpu*(n, t, x: untyped) =
  ## Same expansion as the one-argument form; `n` and `t` are accepted but
  ## not used.
  threads:
    x

src/backend/cuda.nim

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ template getInst*(p: untyped): untyped =
148148
procInst(p)
149149
#var t =
150150
#t
151-
macro `>>`*(px: tuple, y: any): auto =
151+
macro `>>`*(px: tuple, y: auto): auto =
152152
#echo "begin >>:"
153153
#echo px.treerepr
154154
#echo "kernel type:"
@@ -199,7 +199,7 @@ template onGpu*(nn,tpb: untyped, body: untyped): untyped =
199199
type ByCopy[T] {.bycopy.} = object
200200
d: T
201201
proc kern(xx: ByCopy[type(v)]) {.cudaGlobal.} =
202-
template deref(k: int): untyped = xx.d[k]
202+
template deref(k: int): untyped = xx.d[k][]
203203
substVars(body, deref)
204204
let ni = nn.int32
205205
let threadsPerBlock = tpb.int32
@@ -210,6 +210,9 @@ template onGpu*(nn,tpb: untyped, body: untyped): untyped =
210210
template onGpu*(nn: untyped, body: untyped): untyped = onGpu(nn, 64, body)
211211
template onGpu*(body: untyped): untyped = onGpu(512*64, 64, body)
212212

213+
# Expands to `unsafeAddr x` for any `SomeNumber` argument, i.e. the address
# of the numeric value itself.
template getGpuPtr*(x: SomeNumber): untyped = unsafeAddr x
214+
215+
213216
when isMainModule:
214217
type FltArr = UncheckedArray[float32]
215218

src/backend/expr.nim

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ macro substVars*(x: untyped, a: untyped): auto =
9797
var v = newSeq[NimNode](0)
9898
let e = getVars(v, x, a)
9999
result = e
100-
#echo result.treerepr
100+
echo result.treerepr
101101

102102
when isMainModule:
103103
template test(x) =

src/backend/openmp.nim

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ macro onGpu*(body: untyped): untyped =
292292
cpuFinalize = genCpuFinalize v
293293
isDevicePtrs = declarePtrTuple v
294294
result = getast(target(cpuPrepare, gpuPrepare, cpuFinalize, isDevicePtrs, body))
295-
#echo result.repr
295+
echo result.repr
296296

297297
# XXX fix the following
298298
template onGpu*(totalNumThreads, body: untyped): untyped = onGpu(body)
@@ -317,13 +317,21 @@ template onGpu*(nn: untyped, body: untyped): untyped = onGpu(nn, 64, body)
317317
template onGpu*(body: untyped): untyped = onGpu(512*64, 64, body)
318318
]#
319319

320-
template offloadUseVar*(x:SomeNumber):bool = true
321-
template offloadUsePtr*(x:SomeNumber):bool = false
320+
#template offloadUseVar*(x:SomeNumber):bool = true
321+
#template offloadUsePtr*(x:SomeNumber):bool = false
322322
template rungpuPrepareOffload*(x:SomeNumber):bool = false
323-
template runcpuFinalizeOffload*(x:SomeNumber):bool = false
324-
template gpuVarPtr*(v:SomeNumber,p:untyped):untyped = v
323+
#template runcpuFinalizeOffload*(x:SomeNumber):bool = false
324+
#template gpuVarPtr*(v:SomeNumber,p:untyped):untyped = v
325325
template offloadVar*(x:SomeNumber,p:untyped):untyped = x
326326

327+
template runcpuFinalizeOffload*(x:SomeNumber):bool = true
328+
template offloadUseVar*(x:SomeNumber):bool = false
329+
template offloadUsePtr*(x:SomeNumber):bool = true
330+
template gpuVarPtr*(v:untyped,p:ptr SomeNumber):untyped = p[]
331+
template offloadPtr*(x:SomeNumber):untyped = unsafeAddr x
332+
template cpuFinalizeOffload*(x:SomeNumber,v,p:untyped) =
333+
x = p[]
334+
327335
template toUArray(a:untyped):untyped = cast[ptr UncheckedArray[typeof(a[0])]](a[0].unsafeaddr)
328336
proc cleanAst(n:NimNode):NimNode =
329337
if n.kind in {nnkHiddenDeref,nnkHiddenCallConv,nnkHiddenStdConv}:

src/backend/sycl.nim

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
import macros
22

3-
{.pragma: syclh,header:"<CL/sycl.hpp>".}
3+
{.pragma: syclh,header:"<sycl/sycl.hpp>".}
4+
{.passC: "-fsycl".}
5+
{.passL: "-fsycl".}
46

57
type
6-
DefaultSelector* {.importcpp:"sycl::default_selector", syclh.} = object
7-
HostSelector* {.importcpp:"sycl::host_selector", syclh.} = object
8+
#DefaultSelector* {.importcpp:"sycl::default_selector_v", syclh.} = object
9+
#HostSelector* {.importcpp:"sycl::host_selector_v", syclh.} = object
810
Context* {.importcpp:"sycl::context", syclh.} = object
911
Device* {.importcpp:"sycl::device", syclh.} = object
1012
Queue* {.importcpp:"sycl::queue", syclh.} = object
1113
SyclBuffer*[T;N:static[int]] {.importcpp:"sycl::buffer", syclh.} = object
1214
AmRead {.importcpp:"sycl::access::mode::read", syclh.} = object
1315
AmWrite {.importcpp:"sycl::access::mode::write", syclh.} = object
1416
AmReadWrite {.importcpp:"sycl::access::mode::read_write", syclh.} = object
15-
TgHost {.importcpp:"sycl::access::target::host_buffer", syclh.} = object
17+
#TgHost {.importcpp:"sycl::access::target::host_buffer", syclh.} = object
1618
TgGlobal {.importcpp:"sycl::access::target::global_buffer", syclh.} = object
1719
SyclAccessor*[T;N:static[int];M;A] {.
1820
importcpp:"sycl::accessor", syclh.} = object
@@ -37,11 +39,18 @@ type
3739
Accessor*[T;N:static[int];M;G] = ref object
3840
acc: SyclAccessor[T,N,M,G]
3941

40-
proc newDefaultSelector*(): DefaultSelector {.importcpp:"default_selector()", syclh.}
41-
proc newHostSelector*(): HostSelector {.importcpp:"host_selector()", syclh.}
42+
#proc newDefaultSelector*(): DefaultSelector {.importcpp:"default_selector_v", syclh.}
43+
#proc newHostSelector*(): HostSelector {.importcpp:"host_selector()", syclh.}
44+
#var
45+
# defaultSelector*{.importcpp:"default_selector_v", syclh.}
46+
# hostSelector*{.importcpp:"host_selector_v", syclh.}
4247

43-
proc selectDevice*(x: DefaultSelector): Device {.importcpp:"#.select_device()".}
44-
proc selectDevice*(x: HostSelector): Device {.importcpp:"#.select_device()".}
48+
#proc selectDevice*(x: DefaultSelector): Device {.importcpp:"#.select_device()".}
49+
#proc selectDevice*(x: HostSelector): Device {.importcpp:"#.select_device()".}
50+
#proc selectDevice*(x: object): Device {.importcpp:"#.select_device()".}
51+
#proc device*(x: object): Device {.importcpp:"'0(#)", constructor, syclh.}
52+
# Builds a `Device` by brace-initialising the C++ return type (`'0`) with
# `sycl::default_selector_v`.
proc defaultDevice*():Device {.importcpp:"'0{sycl::default_selector_v}", constructor, syclh.}
53+
# Builds a `Device` by brace-initialising the C++ return type (`'0`) with
# `sycl::host_selector_v`.
# NOTE(review): SYCL 2020 appears to define only default/cpu/gpu/accelerator
# `*_selector_v` objects (the host device was dropped), so this pattern may
# fail in the C++ compiler once the proc is actually called -- confirm
# against the <sycl/sycl.hpp> in use.
proc hostDevice*():Device {.importcpp:"'0{sycl::host_selector_v}", constructor, syclh.}
4554

4655
type cppstring {.importcpp:"std::string",header:"string".} = object
4756
proc `len`*(x: cppstring): cint {.importcpp:"length".}
@@ -152,21 +161,21 @@ proc getPointer*[T;N:static[int];M](
152161

153162
template submit0*(q: Queue, body: typed) =
154163
proc qs(qq: Queue) {.gensym.} =
155-
{.emit:[qq,".submit([&](cl::sycl::handler &cgh){"].}
164+
{.emit:[qq,".submit([&](sycl::handler &cgh){"].}
156165
body
157166
{.emit:"});".}
158167
qs(q)
159168

160169
template submit*(q: Queue, body: typed) =
161170
block:
162-
{.emit:[q,".submit([&](cl::sycl::handler &cgh){"].}
171+
{.emit:[q,".submit([&](sycl::handler &cgh){"].}
163172
body
164173
{.emit:["});"].}
165174

166175
template setupSycl* =
167176
{.pragma: id1, importcpp:"it",nodecl,header:"",noinit,codegendecl:"".}
168177
{.pragma: item1, importcpp:"it",nodecl,header:"",noinit,codegendecl:"".}
169-
setupSycl()
178+
#setupSycl()
170179
#macro id1*(x: untyped): untyped =
171180
#echo "test"
172181
#echo x.kind
@@ -208,11 +217,11 @@ macro sumX*(x: typed, et: typedesc): untyped =
208217
template sum*[T](x: SyclAccessor[T,1,AmReadWrite,TgGlobal]): untyped =
209218
sumX(x, type(T))
210219

211-
macro val(x: typed): untyped =
212-
#echo x.repr
213-
let v = x.getImpl[2][1][1]
214-
echo v.treerepr
215-
result = v
220+
#macro val(x: typed): untyped =
221+
# #echo x.repr
222+
# let v = x.getImpl[2][1][1]
223+
# echo v.treerepr
224+
# result = v
216225
macro nm(x: typed): untyped =
217226
let v = x.getImpl[2][1][1]
218227
let i = v.intVal
@@ -287,21 +296,21 @@ proc `[]`*[T;G](x: SyclAccessor[T,1,AmWrite,G], i: int): var T {.
287296
proc `[]`*[T;G](x: SyclAccessor[T,1,AmWrite,G], i: Id1): var T {.
288297
importcpp:"#[#]", syclh.}
289298

290-
proc `[]=`*[T;G](x: SyclAccessor[T,1,AmRead,G], i: int, y: any) {.
299+
proc `[]=`*[T;G](x: SyclAccessor[T,1,AmRead,G], i: int, y: auto) {.
291300
error:"illegal use of []= on read-only accessor".}
292-
proc `[]=`*[T;G](x: SyclAccessor[T,1,AmRead,G], i: Id1, y: any) {.
301+
proc `[]=`*[T;G](x: SyclAccessor[T,1,AmRead,G], i: Id1, y: auto) {.
293302
error:"illegal use of []= on read-only accessor".}
294303

295-
proc `[]=`*[T;G](x: SyclAccessor[T,1,AmWrite,G], i: int, y: any) {.
304+
proc `[]=`*[T;G](x: SyclAccessor[T,1,AmWrite,G], i: int, y: auto) {.
296305
importcpp:"#[#]=#", syclh.}
297-
proc `[]=`*[T;G](x: SyclAccessor[T,1,AmWrite,G], i: Id1, y: any) {.
306+
proc `[]=`*[T;G](x: SyclAccessor[T,1,AmWrite,G], i: Id1, y: auto) {.
298307
importcpp:"#[#]=#", syclh.}
299308

300309
proc `[]`*[T](x: SyclHostAccessor[T,1,AmRead], i: int): T {.
301310
importcpp:"#[#]", syclh.}
302311
proc `[]`*[T](x: SyclHostAccessor[T,1,AmWrite], i: int): T {.
303312
importcpp:"#[#]", syclh.}
304-
proc `[]=`*[T](x: SyclHostAccessor[T,1,AmWrite], i: int, y: any) {.
313+
proc `[]=`*[T](x: SyclHostAccessor[T,1,AmWrite], i: int, y: auto) {.
305314
importcpp:"#[#]=#", syclh.}
306315

307316
when isMainModule:

src/backend/syclbe.nim

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import macros
22
import base/metaUtils
33
import sycl
4-
setupSycl()
4+
#setupSycl()
55

6+
let dev = defaultDevice()
7+
let q = dev.queue
68

79
proc alignatImpl(n:NimNode, byte:int): NimNode =
810
result = n.copyNimNode
@@ -21,7 +23,7 @@ proc alignatImpl(n:NimNode, byte:int): NimNode =
2123
for c in n:
2224
result.add c.alignatImpl byte
2325
macro alignat*(byte:static[int], n:untyped): untyped =
24-
if byte notin {1,2,4,8,16,32,64,128,256}:
26+
if byte notin [1,2,4,8,16,32,64,128,256]:
2527
error("macro alignat: unsupported alignment: " & $byte, n)
2628
#echo "alignatImpl ", byte
2729
#echo n.treerepr
@@ -79,9 +81,10 @@ template dataAddr*(x: typed): pointer =
7981
#else: x
8082

8183
template openmpDefs(body: untyped): untyped =
84+
setupSycl()
8285
var item {.item1.}: Item1
83-
template getThreadNum: untyped = item[]
84-
template getNumThreads: untyped = item.getRange
86+
template getThreadNum: untyped {.used.} = item[]
87+
template getNumThreads: untyped {.used.} = item.getRange
8588
{.emit:["#define nimZeroMem(b,len) memset((b),0,(len))"].}
8689
#inlineProcs:
8790
body
@@ -107,7 +110,8 @@ proc prepareVars(n:NimNode):seq[NimNode] =
107110
nnkWhileStmt, nnkForStmt} + RoutineNodes:
108111
# New lexical scope
109112
newscope = true
110-
ignoreStack.add newPar()
113+
#ignoreStack.add newPar()
114+
ignoreStack.add newNimNode(nnkTupleConstr)
111115
if n.kind == nnkForStmt:
112116
ignoreStack[^1].add n[0]
113117
for i in 0..<n.len:
@@ -155,7 +159,8 @@ proc prepareVars(n:NimNode):seq[NimNode] =
155159
let np = gensym(nsklet, "gpu_ptr_" & $n[i])
156160
ignoreStack[0].add nv
157161
ignoreStack[0].add np
158-
openvars.add newpar(n[i], nv, np)
162+
#openvars.add newpar(n[i], nv, np)
163+
openvars.add newNimNode(nnkTupleConstr).add(n[i], nv, np)
159164
n[i] = newcall("gpuVarPtr",nv,np)
160165
else:
161166
discard
@@ -179,7 +184,7 @@ proc genCpuPrepare(n:seq[NimNode]):NimNode =
179184
result = newstmtlist()
180185
for c in n:
181186
result.add getast r(c[0],c[1],c[2])
182-
echo result.repr
187+
#echo result.repr
183188
proc genGpuPrepare(n:seq[NimNode]):NimNode =
184189
template r(x,v,p:untyped):untyped =
185190
mixin gpuPrepareOffload, rungpuPrepareOffload
@@ -206,9 +211,9 @@ proc declarePtrString(n:seq[NimNode]):NimNode =
206211
ps = infix(getast varname(c[0], $c[2]), "&", ps)
207212
result = getast res(ps)
208213

209-
macro onGpu*(q: Queue, body: untyped): untyped =
214+
macro onGpu*(q: Queue, body: untyped): auto =
210215
# the architecture for cpugpuarray requires us replace body before it gets expanded, so we require untyped.
211-
template target(cpuPrepare, gpuPrepare, cpuFinalize, devicePtrDeclare, body: untyped): untyped =
216+
template target(q, cpuPrepare, gpuPrepare, cpuFinalize, devicePtrDeclare, body: untyped): untyped =
212217
mixin hasGpuPtr, requireGpuMem
213218
{.push checks: off.}
214219
{.push stacktrace: off.}
@@ -222,6 +227,7 @@ macro onGpu*(q: Queue, body: untyped): untyped =
222227
openmpDefs:
223228
#gpuPrepare
224229
body
230+
q.wait
225231
cpuFinalize
226232
gpuProc()
227233
let
@@ -230,20 +236,29 @@ macro onGpu*(q: Queue, body: untyped): untyped =
230236
gpuPrepare = genGpuPrepare v
231237
cpuFinalize = genCpuFinalize v
232238
isDevicePtrs = declarePtrString v
233-
result = getast(target(cpuPrepare, gpuPrepare, cpuFinalize, isDevicePtrs, body))
239+
result = getast(target(q, cpuPrepare, gpuPrepare, cpuFinalize, isDevicePtrs, body))
234240
#echo result.repr
235241

236242
# XXX fix the following
237-
template onGpu*(totalNumThreads, body: untyped): untyped = onGpu(body)
238-
template onGpu*(totalNumThreads, numThreadsPerTeam, body: untyped): untyped = onGpu(body)
243+
template onGpu*(body: untyped) = onGpu(q, body)
244+
#template onGpu*(totalNumThreads, body: untyped): untyped = onGpu(body)
245+
#template onGpu*(totalNumThreads, numThreadsPerTeam, body: untyped): untyped = onGpu(body)
239246

240-
template offloadUseVar*(x:SomeNumber):bool = true
241-
template offloadUsePtr*(x:SomeNumber):bool = false
247+
#template offloadUseVar*(x:SomeNumber):bool = true
248+
#template offloadUsePtr*(x:SomeNumber):bool = false
242249
template rungpuPrepareOffload*(x:SomeNumber):bool = false
243-
template runcpuFinalizeOffload*(x:SomeNumber):bool = false
244-
template gpuVarPtr*(v:SomeNumber,p:untyped):untyped = v
250+
#template runcpuFinalizeOffload*(x:SomeNumber):bool = false
251+
#template gpuVarPtr*(v:SomeNumber,p:untyped):untyped = v
245252
template offloadVar*(x:SomeNumber,p:untyped):untyped = x
246253

254+
template runcpuFinalizeOffload*(x:SomeNumber):bool = true
255+
template offloadUseVar*(x:SomeNumber):bool = false
256+
template offloadUsePtr*(x:SomeNumber):bool = true
257+
template gpuVarPtr*(v:untyped,p:ptr SomeNumber):untyped = p[]
258+
template offloadPtr*(x:SomeNumber):untyped = unsafeAddr x
259+
template cpuFinalizeOffload*(x:SomeNumber,v,p:untyped) =
260+
x = p[]
261+
247262
template toUArray(a:untyped):untyped = cast[ptr UncheckedArray[typeof(a[0])]](a[0].unsafeaddr)
248263
proc cleanAst(n:NimNode):NimNode =
249264
if n.kind in {nnkHiddenDeref,nnkHiddenCallConv,nnkHiddenStdConv}:
@@ -256,10 +271,10 @@ proc identStr(n:NimNode):string =
256271
result = n.repr
257272
for i in 0..<result.len:
258273
if result[i] in {'.','[',']',':'}: result[i] = '_'
259-
proc isIndex(n,i:NimNode):bool =
260-
result = n.eqident i
261-
if n.kind == nnkHiddenStdConv:
262-
result = n[1].eqident i
274+
#proc isIndex(n,i:NimNode):bool =
275+
# result = n.eqident i
276+
# if n.kind == nnkHiddenStdConv:
277+
# result = n[1].eqident i
263278
macro simdForImpl(n:typed):untyped =
264279
proc getIndexedPtrs(n,i:NimNode):(NimNode,seq[NimNode]) =
265280
#echo "### getIndexedPtrs: ", i.repr
@@ -273,7 +288,8 @@ macro simdForImpl(n:typed):untyped =
273288
break
274289
if m < 0:
275290
let v = gensym(nskVar, n.cleanAst.identStr)
276-
ptrs.add newPar(v, n)
291+
#ptrs.add newPar(v, n)
292+
ptrs.add newNimNode(nnkTupleConstr).add(v, n)
277293
return v
278294
else:
279295
return ptrs[m][0]

0 commit comments

Comments
 (0)