3737import de .uni_mannheim .informatik .dws .winter .model .defaultmodel .Attribute ;
3838import de .uni_mannheim .informatik .dws .winter .model .defaultmodel .FeatureVectorDataSet ;
3939import de .uni_mannheim .informatik .dws .winter .model .defaultmodel .Record ;
40+ import de .uni_mannheim .informatik .dws .winter .model .defaultmodel .comparators .RecordComparator ;
4041import de .uni_mannheim .informatik .dws .winter .processing .Processable ;
42+ import weka .attributeSelection .AttributeSelection ;
43+ import weka .attributeSelection .GreedyStepwise ;
44+ import weka .attributeSelection .WrapperSubsetEval ;
4145import weka .classifiers .Classifier ;
4246import weka .classifiers .evaluation .Evaluation ;
4347import weka .core .DenseInstance ;
4448import weka .core .Instance ;
4549import weka .core .Instances ;
4650import weka .core .Utils ;
51+ import weka .core .pmml .PMMLFactory ;
4752
4853/**
4954 * Class that creates and applies a matching Rule based on supervised learning
@@ -62,21 +67,27 @@ public class WekaMatchingRule<RecordType extends Matchable, SchemaElementType ex
6267 private Classifier classifier ;
6368 private List <Comparator <RecordType , SchemaElementType >> comparators ;
6469
70+ // Handling of feature subset selection
71+ private boolean forwardSelection = false ;
72+ private boolean backwardSelection = false ;
73+ private AttributeSelection fs ;
74+
6575 public final String trainingSet = "trainingSet" ;
6676 public final String machtSet = "matchSet" ;
67-
68- // TODO Discuss finalThreshold --> Can be set via options -C <confidence factor for pruning>
77+
78+ // TODO Discuss finalThreshold --> Can be set via options -C <confidence
79+ // factor for pruning>
6980 /**
7081 * Create a MatchingRule, which can be trained using the Weka library for
7182 * identity resolution.
7283 *
7384 * @param finalThreshold
7485 * determines the confidence level, which needs to be exceeded by
7586 * the classifier, so that it can classify a record as match.
76- *
87+ *
7788 * @param classifierName
7889 * Has the name of a specific classifier from the Weka library.
79- *
90+ *
8091 * @param parameters
8192 * Hold the parameters to tune the classifier.
8293 */
@@ -95,6 +106,7 @@ public WekaMatchingRule(double finalThreshold, String classifierName, String par
95106 // create list for comparators
96107 this .comparators = new LinkedList <>();
97108 }
109+
98110
99111 public String [] getparameters () {
100112 return parameters ;
@@ -125,7 +137,10 @@ public void addComparator(Comparator<RecordType, SchemaElementType> comparator)
125137
126138 /**
127139 *
128- * learns the rule from parsed features in a cross validation
140+ * Learns the rule from parsed features in a cross validation and the set
141+ * parameters. Additionally feature subset selection is conducted, if the
142+ * parameters this.forwardSelection or this.backwardSelection are set
143+ * accordingly.
129144 *
130145 * @param features
131146 * Contains features to learn a classifier
@@ -135,13 +150,39 @@ public void addComparator(Comparator<RecordType, SchemaElementType> comparator)
135150 public Performance learnParameters (FeatureVectorDataSet features ) {
136151 // create training
137152 Instances trainingData = transformToWeka (features , this .trainingSet );
153+
138154 try {
139155 Evaluation eval = new Evaluation (trainingData );
156+ // apply feature subset selection
157+ if (this .forwardSelection || this .backwardSelection ) {
158+
159+ GreedyStepwise search = new GreedyStepwise ();
160+ search .setSearchBackwards (this .backwardSelection );
161+
162+ this .fs = new AttributeSelection ();
163+ WrapperSubsetEval wrapper = new WrapperSubsetEval ();
140164
165+ // Do feature subset selection, but using a 10-fold cross
166+ // validation
167+ wrapper .buildEvaluator (trainingData );
168+ wrapper .setClassifier (this .classifier );
169+ wrapper .setFolds (10 );
170+ wrapper .setThreshold (0.01 );
171+
172+ this .fs .setEvaluator (wrapper );
173+ this .fs .setSearch (search );
174+
175+ this .fs .SelectAttributes (trainingData );
176+
177+ trainingData = fs .reduceDimensionality (trainingData );
178+
179+ }
180+ // perform 10-fold Cross Validation to evaluate classifier
141181 eval .crossValidateModel (this .classifier , trainingData , 10 , new Random (1 ));
142- this .classifier .buildClassifier (trainingData );
143182 System .out .println (eval .toSummaryString ("\n Results\n \n " , false ));
144-
183+
184+ this .classifier .buildClassifier (trainingData );
185+
145186 int truePositive = (int ) eval .numTruePositives (trainingData .classIndex ());
146187 int falsePositive = (int ) eval .numFalsePositives (trainingData .classIndex ());
147188 int falseNegative = (int ) eval .numFalseNegatives (trainingData .classIndex ());
@@ -273,7 +314,18 @@ public Record generateFeatures(RecordType record1, RecordType record2,
273314
274315 double similarity = comp .compare (record1 , record2 , null );
275316
276- String name = String .format ("[%d] %s" , i , comp .getClass ().getSimpleName ());
317+ String attribute1 = "" ;
318+ String attribute2 = "" ;
319+ try {
320+ attribute1 = ((RecordComparator )comp ).getAttributeRecord1 ().toString ();
321+ attribute2 = ((RecordComparator )comp ).getAttributeRecord2 ().toString ();
322+
323+ } catch (ClassCastException e ) {
324+ // Not possible to add attribute names
325+ //e.printStackTrace();
326+ }
327+
328+ String name = String .format ("[%d] %s %s %s" , i , comp .getClass ().getSimpleName (), attribute1 , attribute2 );
277329 Attribute att = null ;
278330 for (Attribute elem : features .getSchema ().get ()) {
279331 if (elem .toString ().equals (name )) {
@@ -315,9 +367,18 @@ public Correspondence<RecordType, SchemaElementType> apply(RecordType record1, R
315367 FeatureVectorDataSet matchSet = this .initialiseFeatures ();
316368 Record matchRecord = generateFeatures (record1 , record2 , schemaCorrespondences , matchSet );
317369
370+ // transform entry for classification.
318371 matchSet .add (matchRecord );
319372 Instances matchInstances = this .transformToWeka (matchSet , this .machtSet );
320-
373+
374+ // reduce dimensions if feature subset selection was applied before.
375+ if ((this .backwardSelection || this .forwardSelection ) && this .fs != null )
376+ try {
377+ matchInstances = this .fs .reduceDimensionality (matchInstances );
378+ } catch (Exception e1 ) {
379+ e1 .printStackTrace ();
380+ }
381+ // Apply matching rule
321382 try {
322383 double result = this .classifier .classifyInstance (matchInstances .firstInstance ());
323384 return new Correspondence <RecordType , SchemaElementType >(record1 , record2 , result , schemaCorrespondences );
@@ -365,14 +426,20 @@ public void storeModel(File location) {
365426 @ Override
366427 public void readModel (File location ) {
367428 // deserialize model
429+
368430 try {
369431 ObjectInputStream ois = new ObjectInputStream (new FileInputStream (location ));
370432 this .setClassifier ((Classifier ) ois .readObject ());
371433 ois .close ();
372434 } catch (FileNotFoundException e ) {
373435 e .printStackTrace ();
374436 } catch (IOException e ) {
375- e .printStackTrace ();
437+ try {
438+ this .setClassifier ((Classifier ) PMMLFactory .getPMMLModel (location , null ));
439+
440+ } catch (Exception e1 ) {
441+ e1 .printStackTrace ();
442+ }
376443 } catch (ClassNotFoundException e ) {
377444 e .printStackTrace ();
378445 }
@@ -393,7 +460,7 @@ public double compare(RecordType record1, RecordType record2,
393460
394461 /**
395462 * Create a new FeatureVectorDataSet with the corresponding features, which
396- * result form the added comparators.
463+ * result from the added comparators.
397464 *
398465 * @see de.uni_mannheim.informatik.dws.winter.matching.rules.LearnableMatchingRule#initialiseFeatures()
399466 */
@@ -405,8 +472,19 @@ public FeatureVectorDataSet initialiseFeatures() {
405472 for (int i = 0 ; i < comparators .size (); i ++) {
406473
407474 Comparator <RecordType , SchemaElementType > comp = comparators .get (i );
408-
409- String name = String .format ("[%d] %s" , i , comp .getClass ().getSimpleName ());
475+
476+ String attribute1 = "" ;
477+ String attribute2 = "" ;
478+ try {
479+ attribute1 = ((RecordComparator )comp ).getAttributeRecord1 ().toString ();
480+ attribute2 = ((RecordComparator )comp ).getAttributeRecord2 ().toString ();
481+
482+ } catch (ClassCastException e ) {
483+ // Not possible to add attribute names
484+ //e.printStackTrace();
485+ }
486+
487+ String name = String .format ("[%d] %s %s %s" , i , comp .getClass ().getSimpleName (), attribute1 , attribute2 );
410488
411489 Attribute att = new Attribute (name );
412490 result .addAttribute (att );
@@ -417,4 +495,20 @@ public FeatureVectorDataSet initialiseFeatures() {
417495 return result ;
418496 }
419497
498+ public boolean isForwardSelection () {
499+ return forwardSelection ;
500+ }
501+
502+ public void setForwardSelection (boolean forwardSelection ) {
503+ this .forwardSelection = forwardSelection ;
504+ }
505+
506+ public boolean isBackwardSelection () {
507+ return backwardSelection ;
508+ }
509+
510+ public void setBackwardSelection (boolean backwardSelection ) {
511+ this .backwardSelection = backwardSelection ;
512+ }
513+
420514}
0 commit comments