-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpointer-nets.py
More file actions
134 lines (96 loc) · 2.63 KB
/
pointer-nets.py
File metadata and controls
134 lines (96 loc) · 2.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:light
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.8.1
#   kernelspec:
#     display_name: Python 3
#     language: python
#     name: python3
# ---
# ## Config stuff
# %autosave 0
# %matplotlib inline
# %config InlineBackend.figure_format = 'svg'
# ## Common code
# +
from pprint import pprint
from matplotlib.text import Text
import matplotlib.pyplot as plt
import numpy as np
def plot_points(points, xlim=(-0.5, 1.5), ylim=(-0.5, 1.5)):
    """Scatter-plot points on the current axes and label them p_0, p_1, ...

    Args:
        points: array indexable as ``points[:, 0]`` / ``points[:, 1]``
            (first two columns are plotted as x and y).
        xlim: (left, right) x-axis limits; the default keeps the original
            fixed view.
        ylim: (bottom, top) y-axis limits; the default keeps the original
            fixed view.
    """
    plt.xlim(*xlim)
    plt.ylim(*ylim)
    plt.scatter(points[:, 0], points[:, 1])
    # label points
    label_points(points, "p")
def label_points(points, label, y_off=0):
    """Attach a LaTeX label ``$<label>_{<index>}$`` next to every point.

    Args:
        points: iterable of (x, y) pairs, in the order they should be numbered.
        label: base symbol of the label, e.g. "p" yields $p_{0}$, $p_{1}$, ...
        y_off: offset added to each label's y coordinate.
    """
    axes = plt.gca()
    for i, point in enumerate(points):
        px, py = point
        # add_artist (not ax.text) keeps the Text artist's default clipping.
        text_artist = Text(px, py + y_off, f"${label}_{{{i}}}$")
        axes.add_artist(text_artist)
# -
# ## Convex-Hull demo
# +
from matplotlib.patches import Polygon
from convex_hull_dataset import get_points, get_verts
def get_hull(points):
    """Return the rows of ``points`` that ``get_verts`` selects as hull vertices."""
    return points[get_verts(points)]
def plot_hull(verts):
    """Draw a hull polygon on the current axes.

    Marks the vertices with small dots, fills the polygon they enclose with
    translucent red, and labels them v_0, v_1, ... slightly above each point.

    Args:
        verts: 2-column array of vertex coordinates, in the order the polygon
            edges should connect them.
    """
    plt.scatter(verts[:, 0], verts[:, 1], s=4)
    # surrounding polygon
    ax = plt.gca()
    poly = Polygon(verts, color=(0.8, 0.2, 0.2, 0.5))
    # add_patch is the dedicated Axes method for patches; unlike add_artist it
    # also folds the polygon into the axes' data limits.
    ax.add_patch(poly)
    # label vertices
    label_points(verts, "v", 0.075)
def hull_demo():
    """Plot the point set from ``get_points()`` together with its hull."""
    pts = get_points()
    hull_vertices = get_hull(pts)
    plot_points(pts)
    plot_hull(hull_vertices)


hull_demo()
# +
from convex_hull_dataset import ConvexHullSample
import torch
import ptr_network
# Load the trained model object saved with torch.save. `ptr_network` is
# imported above but not otherwise referenced here; presumably it is needed
# so its classes resolve during unpickling.
# NOTE(review): torch.load unpickles the file — only load trusted files.
# On recent PyTorch (>= 2.6) `weights_only` defaults to True, which may reject
# a pickled nn.Module; `weights_only=False` may be required — confirm against
# the installed torch version.
model = torch.load('trained_model.pt')
def hull_model_demo(size=8):
    """Run the trained model on one generated sample and plot its predicted hull.

    Args:
        size: forwarded as the second argument of
            ``ConvexHullSample.create_samples`` (presumably the number of
            points in the sample — TODO confirm).
    """
    (sample,) = ConvexHullSample.create_samples(1, size)
    # Inference only: no_grad avoids building an autograd graph for the
    # forward pass.
    # NOTE(review): the model is not switched to eval() mode; if it contains
    # dropout/batch-norm layers, consider calling model.eval() first.
    with torch.no_grad():
        model_result = model(sample.points)
    # Keep the first two features of each point for plotting
    # (assumes they are the x/y coordinates — TODO confirm the feature layout).
    points = np.stack([p.numpy()[:2] for p in sample.points])
    print(model_result.decoded_seq)
    # NOTE(review): an earlier variant indexed with decoded_seq[:-1]; confirm
    # whether decoded_seq ends with a terminator index that should be dropped.
    verts = points[model_result.decoded_seq]
    plot_points(points)
    plot_hull(verts)


hull_model_demo(15)
# -
# ## Delaunay Demo
# +
from scipy.spatial import Delaunay
from matplotlib.collections import PolyCollection
import matplotlib.cm as cm
def get_delaunay(points):
    """Return the Delaunay simplices of ``points`` as a sorted list of index lists.

    Each entry is one row of ``scipy.spatial.Delaunay(points).simplices``
    (vertex indices into ``points``) converted to a list of plain Python ints.
    The outer list is sorted lexicographically, giving a deterministic order
    for a given triangulation.
    """
    # ndarray.tolist() produces nested Python lists of ints directly, replacing
    # the redundant sorted(list(map(list, ...))) round-trip.
    return sorted(Delaunay(points).simplices.tolist())
def plot_delaunay(points):
    """Overlay the Delaunay triangulation of ``points`` on the current axes.

    Every triangle is outlined in black and filled from the ``autumn``
    colormap according to its index in the sorted list from ``get_delaunay``.
    A colorbar for that index is added to the figure.
    """
    triangles = get_delaunay(points)
    polygons = [points[tri] for tri in triangles]
    tri_collection = PolyCollection(polygons, edgecolors="black", cmap=cm.autumn)
    # Map each triangle's index to a colour.
    tri_collection.set_array(np.arange(len(triangles)))
    axes = plt.gca()
    axes.add_collection(tri_collection)
    plt.colorbar(tri_collection)
def delaunay_demo():
    """Plot the point set from ``get_points(10)`` and its Delaunay triangulation."""
    pts = get_points(10)
    plot_points(pts)
    plot_delaunay(pts)


delaunay_demo()