Skip to content

Commit 1561715

Browse files
committed
Sync
1 parent 7d907cb commit 1561715

File tree

3 files changed

+394
-63
lines changed

3 files changed

+394
-63
lines changed
Lines changed: 31 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
package org.mwg.experiments.kdtree;
22

33

4-
import org.mwg.Callback;
54
import org.mwg.experiments.smartgridprofiling.utility.GaussianProfile;
65
import org.mwg.ml.common.distance.GaussianDistance;
76

87
public class GaussianTreeNode extends GaussianProfile {
98

10-
KDNode root;
9+
KDNodeSync root;
1110

1211
public void setPrecisions(double[] precisions) {
1312
this.precisions = precisions;
@@ -16,92 +15,61 @@ public void setPrecisions(double[] precisions) {
1615
double[] precisions;
1716

1817

19-
2018
public GaussianTreeNode() {
2119
}
2220

2321

24-
25-
26-
public void internalLearn(final double[] values, final double[] features, final Callback<Boolean> callback) {
22+
public void internalLearn(final double[] values, final double[] features) {
2723
super.learn(values);
2824

2925
if (root == null) {
30-
root = new KDNode();
26+
root = new KDNodeSync();
3127
root.setDistance(new GaussianDistance(precisions));
3228
root.setThreshold(1.001);
3329

3430
GaussianProfile profile = new GaussianProfile();
3531
profile.learn(values);
36-
root.insert(features, profile, new Callback<Boolean>() {
37-
@Override
38-
public void on(Boolean result) {
39-
callback.on(true);
40-
}
41-
});
32+
root.insert(features, profile);
4233
} else {
43-
root.nearestWithinDistance(features, new Callback<Object>() {
44-
@Override
45-
public void on(Object result) {
46-
if (result != null) {
47-
GaussianProfile profile = (GaussianProfile) result;
48-
profile.learn(values);
49-
if (callback != null) {
50-
callback.on(true);
51-
}
52-
} else {
53-
GaussianProfile profile = new GaussianProfile();
54-
profile.learn(values);
55-
root.insert(features, profile, new Callback<Boolean>() {
56-
@Override
57-
public void on(Boolean result) {
58-
if (callback != null) {
59-
callback.on(true);
60-
}
61-
}
62-
});
63-
}
64-
65-
}
66-
});
34+
Object result = root.nearestWithinDistance(features);
35+
if (result != null) {
36+
GaussianProfile profile = (GaussianProfile) result;
37+
profile.learn(values);
38+
} else {
39+
GaussianProfile profile = new GaussianProfile();
40+
profile.learn(values);
41+
root.insert(features, profile);
42+
}
43+
6744
}
6845
}
6946

7047

71-
public int getNumNodes(){
48+
public int getNumNodes() {
7249
return root.getNum();
7350
}
7451

7552

76-
public void predictValue(double[] values, Callback<Double> callback){
77-
if(callback==null){
78-
return;
79-
}
53+
public double predictValue(double[] values) {
54+
8055

8156
double[] features = new double[values.length - 1];
8257
System.arraycopy(values, 0, features, 0, values.length - 1);
83-
if(root==null){
84-
callback.on(null);
85-
return;
58+
if (root == null) {
59+
return 0;
60+
} else {
61+
Object result = root.nearestWithinDistance(features);
62+
if (result != null) {
63+
GaussianProfile profile = (GaussianProfile) result;
64+
double[] avg = profile.getAvg();
65+
Double res = avg[avg.length - 1];
66+
return res;
67+
} else {
68+
double[] avg = getAvg();
69+
Double res = avg[avg.length - 1];
70+
return res;
71+
}
8672
}
87-
else {
88-
root.nearestWithinDistance(features, new Callback<Object>() {
89-
@Override
90-
public void on(Object result) {
91-
if (result != null) {
92-
GaussianProfile profile = (GaussianProfile) result;
93-
double[] avg = profile.getAvg();
94-
Double res = avg[avg.length - 1];
95-
callback.on(res);
96-
}
97-
else {
98-
double[] avg=getAvg();
99-
Double res= avg[avg.length-1];
100-
callback.on(res);
101-
}
102-
}
103-
});
10473

105-
}
10674
}
10775
}
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
package org.mwg.experiments.kdtree;
2+
3+
4+
import org.mwg.ml.common.distance.Distance;
5+
6+
/**
7+
* Created by assaad on 29/06/16.
8+
*/
9+
public class KDNodeSync {
10+
11+
12+
KDNodeSync right;
13+
KDNodeSync left;
14+
double[] key;
15+
Object value;
16+
17+
public int getNum() {
18+
return num;
19+
}
20+
21+
private int num;
22+
23+
24+
public void setThreshold(double threshold) {
25+
this.threshold = threshold;
26+
}
27+
28+
double threshold;
29+
Distance distance;
30+
31+
32+
public void setDistance(Distance d) {
33+
this.distance = d;
34+
}
35+
36+
private Distance getDistance() {
37+
return distance;
38+
}
39+
40+
41+
public void insert(final double[] key, final Object value) {
42+
final int dim = key.length;
43+
44+
if (key.length != dim) {
45+
throw new RuntimeException("Key size should always be the same");
46+
}
47+
48+
Distance distance = getDistance();
49+
internalInsert(this, this, distance, key, 0, dim, threshold, value);
50+
}
51+
52+
53+
public Object nearest(final double[] key) {
54+
55+
// initial call is with infinite hyper-rectangle and max distance
56+
HRect hr = HRect.infiniteHRect(key.length);
57+
double max_dist_sqd = Double.MAX_VALUE;
58+
59+
NearestNeighborList nnl = new NearestNeighborList(1);
60+
Distance distance = getDistance();
61+
internalNearest(this, distance, key, hr, max_dist_sqd, 0, key.length, threshold, nnl);
62+
63+
Object res = nnl.getHighest();
64+
return res;
65+
}
66+
67+
public Object nearestWithinDistance(final double[] key) {
68+
69+
70+
// initial call is with infinite hyper-rectangle and max distance
71+
HRect hr = HRect.infiniteHRect(key.length);
72+
double max_dist_sqd = Double.MAX_VALUE;
73+
74+
NearestNeighborList nnl = new NearestNeighborList(1);
75+
Distance distance = getDistance();
76+
internalNearest(this, distance, key, hr, max_dist_sqd, 0, key.length, threshold, nnl);
77+
78+
if (nnl.getBestDistance() <= threshold) {
79+
Object res = nnl.getHighest();
80+
return res;
81+
} else {
82+
return null;
83+
}
84+
}
85+
86+
87+
public Object[] nearestN(final double[] key, final int n) {
88+
89+
HRect hr = HRect.infiniteHRect(key.length);
90+
double max_dist_sqd = Double.MAX_VALUE;
91+
92+
NearestNeighborList nnl = new NearestNeighborList(n);
93+
Distance distance = getDistance();
94+
internalNearest(this, distance, key, hr, max_dist_sqd, 0, key.length, threshold, nnl);
95+
96+
Object[] res = nnl.getAllNodes();
97+
return res;
98+
99+
100+
}
101+
102+
103+
private static void internalNearest(KDNodeSync node, final Distance distance, double[] target, HRect hr, double max_dist_sqd, int lev, int dim, double err, NearestNeighborList nnl) {
104+
// 1. if kd is empty exit.
105+
if (node == null) {
106+
return;
107+
}
108+
109+
double[] pivot = node.key;
110+
if (pivot == null) {
111+
return;
112+
}
113+
114+
115+
// 2. s := split field of kd
116+
int s = lev % dim;
117+
118+
// 3. pivot := dom-elt field of kd
119+
120+
double pivot_to_target = distance.measure(pivot, target);
121+
122+
// 4. Cut hr into to sub-hyperrectangles left-hr and right-hr.
123+
// The cut plane is through pivot and perpendicular to the s
124+
// dimension.
125+
HRect left_hr = hr; // optimize by not cloning
126+
HRect right_hr = (HRect) hr.clone();
127+
left_hr.max[s] = pivot[s];
128+
right_hr.min[s] = pivot[s];
129+
130+
// 5. target-in-left := target_s <= pivot_s
131+
boolean target_in_left = target[s] < pivot[s];
132+
133+
KDNodeSync nearer_kd;
134+
HRect nearer_hr;
135+
KDNodeSync further_kd;
136+
HRect further_hr;
137+
138+
// 6. if target-in-left then
139+
// 6.1. nearer-kd := left field of kd and nearer-hr := left-hr
140+
// 6.2. further-kd := right field of kd and further-hr := right-hr
141+
if (target_in_left) {
142+
nearer_kd = node.left;
143+
nearer_hr = left_hr;
144+
further_kd = node.right;
145+
further_hr = right_hr;
146+
}
147+
//
148+
// 7. if not target-in-left then
149+
// 7.1. nearer-kd := right field of kd and nearer-hr := right-hr
150+
// 7.2. further-kd := left field of kd and further-hr := left-hr
151+
else {
152+
nearer_kd = node.right;
153+
nearer_hr = right_hr;
154+
further_kd = node.left;
155+
further_hr = left_hr;
156+
}
157+
158+
// 8. Recursively call Nearest Neighbor with paramters
159+
// (nearer-kd, target, nearer-hr, max-dist-sqd), storing the
160+
// results in nearest and dist-sqd
161+
//nnbr(nearer_kd, target, nearer_hr, max_dist_sqd, lev + 1, K, nnl);
162+
internalNearest(nearer_kd, distance, target, nearer_hr, max_dist_sqd, lev + 1, dim, err, nnl);
163+
164+
165+
double dist_sqd;
166+
167+
if (!nnl.isCapacityReached()) {
168+
dist_sqd = Double.MAX_VALUE;
169+
} else {
170+
dist_sqd = nnl.getMaxPriority();
171+
}
172+
173+
// 9. max-dist-sqd := minimum of max-dist-sqd and dist-sqd
174+
max_dist_sqd = Math.min(max_dist_sqd, dist_sqd);
175+
176+
// 10. A nearer point could only lie in further-kd if there were some
177+
// part of further-hr within distance sqrt(max-dist-sqd) of
178+
// target. If this is the case then
179+
double[] closest = further_hr.closest(target);
180+
if (distance.measure(closest, target) < max_dist_sqd) {
181+
182+
// 10.1 if (pivot-target)^2 < dist-sqd then
183+
if (pivot_to_target < dist_sqd) {
184+
185+
// 10.1.2 dist-sqd = (pivot-target)^2
186+
dist_sqd = pivot_to_target;
187+
nnl.insert(node.value, dist_sqd);
188+
189+
// 10.1.3 max-dist-sqd = dist-sqd
190+
// max_dist_sqd = dist_sqd;
191+
if (nnl.isCapacityReached()) {
192+
max_dist_sqd = nnl.getMaxPriority();
193+
} else {
194+
max_dist_sqd = Double.MAX_VALUE;
195+
}
196+
}
197+
198+
// 10.2 Recursively call Nearest Neighbor with parameters
199+
// (further-kd, target, further-hr, max-dist_sqd),
200+
// storing results in temp-nearest and temp-dist-sqd
201+
//nnbr(further_kd, target, further_hr, max_dist_sqd, lev + 1, K, nnl);
202+
internalNearest(further_kd, distance, target, further_hr, max_dist_sqd, lev + 1, dim, err, nnl);
203+
}
204+
205+
206+
}
207+
208+
209+
private static boolean internalInsert(final KDNodeSync node, final KDNodeSync root, final Distance distance, final double[] key, final int lev, final int dim, final double err, final Object value) {
210+
211+
double[] tk = node.key;
212+
if (tk == null) {
213+
node.key = key.clone();
214+
node.value = value;
215+
216+
if (node == root) {
217+
node.num = 1;
218+
}
219+
return true;
220+
221+
} else if (distance.measure(key, tk) < err) {
222+
node.value = value;
223+
return true;
224+
} else if (key[lev] > tk[lev]) {
225+
//check right
226+
if (node.right == null) {
227+
KDNodeSync rightNode = new KDNodeSync();
228+
rightNode.key = key.clone();
229+
rightNode.value = value;
230+
node.right = rightNode;
231+
root.num = root.num + 1;
232+
return true;
233+
} else {
234+
internalInsert(node.right, root, distance, key, (lev + 1) % dim, dim, err, value);
235+
return true;
236+
}
237+
238+
} else {
239+
if (node.left == null) {
240+
KDNodeSync leftNode = new KDNodeSync();
241+
leftNode.key = key.clone();
242+
leftNode.value = value;
243+
node.left = leftNode;
244+
root.num = root.num + 1;
245+
return true;
246+
} else {
247+
internalInsert(node.left, root, distance, key, (lev + 1) % dim, dim, err, value);
248+
return true;
249+
}
250+
}
251+
}
252+
253+
254+
}

0 commit comments

Comments
 (0)