 crossMapLRNBackward = None
 
 
+instanceNorm2d = None
+instanceNorm2dBackward = None
+
+
+spatialTf = None
+spatialTfBackward = None
+
+
+RNNMode = None
+DirectionMode = None
+
+createRnn = None
+
+acquireRnnParams = None
+updateRnnParams = None
+
+forwardRnn = None
+backwardDataRnn = None
+backwardParamsRnn = None
+
+deviceSupportsBatchHint = None
+
+
 def autoinit():
 	if not Config.shouldInit():
 		return
@@ -87,10 +110,19 @@ def wrapCrossMapLRNBackward(data, outdata, grad, _, N, alpha, beta, K):
 	crossMapLRN = wrapCrossMapLRN
 	crossMapLRNBackward = wrapCrossMapLRNBackward
 
+	def wrapSpatialTf(data, transform, outshape, getGrid):
+		return dnn.spatialTf(data, transform, outshape, getGrid, allocator=memoryPool)
+
+	def wrapSpatialTfBackward(grad, data, grid):
+		return dnn.spatialTfBackward(grad, data, grid, allocator=memoryPool)
+
+	global spatialTf, spatialTfBackward
+	spatialTf = wrapSpatialTf
+	spatialTfBackward = wrapSpatialTfBackward
+
 
 def initHip():
 	from PuzzleLib.Hip import Backend
-	initGPU(Backend)
 
 	backend = initGPU(Backend)
 	memoryPool, dnn = backend.memoryPool, backend.dnn
@@ -125,9 +157,17 @@ def wrapMapLRNBackward(data, outdata, grad, means, workspace, N, alpha, beta, K)
 
 
 def initGPU(Backend):
-	import numpy as np
+	backend = Backend.getBackend(Config.deviceIdx, initmode=1, logger=Config.getLogger())
+
+	initBaseGPU(backend)
+	initInstanceNormGPU(backend)
+	initRnnGPU(backend)
+
+	return backend
+
 
-	backend = Backend.getBackend(Config.deviceIdx, initmode=1)
+def initBaseGPU(backend):
+	import numpy as np
 	memoryPool, dnn = backend.memoryPool, backend.dnn
 
 	global ConvFwdAlgo, ConvBwdDataAlgo, ConvBwdFilterAlgo
@@ -143,7 +183,7 @@ def wrapConvNd(data, W, bias, stride, pad, dilation, groups, algo):
 
 	def wrapConvNdBackwardData(grad, W, data, stride, pad, dilation, groups, algo):
 		return dnn.convNdBackwardData(
-			grad, W, None, data, stride, pad, dilation, groups, algo.value, None, memoryPool
+			grad, W, None, data, stride, pad, dilation, None, groups, algo.value, None, memoryPool
 		)
 
 	def wrapConvNdBackwardParams(data, grad, W, bias, stride, pad, dilation, groups,
@@ -168,9 +208,9 @@ def wrapConvNdbenchmark(datashape, Wshape, stride, pad, dilation, groups, transp
 	global convNdbenchmark
 	convNdbenchmark = wrapConvNdbenchmark
 
-	def wrapDeconvNd(data, W, bias, stride, pad, dilation, groups, algo):
+	def wrapDeconvNd(data, W, bias, stride, pad, dilation, postpad, groups, algo):
 		return dnn.convNdBackwardData(
-			data, W, bias.ravel() if bias is not None else None, None, stride, pad, dilation, groups,
+			data, W, bias.ravel() if bias is not None else None, None, stride, pad, dilation, postpad, groups,
 			algo.value, None, memoryPool
 		)
 
@@ -231,15 +271,79 @@ def wrapSoftmaxNdBackward(outdata, grad):
 	softmaxNd = wrapSoftmaxNd
 	softmaxNdBackward = wrapSoftmaxNdBackward
 
+
+def initInstanceNormGPU(backend):
+	memoryPool = backend.memoryPool
+
+	def wrapInstanceNorm2d(data, scale, bias, epsilon=1e-5):
+		return backend.instanceNorm2d(data, scale.ravel(), bias.ravel(), epsilon, allocator=memoryPool)
+
+	def wrapInstanceNorm2dBackward(grad, data, extscale, savemean, saveinvvar, epsilon, affine):
+		return backend.instanceNorm2dBackward(
+			grad, data, extscale, savemean, saveinvvar, epsilon, affine, allocator=memoryPool
+		)
+
+	global instanceNorm2d, instanceNorm2dBackward
+	instanceNorm2d = wrapInstanceNorm2d
+	instanceNorm2dBackward = wrapInstanceNorm2dBackward
+
+
+def initRnnGPU(backend):
+	import numpy as np
+	memoryPool = backend.memoryPool
+
+	global RNNMode, DirectionMode
+	RNNMode = backend.RNNMode
+	DirectionMode = backend.DirectionMode
+
+	def wrapCreateRnn(insize, hsize, layers, mode, direction, dropout, seed, batchsize):
+		rnn, W, params = backend.createRnn(
+			insize, hsize, np.float32, layers, mode=mode, direction=direction, dropout=dropout,
+			seed=seed, batchsize=0 if batchsize is None else batchsize
+		)
+
+		return rnn, W, {i: layer for i, layer in enumerate(params)}
+
+	def wrapAcquireRnnParams(descRnn, w):
+		params = backend.acquireRnnParams(descRnn, w)
+		return w, params
+
+	def wrapUpdateRnnParams(descRnn, w, params):
+		params = [params[layer] for layer in sorted(params.keys())]
+		backend.updateRnnParams(descRnn, w, params)
+
+	global createRnn, acquireRnnParams, updateRnnParams
+	createRnn = wrapCreateRnn
+	acquireRnnParams = wrapAcquireRnnParams
+	updateRnnParams = wrapUpdateRnnParams
+
+	def wrapForwardRnn(data, W, descRnn, test=False):
+		return descRnn.forward(data, W, test=test, allocator=memoryPool)
+
+	def wrapBackwardDataRnn(grad, outdata, W, reserve, descRnn):
+		ingrad, _, _ = descRnn.backwardData(grad, outdata, W, reserve, allocator=memoryPool)
+		return ingrad, reserve
+
+	def wrapBackwardParamsRnn(data, outdata, _, reserve, descRnn):
+		return descRnn.backwardParams(data, outdata, reserve, allocator=memoryPool)
+
+	global forwardRnn, backwardDataRnn, backwardParamsRnn
+	forwardRnn = wrapForwardRnn
+	backwardDataRnn = wrapBackwardDataRnn
+	backwardParamsRnn = wrapBackwardParamsRnn
+
+	global deviceSupportsBatchHint
+	deviceSupportsBatchHint = backend.deviceSupportsBatchHint
+
 	return backend
 
 
 def initCPU():
 	from PuzzleLib.CPU.Wrappers import NumpyDnn
 
 	def wrapConvNd(data, W, bias, stride, pad, dilation, groups, algo):
-		assert dilation == (1, 1) and groups == 1
-		return NumpyDnn.conv2d(data, W, bias, stride, pad)
+		assert groups == 1
+		return NumpyDnn.conv2d(data, W, bias, stride, pad, dilation)
 
 	global convNd, convNdBackwardData, convNdBackwardParams
 	convNd = wrapConvNd
@@ -263,9 +367,12 @@ def wrapBatchNormNd(data, scale, bias, mean, var, epsilon, factor, test, mode=No
 	BatchNormMode = ProxyBatchNormMode
 	batchNormNd = wrapBatchNormNd
 
+	global deviceSupportsBatchHint
+	deviceSupportsBatchHint = lambda: False
+
 
 def initIntel():
-	from PuzzleLib.Intel.Wrappers import DNNL
+	from PuzzleLib.Intel.Wrappers import DNNL, DNNLInstanceNorm
 
 	global ConvFwdAlgo, ConvBwdDataAlgo, ConvBwdFilterAlgo
 	ConvFwdAlgo = DNNL.ConvAlgo
@@ -299,9 +406,9 @@ def wrapConvNdbenchmark(datashape, Wshape, stride, pad, dilation, groups, transp
 	global convNdbenchmark
 	convNdbenchmark = wrapConvNdbenchmark
 
-	def wrapDeconvNd(data, W, bias, stride, pad, dilation, groups, algo):
+	def wrapDeconvNd(data, W, bias, stride, pad, dilation, groups, postpad, algo):
 		assert groups == 1
-		return DNNL.convNd(data, W, bias, stride, pad, dilation, algo=algo, transpose=True)
+		return DNNL.convNd(data, W, bias, stride, pad, dilation, postpad, algo=algo, transpose=True)
 
 	def wrapDeconvNdBackwardData(grad, W, data, stride, pad, dilation, groups, algo):
 		assert groups == 1
@@ -360,5 +467,24 @@ def wrapCrossMapLRNBackward(data, outdata, grad, workspace, N, alpha, beta, K):
 	crossMapLRN = wrapCrossMapLRN
 	crossMapLRNBackward = wrapCrossMapLRNBackward
 
+	def wrapInstanceNorm2d(data, scale, bias, epsilon):
+		result = DNNLInstanceNorm.instanceNorm2d(data, scale, bias, epsilon)
+
+		outdata, savemean, savevar, extscale, extbias, desc = result
+		return outdata, savemean, savevar, (extscale, extbias, desc)
+
+	def wrapInstanceNorm2dBackward(grad, data, exts, savemean, savevar, epsilon, affine):
+		extscale, extbias, desc = exts
+		return DNNLInstanceNorm.instanceNorm2dBackward(
+			grad, data, extscale, extbias, savemean, savevar, epsilon, desc, affine
+		)
+
+	global instanceNorm2d, instanceNorm2dBackward
+	instanceNorm2d = wrapInstanceNorm2d
+	instanceNorm2dBackward = wrapInstanceNorm2dBackward
+
+	global deviceSupportsBatchHint
+	deviceSupportsBatchHint = lambda: False
+
 
 autoinit()
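
Reviewer note, not part of the patch: a minimal sketch of exercising the newly exported `instanceNorm2d` / `instanceNorm2dBackward` globals once `autoinit()` has selected a GPU backend. Treat it as if it were appended to this module after `autoinit()`, so the rebound globals are in scope. The `gpuarray.to_gpu` import path, the 4-tuple forward return (inferred from the arguments `instanceNorm2dBackward` consumes), and the backward's `(ingrad, dscale, dbias)` return are assumptions, not something this diff establishes.

```python
# Hypothetical usage sketch; assumes it runs inside this module after autoinit().
import numpy as np
from PuzzleLib.Backend import gpuarray  # assumed location of the host<->device helper

batchsize, maps, h, w = 4, 3, 8, 8

data = gpuarray.to_gpu(np.random.randn(batchsize, maps, h, w).astype(np.float32))
scale = gpuarray.to_gpu(np.ones((1, maps, 1, 1), dtype=np.float32))
bias = gpuarray.to_gpu(np.zeros((1, maps, 1, 1), dtype=np.float32))

# Forward: the unpacking mirrors what the backward wrapper takes (assumption)
outdata, savemean, saveinvvar, extscale = instanceNorm2d(data, scale, bias, epsilon=1e-5)

# Backward: with affine=True, per-map scale/bias gradients are assumed to be returned too
grad = gpuarray.to_gpu(np.random.randn(*outdata.shape).astype(np.float32))
ingrad, dscale, dbias = instanceNorm2dBackward(grad, data, extscale, savemean, saveinvvar, 1e-5, True)
```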
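Similarly, a hedged sketch of the new RNN call chain (`createRnn` → `forwardRnn` → `backwardDataRnn` → `backwardParamsRnn`). The `(seqlen, batch, insize)` data layout, the `(outdata, reserve)` return of `forwardRnn`, and the `RNNMode.lstm` / `DirectionMode.uni` member names are assumptions inferred from the wrappers above, not confirmed by the PR.

```python
# Hypothetical call sequence; same scoping assumption as the sketch above.
import numpy as np
from PuzzleLib.Backend import gpuarray  # assumed transfer helper

seqlen, batchsize, insize, hsize = 10, 4, 16, 32

# createRnn returns the descriptor, the flat weight tensor and a {layer index: params} dict
rnn, W, params = createRnn(
	insize, hsize, layers=1, mode=RNNMode.lstm, direction=DirectionMode.uni,
	dropout=0.0, seed=42, batchsize=None
)

data = gpuarray.to_gpu(np.random.randn(seqlen, batchsize, insize).astype(np.float32))

outdata, reserve = forwardRnn(data, W, rnn)                  # return layout assumed
grad = gpuarray.to_gpu(np.random.randn(*outdata.shape).astype(np.float32))

ingrad, reserve = backwardDataRnn(grad, outdata, W, reserve, rnn)
dw = backwardParamsRnn(data, outdata, None, reserve, rnn)    # gradient w.r.t. W
```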