Skip to content

Commit c59a2d3

Browse files
committed
backward incompatible change:
- nn modules are now in `PyTorchAug.nn`, rather than in `PyTorchAug` directly

new functionality:
- anything in `nn` can now be used, without having to explicitly add it to the program
- this also paves the way for being able to easily (and perhaps automatically) handle other types, like nnx etc
1 parent a2b7dff commit c59a2d3

File tree

3 files changed

+57
-32
lines changed

3 files changed

+57
-32
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def get_file_datetime(filepath):
139139

140140
setup(
141141
name='PyTorch',
142-
version='2.10.0-SNAPSHOT',
142+
version='3.0.0',
143143
author='Hugh Perkins',
144144
author_email='hughperkins@gmail.com',
145145
description=(

src/PyTorchAug.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@ def getNextObjectId():
1212
nextObjectId += 1
1313
return res
1414

15-
class Linear(object):
16-
def __init__(self):
17-
print('Linear.__init__')
18-
19-
def __attr__(self):
20-
print('Linear.__attr__')
21-
2215
def pushGlobal(lua, name1, name2=None, name3=None):
2316
lua.getGlobal(name1)
2417
if name2 is None:
@@ -88,6 +81,11 @@ def pushSomething(lua, something):
8881
pushObject(lua, something)
8982
return
9083

84+
# print('type(something)', type(something))
85+
# print('str(type(something))', str(type(something)))
86+
# if 'nn.' in str(type(something)):
87+
# setupNnClass(
88+
9189
raise Exception('pushing type ' + str(type(something)) + ' not implemented, value ', something)
9290

9391
def popSomething(lua, self=None, name=None):
@@ -235,15 +233,42 @@ def __init__(self, *args, _fromLua=False, **kwargs):
235233
renamedClass = type(AnNnClass)(nnClassName, (AnNnClass,), {})
236234
return renamedClass
237235

238-
239-
luaClasses = {}
240-
nnClasses = [
241-
'Linear', 'ClassNLLCriterion', 'MSECriterion', 'Sequential', 'LogSoftMax',
242-
'Reshape', 'SpatialConvolutionMM', 'SpatialMaxPooling', 'ReLU', 'Tanh']
243-
for nnClassName in nnClasses:
236+
def setupNnClass(nnClassName):
244237
nnClass = loadNnClass(nnClassName)
245238
globals()[nnClassName] = nnClass
246239
luaClasses['nn.' + nnClassName] = nnClass
240+
luaClassesReverse[nnClass] = 'nn.' + nnClassName
241+
return nnClass
242+
243+
#def mygetattr():
244+
# print('mygetattr()')
245+
246+
#def __getattr__():
247+
# print('__getattr__')
248+
249+
class Nn(object):
250+
def __init__(self):
251+
self.classes = {}
252+
253+
def __getattr__(self, name):
254+
# print('Nn.__getattr__', name)
255+
if name not in self.classes:
256+
self.classes[name] = setupNnClass(name)
257+
thisClass = self.classes[name]
258+
# print('thisClass', thisClass)
259+
return thisClass
260+
261+
nn = Nn()
262+
263+
luaClasses = {}
264+
#nnClasses = [
265+
# 'Linear', 'ClassNLLCriterion', 'MSECriterion', 'Sequential', 'LogSoftMax',
266+
# 'Reshape', 'SpatialConvolutionMM', 'SpatialMaxPooling', 'ReLU', 'Tanh']
267+
#for nnClassName in nnClasses:
268+
# setupNnClass(nnClassName)
269+
# nnClass = loadNnClass(nnClassName)
270+
# globals()[nnClassName] = nnClass
271+
# luaClasses['nn.' + nnClassName] = nnClass
247272

248273

249274
luaClassesReverse = {}
@@ -252,7 +277,7 @@ def populateLuaClassesReverse():
252277
for name in luaClasses:
253278
classtype = luaClasses[name]
254279
luaClassesReverse[classtype] = name
255-
populateLuaClassesReverse()
280+
#populateLuaClassesReverse()
256281

257282
cythonClasses = {}
258283
cythonClasses['torch.FloatTensor'] = {'popFunction': PyTorch._popFloatTensor}

test/test_pynn.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from __future__ import print_function, division
22
import PyTorch
33

4-
from PyTorchAug import *
4+
from PyTorchAug import nn
55

66
def test_pynn():
77
PyTorch.manualSeed(123)
8-
linear = Linear(3, 5)
8+
linear = nn.Linear(3, 5)
99
linear
1010
print('linear', linear)
1111
print('linear.weight', linear.weight)
@@ -20,12 +20,12 @@ def test_pynn():
2020
gradInput = linear.updateGradInput(input, output)
2121
print('gradInput', gradInput)
2222

23-
criterion = ClassNLLCriterion()
23+
criterion = nn.ClassNLLCriterion()
2424
print('criterion', criterion)
2525

2626
print('dir(linear)', dir(linear))
2727

28-
mlp = Sequential()
28+
mlp = nn.Sequential()
2929
mlp.add(linear)
3030

3131
output = mlp.forward(input)
@@ -39,23 +39,23 @@ def test_pynn():
3939

4040
numpy.random.seed(123)
4141

42-
mlp = Sequential()
42+
mlp = nn.Sequential()
4343

44-
mlp.add(SpatialConvolutionMM(1, 16, 5, 5, 1, 1, 2, 2))
45-
mlp.add(ReLU())
46-
mlp.add(SpatialMaxPooling(3, 3, 3, 3))
44+
mlp.add(nn.SpatialConvolutionMM(1, 16, 5, 5, 1, 1, 2, 2))
45+
mlp.add(nn.ReLU())
46+
mlp.add(nn.SpatialMaxPooling(3, 3, 3, 3))
4747

48-
mlp.add(SpatialConvolutionMM(16, 32, 3, 3, 1, 1, 1, 1))
49-
mlp.add(ReLU())
50-
mlp.add(SpatialMaxPooling(2, 2, 2, 2))
48+
mlp.add(nn.SpatialConvolutionMM(16, 32, 3, 3, 1, 1, 1, 1))
49+
mlp.add(nn.ReLU())
50+
mlp.add(nn.SpatialMaxPooling(2, 2, 2, 2))
5151

52-
mlp.add(Reshape(32 * 4 * 4))
53-
mlp.add(Linear(32 * 4 * 4, 150))
54-
mlp.add(Tanh())
55-
mlp.add(Linear(150, 10))
56-
mlp.add(LogSoftMax())
52+
mlp.add(nn.Reshape(32 * 4 * 4))
53+
mlp.add(nn.Linear(32 * 4 * 4, 150))
54+
mlp.add(nn.Tanh())
55+
mlp.add(nn.Linear(150, 10))
56+
mlp.add(nn.LogSoftMax())
5757

58-
criterion = ClassNLLCriterion()
58+
criterion = nn.ClassNLLCriterion()
5959
print('got criterion')
6060

6161
learningRate = 0.02

0 commit comments

Comments
 (0)