@@ -96,13 +96,61 @@ def pushSomething(lua, something):
9696
9797 raise Exception ('pushing type ' + str (type (something )) + ' not implemented, value ' , something )
9898
99+ def popSomething (lua ):
100+ lua .pushValue (- 1 )
101+ pushGlobal (lua , 'torch' , 'type' )
102+ lua .insert (- 2 )
103+ lua .call (1 , 1 )
104+ typestring = popString (lua )
105+
106+ if typestring in cythonClasses :
107+ popFunction = cythonClasses [typestring ]['popFunction' ]
108+ res = popFunction ()
109+ return res
110+
111+ if typestring == 'number' :
112+ res = lua .toNumber (- 1 )
113+ lua .remove (- 1 )
114+ return res
115+
116+ if typestring == 'string' :
117+ res = popString (lua )
118+ return res
119+
120+ if typestring == 'table' :
121+ return popTable (lua )
122+
123+ if typestring in luaClasses :
124+ returnobject = luaClasses [typestring ](_fromLua = True )
125+ registerObject (lua , returnobject )
126+ return returnobject
127+
128+ if typestring == 'nil' :
129+ lua .remove (- 1 )
130+ return None
131+
132+ # raise Exception('pop type ' + str(typestring) + ' not implemented')
133+ print ('pop type ' + str (typestring ) + ' not implemented' )
134+
99135def pushTable (lua , table ):
100136 lua .newTable ()
101137 for k , v in table .items ():
102138 pushSomething (lua , k )
103139 pushSomething (lua , v )
104140 lua .setTable (- 3 )
105141
142+ def popTable (lua ):
143+ res = {}
144+ lua .pushNil ()
145+ while lua .next (- 2 ) != 0 :
146+ value = popSomething (lua )
147+ lua .pushValue (- 1 )
148+ key = popSomething (lua )
149+ res [key ] = value
150+ lua .remove (- 1 )
151+ return res
152+
153+
106154class LuaClass (object ):
107155 def __init__ (self , * args , nameList ):
108156 lua = PyTorch .getGlobalState ().getLua ()
@@ -115,22 +163,14 @@ def __init__(self, *args, nameList):
115163 lua .call (len (args ), 1 )
116164 registerObject (lua , self )
117165
118- # nameList = nameList[:]
119- # nameList.append('float')
120- # pushGlobalFromList(lua, nameList)
121- # pushObject(lua, self)
122- # lua.call(1, 0)
123-
124166 topEnd = lua .getTop ()
125167 assert topStart == topEnd
126168
127169 def __del__ (self ):
128170 name = self .__class__ .__name__
129- # print(name + '.__del__')
130171
131172 def __repr__ (self ):
132173 topStart = lua .getTop ()
133- # name = self.__class__.__name__
134174 luaClass = self .luaclass
135175 if luaClass == 'table' :
136176 return 'table'
@@ -141,7 +181,6 @@ def __repr__(self):
141181 pushGlobal (lua , splitLuaClass [0 ], splitLuaClass [1 ], '__tostring' )
142182 else :
143183 raise Exception ('not implemented: luaclass with more than 2 parts ' + luaClass )
144- # pushGlobal(lua, 'nn', name, '__tostring')
145184 pushObject (lua , self )
146185 lua .call (1 , 1 )
147186 res = popString (lua )
@@ -192,43 +231,12 @@ def mymethod(*args):
192231 for arg in args :
193232 pushSomething (lua , arg )
194233 lua .call (len (args ) + 1 , 1 ) # +1 for self
195- lua .pushValue (- 1 )
196- pushGlobal (lua , 'torch' , 'type' )
197- lua .insert (- 2 )
198- lua .call (1 , 1 )
199- returntype = popString (lua )
200234 # this is getting a bit recursive :-P
201235# print('cythonClasses', cythonClasses)
202- if returntype in cythonClasses :
203- popFunction = cythonClasses [returntype ]['popFunction' ]
204- res = popFunction ()
205- topEnd = lua .getTop ()
206- assert topStart == topEnd
207- return res
208- elif returntype == 'number' :
209- res = lua .toNumber (- 1 )
210- lua .remove (- 1 )
211- topEnd = lua .getTop ()
212- assert topStart == topEnd
213- return res
214- elif returntype == 'string' :
215- res = popString (lua )
216- topEnd = lua .getTop ()
217- assert topStart == topEnd
218- return res
219- elif returntype in luaClasses :
220- returnobject = luaClasses [returntype ](_fromLua = True )
221- registerObject (lua , returnobject )
222- topEnd = lua .getTop ()
223- assert topStart == topEnd
224- return returnobject
225- elif returntype == 'nil' :
226- lua .remove (- 1 )
227- topEnd = lua .getTop ()
228- assert topStart == topEnd
229- return None
230- else :
231- raise Exception ('return type ' + str (returntype ) + ' not implemented' )
236+ res = popSomething (lua )
237+ topEnd = lua .getTop ()
238+ assert topStart == topEnd
239+ return res
232240 lua .remove (- 1 )
233241 topEnd = lua .getTop ()
234242 assert topStart == topEnd
@@ -241,16 +249,6 @@ def mymethod(*args):
241249 else :
242250 raise Exception ('handling type ' + typename + ' not implemented' )
243251
244- class Table (LuaClass ):
245- def __init__ (self , _fromLua = False ):
246- # print('Table.__init__')
247- if not _fromLua :
248- name = self .__class__ .__name__
249- super (self .__class__ , self ).__init__ (nameList = ['nn' , name ])
250- else :
251- self .__dict__ ['__objectId' ] = getNextObjectId ()
252- self .luaclass = 'table'
253-
254252class Linear (LuaClass ):
255253 def __init__ (self , numIn = 1 , numOut = 1 , _fromLua = False ):
256254 # print('Linear.__init__')
@@ -360,7 +358,7 @@ def __init__(self, _fromLua=False):
360358luaClasses ['nn.SpatialMaxPooling' ] = SpatialMaxPooling
361359luaClasses ['nn.ReLU' ] = ReLU
362360luaClasses ['nn.Tanh' ] = Tanh
363- luaClasses ['table' ] = Table
361+ # luaClasses['table'] = Table
364362
365363luaClassesReverse = {}
366364def populateLuaClassesReverse ():
0 commit comments