Skip to content

Commit fa7a2df

Browse files
committed
Implement the removeTestExemplar feature
Add a method to `Label` that indicates whether it consists of all 0's (i.e. every feature matches). Allow configuring `SubcontextList` so that it skips adding exemplars with such labels. Then pass this configuration from `AnalogicalModeling`. Change the default of the `removeTestExemplar` option to `true` to match the behavior of Algorithm::AM. Fixes #26.
1 parent 572dbad commit fa7a2df

File tree

11 files changed

+69
-16
lines changed

11 files changed

+69
-16
lines changed

src/main/java/weka/classifiers/lazy/AM/data/SubcontextList.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ public class SubcontextList implements Iterable<Subcontext> {
4040
private final HashMap<Label, Subcontext> labelToSubcontext = new HashMap<>();
4141

4242
private final Labeler labeler;
43+
private final boolean ignoreFullMatches;
4344

4445
/**
4546
* @return the number of attributes used to predict an outcome
@@ -54,9 +55,11 @@ public int getCardinality() {
5455
*
5556
* @param labeler Labeler for assigning labels to items in data
5657
* @param data Training data (exemplars)
58+
* @param ignoreFullMatches if true, will not add entirely matching contexts
5759
*/
58-
public SubcontextList(Labeler labeler, List<Instance> data) {
60+
public SubcontextList(Labeler labeler, List<Instance> data, boolean ignoreFullMatches) {
5961
this.labeler = labeler;
62+
this.ignoreFullMatches = ignoreFullMatches;
6063
for (Instance se : data)
6164
add(se);
6265
}
@@ -66,6 +69,9 @@ public SubcontextList(Labeler labeler, List<Instance> data) {
6669
*/
6770
void add(Instance exemplar) {
6871
Label label = labeler.label(exemplar);
72+
if(ignoreFullMatches && label.allMatching()) {
73+
return;
74+
}
6975
if (!labelToSubcontext.containsKey(label)) {
7076
labelToSubcontext.put(label, new Subcontext(label, labeler.getContextString(label)));
7177
}

src/main/java/weka/classifiers/lazy/AM/label/BitSetLabel.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ public Label union(Label other) {
6666
return new BitSetLabel(bitSet, getCardinality());
6767
}
6868

69+
@Override
70+
public boolean allMatching() {
71+
return labelBits.isEmpty();
72+
}
73+
6974
@Override
7075
public String toString() {
7176
if (getCardinality() == 0) {

src/main/java/weka/classifiers/lazy/AM/label/IntLabel.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ public Label union(Label other) {
104104
return new IntLabel(labelBits & otherLabel.labelBits, getCardinality());
105105
}
106106

107+
@Override
108+
public boolean allMatching() {
109+
return labelBits == 0;
110+
}
111+
107112
@Override
108113
public String toString() {
109114
if(getCardinality() == 0) {

src/main/java/weka/classifiers/lazy/AM/label/Label.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,10 @@ public abstract class Label {
7777
* words, keep all matches from both labels.
7878
*/
7979
public abstract Label union(Label other);
80+
81+
/**
82+
* @return true if every feature of this label is a match (i.e. this is the
83+
* {@link Labeler#getLatticeTop() top of the lattice}); false otherwise
84+
*/
85+
public abstract boolean allMatching();
8086
}

src/main/java/weka/classifiers/lazy/AM/label/LongLabel.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ public Label union(Label other) {
9797
return new LongLabel(labelBits & otherLabel.labelBits, getCardinality());
9898
}
9999

100+
@Override
101+
public boolean allMatching() {
102+
return labelBits == 0L;
103+
}
104+
100105
@Override
101106
public String toString() {
102107
if(getCardinality() == 0) {

src/main/java/weka/classifiers/lazy/AnalogicalModeling.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ private AMResults classify(Instance testItem) throws InterruptedException, Execu
167167
Labeler labeler = new LabelerFactory.CardinalityBasedLabelerFactory().createLabeler(testItem, m_ignoreUnknowns, mdc);
168168
// 3 steps to assigning outcome probabilities:
169169
// 1. Place each data item in a subcontext
170-
SubcontextList subList = new SubcontextList(labeler, trainingExemplars);
170+
SubcontextList subList = new SubcontextList(labeler, trainingExemplars, getRemoveTestExemplar());
171171
// 2. Create a supracontextual lattice and fill it with subcontexts
172172
LatticeFactory latticeFactory;
173173
if (randomProvider == null) {
@@ -234,7 +234,7 @@ public String ignoreUnknownsTipText() {
234234
/**
235235
* By default, we remove any exemplar with the same features as the test exemplar
236236
*/
237-
private boolean m_removeTestExemplar = false;
237+
private boolean m_removeTestExemplar = true;
238238

239239
/**
240240
* @return true if we remove a test instance from training before predicting its outcome
@@ -255,8 +255,8 @@ public void setRemoveTestExemplar(boolean removeTestExemplar) {
255255
*/
256256
@SuppressWarnings("unused") // used by Weka UI
257257
public String removeTestExemplarTipText() {
258-
return "Set to true if you wish to remove a test instance from the training set before "
259-
+ "attempting to predict its outcome.";
258+
return "Set to true if you wish to remove the test instance from the training set before "
259+
+ "attempting to classify it.";
260260
}
261261

262262
/**

src/test/java/weka/classifiers/lazy/AM/data/SubcontextListTest.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
import java.util.ArrayList;
1414
import java.util.List;
1515

16-
import static org.junit.Assert.assertEquals;
17-
import static org.junit.Assert.assertTrue;
16+
import static org.junit.Assert.*;
1817

1918
public class SubcontextListTest {
2019

@@ -26,7 +25,7 @@ public void testChapter3Data() throws Exception {
2625

2726
Labeler labeler = new IntLabeler(test, false, MissingDataCompare.MATCH);
2827

29-
SubcontextList subs = new SubcontextList(labeler, train);
28+
SubcontextList subs = new SubcontextList(labeler, train, false);
3029
assertEquals(subs.getCardinality(), 3);
3130

3231
List<Subcontext> subList = getSubList(subs);
@@ -50,6 +49,22 @@ public void testChapter3Data() throws Exception {
5049
assertTrue(subList.contains(expected));
5150
}
5251

52+
@Test
53+
public void testIgnoreFullMatches() throws Exception {
54+
Instances train = TestUtils.getDataSet(TestUtils.CHAPTER_3_DATA);
55+
Instance test = train.get(0);
56+
57+
Labeler labeler = new IntLabeler(test, false, MissingDataCompare.MATCH);
58+
Subcontext allMatchingSub = new Subcontext(new IntLabel(0b000, 3), "foo");
59+
allMatchingSub.add(train.get(0));// 310e
60+
61+
SubcontextList subs = new SubcontextList(labeler, train, false);
62+
assertTrue("Should contain 000 sub when not ignoring full matches", getSubList(subs).contains(allMatchingSub));
63+
64+
subs = new SubcontextList(labeler, train, true);
65+
assertFalse("Should not contain 000 sub when ignoring full matches", getSubList(subs).contains(allMatchingSub));
66+
}
67+
5368
private List<Subcontext> getSubList(final SubcontextList subcontextList) {
5469
return new ArrayList<>() {
5570
{
@@ -67,7 +82,7 @@ public void testAccessors() throws Exception {
6782

6883
Labeler labeler = new IntLabeler(test, false, MissingDataCompare.MATCH);
6984

70-
SubcontextList subs = new SubcontextList(labeler, train);
85+
SubcontextList subs = new SubcontextList(labeler, train, false);
7186
assertEquals("getLabeler returns the labeler used in the constructor", subs.getLabeler(), labeler);
7287
assertEquals("getCardinality returns the cardinality of the test item", subs.getCardinality(), 3);
7388
}

src/test/java/weka/classifiers/lazy/AM/label/LabelTest.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,15 @@ public void testUnion() {
127127
assertTrue(intersected.matches(i));
128128
}
129129

130+
@Test
131+
public void testAllMatching() {
132+
Labeler labeler = labelerFactory.createLabeler(mockInstance(3), false, MATCH);
133+
assertTrue("Label composed of all 0's", labeler.fromBits(0b000).allMatching());
134+
for (int bits : List.of(0b100, 0b001, 0b010, 0b111)) {
135+
assertFalse("Label with a 1 in it", labeler.fromBits(bits).allMatching());
136+
}
137+
}
138+
130139
@Test
131140
public void testMatchesThrowsExceptionForIndexTooLow() {
132141
Labeler labeler = labelerFactory.createLabeler(mockInstance(3), false, MATCH);

src/test/java/weka/classifiers/lazy/AM/lattice/HeterogeneousLatticeTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public void testChapter3Data() throws Exception {
3535
// cardinality of one
3636
Labeler noPartitionLabeler = Mockito.spy(new LabelerFactory.IntLabelerFactory().createLabeler(test, false, MissingDataCompare.MATCH));
3737
when(noPartitionLabeler.numPartitions()).thenReturn(1);
38-
SubcontextList subList = new SubcontextList(noPartitionLabeler, train);
38+
SubcontextList subList = new SubcontextList(noPartitionLabeler, train, false);
3939
HeterogeneousLattice heteroLattice = new HeterogeneousLattice(0);
4040
heteroLattice.fill(subList);
4141

src/test/java/weka/classifiers/lazy/AM/lattice/LatticeTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,16 @@ private void skipForLatticeClass(String reason, Class<? extends Lattice> clazz)
7272
@Test
7373
public void testFillingWithEmptySubcontextList() throws Exception {
7474
Lattice lattice = latticeSupplier.get();
75-
lattice.fill(new SubcontextList(mock(IntLabeler.class), Collections.emptyList()));
75+
lattice.fill(new SubcontextList(mock(IntLabeler.class), Collections.emptyList(), false));
7676
}
7777

7878
@Test
7979
public void testLatticeCannotBeFilledTwice() throws Exception {
8080
Lattice lattice = latticeSupplier.get();
81-
lattice.fill(new SubcontextList(mock(IntLabeler.class), Collections.emptyList()));
81+
lattice.fill(new SubcontextList(mock(IntLabeler.class), Collections.emptyList(), false));
8282
exception.expect(IllegalStateException.class);
8383
exception.expectMessage(new StringContains("already filled"));
84-
lattice.fill(new SubcontextList(mock(IntLabeler.class), Collections.emptyList()));
84+
lattice.fill(new SubcontextList(mock(IntLabeler.class), Collections.emptyList(), false));
8585
}
8686

8787
@Test
@@ -150,7 +150,7 @@ public void testCleanSupraTiming() throws Exception {
150150
private void testSupras(Instances train, int testIndex, String[] expectedSupras) throws ExecutionException, InterruptedException {
151151
final Instance test = train.get(testIndex);
152152
train.remove(testIndex);
153-
SubcontextList subList = new SubcontextList(getFullSplitLabeler(test), train);
153+
SubcontextList subList = new SubcontextList(getFullSplitLabeler(test), train, false);
154154
Lattice testLattice = latticeSupplier.get();
155155
testLattice.fill(subList);
156156
Set<Supracontext> actualSupras = testLattice.getSupracontexts();

0 commit comments

Comments
 (0)