Skip to content

Commit 930ca07

Browse files
committed
Merge branch 'CleanUp_WekaMatchingRule'
2 parents b172c07 + 08934cc commit 930ca07

26 files changed: +249,759 additions, −668 deletions

src/main/java/de/uni_mannheim/informatik/dws/winter/matching/rules/WekaMatchingRule.java

Lines changed: 107 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,18 @@
3737
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.Attribute;
3838
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.FeatureVectorDataSet;
3939
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.Record;
40+
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.comparators.RecordComparator;
4041
import de.uni_mannheim.informatik.dws.winter.processing.Processable;
42+
import weka.attributeSelection.AttributeSelection;
43+
import weka.attributeSelection.GreedyStepwise;
44+
import weka.attributeSelection.WrapperSubsetEval;
4145
import weka.classifiers.Classifier;
4246
import weka.classifiers.evaluation.Evaluation;
4347
import weka.core.DenseInstance;
4448
import weka.core.Instance;
4549
import weka.core.Instances;
4650
import weka.core.Utils;
51+
import weka.core.pmml.PMMLFactory;
4752

4853
/**
4954
* Class that creates and applies a matching Rule based on supervised learning
@@ -62,21 +67,27 @@ public class WekaMatchingRule<RecordType extends Matchable, SchemaElementType ex
6267
private Classifier classifier;
6368
private List<Comparator<RecordType, SchemaElementType>> comparators;
6469

70+
// Handling of feature subset selection
71+
private boolean forwardSelection = false;
72+
private boolean backwardSelection = false;
73+
private AttributeSelection fs;
74+
6575
public final String trainingSet = "trainingSet";
6676
public final String machtSet = "matchSet";
67-
68-
// TODO Discuss finalThreshold --> Can be set via options -C <confidence factor for pruning>
77+
78+
// TODO Discuss finalThreshold --> Can be set via options -C <confidence
79+
// factor for pruning>
6980
/**
7081
* Create a MatchingRule, which can be trained using the Weka library for
7182
* identity resolution.
7283
*
7384
* @param finalThreshold
7485
* determines the confidence level, which needs to be exceeded by
7586
* the classifier, so that it can classify a record as match.
76-
*
87+
*
7788
* @param classifierName
7889
* Has the name of a specific classifier from the Weka library.
79-
*
90+
*
8091
* @param parameters
8192
* Hold the parameters to tune the classifier.
8293
*/
@@ -95,6 +106,7 @@ public WekaMatchingRule(double finalThreshold, String classifierName, String par
95106
// create list for comparators
96107
this.comparators = new LinkedList<>();
97108
}
109+
98110

99111
public String[] getparameters() {
100112
return parameters;
@@ -125,7 +137,10 @@ public void addComparator(Comparator<RecordType, SchemaElementType> comparator)
125137

126138
/**
127139
*
128-
* learns the rule from parsed features in a cross validation
140+
* Learns the rule from parsed features in a cross validation and the set
141+
* parameters. Additionally feature subset selection is conducted, if the
142+
* parameters this.forwardSelection or this.backwardSelection are set
143+
* accordingly.
129144
*
130145
* @param features
131146
* Contains features to learn a classifier
@@ -135,13 +150,39 @@ public void addComparator(Comparator<RecordType, SchemaElementType> comparator)
135150
public Performance learnParameters(FeatureVectorDataSet features) {
136151
// create training
137152
Instances trainingData = transformToWeka(features, this.trainingSet);
153+
138154
try {
139155
Evaluation eval = new Evaluation(trainingData);
156+
// apply feature subset selection
157+
if (this.forwardSelection || this.backwardSelection) {
158+
159+
GreedyStepwise search = new GreedyStepwise();
160+
search.setSearchBackwards(this.backwardSelection);
161+
162+
this.fs = new AttributeSelection();
163+
WrapperSubsetEval wrapper = new WrapperSubsetEval();
140164

165+
// Do feature subset selection, but using a 10-fold cross
166+
// validation
167+
wrapper.buildEvaluator(trainingData);
168+
wrapper.setClassifier(this.classifier);
169+
wrapper.setFolds(10);
170+
wrapper.setThreshold(0.01);
171+
172+
this.fs.setEvaluator(wrapper);
173+
this.fs.setSearch(search);
174+
175+
this.fs.SelectAttributes(trainingData);
176+
177+
trainingData = fs.reduceDimensionality(trainingData);
178+
179+
}
180+
// perform 10-fold Cross Validation to evaluate classifier
141181
eval.crossValidateModel(this.classifier, trainingData, 10, new Random(1));
142-
this.classifier.buildClassifier(trainingData);
143182
System.out.println(eval.toSummaryString("\nResults\n\n", false));
144-
183+
184+
this.classifier.buildClassifier(trainingData);
185+
145186
int truePositive = (int) eval.numTruePositives(trainingData.classIndex());
146187
int falsePositive = (int) eval.numFalsePositives(trainingData.classIndex());
147188
int falseNegative = (int) eval.numFalseNegatives(trainingData.classIndex());
@@ -273,7 +314,18 @@ public Record generateFeatures(RecordType record1, RecordType record2,
273314

274315
double similarity = comp.compare(record1, record2, null);
275316

276-
String name = String.format("[%d] %s", i, comp.getClass().getSimpleName());
317+
String attribute1 = "";
318+
String attribute2 = "";
319+
try{
320+
attribute1 = ((RecordComparator)comp).getAttributeRecord1().toString();
321+
attribute2 = ((RecordComparator)comp).getAttributeRecord2().toString();
322+
323+
} catch (ClassCastException e) {
324+
// Not possible to add attribute names
325+
//e.printStackTrace();
326+
}
327+
328+
String name = String.format("[%d] %s %s %s", i, comp.getClass().getSimpleName(), attribute1, attribute2);
277329
Attribute att = null;
278330
for (Attribute elem : features.getSchema().get()) {
279331
if (elem.toString().equals(name)) {
@@ -315,9 +367,18 @@ public Correspondence<RecordType, SchemaElementType> apply(RecordType record1, R
315367
FeatureVectorDataSet matchSet = this.initialiseFeatures();
316368
Record matchRecord = generateFeatures(record1, record2, schemaCorrespondences, matchSet);
317369

370+
// transform entry for classification.
318371
matchSet.add(matchRecord);
319372
Instances matchInstances = this.transformToWeka(matchSet, this.machtSet);
320-
373+
374+
// reduce dimensions if feature subset selection was applied before.
375+
if((this.backwardSelection|| this.forwardSelection) && this.fs != null)
376+
try {
377+
matchInstances = this.fs.reduceDimensionality(matchInstances);
378+
} catch (Exception e1) {
379+
e1.printStackTrace();
380+
}
381+
// Apply matching rule
321382
try {
322383
double result = this.classifier.classifyInstance(matchInstances.firstInstance());
323384
return new Correspondence<RecordType, SchemaElementType>(record1, record2, result, schemaCorrespondences);
@@ -365,14 +426,20 @@ public void storeModel(File location) {
365426
@Override
366427
public void readModel(File location) {
367428
// deserialize model
429+
368430
try {
369431
ObjectInputStream ois = new ObjectInputStream(new FileInputStream(location));
370432
this.setClassifier((Classifier) ois.readObject());
371433
ois.close();
372434
} catch (FileNotFoundException e) {
373435
e.printStackTrace();
374436
} catch (IOException e) {
375-
e.printStackTrace();
437+
try {
438+
this.setClassifier((Classifier) PMMLFactory.getPMMLModel(location, null));
439+
440+
} catch (Exception e1) {
441+
e1.printStackTrace();
442+
}
376443
} catch (ClassNotFoundException e) {
377444
e.printStackTrace();
378445
}
@@ -393,7 +460,7 @@ public double compare(RecordType record1, RecordType record2,
393460

394461
/**
395462
* Create a new FeatureVectorDataSet with the corresponding features, which
396-
* result form the added comparators.
463+
* result from the added comparators.
397464
*
398465
* @see de.uni_mannheim.informatik.dws.winter.matching.rules.LearnableMatchingRule#initialiseFeatures()
399466
*/
@@ -405,8 +472,19 @@ public FeatureVectorDataSet initialiseFeatures() {
405472
for (int i = 0; i < comparators.size(); i++) {
406473

407474
Comparator<RecordType, SchemaElementType> comp = comparators.get(i);
408-
409-
String name = String.format("[%d] %s", i, comp.getClass().getSimpleName());
475+
476+
String attribute1 = "";
477+
String attribute2 = "";
478+
try{
479+
attribute1 = ((RecordComparator)comp).getAttributeRecord1().toString();
480+
attribute2 = ((RecordComparator)comp).getAttributeRecord2().toString();
481+
482+
} catch (ClassCastException e) {
483+
// Not possible to add attribute names
484+
//e.printStackTrace();
485+
}
486+
487+
String name = String.format("[%d] %s %s %s", i, comp.getClass().getSimpleName(), attribute1, attribute2);
410488

411489
Attribute att = new Attribute(name);
412490
result.addAttribute(att);
@@ -417,4 +495,20 @@ public FeatureVectorDataSet initialiseFeatures() {
417495
return result;
418496
}
419497

498+
public boolean isForwardSelection() {
499+
return forwardSelection;
500+
}
501+
502+
public void setForwardSelection(boolean forwardSelection) {
503+
this.forwardSelection = forwardSelection;
504+
}
505+
506+
public boolean isBackwardSelection() {
507+
return backwardSelection;
508+
}
509+
510+
public void setBackwardSelection(boolean backwardSelection) {
511+
this.backwardSelection = backwardSelection;
512+
}
513+
420514
}

src/main/java/de/uni_mannheim/informatik/dws/winter/model/defaultmodel/CSVRecordReader.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import java.io.File;
1515
import java.util.HashSet;
16+
import java.util.Map;
1617
import java.util.Set;
1718

1819
import de.uni_mannheim.informatik.dws.winter.model.DataSet;
@@ -28,16 +29,24 @@
2829
public class CSVRecordReader extends CSVMatchableReader<Record, Attribute> {
2930

3031
private int idIndex = -1;
32+
private Map<String, Attribute> attributeMapping;
3133

3234
/**
3335
*
3436
* @param idColumnIndex
3537
* The index of the column that contains the ID attribute. Specify -1 if the file does not contain a unique ID attribute.
38+
* @param attributeMapping
39+
* The position of a column and the corresponding attribute
3640
*/
3741
public CSVRecordReader(int idColumnIndex) {
3842
this.idIndex = idColumnIndex;
3943
}
4044

45+
public CSVRecordReader(int idColumnIndex, Map<String, Attribute> attributeMapping) {
46+
this.idIndex = idColumnIndex;
47+
this.attributeMapping = attributeMapping;
48+
}
49+
4150
/* (non-Javadoc)
4251
* @see de.uni_mannheim.informatik.wdi.model.io.CSVMatchableReader#readLine(java.lang.String[], de.uni_mannheim.informatik.wdi.model.DataSet)
4352
*/
@@ -75,8 +84,15 @@ protected void readLine(File file, int rowNumber, String[] values, DataSet<Recor
7584
Record r = new Record(id, file.getAbsolutePath());
7685

7786
for(int i = 0; i < values.length; i++) {
78-
String attributeId = String.format("%s_Col%d", file.getName(), i);
79-
Attribute a = dataset.getAttribute(attributeId);
87+
Attribute a;
88+
if(this.attributeMapping == null){
89+
String attributeId = String.format("%s_Col%d", file.getName(), i);
90+
a = dataset.getAttribute(attributeId);
91+
}
92+
else{
93+
a = this.attributeMapping.get(Integer.toString(i));
94+
}
95+
8096
String v = values[i];
8197

8298
if(v.isEmpty()) {
Lines changed: 55 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,55 @@
1-
/*
2-
* Copyright (c) 2017 Data and Web Science Group, University of Mannheim, Germany (http://dws.informatik.uni-mannheim.de/)
3-
*
4-
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
5-
* You may obtain a copy of the License at
6-
*
7-
* http://www.apache.org/licenses/LICENSE-2.0
8-
*
9-
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10-
* See the License for the specific language governing permissions and limitations under the License.
11-
*/
12-
package de.uni_mannheim.informatik.dws.winter.model.defaultmodel.comparators;
13-
14-
15-
16-
import de.uni_mannheim.informatik.dws.winter.matching.rules.Comparator;
17-
import de.uni_mannheim.informatik.dws.winter.model.Correspondence;
18-
import de.uni_mannheim.informatik.dws.winter.model.Matchable;
19-
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.Attribute;
20-
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.Record;
21-
import de.uni_mannheim.informatik.dws.winter.similarity.EqualsSimilarity;
22-
import de.uni_mannheim.informatik.dws.winter.similarity.string.LevenshteinSimilarity;
23-
24-
/**
25-
* {@link Comparator} for {@link Record}s
26-
* values and their {@link LevenshteinSimilarity} value.
27-
*
28-
* @author Alexander Brinkmann (albrinkm@mail.uni-mannheim.de)
29-
*
30-
*/
31-
public class RecordComparatorEqual extends RecordComparator {
32-
33-
public RecordComparatorEqual(Attribute attributeRecord1, Attribute attributeRecord2) {
34-
super(attributeRecord1, attributeRecord2);
35-
}
36-
37-
38-
private static final long serialVersionUID = 1L;
39-
private EqualsSimilarity<String> sim = new EqualsSimilarity<String>();
40-
41-
42-
@Override
43-
public double compare(Record record1, Record record2, Correspondence<Attribute, Matchable> schemaCorrespondence) {
44-
// preprocessing
45-
String s1 = record1.getValue(this.getAttributeRecord1());
46-
String s2 = record2.getValue(this.getAttributeRecord2());
47-
48-
double similarity = sim.calculate(s1, s2);
49-
50-
return similarity;
51-
}
52-
53-
}
1+
/*
2+
* Copyright (c) 2017 Data and Web Science Group, University of Mannheim, Germany (http://dws.informatik.uni-mannheim.de/)
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
* See the License for the specific language governing permissions and limitations under the License.
11+
*/
12+
package de.uni_mannheim.informatik.dws.winter.model.defaultmodel.comparators;
13+
14+
15+
16+
import de.uni_mannheim.informatik.dws.winter.matching.rules.Comparator;
17+
import de.uni_mannheim.informatik.dws.winter.model.Correspondence;
18+
import de.uni_mannheim.informatik.dws.winter.model.Matchable;
19+
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.Attribute;
20+
import de.uni_mannheim.informatik.dws.winter.model.defaultmodel.Record;
21+
import de.uni_mannheim.informatik.dws.winter.similarity.EqualsSimilarity;
22+
23+
/**
24+
* {@link Comparator} for {@link Record}s
25+
* exactly matching.
26+
*
27+
* @author Alexander Brinkmann (albrinkm@mail.uni-mannheim.de)
28+
*
29+
*/
30+
public class RecordComparatorEqual extends StringComparator {
31+
32+
public RecordComparatorEqual(Attribute attributeRecord1, Attribute attributeRecord2) {
33+
super(attributeRecord1, attributeRecord2);
34+
}
35+
36+
37+
private static final long serialVersionUID = 1L;
38+
private EqualsSimilarity<String> sim = new EqualsSimilarity<String>();
39+
40+
41+
@Override
42+
public double compare(Record record1, Record record2, Correspondence<Attribute, Matchable> schemaCorrespondence) {
43+
// preprocessing
44+
String s1 = record1.getValue(this.getAttributeRecord1());
45+
String s2 = record2.getValue(this.getAttributeRecord2());
46+
47+
s1 = preprocess(s1);
48+
s2 = preprocess(s2);
49+
50+
double similarity = sim.calculate(s1, s2);
51+
52+
return similarity;
53+
}
54+
55+
}

0 commit comments

Comments (0)