Skip to content

Commit 6651ed3

Browse files
committed
- lua functions can return tables to python now
1 parent 4380874 commit 6651ed3

File tree

10 files changed

+514
-297
lines changed

10 files changed

+514
-297
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def get_file_datetime(filepath):
139139

140140
setup(
141141
name='PyTorch',
142-
version='2.8.2-SNAPSHOT',
142+
version='2.9.0',
143143
author='Hugh Perkins',
144144
author_email='hughperkins@gmail.com',
145145
description=(

simpleexample/luabit.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,6 @@ function Luabit:printTable(sometable, somestring, table2)
2929
for k, v in pairs(table2) do
3030
print('Luabit table2 ', k, v)
3131
end
32+
return {bear='happy', result=12.345, foo='bar'}
3233
end
3334

simpleexample/pybit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
outTensor = luaout.asNumpyTensor()
2424
print('outTensor', outTensor)
2525

26-
luabit.printTable({'color': 'red', 'weather': 'sunny', 'anumber': 10, 'afloat': 1.234}, 'mistletoe', {
26+
res = luabit.printTable({'color': 'red', 'weather': 'sunny', 'anumber': 10, 'afloat': 1.234}, 'mistletoe', {
2727
'row1': 'col1', 'meta': 'data'})
28+
print('res', res)
2829

src/PyTorch.cpp

Lines changed: 12 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/PyTorchAug.py

Lines changed: 53 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,61 @@ def pushSomething(lua, something):
9696

9797
raise Exception('pushing type ' + str(type(something)) + ' not implemented, value ', something)
9898

99+
def popSomething(lua):
100+
lua.pushValue(-1)
101+
pushGlobal(lua, 'torch', 'type')
102+
lua.insert(-2)
103+
lua.call(1, 1)
104+
typestring = popString(lua)
105+
106+
if typestring in cythonClasses:
107+
popFunction = cythonClasses[typestring]['popFunction']
108+
res = popFunction()
109+
return res
110+
111+
if typestring == 'number':
112+
res = lua.toNumber(-1)
113+
lua.remove(-1)
114+
return res
115+
116+
if typestring == 'string':
117+
res = popString(lua)
118+
return res
119+
120+
if typestring == 'table':
121+
return popTable(lua)
122+
123+
if typestring in luaClasses:
124+
returnobject = luaClasses[typestring](_fromLua=True)
125+
registerObject(lua, returnobject)
126+
return returnobject
127+
128+
if typestring == 'nil':
129+
lua.remove(-1)
130+
return None
131+
132+
# raise Exception('pop type ' + str(typestring) + ' not implemented')
133+
print('pop type ' + str(typestring) + ' not implemented')
134+
99135
def pushTable(lua, table):
100136
lua.newTable()
101137
for k, v in table.items():
102138
pushSomething(lua, k)
103139
pushSomething(lua, v)
104140
lua.setTable(-3)
105141

142+
def popTable(lua):
143+
res = {}
144+
lua.pushNil()
145+
while lua.next(-2) != 0:
146+
value = popSomething(lua)
147+
lua.pushValue(-1)
148+
key = popSomething(lua)
149+
res[key] = value
150+
lua.remove(-1)
151+
return res
152+
153+
106154
class LuaClass(object):
107155
def __init__(self, *args, nameList):
108156
lua = PyTorch.getGlobalState().getLua()
@@ -115,22 +163,14 @@ def __init__(self, *args, nameList):
115163
lua.call(len(args), 1)
116164
registerObject(lua, self)
117165

118-
# nameList = nameList[:]
119-
# nameList.append('float')
120-
# pushGlobalFromList(lua, nameList)
121-
# pushObject(lua, self)
122-
# lua.call(1, 0)
123-
124166
topEnd = lua.getTop()
125167
assert topStart == topEnd
126168

127169
def __del__(self):
128170
name = self.__class__.__name__
129-
# print(name + '.__del__')
130171

131172
def __repr__(self):
132173
topStart = lua.getTop()
133-
# name = self.__class__.__name__
134174
luaClass = self.luaclass
135175
if luaClass == 'table':
136176
return 'table'
@@ -141,7 +181,6 @@ def __repr__(self):
141181
pushGlobal(lua, splitLuaClass[0], splitLuaClass[1], '__tostring')
142182
else:
143183
raise Exception('not implemented: luaclass with more than 2 parts ' + luaClass)
144-
# pushGlobal(lua, 'nn', name, '__tostring')
145184
pushObject(lua, self)
146185
lua.call(1, 1)
147186
res = popString(lua)
@@ -192,43 +231,12 @@ def mymethod(*args):
192231
for arg in args:
193232
pushSomething(lua, arg)
194233
lua.call(len(args) + 1, 1) # +1 for self
195-
lua.pushValue(-1)
196-
pushGlobal(lua, 'torch', 'type')
197-
lua.insert(-2)
198-
lua.call(1, 1)
199-
returntype = popString(lua)
200234
# this is getting a bit recursive :-P
201235
# print('cythonClasses', cythonClasses)
202-
if returntype in cythonClasses:
203-
popFunction = cythonClasses[returntype]['popFunction']
204-
res = popFunction()
205-
topEnd = lua.getTop()
206-
assert topStart == topEnd
207-
return res
208-
elif returntype == 'number':
209-
res = lua.toNumber(-1)
210-
lua.remove(-1)
211-
topEnd = lua.getTop()
212-
assert topStart == topEnd
213-
return res
214-
elif returntype == 'string':
215-
res = popString(lua)
216-
topEnd = lua.getTop()
217-
assert topStart == topEnd
218-
return res
219-
elif returntype in luaClasses:
220-
returnobject = luaClasses[returntype](_fromLua=True)
221-
registerObject(lua, returnobject)
222-
topEnd = lua.getTop()
223-
assert topStart == topEnd
224-
return returnobject
225-
elif returntype == 'nil':
226-
lua.remove(-1)
227-
topEnd = lua.getTop()
228-
assert topStart == topEnd
229-
return None
230-
else:
231-
raise Exception('return type ' + str(returntype) + ' not implemented')
236+
res = popSomething(lua)
237+
topEnd = lua.getTop()
238+
assert topStart == topEnd
239+
return res
232240
lua.remove(-1)
233241
topEnd = lua.getTop()
234242
assert topStart == topEnd
@@ -241,16 +249,6 @@ def mymethod(*args):
241249
else:
242250
raise Exception('handling type ' + typename + ' not implemented')
243251

244-
class Table(LuaClass):
245-
def __init__(self, _fromLua=False):
246-
# print('Table.__init__')
247-
if not _fromLua:
248-
name = self.__class__.__name__
249-
super(self.__class__, self).__init__(nameList=['nn', name])
250-
else:
251-
self.__dict__['__objectId'] = getNextObjectId()
252-
self.luaclass = 'table'
253-
254252
class Linear(LuaClass):
255253
def __init__(self, numIn=1, numOut=1, _fromLua=False):
256254
# print('Linear.__init__')
@@ -360,7 +358,7 @@ def __init__(self, _fromLua=False):
360358
luaClasses['nn.SpatialMaxPooling'] = SpatialMaxPooling
361359
luaClasses['nn.ReLU'] = ReLU
362360
luaClasses['nn.Tanh'] = Tanh
363-
luaClasses['table'] = Table
361+
# luaClasses['table'] = Table
364362

365363
luaClassesReverse = {}
366364
def populateLuaClassesReverse():

0 commit comments

Comments
 (0)