@@ -69,7 +69,7 @@ def torchType(lua, pos):
6969 return popString (lua )
7070
7171class LuaClass (object ):
72- def __init__ (self , nameList , * args ):
72+ def __init__ (self , * args , nameList ):
7373 # print('LuaClass.__init__()')
7474 lua = PyTorch .getGlobalState ().getLua ()
7575# self.luaclass = luaclass
@@ -79,6 +79,8 @@ def __init__(self, nameList, *args):
7979 for arg in args :
8080 if isinstance (arg , int ):
8181 lua .pushNumber (arg )
82+ elif isinstance (arg , str ):
83+ lua .pushString (arg )
8284 else :
8385 raise Exception ('arg type ' + str (type (arg )) + ' not implemented' )
8486 lua .call (len (args ), 1 )
@@ -236,7 +238,7 @@ def __init__(self, _fromLua=False):
236238 # print('Table.__init__')
237239 if not _fromLua :
238240 name = self .__class__ .__name__
239- super (self .__class__ , self ).__init__ (['nn' , name ])
241+ super (self .__class__ , self ).__init__ (nameList = ['nn' , name ])
240242 else :
241243 self .__dict__ ['__objectId' ] = getNextObjectId ()
242244 self .luaclass = 'table'
@@ -247,7 +249,7 @@ def __init__(self, numIn=1, numOut=1, _fromLua=False):
247249 self .luaclass = 'nn.Linear'
248250 if not _fromLua :
249251 name = self .__class__ .__name__
250- super (self .__class__ , self ).__init__ (['nn' , name ], numIn , numOut )
252+ super (self .__class__ , self ).__init__ (numIn , numOut , nameList = ['nn' , name ])
251253 else :
252254 self .__dict__ ['__objectId' ] = getNextObjectId ()
253255
@@ -256,7 +258,7 @@ def __init__(self, _fromLua=False):
256258 self .luaclass = 'nn.ClassNLLCriterion'
257259 if not _fromLua :
258260 name = self .__class__ .__name__
259- super (self .__class__ , self ).__init__ (['nn' , name ])
261+ super (self .__class__ , self ).__init__ (nameList = ['nn' , name ])
260262 else :
261263 self .__dict__ ['__objectId' ] = getNextObjectId ()
262264
@@ -265,7 +267,7 @@ def __init__(self, _fromLua=False):
265267 self .luaclass = 'nn.MSECriterion'
266268 if not _fromLua :
267269 name = self .__class__ .__name__
268- super (self .__class__ , self ).__init__ (['nn' , name ])
270+ super (self .__class__ , self ).__init__ (nameList = ['nn' , name ])
269271 else :
270272 self .__dict__ ['__objectId' ] = getNextObjectId ()
271273
@@ -274,7 +276,7 @@ def __init__(self, _fromLua=False):
274276 self .luaclass = 'nn.Sequential'
275277 if not _fromLua :
276278 name = self .__class__ .__name__
277- super (self .__class__ , self ).__init__ (['nn' , name ])
279+ super (self .__class__ , self ).__init__ (nameList = ['nn' , name ])
278280 else :
279281 self .__dict__ ['__objectId' ] = getNextObjectId ()
280282
@@ -283,7 +285,7 @@ def __init__(self, _fromLua=False):
283285 self .luaclass = 'nn.LogSoftMax'
284286 if not _fromLua :
285287 name = self .__class__ .__name__
286- super (self .__class__ , self ).__init__ (['nn' , name ])
288+ super (self .__class__ , self ).__init__ (nameList = ['nn' , name ])
287289 else :
288290 self .__dict__ ['__objectId' ] = getNextObjectId ()
289291
@@ -293,13 +295,13 @@ def __init__(self, s1, s2=None, s3=None, s4=None, _fromLua=False):
293295 if not _fromLua :
294296 name = self .__class__ .__name__
295297 if s4 is not None : # this is a bit hacky, but gets it working for now...
296- super (self .__class__ , self ).__init__ ([ 'nn' , name ], s1 , s2 , s3 , s4 )
298+ super (self .__class__ , self ).__init__ (s1 , s2 , s3 , s4 , nameList = [ 'nn' , name ] )
297299 elif s3 is not None :
298- super (self .__class__ , self ).__init__ ([ 'nn' , name ], s1 , s2 , s3 )
300+ super (self .__class__ , self ).__init__ (s1 , s2 , s3 , nameList = [ 'nn' , name ] )
299301 elif s2 is not None :
300- super (self .__class__ , self ).__init__ (['nn' , name ], s1 , s2 )
302+ super (self .__class__ , self ).__init__ (s1 , s2 , nameList = ['nn' , name ])
301303 else :
302- super (self .__class__ , self ).__init__ (['nn' , name ], s1 )
304+ super (self .__class__ , self ).__init__ (s1 , nameList = ['nn' , name ])
303305 else :
304306 self .__dict__ ['__objectId' ] = getNextObjectId ()
305307
@@ -308,7 +310,7 @@ def __init__(self, nInputPlane, nOutputPlane, kW, kH, dW=1, dH=1, padW=0, padH=0
308310 self .luaclass = 'nn.SpatialConvolutionMM'
309311 if not _fromLua :
310312 name = self .__class__ .__name__
311- super (self .__class__ , self ).__init__ ([ 'nn' , name ], nInputPlane , nOutputPlane , kW , kH , dW , dH , padW , padH )
313+ super (self .__class__ , self ).__init__ (nInputPlane , nOutputPlane , kW , kH , dW , dH , padW , padH , nameList = [ 'nn' , name ] )
312314 else :
313315 self .__dict__ ['__objectId' ] = getNextObjectId ()
314316
@@ -317,7 +319,7 @@ def __init__(self, kW, kH, dW, dH, padW=0, padH=0, _fromLua=False):
317319 self .luaclass = 'nn.SpatialMaxPooling'
318320 if not _fromLua :
319321 name = self .__class__ .__name__
320- super (self .__class__ , self ).__init__ ([ 'nn' , name ], kW , kH , dW , dH , padW , padH )
322+ super (self .__class__ , self ).__init__ (kW , kH , dW , dH , padW , padH , nameList = [ 'nn' , name ] )
321323 else :
322324 self .__dict__ ['__objectId' ] = getNextObjectId ()
323325
@@ -326,7 +328,7 @@ def __init__(self, _fromLua=False):
326328 self .luaclass = 'nn.ReLU'
327329 if not _fromLua :
328330 name = self .__class__ .__name__
329- super (self .__class__ , self ).__init__ (['nn' , name ])
331+ super (self .__class__ , self ).__init__ (nameList = ['nn' , name ])
330332 else :
331333 self .__dict__ ['__objectId' ] = getNextObjectId ()
332334
@@ -335,7 +337,7 @@ def __init__(self, _fromLua=False):
335337 self .luaclass = 'nn.Tanh'
336338 if not _fromLua :
337339 name = self .__class__ .__name__
338- super (self .__class__ , self ).__init__ (['nn' , name ])
340+ super (self .__class__ , self ).__init__ (nameList = ['nn' , name ])
339341 else :
340342 self .__dict__ ['__objectId' ] = getNextObjectId ()
341343
0 commit comments