|
| 1 | +/* |
| 2 | + * Created on Aug 6, 2005 |
| 3 | + * |
| 4 | + */ |
| 5 | +package aima.learning.neural; |
| 6 | + |
| 7 | +import java.util.ArrayList; |
| 8 | +import java.util.List; |
| 9 | +import java.util.Arrays; |
| 10 | + |
| 11 | +import aima.learning.framework.Example; |
| 12 | +import aima.util.Pair; |
| 13 | + |
| 14 | +/** |
| 15 | + * @author Ravi Mohan |
| 16 | + * |
| 17 | + */ |
| 18 | + |
| 19 | +public class IrisDataSetNumerizer implements Numerizer { |
| 20 | + |
| 21 | + public Pair<List<Double>, List<Double>> numerize(Example e) { |
| 22 | + List<Double> input = new ArrayList<Double>(); |
| 23 | + List<Double> desiredOutput = new ArrayList<Double>(); |
| 24 | + |
| 25 | + double sepal_length = e.getAttributeValueAsDouble("sepal_length"); |
| 26 | + double sepal_width = e.getAttributeValueAsDouble("sepal_width"); |
| 27 | + double petal_length = e.getAttributeValueAsDouble("petal_length"); |
| 28 | + double petal_width = e.getAttributeValueAsDouble("petal_width"); |
| 29 | + |
| 30 | + input.add(sepal_length); |
| 31 | + input.add(sepal_width); |
| 32 | + input.add(petal_length); |
| 33 | + input.add(petal_width); |
| 34 | + |
| 35 | + String plant_category_string = e |
| 36 | + .getAttributeValueAsString("plant_category"); |
| 37 | + |
| 38 | + desiredOutput = convertCategoryToListOfDoubles(plant_category_string); |
| 39 | + |
| 40 | + Pair<List<Double>, List<Double>> io = new Pair<List<Double>, List<Double>>( |
| 41 | + input, desiredOutput); |
| 42 | + |
| 43 | + return io; |
| 44 | + } |
| 45 | + |
| 46 | + public String denumerize(List<Double> outputValue) { |
| 47 | + List<Double> rounded = new ArrayList<Double>(); |
| 48 | + for (Double d : outputValue) { |
| 49 | + rounded.add(round(d)); |
| 50 | + } |
| 51 | + if (rounded.equals(Arrays.asList(0.0, 0.0, 1.0))) { |
| 52 | + return "setosa"; |
| 53 | + } else if (rounded.equals(Arrays.asList(0.0, 1.0, 0.0))) { |
| 54 | + return "versicolor"; |
| 55 | + } else if (rounded.equals(Arrays.asList(1.0, 0.0, 0.0))) { |
| 56 | + return "virginica"; |
| 57 | + } else { |
| 58 | + return "unknown"; |
| 59 | + } |
| 60 | + } |
| 61 | + |
| 62 | + private double round(Double d) { |
| 63 | + if (d < 0) { |
| 64 | + return 0.0; |
| 65 | + } |
| 66 | + if (d > 1) { |
| 67 | + return 1.0; |
| 68 | + } else { |
| 69 | + return Math.round(d); |
| 70 | + } |
| 71 | + } |
| 72 | + |
| 73 | + private List<Double> convertCategoryToListOfDoubles( |
| 74 | + String plant_category_string) { |
| 75 | + if (plant_category_string.equals("setosa")) { |
| 76 | + return Arrays.asList(0.0, 0.0, 1.0); |
| 77 | + } else if (plant_category_string.equals("versicolor")) { |
| 78 | + return Arrays.asList(0.0, 1.0, 0.0); |
| 79 | + } else if (plant_category_string.equals("virginica")) { |
| 80 | + return Arrays.asList(1.0, 0.0, 0.0); |
| 81 | + } else { |
| 82 | + throw new RuntimeException("invalid plant category"); |
| 83 | + } |
| 84 | + } |
| 85 | + |
| 86 | +} |
0 commit comments