Skip to content

Commit 7522c32

Browse files
committed
Can pass parameters to lua class constructors now
1 parent 96b1bae commit 7522c32

File tree

5 files changed

+32
-20
lines changed

5 files changed

+32
-20
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def get_file_datetime(filepath):
139139

140140
setup(
141141
name='PyTorch',
142-
version='2.6.0',
142+
version='2.7.0',
143143
author='Hugh Perkins',
144144
author_email='hughperkins@gmail.com',
145145
description=(

simpleexample/luabit.lua

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@ require 'nn'
33

44
local Luabit = torch.class('Luabit')
55

6-
function Luabit:__init()
6+
function Luabit:__init(someName)
7+
print('Luabit:__init(', someName, ')')
8+
self.someName = someName
9+
end
10+
11+
function Luabit:getName()
12+
return self.someName
713
end
814

915
function Luabit:getOut(inTensor, outSize, kernelSize)

simpleexample/pybit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
outSize = 3
1313
kernelSize = 3
1414

15-
luabit = Luabit()
15+
luabit = Luabit('green')
16+
print(luabit.getName())
1617

1718
inTensor = np.random.randn(batchSize, numFrames, inSize).astype('float32')
1819
luain = PyTorch.asFloatTensor(inTensor)

src/PyTorchAug.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def torchType(lua, pos):
6969
return popString(lua)
7070

7171
class LuaClass(object):
72-
def __init__(self, nameList, *args):
72+
def __init__(self, *args, nameList):
7373
# print('LuaClass.__init__()')
7474
lua = PyTorch.getGlobalState().getLua()
7575
# self.luaclass = luaclass
@@ -79,6 +79,8 @@ def __init__(self, nameList, *args):
7979
for arg in args:
8080
if isinstance(arg, int):
8181
lua.pushNumber(arg)
82+
elif isinstance(arg, str):
83+
lua.pushString(arg)
8284
else:
8385
raise Exception('arg type ' + str(type(arg)) + ' not implemented')
8486
lua.call(len(args), 1)
@@ -236,7 +238,7 @@ def __init__(self, _fromLua=False):
236238
# print('Table.__init__')
237239
if not _fromLua:
238240
name = self.__class__.__name__
239-
super(self.__class__, self).__init__(['nn', name])
241+
super(self.__class__, self).__init__(nameList=['nn', name])
240242
else:
241243
self.__dict__['__objectId'] = getNextObjectId()
242244
self.luaclass = 'table'
@@ -247,7 +249,7 @@ def __init__(self, numIn=1, numOut=1, _fromLua=False):
247249
self.luaclass = 'nn.Linear'
248250
if not _fromLua:
249251
name = self.__class__.__name__
250-
super(self.__class__, self).__init__(['nn', name], numIn, numOut)
252+
super(self.__class__, self).__init__(numIn, numOut, nameList=['nn', name])
251253
else:
252254
self.__dict__['__objectId'] = getNextObjectId()
253255

@@ -256,7 +258,7 @@ def __init__(self, _fromLua=False):
256258
self.luaclass = 'nn.ClassNLLCriterion'
257259
if not _fromLua:
258260
name = self.__class__.__name__
259-
super(self.__class__, self).__init__(['nn', name])
261+
super(self.__class__, self).__init__(nameList=['nn', name])
260262
else:
261263
self.__dict__['__objectId'] = getNextObjectId()
262264

@@ -265,7 +267,7 @@ def __init__(self, _fromLua=False):
265267
self.luaclass = 'nn.MSECriterion'
266268
if not _fromLua:
267269
name = self.__class__.__name__
268-
super(self.__class__, self).__init__(['nn', name])
270+
super(self.__class__, self).__init__(nameList=['nn', name])
269271
else:
270272
self.__dict__['__objectId'] = getNextObjectId()
271273

@@ -274,7 +276,7 @@ def __init__(self, _fromLua=False):
274276
self.luaclass = 'nn.Sequential'
275277
if not _fromLua:
276278
name = self.__class__.__name__
277-
super(self.__class__, self).__init__(['nn', name])
279+
super(self.__class__, self).__init__(nameList=['nn', name])
278280
else:
279281
self.__dict__['__objectId'] = getNextObjectId()
280282

@@ -283,7 +285,7 @@ def __init__(self, _fromLua=False):
283285
self.luaclass = 'nn.LogSoftMax'
284286
if not _fromLua:
285287
name = self.__class__.__name__
286-
super(self.__class__, self).__init__(['nn', name])
288+
super(self.__class__, self).__init__(nameList=['nn', name])
287289
else:
288290
self.__dict__['__objectId'] = getNextObjectId()
289291

@@ -293,13 +295,13 @@ def __init__(self, s1, s2=None, s3=None, s4=None, _fromLua=False):
293295
if not _fromLua:
294296
name = self.__class__.__name__
295297
if s4 is not None: # this is a bit hacky, but gets it working for now...
296-
super(self.__class__, self).__init__(['nn', name], s1, s2, s3, s4)
298+
super(self.__class__, self).__init__(s1, s2, s3, s4, nameList=['nn', name])
297299
elif s3 is not None:
298-
super(self.__class__, self).__init__(['nn', name], s1, s2, s3)
300+
super(self.__class__, self).__init__(s1, s2, s3, nameList=['nn', name])
299301
elif s2 is not None:
300-
super(self.__class__, self).__init__(['nn', name], s1, s2)
302+
super(self.__class__, self).__init__(s1, s2, nameList=['nn', name])
301303
else:
302-
super(self.__class__, self).__init__(['nn', name], s1)
304+
super(self.__class__, self).__init__(s1, nameList=['nn', name])
303305
else:
304306
self.__dict__['__objectId'] = getNextObjectId()
305307

@@ -308,7 +310,7 @@ def __init__(self, nInputPlane, nOutputPlane, kW, kH, dW=1, dH=1, padW=0, padH=0
308310
self.luaclass = 'nn.SpatialConvolutionMM'
309311
if not _fromLua:
310312
name = self.__class__.__name__
311-
super(self.__class__, self).__init__(['nn', name], nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
313+
super(self.__class__, self).__init__(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH, nameList=['nn', name])
312314
else:
313315
self.__dict__['__objectId'] = getNextObjectId()
314316

@@ -317,7 +319,7 @@ def __init__(self, kW, kH, dW, dH, padW=0, padH=0, _fromLua=False):
317319
self.luaclass = 'nn.SpatialMaxPooling'
318320
if not _fromLua:
319321
name = self.__class__.__name__
320-
super(self.__class__, self).__init__(['nn', name], kW, kH, dW, dH, padW, padH)
322+
super(self.__class__, self).__init__(kW, kH, dW, dH, padW, padH, nameList=['nn', name])
321323
else:
322324
self.__dict__['__objectId'] = getNextObjectId()
323325

@@ -326,7 +328,7 @@ def __init__(self, _fromLua=False):
326328
self.luaclass = 'nn.ReLU'
327329
if not _fromLua:
328330
name = self.__class__.__name__
329-
super(self.__class__, self).__init__(['nn', name])
331+
super(self.__class__, self).__init__(nameList=['nn', name])
330332
else:
331333
self.__dict__['__objectId'] = getNextObjectId()
332334

@@ -335,7 +337,7 @@ def __init__(self, _fromLua=False):
335337
self.luaclass = 'nn.Tanh'
336338
if not _fromLua:
337339
name = self.__class__.__name__
338-
super(self.__class__, self).__init__(['nn', name])
340+
super(self.__class__, self).__init__(nameList=['nn', name])
339341
else:
340342
self.__dict__['__objectId'] = getNextObjectId()
341343

src/PyTorchHelpers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@ def load_lua_class(lua_filename, lua_classname):
55
module = lua_filename.replace('.lua', '')
66
PyTorch.require(module)
77
class LuaWrapper(PyTorchAug.LuaClass):
8-
def __init__(self, _fromLua=False):
8+
def __init__(self, *args, _fromLua=False):
9+
#print('calling super constructor with', args)
10+
#super(LuaWrapper, self).__init__(*args)
911
self.luaclass = lua_classname
1012
if not _fromLua:
1113
name = lua_classname
12-
super(self.__class__, self).__init__([name])
14+
super(self.__class__, self).__init__(*args, nameList=[name])
1315
else:
1416
self.__dict__['__objectId'] = getNextObjectId()
17+
# self.__getattr__('__init')(*args)
1518
return LuaWrapper
1619

0 commit comments

Comments
 (0)