Skip to content

Commit c1a8b0c

Browse files
committed
Can pass tables to lua functions now ,from python
1 parent 7522c32 commit c1a8b0c

File tree

12 files changed

+284
-224
lines changed

12 files changed

+284
-224
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ from __future__ import print_function, division
125125

126126
# Recent news
127127

128+
5 March:
129+
* added `PyTorchHelpers.load_lua_class(lua_filename, lua_classname)` to easily import a lua class from a lua file
130+
* can pass parameters to lua class constructors, from python
131+
* can pass tables to lua functions, from python (pass in as python dictionaries, become lua tables)
132+
128133
2 March:
129134
* removed requirements on Cython, Jinja2 for installation
130135

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def get_file_datetime(filepath):
139139

140140
setup(
141141
name='PyTorch',
142-
version='2.7.0',
142+
version='2.8.0',
143143
author='Hugh Perkins',
144144
author_email='hughperkins@gmail.com',
145145
description=(

simpleexample/luabit.lua

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,9 @@ function Luabit:getOut(inTensor, outSize, kernelSize)
2121
return out
2222
end
2323

24+
function Luabit:printTable(sometable)
25+
for k, v in pairs(sometable) do
26+
print('Luabit:printTable ', k, v)
27+
end
28+
end
29+

simpleexample/pybit.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,5 @@
2323
outTensor = luaout.asNumpyTensor()
2424
print('outTensor', outTensor)
2525

26+
luabit.printTable({'color': 'red', 'weather': 'sunny', 'anumber': 10, 'afloat': 1.234})
27+

src/PyTorch.cpp

Lines changed: 21 additions & 52 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/PyTorchAug.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,50 @@ def torchType(lua, pos):
6868
lua.call(1, 1)
6969
return popString(lua)
7070

71+
def pushSomething(lua, something):
72+
if isinstance(something, int):
73+
lua.pushNumber(something)
74+
return
75+
76+
if isinstance(something, float):
77+
lua.pushNumber(something)
78+
return
79+
80+
if isinstance(something, str):
81+
lua.pushString(something)
82+
return
83+
84+
if isinstance(something, dict):
85+
pushTable(lua, something)
86+
return
87+
88+
for pythonClass in pushFunctionByPythonClass:
89+
if isinstance(something, pythonClass):
90+
pushFunctionByPythonClass[pythonClass](something)
91+
return
92+
93+
if type(something) in luaClassesReverse:
94+
pushObject(lua, something)
95+
return
96+
97+
raise Exception('pushing type ' + str(type(something)) + ' not implemented, value ', something)
98+
99+
def pushTable(lua, table):
100+
lua.newTable()
101+
for k, v in table.items():
102+
pushSomething(lua, k)
103+
pushSomething(lua, v)
104+
lua.setTable(3)
105+
71106
class LuaClass(object):
72107
def __init__(self, *args, nameList):
73-
# print('LuaClass.__init__()')
74108
lua = PyTorch.getGlobalState().getLua()
75109
# self.luaclass = luaclass
76110
self.__dict__['__objectId'] = getNextObjectId()
77111
topStart = lua.getTop()
78112
pushGlobalFromList(lua, nameList)
79113
for arg in args:
80-
if isinstance(arg, int):
81-
lua.pushNumber(arg)
82-
elif isinstance(arg, str):
83-
lua.pushString(arg)
84-
else:
85-
raise Exception('arg type ' + str(type(arg)) + ' not implemented')
114+
pushSomething(lua, arg)
86115
lua.call(len(args), 1)
87116
registerObject(lua, self)
88117

@@ -160,29 +189,8 @@ def mymethod(*args):
160189
pushObject(lua, self)
161190
lua.getField(-1, name)
162191
lua.insert(-2)
163-
# pushObject(lua, self)
164192
for arg in args:
165-
# print('arg', arg, type(arg))
166-
pushedArg = False
167-
for pythonClass in pushFunctionByPythonClass:
168-
if isinstance(arg, pythonClass):
169-
pushFunctionByPythonClass[pythonClass](arg)
170-
pushedArg = True
171-
break
172-
if not pushedArg and type(arg) in luaClassesReverse:
173-
pushObject(lua, arg)
174-
pushedArg = True
175-
if not pushedArg and isinstance(arg, float):
176-
lua.pushNumber(arg)
177-
pushedArg = True
178-
if not pushedArg and isinstance(arg, int):
179-
lua.pushNumber(arg)
180-
pushedArg = True
181-
if not pushedArg and isinstance(arg, str):
182-
lua.pushString(arg)
183-
pushedArg = True
184-
if not pushedArg:
185-
raise Exception('arg type ' + str(type(arg)) + ' not implemented')
193+
pushSomething(lua, arg)
186194
lua.call(len(args) + 1, 1) # +1 for self
187195
lua.pushValue(-1)
188196
pushGlobal(lua, 'torch', 'type')

src/Storage.cpp

Lines changed: 2 additions & 32 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)