Skip to content

Commit 0959a5c

Browse files
author
Maksym Zavershynskyi
committed
Prepare TorchFold for PyPI
1 parent 72ae4ea commit 0959a5c

File tree

9 files changed

+35
-206
lines changed

9 files changed

+35
-206
lines changed

README.md

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1-
# pytorch-tools
2-
Tools for PyTorch
1+
# TorchFold
32

4-
## Torch Fold
3+
Blog post: http://near.ai/articles/2017-09-06-PyTorch-Dynamic-Batching/
54

65
Analogous to [TensorFlow Fold](https://github.com/tensorflow/fold), implements dynamic batching with super simple interface.
76
Replace every direct call in your computation to nn module with `f.add('function name', arguments)`.
8-
It will construct an optimized version of computation and on `f.apply` will dynmically batch and execute the computation on given nn module.
7+
It will construct an optimized version of computation and on `f.apply` will dynamically batch and execute the computation on given nn module.
98

10-
For example:
9+
## Installation
10+
We recommend using pip package manager:
11+
```
12+
pip install torchfold
13+
```
14+
15+
## Example
1116

1217
```
1318
f = torchfold.Fold()
@@ -35,9 +40,3 @@ For example:
3540
model = Model(...)
3641
f.apply(model, [[res]])
3742
```
38-
39-
## Embeddings
40-
41-
Many times you find yourself with new words in vocabulary as you are working on your model.
42-
Instead of re-training you can use `embeddings.expand_embeddings` to expand them on the fly with given vocabulary.
43-

examples/snli/spinn-example.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88
from torch.autograd import Variable
99
import torch.nn as nn
1010
from torch import optim
11-
import torch.nn.functional as F
1211

1312
from torchtext import data
1413
from torchtext import datasets
1514

16-
from pytorch_tools import torchfold
15+
import torchfold
1716

1817

1918
parser = argparse.ArgumentParser(description='SPINN')
@@ -97,7 +96,6 @@ def is_leaf(self):
9796
def __repr__(self):
9897
return str(self.id) if self.is_leaf() else "(%s, %s)" % (self.left, self.right)
9998

100-
10199
def __init__(self, example, inputs_vocab, answer_vocab):
102100
self.label = answer_vocab.stoi[example.label] - 1
103101
queue = []

pytorch_tools/__init__.py

Whitespace-only changes.

pytorch_tools/embeddings.py

Lines changed: 0 additions & 80 deletions
This file was deleted.

pytorch_tools/trainer.py

Lines changed: 0 additions & 91 deletions
This file was deleted.

setup.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
1-
from setuptools import setup
1+
from setuptools import setup, find_packages
2+
from os import path
3+
4+
VERSION = "0.1.0"
5+
6+
here = path.abspath(path.dirname(__file__))
7+
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
8+
long_description = f.read()
29

3-
VERSION = "0.0.1"
410

511
setup(
6-
name='pytorch-tools',
12+
name='torchfold',
713
version=VERSION,
8-
description='Tools for PyTorch',
9-
packages=['pytorch_tools'],
14+
description='Dynamic Batching with PyTorch',
15+
long_description=long_description,
16+
long_description_content_type='text/markdown',
17+
packages=find_packages(exclude=["*_test.py"]),
1018
license='Apache License, Version 2.0',
11-
author='Illia Polosukhin',
19+
author='Illia Polosukhin, NEAR Inc',
20+
author_email="illia@near.ai",
21+
project_urls={
22+
'Blog Post': "http://near.ai/articles/2017-09-06-PyTorch-Dynamic-Batching/",
23+
'Source': "https://github.com/nearai/torchfold",
24+
},
1225
)
1326

torchfold/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .torchfold import Fold, Unfold
Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22

33
import torch
44
from torch.autograd import Variable
5-
import torch.nn as nn
6-
from torch import optim
7-
import torch.nn.functional as F
85

96

107
class Fold(object):
@@ -38,7 +35,6 @@ def __repr__(self):
3835
return "[%d:%d]%s" % (
3936
self.step, self.index, self.op)
4037

41-
4238
class ComputedResult(object):
4339
def __init__(self, batch_size, batched_result):
4440
self.batch_size = batch_size
@@ -79,7 +75,6 @@ def get(self, index, split_idx=-1):
7975
self.result[split_idx] = torch.chunk(self.result[split_idx], self.batch_size)
8076
return self.result[split_idx][index]
8177

82-
8378
def __init__(self, volatile=False, cuda=False):
8479
self.steps = collections.defaultdict(
8580
lambda: collections.defaultdict(list))
@@ -96,7 +91,7 @@ def add(self, op, *args):
9691
"""Add op to the fold."""
9792
self.total_nodes += 1
9893
if not all([isinstance(arg, (
99-
Fold.Node, int, torch.tensor._TensorBase, Variable)) for arg in args]):
94+
Fold.Node, int, torch.tensor._TensorBase, Variable)) for arg in args]):
10095
raise ValueError(
10196
"All args should be Tensor, Variable, int or Node, got: %s" % str(args))
10297
if args not in self.cached_nodes[op]:
@@ -116,8 +111,7 @@ def _batch_args(self, arg_lists, values):
116111
for arg_item in arg[1:])
117112

118113
if arg[0].batch:
119-
batched_arg = values[arg[0].step][arg[0]
120-
.op].try_get_batched(arg)
114+
batched_arg = values[arg[0].step][arg[0].op].try_get_batched(arg)
121115
if batched_arg is not None:
122116
res.append(batched_arg)
123117
else:
@@ -142,8 +136,7 @@ def _batch_args(self, arg_lists, values):
142136
if isinstance(arg_item, Fold.Node):
143137
assert arg_item.batch
144138
r.append(arg_item.get(values))
145-
elif isinstance(arg_item, (torch.tensor._TensorBase,
146-
Variable)):
139+
elif isinstance(arg_item, (torch.tensor._TensorBase, Variable)):
147140
r.append(arg_item)
148141
else:
149142
raise ValueError(
Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
1-
import collections
2-
31
import torch
42
from torch.autograd import Variable
53
import torch.nn as nn
6-
from torch import optim
7-
import torch.nn.functional as F
84

95
import torchfold
106

@@ -111,4 +107,4 @@ def test_rnn_optimized_chunking(self):
111107

112108

113109
if __name__ == "__main__":
114-
unittest.main()
110+
unittest.main()

0 commit comments

Comments
 (0)