Skip to content

Commit 96b1bae

Browse files
committed
Simplify loading lua classes: add PyTorchHelpers.load_lua_class(lua_filename, lua_classname)
1 parent dd5dd61 commit 96b1bae

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def get_file_datetime(filepath):
139139

140140
setup(
141141
name='PyTorch',
142-
version='2.5.0',
142+
version='2.6.0',
143143
author='Hugh Perkins',
144144
author_email='hughperkins@gmail.com',
145145
description=(
@@ -152,7 +152,7 @@ def get_file_datetime(filepath):
152152
install_requires=['numpy'],
153153
scripts=[],
154154
ext_modules=ext_modules,
155-
py_modules=['floattensor', 'PyTorchAug'],
155+
py_modules=['floattensor', 'PyTorchAug', 'PyTorchHelpers'],
156156
package_dir={'': 'src'}
157157
)
158158

simpleexample/pybit.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
11
import sys
22
import os
3-
import PyTorchAug
43
import PyTorch
4+
import PyTorchHelpers
55
import numpy as np
66

7-
PyTorch.require('luabit')
8-
class Luabit(PyTorchAug.LuaClass):
9-
def __init__(self, _fromLua=False):
10-
self.luaclass = 'Luabit'
11-
if not _fromLua:
12-
name = self.__class__.__name__
13-
super(self.__class__, self).__init__([name])
14-
else:
15-
self.__dict__['__objectId'] = getNextObjectId()
7+
Luabit = PyTorchHelpers.load_lua_class('luabit.lua', 'Luabit')
168

179
batchSize = 2
1810
numFrames = 4

src/PyTorchHelpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import PyTorch
2+
import PyTorchAug
3+
4+
def load_lua_class(lua_filename, lua_classname):
5+
module = lua_filename.replace('.lua', '')
6+
PyTorch.require(module)
7+
class LuaWrapper(PyTorchAug.LuaClass):
8+
def __init__(self, _fromLua=False):
9+
self.luaclass = lua_classname
10+
if not _fromLua:
11+
name = lua_classname
12+
super(self.__class__, self).__init__([name])
13+
else:
14+
self.__dict__['__objectId'] = getNextObjectId()
15+
return LuaWrapper
16+

0 commit comments

Comments
 (0)