Skip to content

Commit 56cd1ac

Browse files
author
magicindian
committed
committing moved file
1 parent d1fe8d2 commit 56cd1ac

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Created on Aug 6, 2005
3+
*
4+
*/
5+
package aima.learning.neural;
6+
7+
import java.util.ArrayList;
8+
import java.util.List;
9+
import java.util.Arrays;
10+
11+
import aima.learning.framework.Example;
12+
import aima.util.Pair;
13+
14+
/**
15+
* @author Ravi Mohan
16+
*
17+
*/
18+
19+
public class IrisDataSetNumerizer implements Numerizer {
20+
21+
public Pair<List<Double>, List<Double>> numerize(Example e) {
22+
List<Double> input = new ArrayList<Double>();
23+
List<Double> desiredOutput = new ArrayList<Double>();
24+
25+
double sepal_length = e.getAttributeValueAsDouble("sepal_length");
26+
double sepal_width = e.getAttributeValueAsDouble("sepal_width");
27+
double petal_length = e.getAttributeValueAsDouble("petal_length");
28+
double petal_width = e.getAttributeValueAsDouble("petal_width");
29+
30+
input.add(sepal_length);
31+
input.add(sepal_width);
32+
input.add(petal_length);
33+
input.add(petal_width);
34+
35+
String plant_category_string = e
36+
.getAttributeValueAsString("plant_category");
37+
38+
desiredOutput = convertCategoryToListOfDoubles(plant_category_string);
39+
40+
Pair<List<Double>, List<Double>> io = new Pair<List<Double>, List<Double>>(
41+
input, desiredOutput);
42+
43+
return io;
44+
}
45+
46+
public String denumerize(List<Double> outputValue) {
47+
List<Double> rounded = new ArrayList<Double>();
48+
for (Double d : outputValue) {
49+
rounded.add(round(d));
50+
}
51+
if (rounded.equals(Arrays.asList(0.0, 0.0, 1.0))) {
52+
return "setosa";
53+
} else if (rounded.equals(Arrays.asList(0.0, 1.0, 0.0))) {
54+
return "versicolor";
55+
} else if (rounded.equals(Arrays.asList(1.0, 0.0, 0.0))) {
56+
return "virginica";
57+
} else {
58+
return "unknown";
59+
}
60+
}
61+
62+
private double round(Double d) {
63+
if (d < 0) {
64+
return 0.0;
65+
}
66+
if (d > 1) {
67+
return 1.0;
68+
} else {
69+
return Math.round(d);
70+
}
71+
}
72+
73+
private List<Double> convertCategoryToListOfDoubles(
74+
String plant_category_string) {
75+
if (plant_category_string.equals("setosa")) {
76+
return Arrays.asList(0.0, 0.0, 1.0);
77+
} else if (plant_category_string.equals("versicolor")) {
78+
return Arrays.asList(0.0, 1.0, 0.0);
79+
} else if (plant_category_string.equals("virginica")) {
80+
return Arrays.asList(1.0, 0.0, 0.0);
81+
} else {
82+
throw new RuntimeException("invalid plant category");
83+
}
84+
}
85+
86+
}

0 commit comments

Comments
 (0)